import numpy as np
import scipy.signal
import json
import time
import datetime
import os
import sys
import shutil
import tempfile
import matplotlib.pyplot as plt

from .config import Configuration
from .receiver import Receiver

#
# For legacy mode when running in Powder. If EMULAB_BOOTDIR exists, Monitor
# reads the node's domain and nickname from it, and saves each scan as a
# gzipped CSV into EMULAB_SAVEDIR (or into the project's experiment directory
# when the domain is "emulab.net").
#
EMULAB_BOOTDIR = "/var/emulab/boot"
EMULAB_SAVEDIR = "/var/emulab/save"

class Monitor:
    """Run spectrum scans on a receiver and route each resulting PSD.

    The device configuration comes either from the command line
    (--antenna, optionally --device) or from a JSON config file. Depending
    on the arguments, each scan's PSD is plotted, written as CSV, saved as a
    gzipped CSV under --www-dir, saved into the Emulab/Powder store (legacy
    mode), or dumped via ``psd.dump()``. In ZMS mode the receiver is handled
    by the ZMS module instead.
    """

    def __init__(self, args, logger):
        """Build the device configuration, then the receiver (or ZMS handler).

        Args:
            args: parsed command-line arguments. Read here: antenna, device,
                configfile, www_mode, www_dir, zms_mode.
            logger: logger used for progress messages.

        Raises:
            Exception: if no device configuration is supplied, the given
                config file does not exist, or the --www-mode requirements
                (a device name and an existing, writable --www-dir) are not met.
        """
        self.args = args
        self.logger = logger
        self.legacy = False
        self.devname = None

        if self.args.antenna:
            # Synthesize a one-channel config; the device name defaults to
            # the antenna name when --device is not given.
            devname = self.args.device if self.args.device else self.args.antenna
            self.device_config = {
                "devices" : {
                    devname : {
                        "name"     : devname,
                        "channels" :  {"0" : self.args.antenna}
                    }
                }
            }
        elif self.args.configfile:
            if not os.path.exists(self.args.configfile):
                # Name the missing file rather than the generic message below.
                raise Exception(
                    "Device config file %s does not exist; must supply a "
                    "device config file or --antenna/--device"
                    % (self.args.configfile,))
            with open(self.args.configfile, 'r') as f:
                self.device_config = json.load(f)
                pass
        else:
            raise Exception(
                "Must supply a device config file or --antenna/--device")

        #
        # Legacy mode writes files to the WB store directory. It is enabled
        # whenever the Emulab boot directory exists on this node.
        #
        if os.path.exists(EMULAB_BOOTDIR):
            with open(EMULAB_BOOTDIR + "/mydomain", 'r') as f:
                self.domain = f.readline().rstrip("\n")
                pass

            with open(EMULAB_BOOTDIR + "/nickname", 'r') as f:
                # The 2nd and 3rd dot-separated fields become eid and pid,
                # later used as /proj/<pid>/exp/<eid>.
                tokens = f.readline().rstrip("\n").split(".")
                self.eid = tokens[1]
                self.pid = tokens[2]
                pass

            self.legacy = True
            # Legacy saves are named after the first configured device.
            self.devname = list(self.device_config["devices"].keys())[0]
            pass

        if self.args.www_mode:
            if not (self.args.device or self.devname):
                raise Exception("Must provide --device (ex: 'nuc1:rf0') with --www-mode")
            if not os.path.exists(self.args.www_dir):
                raise Exception("WWW directory (--www-dir) does not exist")
            if not os.access(self.args.www_dir, os.W_OK):
                raise Exception("WWW directory (--www-dir) is not writable")
            pass

        self.logger.info("Device Config: " + str(self.device_config))
        self.config = Configuration(args, self.device_config)
        self.logger.info("Config: " + str(vars(self.config)))

        #
        # In ZMS mode, hand off Receiver initialization to the ZMS module
        #
        if args.zms_mode:
            from .zms import ZMS
            self.zms = ZMS(self.args, self.logger, self.config)
            return

        self.receiver = Receiver(self.args, logger, self.config)
        pass

    def start(self):
        """Run scans and dispatch each PSD according to the arguments.

        In ZMS mode this delegates to ``self.zms.start()``. Otherwise it
        scans --repeat times (forever when --repeat is 0/unset), sleeping
        --interval seconds between scans. With --csv (and no --plot) only
        the first scan is written and the method returns immediately.
        """
        if self.args.zms_mode:
            return self.zms.start()

        loops = self.args.repeat

        #
        # If plotting, create a plot that will be updated when looping
        #
        if self.args.plot:
            # Interactive (redrawn in place) unless exactly one scan was asked for.
            self.plotter = Plot(self.args.repeat != 1, self.logger)
            pass

        while True:
            self.logger.info("Running a scan at " +
                             datetime.datetime.now().strftime("%m/%d/%Y %H:%M:%S"))

            psd = self.get_full_spectrum_psd()
            if self.args.plot:
                self.plotter.plotPSD(psd)
            elif self.args.csv:
                # CSV mode writes only the first scan and stops.
                return self.writecsv(psd)
            elif self.args.www_mode:
                self.wwwsave(psd)
            elif self.legacy:
                self.wbsave(psd)
            else:
                psd.dump()
                pass

            if self.args.repeat:
                loops -= 1
                if loops == 0:
                    break
                pass

            if self.args.interval:
                time.sleep(self.args.interval)
                pass
            pass

        #
        # If interactively plotting, hold here
        #
        if self.args.plot and self.args.repeat:
            plt.show(block=True)
            pass
        pass

    # Get samples for one center frequency.
    def get_samples(self, frequency):
        return self.receiver.get_samples(frequency)

    # Get PSD for all center frequencies.
    def get_full_spectrum_psd(self):
        return self.receiver.get_full_spectrum_psd()

    # Get I/Q samples for all center frequencies.
    def get_full_spectrum(self):
        return self.receiver.get_full_spectrum()

    def writecsv(self, psd):
        """Write psd as CSV to --output if given, otherwise to stdout."""
        if self.args.output:
            with open(self.args.output, "w") as f:
                psd.writecsv(f)
                pass
        else:
            psd.writecsv(sys.stdout)
            pass
        pass

    @staticmethod
    def _save_gzipped_csv(psd, csvfile):
        """Write psd's CSV, gzip it, and atomically move it to *csvfile*.

        The CSV goes to a temporary file first. It is then compressed into
        ``<csvfile>.<pid>.tmp`` next to the destination and renamed over
        *csvfile*, so readers never observe a partially written archive.
        Both temporaries are removed even if a step fails.
        """
        import gzip

        staging = "%s.%d.tmp" % (csvfile, os.getpid())
        raw = tempfile.NamedTemporaryFile(mode="w", delete=False)
        try:
            with raw:
                psd.writecsv(raw)
                pass
            with open(raw.name, mode='rb') as f_in, gzip.open(staging, 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)
                pass
            # Same directory, so this rename is atomic on POSIX.
            os.replace(staging, csvfile)
        finally:
            os.remove(raw.name)
            if os.path.exists(staging):
                os.remove(staging)
                pass
            pass
        pass

    def wwwsave(self, psd):
        """Save psd as ``<device>-<epoch>.csv.gz`` in a per-day dir of --www-dir."""
        #
        # Files go in a subdir named by the day. This avoids ending up
        # with a giant directory that takes forever to load and has a
        # menu with 100s of entries.
        #
        now    = datetime.datetime.now()
        utcnow = int(now.timestamp())
        subdir = now.strftime("%Y-%m-%d")
        wwwdir = self.args.www_dir + "/" + subdir
        try:
            os.mkdir(wwwdir)
        except FileExistsError:
            # Another scan (or process) already created today's directory.
            pass
        else:
            os.chmod(wwwdir, 0o0775)
            pass

        # __init__ accepts www mode with only the legacy devname set, so fall
        # back to it instead of naming the file "None-...".
        devname = self.args.device or self.devname
        csvname = "%s-%s.csv.gz" % (devname, utcnow)
        self._save_gzipped_csv(psd, wwwdir + "/" + csvname)
        pass

    #
    # Legacy mode is to write into the wbstore directory. This
    # only makes sense when running on Powder.
    #
    def wbsave(self, psd):
        """Save psd as a gzipped CSV in the legacy store (and mirror it on the Mothership)."""
        import subprocess

        #
        # Files go into the local wbstore directory, unless running on the
        # Mothership, in which case it goes to the project directory
        #
        if self.domain == "emulab.net":
            savedir = "/proj/%s/exp/%s" % (self.pid, self.eid)
        else:
            savedir = EMULAB_SAVEDIR
            pass

        now     = datetime.datetime.now()
        utcnow  = int(now.timestamp())
        csvname = "%s-%s.csv.gz" % (self.devname, utcnow)
        self._save_gzipped_csv(psd, savedir + "/" + csvname)

        # An extra step for the Mothership.
        if self.domain != "emulab.net":
            return

        mdir = "/proj/%s/monitor" % (self.pid,)
        if not os.path.exists(mdir):
            os.mkdir(mdir)
            os.chmod(mdir, 0o0775)
            pass
        tfile = mdir + "/" + self.eid + ".gz"
        tmpfile = tfile + ".tmp"
        # NOTE(review): -cf produces an uncompressed tar even though the
        # archive is named .gz (its single member is itself a .csv.gz).
        # Kept as-is since downstream readers may depend on it.
        cmd = ["/bin/tar", "-cf", tmpfile, "-C", "/proj",
               "%s/exp/%s/%s" % (self.pid, self.eid, csvname)]
        self.logger.info(" ".join(cmd))
        # argv list, no shell: pid/eid/csvname are never shell-interpreted.
        if subprocess.run(cmd).returncode:
            raise Exception("cmd failed")
        # Atomic replace; the old archive stays visible until this point.
        os.replace(tmpfile, tfile)
        pass

    pass

class Plot:
    """Matplotlib view showing a PSD and its spectrogram side by side.

    When ``interactive`` is true, matplotlib interactive mode is enabled so
    the same figure can be redrawn after every scan.
    """
    def __init__(self, interactive, logger):
        self.interactive = interactive
        self.logger = logger
        self.logger.info("Initializing the plotter")
        if interactive:
            plt.ion()
            pass
        # The figure and its two axes are created lazily by the first plotPSD().
        self.fig = None
        self.axes = None
        self.logger.info("Plotter initialized")
        pass

    def _draw(self, psd):
        """Draw psd in the left axes and its spectrogram in the right axes."""
        psd.plot(self.axes[0])
        spectrogram = psd.IQ.compute_spectrogram()
        spectrogram.plot(self.axes[1])
        pass

    def plotPSD(self, psd):
        """Plot psd, creating the figure on the first call and redrawing it afterwards."""
        if self.fig is None:
            self.fig, self.axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
            self._draw(psd)
            plt.show()
            if self.interactive:
                # Give the GUI event loop time to render the first frame.
                plt.pause(5)
                pass
            return

        for ax in self.axes:
            ax.clear()
            pass
        self._draw(psd)
        plt.pause(1)
        pass

    pass
