
import datetime
import logging
import time
import numpy as np
import scipy.signal
import scipy.fft
import matplotlib.pyplot as plt

from .device import Device
from .sigmf import make_archive

# Module-level logger, named after this module's import path.
LOG = logging.getLogger(__name__)

class Receiver:
    """Drive one SDR device and turn its I/Q captures into PSD sweeps.

    The receiver steps through ``config.CENTER_FREQS``, captures
    ``config.NUM_FFTS * config.FFT_SIZE`` samples at every center frequency
    and returns either the raw samples (:meth:`get_full_spectrum`) or the
    stitched power spectral density (:meth:`get_full_spectrum_psd`).
    """

    def __init__(self, args, logger, config):
        """Remember the runtime handles and create the underlying device.

        Args:
            args: program arguments; ``psd_method`` is read by
                :meth:`get_full_spectrum_psd`.
            logger: logger used for sweep progress messages.
            config: sweep configuration; must expose ``device_config`` and
                the constants read by the methods of this class.
        """
        self.args = args
        self.logger = logger
        self.config = config
        self.device_config = self.config.device_config
        # The first entry under "devices" labels every PSD row we produce.
        self.rf_port = list(self.device_config["devices"])[0]

        self.device = Device(args, logger, config)

    def updateConfig(self, range=None, gain=None, dwelltime=None):
        """Apply a configuration change and return the receiver to use.

        For ZMS dynamic mode, some changes require re-initializing the
        device, so a new receiver might be returned.  At the moment none of
        these changes require a new receiver, so ``self`` is returned.

        Args:
            range: forwarded to ``config.update``.
            gain: forwarded to ``config.update``.
            dwelltime: forwarded to ``config.update``.

        Returns:
            The receiver that should be used from now on (currently ``self``).
        """
        self.config.update(range=range, gain=gain, dwelltime=dwelltime)
        # Push the (possibly changed) gain down to the hardware.
        self.device.setGain(self.config.GAIN)
        self.config.dump()
        return self

    def get_samples(self, frequency, num_samples = 0):
        """Capture samples at a single center frequency.

        Args:
            frequency: center frequency; converted with ``int()``.
            num_samples: number of samples to capture; a falsy value selects
                ``config.NUM_FFTS * config.FFT_SIZE``.

        Returns:
            Whatever ``device.get_N_samples`` returns.
        """
        if not num_samples:
            num_samples = self.config.NUM_FFTS * self.config.FFT_SIZE
        return self.device.get_N_samples(int(frequency), int(num_samples))

    @staticmethod
    def _utcnow():
        """Return the current UTC time as a naive ``datetime``.

        Equivalent to the deprecated ``datetime.utcnow()``; the result stays
        naive so that ``isoformat() + "Z"`` keeps producing the same text.
        """
        return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

    def _trim_band_edges(self, psd, labels):
        """Drop ``config.N_REMOVE`` bins from both ends of a shifted PSD.

        High frequencies sit near fn = f/Fs = -0.5 and 0.5, i.e. at the
        beginning and end of an FFT-shifted array.  A zero ``N_REMOVE`` must
        be skipped explicitly: ``x[0:-0]`` is an empty slice.
        """
        n_remove = self.config.N_REMOVE
        if n_remove:
            psd = psd[n_remove:-n_remove]
            labels = labels[n_remove:-n_remove]
        return psd, labels

    def compute_psd_welch(self, samples, frequency, noverlap=None, window="hann"):
        """Compute the PSD of samples for a single center frequency.

        Uses ``scipy.signal.welch`` with ``fs=config.MAX_INST_BW``,
        ``nperseg=config.FFT_SIZE`` and a two-sided spectrum.

        Args:
            samples: samples captured at ``frequency``.
            frequency: center frequency added to the frequency labels.
            noverlap: forwarded to ``scipy.signal.welch``.
            window: forwarded to ``scipy.signal.welch``.

        Returns:
            ``(psd, labels)``: ``10 * log10`` of the PSD with DC centered and
            the band edges trimmed, and the matching absolute frequencies.
        """
        labels, psdlin = scipy.signal.welch(
            samples, self.config.MAX_INST_BW,
            noverlap=noverlap, window=window,
            nperseg=self.config.FFT_SIZE, return_onesided=False)

        # nan_to_num keeps log10(0) = -inf (and NaN) out of the result.
        psd = np.nan_to_num(10.0 * np.log10(psdlin))

        # put DC in the center
        psd = np.fft.fftshift(psd)
        labels = np.fft.fftshift(labels)

        # reject samples above cutoff
        psd, labels = self._trim_band_edges(psd, labels)
        labels = labels + frequency
        return psd, labels

    def compute_psd_bartlett(self, samples, frequency):
        """Compute the PSD with non-overlapping, Bartlett-windowed segments.

        Thin wrapper around :meth:`compute_psd_welch` with ``noverlap=0`` and
        ``window="bartlett"``.
        """
        return self.compute_psd_welch(
            samples, frequency, noverlap=0, window="bartlett")

    def compute_psd_bartlettX(self, samples, frequency):
        """Compute a single-periodogram PSD for one center frequency.

        Despite its name this uses ``scipy.signal.periodogram`` with a Hann
        window, ``fs=config.MAX_INST_BW`` and ``nfft=config.FFT_SIZE``.

        Args:
            samples: samples captured at ``frequency``.
            frequency: center frequency; converted with ``float()`` and added
                to the frequency labels.

        Returns:
            ``(psd, labels)`` with DC centered and the band edges trimmed.
            With ``config.N_REMOVE == 0`` nothing is trimmed.
        """
        labels, psdlin = scipy.signal.periodogram(
            samples, self.config.MAX_INST_BW,
            window="hann",
            nfft=self.config.FFT_SIZE, return_onesided=False)

        psd = np.nan_to_num(10.0 * np.log10(psdlin))
        labels = labels + float(frequency)

        # put DC in the center
        psd = np.fft.fftshift(psd)
        labels = np.fft.fftshift(labels)

        # reject samples above cutoff (no-op when N_REMOVE is zero)
        return self._trim_band_edges(psd, labels)

    def get_full_spectrum(self):
        """Capture raw I/Q samples at every configured center frequency.

        Returns:
            An ``IQ`` holding the flattened samples of all steps together
            with the naive-UTC start and end time of the whole sweep.
        """
        samples_per_center = self.config.NUM_FFTS * self.config.FFT_SIZE

        #
        # Only one channel for the moment.
        #
        all_samples = np.empty(
            (self.config.NUM_STEPS, samples_per_center), dtype=np.complex64)

        sdt = self._utcnow()
        for ii, center_freq in enumerate(self.config.CENTER_FREQS):
            samples = self.device.get_N_samples(center_freq, samples_per_center)
            all_samples[ii, :] = samples[0:]
        edt = self._utcnow()

        return IQ(self.config, np.ravel(all_samples), start=sdt, end=edt)

    def get_full_spectrum_psd(self):
        """Sweep all center frequencies and compute the PSD of each step.

        ``args.psd_method == "welch"`` selects :meth:`compute_psd_welch`;
        any other value selects :meth:`compute_psd_bartlett`.  Only the first
        channel of each capture (``samples[0]``) is analysed.

        Returns:
            A ``PSD`` with the flattened powers and frequencies, per-bin
            timestamp / device / center-frequency (MHz) arrays, the raw
            samples wrapped in an ``IQ``, and one metadata dict per step.
        """
        samples_per_center = self.config.NUM_FFTS * self.config.FFT_SIZE
        #
        # Only one channel for the moment.
        #
        all_samples = np.empty(
            (self.config.NUM_STEPS, samples_per_center), dtype=np.complex64)
        all_spectrum = np.empty(
            (self.config.NUM_STEPS, self.config.N), dtype=np.float32)
        all_frequencies = np.empty(
            (self.config.NUM_STEPS, self.config.N), dtype=np.float64)
        all_center_freqs = np.array([], dtype=np.float32)
        all_times = np.array([], dtype=int)
        all_devs = np.array([])
        meta = []

        for ii, center_freq in enumerate(self.config.CENTER_FREQS):
            t = time.time()
            sdt = self._utcnow()
            samples = self.device.get_N_samples(center_freq, samples_per_center)
            edt = self._utcnow()
            if self.args.psd_method == "welch":
                psd, labels = self.compute_psd_welch(samples[0], center_freq)
            else:
                psd, labels = self.compute_psd_bartlett(samples[0], center_freq)

            m = dict(start_freq=labels[0], end_freq=labels[-1],
                     center_freq=center_freq, step_size=labels[1]-labels[0],
                     start_time=sdt, end_time=edt, len=len(psd))

            # extend devs, times to match the number of lines found per iteration.
            trep = np.tile(int(t), self.config.N_BINS_OUT_PER_STEP)
            devrep = np.tile(self.rf_port, self.config.N_BINS_OUT_PER_STEP)
            cfrep = np.tile(center_freq / 1e6, self.config.N_BINS_OUT_PER_STEP)

            all_samples[ii, :] = samples[0:]
            all_spectrum[ii, :] = psd[0:]
            all_frequencies[ii, :] = labels[0:]
            all_times = np.append(all_times, trep)
            all_devs = np.append(all_devs, devrep)
            all_center_freqs = np.append(all_center_freqs, cfrep)

            # Report progress on every tenth step only.
            if ii % 10 == 0:
                self.logger.debug(
                    "spectrum collection {:.1f}% complete".format(
                        float(ii)/self.config.NUM_STEPS*100))

            meta.append(m)

        return PSD(self.config, np.ravel(all_spectrum),
                   np.ravel(all_frequencies),
                   all_times, all_devs, all_center_freqs,
                   IQ(self.config, np.ravel(all_samples)), meta)

#
# Wrapper class for all of the PSD data we need.
#
class PSD:
    """Container for one swept power spectral density measurement.

    Holds parallel, flattened per-bin arrays (power, frequency, timestamp,
    device, center frequency), the raw I/Q samples they were computed from
    and one metadata dict per sweep step.
    """

    def __init__(self, config, all_spectrum, all_frequencies,
                 all_times, all_devs, all_center_freqs, IQ, meta):
        """Store the sweep results.

        Args:
            config: sweep configuration (``ANALOG_BW``, ``GAIN`` and
                ``args.monitor_id`` are read when building SigMF output).
            all_spectrum: per-bin power values.
            all_frequencies: per-bin frequencies (divided by 1e6 on output).
            all_times: per-bin timestamps.
            all_devs: per-bin device labels.
            all_center_freqs: per-bin center frequencies.
            IQ: the raw samples behind this PSD.
            meta: one dict per sweep step with the keys ``start_freq``,
                ``end_freq``, ``center_freq``, ``step_size``, ``start_time``
                and ``len`` (as read by :meth:`make_sigmf_metadata`).
        """
        self.config = config
        self.powers = all_spectrum
        self.frequencies = all_frequencies
        self.timestamps = all_times
        self.devices = all_devs
        self.center_freqs = all_center_freqs
        self.IQ = IQ
        self.meta = meta

    def dump(self):
        """Print one ``device,timestamp,frequency/1e6,power,center`` row per bin."""
        for ii in range(self.frequencies.shape[0]):
            strrow = "{},{},{},{},{}".format(self.devices[ii],
                                             self.timestamps[ii],
                                             self.frequencies[ii] / 1e6,
                                             self.powers[ii],
                                             self.center_freqs[ii])
            print(strrow)

    def writecsv(self, fd):
        """Write the PSD as CSV (``frequency,power,center_freq``) to ``fd``.

        Args:
            fd: writable text file object.
        """
        fd.write("frequency,power,center_freq\n")
        for ii in range(self.frequencies.shape[0]):
            fd.write("%.3f,%.3f,%.3f\n" %
                     (self.frequencies[ii] / 1e6,
                      self.powers[ii], self.center_freqs[ii]))

    def plot(self, axe):
        """Plot power against frequency on the given matplotlib axes."""
        axe.plot(self.frequencies, self.powers)

    class OneFreq:
        """A single frequency bin as produced when iterating a ``PSD``."""

        def __init__(self, dev, stamp, freq, power, center):
            self.center_freq = center
            self.frequency   = freq
            self.device      = dev
            self.timestamp   = stamp
            self.power       = power

    def _bin_at(self, ii):
        """Build the ``OneFreq`` record for bin index ``ii``."""
        return self.OneFreq(self.devices[ii],
                            self.timestamps[ii],
                            self.frequencies[ii],
                            self.powers[ii],
                            self.center_freqs[ii])

    def _iter_bins(self):
        """Yield every bin in order, with a cursor private to this generator."""
        for ii in range(self.frequencies.shape[0]):
            yield self._bin_at(ii)

    def __iter__(self):
        """Return an independent iterator over the bins.

        Each call gets its own cursor, so nested or overlapping loops over
        the same ``PSD`` no longer disturb each other.  The legacy cursor
        used by :meth:`__next__` is still reset here so that code calling
        ``iter(psd)`` followed by ``next(psd)`` keeps working.
        """
        self.iter_index  = 0
        self.iter_length = self.frequencies.shape[0]
        return self._iter_bins()

    def __next__(self):
        """Advance the legacy, instance-wide cursor (see :meth:`__iter__`).

        Raises:
            StopIteration: when the cursor has passed the last bin.
        """
        if self.iter_index >= self.iter_length:
            raise StopIteration

        ii = self.iter_index
        self.iter_index += 1
        return self._bin_at(ii)

    def make_sigmf_metadata(self):
        """Build the SigMF metadata dict describing this PSD.

        One data product, one processing-info entry and one capture are
        emitted per entry of ``self.meta``.

        Returns:
            A dict with the ``global``, ``annotations`` and ``captures``
            sections.
        """
        metadata = {
            "global": {
                "core:description": "rfmonitor sweep psd",
                "core:datatype": "rf32le",
                "core:sample_rate": int(self.config.ANALOG_BW),
                "core:version": "1.2.5",
                "core:num_channels": 1,
                "core:recorder": "rfmonitor",
                "core:extensions": [
                    {
                        "name": "ntia-algorithm",
                        "version": "v2.0.0",
                        "optional": False
                    },
                    {
                        "name": "ntia-sensor",
                        "version": "v2.0.0",
                        "optional": False
                    },
                    {
                        "name": "openzms-core",
                        "version": "v1.0.0",
                        "optional": True
                    }
                ],
                "openzms-core:kind": "psd",
                "openzms-core:types": "powder.rfmonitor.v2.psd",
                "openzms-core:labels": "ota,sweep",
                "openzms-core:min_freq": int(self.frequencies[0]),
                "openzms-core:max_freq": int(self.frequencies[-1]),
                "openzms-core:freq_step": int(self.meta[0]["step_size"]),
                "ntia-algorithm:data_products": [],
                "ntia-algorithm:processing_info": []
            },
            "annotations": [],
            "captures": []
        }

        data_products = metadata["global"]["ntia-algorithm:data_products"]
        processing_info = metadata["global"]["ntia-algorithm:processing_info"]
        captures = metadata["captures"]

        for i, m in enumerate(self.meta):
            data_products.append({
                "processing": [f"powder_rfmonitor_psd_{i}"],
                "name": "powder_rfmonitor_psd",
                "description": "POWDER rfmonitor aggregate sweep psd",
                "series": ["psd"],
                "length": m["len"],
                "x_units": "Hz",
                "x_start": [float(m["start_freq"])],
                "x_stop": [float(m["end_freq"])],
                "x_step": [float(m["step_size"])],
                "y_units": "dBX/Hz"
            })

            processing_info.append({
                "type": "powder_rfmonitor_psd",
                "id": f"powder_rfmonitor_psd_{i}",
                "description": "An aggregate multi-center sweep power spectral density calculation",
                "samples": m["len"],
                #"dfts": 1
            })

            captures.append({
                # NOTE(review): this offset assumes every step has the same
                # length as the current one - confirm against the producer.
                "core:sample_start": i * m["len"],
                "core:frequency": float(m["center_freq"]),
                "core:datetime": m["start_time"].isoformat() + "Z",
                "ntia-sensor:sigan_settings": {
                    "gain": int(self.config.GAIN)
                }
            })

        return metadata

    def make_sigmf_archive(self, gzip=False):
        """Pack the metadata and the raw power bytes into a SigMF archive.

        Args:
            gzip: forwarded to ``make_archive``.

        Returns:
            Whatever ``make_archive`` returns.
        """
        metadata = self.make_sigmf_metadata()
        return make_archive(
            metadata, self.powers.tobytes(),
            basename="powder-rfmonitor-psd",
            monitor_id=self.config.args.monitor_id, gzip=gzip)

class Spectrogram:
    """Container for a spectrogram plus the axes and samples behind it."""

    def __init__(self, config, spectrogram, frequencies_,
                 times_, IQ, meta):
        """Store the spectrogram and its axes.

        Args:
            config: configuration (``ANALOG_BW``, ``GAIN``, ``args.center``
                and ``args.monitor_id`` are read when building SigMF output).
            spectrogram: 2-D array; :meth:`plot` draws it with
                ``frequencies_`` on x and ``times_`` on y.
            frequencies_: frequency axis values.
            times_: time axis values.
            IQ: the samples this spectrogram was computed from; its
                ``start`` attribute becomes the capture timestamp.
            meta: extra metadata, stored as-is.
        """
        self.config = config
        self.spectrogram = spectrogram
        self.frequencies = frequencies_
        self.times = times_
        self.IQ = IQ
        self.meta = meta

    def plot(self, axis):
        """Draw the spectrogram on the given matplotlib axes.

        The color scale is clipped to ``[min + 40, max - 10]`` of the data.
        """
        cmap = plt.colormaps["viridis"]
        axis.pcolormesh(self.frequencies,
                        self.times,
                        self.spectrogram,
                        cmap=cmap, norm=None,
                        vmin=40 + np.min(self.spectrogram),
                        vmax=np.max(self.spectrogram) - 10)
        return

    def make_sigmf_metadata(self):
        """Build the SigMF metadata dict describing this spectrogram.

        Axis values are converted to built-in ``float`` so the result does
        not contain NumPy scalars (``np.float32`` is not JSON serializable).
        ``core:datetime`` is optional and is left out when the I/Q capture
        has no start time.

        Returns:
            A dict with the ``global``, ``annotations`` and ``captures``
            sections.
        """
        freq_first = self.frequencies[0]
        freq_last = self.frequencies[-1]
        time_first = self.times[0]
        time_last = self.times[-1]

        capture = {
            "core:sample_start": 0,
            "core:frequency": float(self.config.args.center),
        }
        # IQ objects may be created without a start time; do not crash on it.
        start = getattr(self.IQ, "start", None)
        if start is not None:
            capture["core:datetime"] = start.isoformat() + "Z"
        capture["ntia-sensor:sigan_settings"] = {
            "gain": int(self.config.GAIN)
        }

        metadata = {
            "global": {
                "core:description": "rfmonitor spectrogram",
                "core:datatype": "rf32le",
                "core:sample_rate": int(self.config.ANALOG_BW),
                "core:version": "1.2.5",
                "core:num_channels": 1,
                "core:recorder": "rfmonitor",
                "core:extensions": [
                    {
                        "name": "ntia-algorithm",
                        "version": "v2.0.0",
                        "optional": False
                    },
                    {
                        "name": "ntia-sensor",
                        "version": "v2.0.0",
                        "optional": False
                    },
                    {
                        "name": "openzms-core",
                        "version": "v1.0.0",
                        "optional": True
                    }
                ],
                "openzms-core:kind": "spectrogram",
                "openzms-core:types": "powder.rfmonitor.v2.spectrogram",
                "openzms-core:labels": "ota",
                "openzms-core:min_freq": int(freq_first),
                "openzms-core:max_freq": int(freq_last),
                "ntia-algorithm:data_products": [{
                    "processing": ["powder_rfmonitor_spectrogram"],
                    "name": "powder_rfmonitor_spectrogram",
                    "description": "POWDER rfmonitor spectrogram",
                    "series": ["spectrogram"],
                    "length": len(self.spectrogram[0]),
                    "x_units": "Hz",
                    "x_start": [float(freq_first)],
                    "x_stop": [float(freq_last)],
                    "x_step": [float(self.frequencies[1] - freq_first)],
                    "y_units": "Time",
                    "y_start": [float(time_first)],
                    "y_stop": [float(time_last)],
                    "y_step": [float(self.times[1] - time_first)],
                }],
                "ntia-algorithm:processing_info": [{
                    "type": "powder_rfmonitor_spectrogram",
                    "id": "powder_rfmonitor_spectrogram",
                    "description": "A spectrogram.",
                    "samples": len(self.spectrogram),
                    #"dfts": 1
                }]
            },
            "annotations": [],
            "captures": [capture]
        }

        return metadata

    def make_sigmf_archive(self, gzip=False):
        """Pack the metadata and the raw spectrogram bytes into a SigMF archive.

        Args:
            gzip: forwarded to ``make_archive``.

        Returns:
            Whatever ``make_archive`` returns.
        """
        metadata = self.make_sigmf_metadata()
        return make_archive(
            metadata, self.spectrogram.tobytes(),
            basename="powder-rfmonitor-spectrogram",
            monitor_id=self.config.args.monitor_id, gzip=gzip)

#
# Wrapper class for raw I/Q samples and the start/end time of their capture.
#
class IQ:
    """Container for raw I/Q samples and the time span of their capture."""

    def __init__(self, config, all_samples, start=None, end=None):
        """Store the samples and their capture window.

        Args:
            config: configuration; ``ANALOG_BW``, ``FFT_SIZE`` and
                ``CENTER_FREQS`` are read by :meth:`compute_spectrogram`.
            all_samples: the captured samples.
            start: capture start time, or ``None`` when not recorded.
            end: capture end time, or ``None`` when not recorded.
        """
        self.samples = all_samples
        self.config  = config
        self.start = start
        self.end = end

    def dump(self):
        """Print the string form of the sample array."""
        print(str(self.samples))

    def plot(self):
        """Plot the sample magnitudes in a new matplotlib window."""
        plt.plot(np.abs(self.samples))
        plt.show()

    def compute_spectrogram(self):
        """Compute a two-sided spectrogram of the samples.

        Uses ``scipy.signal.spectrogram`` with non-overlapping Hann-windowed
        segments of ``config.FFT_SIZE`` samples and ``fs=config.ANALOG_BW``.

        Returns:
            A ``Spectrogram`` whose data is ``10 * log10`` of the power,
            indexed ``[time, frequency]``, with DC centered on the frequency
            axis and the frequencies offset by ``config.CENTER_FREQS[0]``.
        """
        frequencies, times, spectrogram = scipy.signal.spectrogram(
            self.samples, fs=self.config.ANALOG_BW, nfft=self.config.FFT_SIZE,
            nperseg=self.config.FFT_SIZE, noverlap=0, window="hann",
            # To avoid a warning
            return_onesided=False)

        # Put DC in the center of the frequency axis (axis 0).  The last
        # axis holds the segment times and must stay in chronological order
        # so that the rows keep matching ``times``; it is not shifted.
        shifted = scipy.fft.fftshift(spectrogram, axes=0)
        spectrogram = 10*np.log10(np.transpose(shifted))
        frequencies = scipy.fft.fftshift(frequencies) + self.config.CENTER_FREQS[0]

        meta = {}

        return Spectrogram(self.config, spectrogram, frequencies, times,
                           self, meta)
