
import datetime
import logging
import time
import numpy as np
import matplotlib.pyplot as plt

from .device import Device
from .spectrum_analyzer import SpectrumAnalyzer
from .frequency_manager import FrequencyManager
from .spectrum_plotter import SpectrumPlotter
from .spectrum_exporter import SpectrumExporter
from .psd_stitcher import PSDStitcher
from .sigmf import make_archive


class Receiver:
    """Drive one RF device through a frequency sweep and build PSDs.

    Bundles the ``Device`` with the spectrum analyzer, frequency manager and
    an overlap-aware ``PSDStitcher`` configured from ``config``.
    """

    def __init__(self, args, logger, config):
        self.args = args
        self.logger = logger
        self.config = config
        self.device_config = self.config.device_config
        # The first entry of device_config["devices"] is used as the RF port label.
        self.rf_port = list(self.device_config["devices"])[0]
        self.device = Device(args, logger, config)

        self.spectrum_analyzer = SpectrumAnalyzer(config, logger)
        self.frequency_manager = FrequencyManager(config, logger)

        # Get optimal PSD stitcher configuration from frequency manager
        stitcher_config = self.frequency_manager.calculate_psd_stitcher_config()
        self.logger.info(f"PSD Stitcher auto-configured: {stitcher_config['reasoning']}")

        # Initialize PSD stitcher with overlap-aware configuration
        self.psd_stitcher = PSDStitcher(
            edge_bin_divisor=stitcher_config['edge_bin_divisor'],
            mid_bin_divisor=stitcher_config['mid_bin_divisor'],
            no_masking=stitcher_config['no_masking'],
            overlap_ratio=stitcher_config['overlap_ratio']
        )

    #
    # For ZMS dynamic mode, some changes require re-initializing the device
    # so we might return a new receiver.
    #
    # At the moment, none of these changes require a new receiver.
    #
    def updateConfig(self, range=None, gain=None, dwell_time=None):
        """Apply new range/gain/dwell_time, push the gain to the device, dump config.

        Returns:
            The receiver to use from now on (currently always ``self``).
        """
        self.config.update(range=range, gain=gain, dwell_time=dwell_time)
        self.device.setGain(self.config.GAIN)
        self.config.dump()
        return self

    def get_samples(self, frequency, num_samples = 0):
        """Capture ``num_samples`` at ``frequency`` (default NUM_FFTS * FFT_SIZE)."""
        if not num_samples:
            num_samples = self.config.NUM_FFTS * self.config.FFT_SIZE

        return self.device.get_N_samples(int(frequency), int(num_samples))

    # Get the PSD of samples for a single center frequency.
    def compute_psd_welch(self, samples, frequency, noverlap=None, window="hann", return_linear=False):
        """Delegate to ``SpectrumAnalyzer.compute_psd_welch``."""
        return self.spectrum_analyzer.compute_psd_welch(samples, frequency, noverlap, window, return_linear)

    def compute_psd_bartlett(self, samples, frequency, return_linear=False):
        """Delegate to ``SpectrumAnalyzer.compute_psd_bartlett``."""
        return self.spectrum_analyzer.compute_psd_bartlett(samples, frequency, return_linear)

    def compute_psd_bartlettX(self, samples, frequency, return_linear=False):
        """Delegate to ``SpectrumAnalyzer.compute_psd_bartlettX``."""
        return self.spectrum_analyzer.compute_psd_bartlettX(samples, frequency, return_linear)

    def get_full_spectrum(self):
        """Capture raw I/Q at every configured center frequency.

        Returns:
            IQ: every capture concatenated, in sweep order, into one flat
            complex64 array. ``frequency`` is the last center frequency of the
            sweep, or None when no center frequencies are configured.
            ``start``/``end`` bracket the whole sweep (``utcnow``).
        """
        center_freqs = self.config.CENTER_FREQS
        samples_per_center = self.config.NUM_FFTS * self.config.FFT_SIZE

        #
        # Only one channel for the moment. Size the buffer from the list we
        # actually iterate: np.empty leaves any unfilled row as garbage.
        #
        all_samples = np.empty(
            (len(center_freqs), int(samples_per_center)), dtype=np.complex64)

        last_center = None
        sdt = datetime.datetime.utcnow()
        for ii, center_freq in enumerate(center_freqs):
            samples = self.device.get_N_samples(center_freq, samples_per_center)
            # NOTE(review): get_full_spectrum_psd uses samples[0] (channel 0);
            # samples[0:] here relies on a single-channel capture fitting the
            # row. Confirm before adding multi-channel devices.
            all_samples[ii, :] = samples[0:]
            last_center = center_freq
        edt = datetime.datetime.utcnow()

        return IQ(self.config, np.ravel(all_samples), last_center, start=sdt, end=edt)

    def get_full_spectrum_psd(self, include_samples=False):
        """Sweep all center frequencies and stitch their PSDs into one spectrum.

        Args:
            include_samples: when True, each ``individual_captures`` entry also
                carries its raw capture as an ``IQ`` under the key ``'IQ'``.

        Returns:
            PSD: stitched powers (dB), frequencies (MHz), per-bin timestamps,
            device labels and source center frequencies (MHz), plus the
            per-center captures, trimmed to [START_FREQ, END_FREQ] when that
            range overlaps the result. An empty PSD is returned when no center
            frequencies are configured.
        """
        # Debug: Print overall PSD generation parameters
        self.frequency_manager.log_frequency_plan()

        center_freqs = self.config.CENTER_FREQS
        num_centers = len(center_freqs)
        if num_centers == 0:
            self.logger.debug("No center frequencies configured - returning empty PSD")
            return PSD(self.config, np.array([]), np.array([]), np.array([]), np.array([]), np.array([]))

        # Collect all PSDs from each center frequency (without masking)
        self.logger.debug("Collecting PSDs from all center frequencies...")
        psd_array = []
        individual_captures = []
        psd_bins = None  # bins per capture, taken from the first capture

        all_sample_start_dt = datetime.datetime.utcnow()
        for ii, center_freq in enumerate(center_freqs):
            if ii % 10 == 0 or ii < 5 or ii == num_centers - 1:
                self.logger.debug(f"Processing center freq {ii+1}/{num_centers}: {center_freq/1e6:.3f} MHz")

            sample_start_dt = datetime.datetime.utcnow()
            samples = self.device.get_N_samples(
                center_freq, self.config.NUM_FFTS * self.config.FFT_SIZE)
            sample_end_dt = datetime.datetime.utcnow()
            # Get LINEAR PSD values for proper stitching
            psd, freqs, freq_min, freq_max = self.compute_psd_welch(
                samples[0], center_freq, noverlap=self.args.psd_fft_overlap,
                window=self.args.psd_window, return_linear=True)
            psd_array.append(psd)

            # Bin spacing; a single-bin PSD has no spacing to measure.
            bin_step = (freqs[1] - freqs[0]) if len(freqs) > 1 else 0
            capture = {
                'center_freq': center_freq / 1e6,  # Store in MHz
                'frequencies': freqs / 1e6,        # Store in MHz
                'powers': np.nan_to_num(10.0 * np.log10(psd)),  # dB for plotting
                'freq_min': freq_min,
                'freq_max': freq_max,
                'step_size': bin_step,
                'start': sample_start_dt,
                'end': sample_end_dt
            }
            individual_captures.append(capture)
            if include_samples:
                capture['IQ'] = IQ(
                    self.config, samples, center_freq, start=sample_start_dt, end=sample_end_dt)

            if psd_bins is None:
                psd_bins = len(psd)
                self.logger.debug(f"PSD frequency resolution: {bin_step/1e3:.3f} kHz")
                self.logger.debug(f"PSD bins per center freq: {psd_bins}")
        all_sample_end_dt = datetime.datetime.utcnow()

        # Convert to 2D numpy array for stitching
        psd_2d = np.array(psd_array, dtype=np.float32)
        self.logger.debug(f"PSD array shape: {psd_2d.shape}")

        # Use PSDStitcher to combine PSDs (in linear scale)
        self.logger.debug("Stitching PSDs together (in linear scale)...")
        stitched_psd_linear = self.psd_stitcher.stitch(psd_2d, method='mean')

        # Convert to dB AFTER stitching
        stitched_psd = np.nan_to_num(10.0 * np.log10(stitched_psd_linear))
        num_bins = len(stitched_psd)
        self.logger.debug(f"Stitched PSD length: {num_bins}")

        stitched_frequencies = self._stitched_frequency_axis(individual_captures, num_bins)

        # Create timestamps and device info arrays
        current_time = int(time.time())
        all_times = np.full(num_bins, current_time, dtype=int)
        all_devs = np.full(num_bins, self.rf_port, dtype=object)
        all_center_freqs = self._map_center_freqs(num_bins, psd_bins)

        # Convert frequencies to MHz
        stitched_frequencies_mhz = stitched_frequencies / 1e6

        # Trim to requested frequency range
        start_mhz = self.config.START_FREQ / 1e6
        end_mhz = self.config.END_FREQ / 1e6
        valid_indices = np.where((stitched_frequencies_mhz >= start_mhz) & (stitched_frequencies_mhz <= end_mhz))[0]

        if len(valid_indices) > 0:
            # Trim all arrays to valid range
            stitched_psd = stitched_psd[valid_indices]
            stitched_frequencies_mhz = stitched_frequencies_mhz[valid_indices]
            all_times = all_times[valid_indices]
            all_devs = all_devs[valid_indices]
            all_center_freqs = all_center_freqs[valid_indices]
            self.logger.debug(f"Trimmed to requested range: {len(valid_indices)} points")

        self._log_psd_summary(stitched_psd, stitched_frequencies_mhz, all_center_freqs)

        return PSD(self.config, stitched_psd, stitched_frequencies_mhz, all_times, all_devs, all_center_freqs, individual_captures, start=all_sample_start_dt, end=all_sample_end_dt)

    def _stitched_frequency_axis(self, individual_captures, num_bins):
        """Return evenly spaced frequencies (Hz) for ``num_bins`` stitched bins.

        The first bin is pinned to the first frequency of the first capture
        and the last bin to the last frequency of the last capture. With no
        captures, bin indices are returned instead.
        """
        if not individual_captures:
            # Fallback case
            return np.arange(num_bins)

        # Captures store frequencies in MHz; convert back to Hz.
        freq_start = individual_captures[0]['frequencies'][0] * 1e6
        freq_end = individual_captures[-1]['frequencies'][-1] * 1e6
        # num_bins points covering [freq_start, freq_end] have num_bins - 1
        # gaps; dividing by num_bins would leave the last bin short of freq_end.
        step = (freq_end - freq_start) / (num_bins - 1) if num_bins > 1 else 0.0
        self.logger.debug(f"Frequency mapping - actual range: {freq_start/1e6:.3f}-{freq_end/1e6:.3f} MHz, resolution: {step/1e3:.1f} kHz, length: {num_bins}")
        return freq_start + np.arange(num_bins) * step

    def _map_center_freqs(self, num_bins, psd_bins):
        """Approximate the source center frequency (MHz) of each stitched bin.

        Overlapping regions end up attributed to the later capture, because
        later captures overwrite earlier ones in the loop below.
        """
        center_freqs = self.config.CENTER_FREQS
        mapped = np.zeros(num_bins, dtype=np.float32)
        if psd_bins is not None and hasattr(self.psd_stitcher, 'shift'):
            # Assumes the stitcher advances `shift` bins per capture -- confirm
            # against PSDStitcher.
            offset = self.psd_stitcher.shift
            span = psd_bins
        else:
            # Fallback: distribute center frequencies evenly across the result
            offset = span = num_bins // len(center_freqs)

        for ii, center_freq in enumerate(center_freqs):
            start_idx = ii * offset
            end_idx = min(start_idx + span, num_bins)
            mapped[start_idx:end_idx] = center_freq / 1e6
        return mapped

    def _log_psd_summary(self, powers_db, freqs_mhz, center_freqs_mhz):
        """Log a debug summary of the final stitched PSD (safe when empty)."""
        self.logger.debug("=== Final Stitched PSD Summary ===")
        self.logger.debug(f"Total frequency points: {len(freqs_mhz)}")
        # np.min/np.max/np.mean cannot summarize an empty result.
        if len(freqs_mhz) > 0 and len(powers_db) > 0:
            self.logger.debug(f"Frequency range: {np.min(freqs_mhz):.3f} - {np.max(freqs_mhz):.3f} MHz")
            self.logger.debug(f"Power range: {np.min(powers_db):.1f} - {np.max(powers_db):.1f} dB")
            self.logger.debug(f"Mean power: {np.mean(powers_db):.1f} dB")
        self.logger.debug(f"Unique center frequencies: {len(np.unique(center_freqs_mhz))}")
        self.logger.debug("=== End PSD Debug ===")

#
# Wrapper class for all of the PSD data we need.
#
class PSD:
    """Stitched power spectral density sweep plus its per-capture pieces.

    The per-bin arrays ``powers``, ``frequencies``, ``timestamps``,
    ``devices`` and ``center_freqs`` are parallel: index ``i`` in each refers
    to the same bin. As built by ``Receiver.get_full_spectrum_psd``, powers
    are in dB and frequencies/center frequencies are in MHz.
    """

    def __init__(self, config, all_spectrum, all_frequencies,
                 all_times, all_devs, all_center_freqs, individual_captures=None,
                 start=None, end=None):
        self.config = config
        self.powers = all_spectrum
        self.frequencies = all_frequencies
        self.timestamps = all_times
        self.devices = all_devs
        self.center_freqs = all_center_freqs
        self.individual_captures = individual_captures or []
        # Sweep start/end times supplied by the caller.
        self.start = start
        self.end = end

        # Initialize helper objects
        self._plotter = SpectrumPlotter()
        self._exporter = SpectrumExporter()

        # Creation summary goes through logging (not stdout) so it honours
        # the application's log level and handlers.
        log = logging.getLogger(__name__)
        log.debug("PSD object created with %d data points", len(all_spectrum))
        if len(all_spectrum) > 0:
            log.debug("PSD frequency range: %.3f - %.3f MHz",
                      np.min(all_frequencies), np.max(all_frequencies))
            log.debug("PSD power range: %.1f - %.1f dB",
                      np.min(all_spectrum), np.max(all_spectrum))

    def dump(self):
        """Dump PSD data to console - delegates to exporter"""
        self._exporter.dump_to_console(self)

    def writecsv(self, fd):
        """Write CSV data to file handle - delegates to exporter"""
        self._exporter.write_csv(self, fd)

    def plot(self):
        """Plot combined spectrum - delegates to plotter"""
        self._plotter.plot_combined_spectrum(self)

    def plot_individual(self):
        """Plot individual captures - delegates to plotter"""
        self._plotter.plot_individual_captures(self)

    def spectrogram(self):
        """Plot spectrogram - delegates to plotter"""
        self._plotter.plot_spectrogram_from_psd(self)

    class OneFreq:
        """One bin of the PSD, as yielded by iterating a ``PSD``."""

        def __init__(self, dev, stamp, freq, power, center):
            self.center_freq = center
            self.frequency   = freq
            self.device      = dev
            self.timestamp   = stamp
            self.power       = power

    def __iter__(self):
        """Start iterating bins; the PSD object itself is the iterator."""
        self.iter_index  = 0
        self.iter_length = self.frequencies.shape[0]
        return self

    def __next__(self):
        """Return the next bin as a ``PSD.OneFreq``."""
        if self.iter_index >= self.iter_length:
            raise StopIteration

        ii = self.iter_index
        self.iter_index += 1

        return self.OneFreq(self.devices[ii],
                            self.timestamps[ii],
                            self.frequencies[ii],
                            self.powers[ii],
                            self.center_freqs[ii])

    @staticmethod
    def _mhz_to_hz(value_mhz):
        """Convert a frequency in MHz to integer Hz, rounding to nearest.

        ``int()`` alone truncates, so float error in the multiply (e.g.
        2399999999.9999995) would lose a whole hertz.
        """
        return int(round(float(value_mhz) * 1e6))

    def make_sigmf_metadata(self):
        """Build the SigMF metadata dict describing the aggregate PSD.

        Frequencies are converted from MHz to integer Hz. For an empty PSD the
        min/max frequency are reported as 0, and for fewer than two points the
        frequency step is 0.
        """
        values = []
        num_points = len(self.frequencies)
        min_freq = self._mhz_to_hz(self.frequencies[0]) if num_points > 0 else 0
        max_freq = self._mhz_to_hz(self.frequencies[-1]) if num_points > 0 else 0
        if num_points > 1:
            freq_step = int(round(float(self.frequencies[1]) * 1e6
                                  - float(self.frequencies[0]) * 1e6))
        else:
            freq_step = 0

        dp = [{
            "processing": ["powder_rfmonitor_aggregate_psd"],
            "name": "powder_rfmonitor_aggregate_psd",
            "description": "POWDER rfmonitor aggregate sweep psd",
            "series": ["power"],
            "length": int(len(self.powers)),
            "x_units": "Hz",
            "x_start": [min_freq],
            "x_stop": [max_freq],
            "x_step": [freq_step],
            "y_units": "dBx"
        }]
        metadata = {
            "global": {
                "core:description": "rfmonitor sweep psd",
                "core:datatype": "rf32_le",
                "core:sample_rate": int(self.config.SAMPLE_RATE),
                "core:version": "1.2.5",
                "core:num_channels": 1,
                "core:recorder": "rfmonitor",
                "core:extensions": [
                    {
                        "name": "ntia-algorithm",
                        "version": "v2.0.0",
                        "optional": False
                    },
                    {
                        "name": "ntia-sensor",
                        "version": "v2.0.0",
                        "optional": False
                    },
                    {
                        "name": "openzms-core",
                        "version": "v1.0.0",
                        "optional": True
                    }
                ],
                "openzms-core:kind": "psd",
                "openzms-core:types": "powder.rfmonitor.v2.psd",
                "openzms-core:labels": "ota,sweep",
                "openzms-core:min_freq": min_freq,
                "openzms-core:max_freq": max_freq,
                "openzms-core:freq_step": freq_step,
                "openzms-core:values": values,
                "ntia-algorithm:data_products": dp,
                "ntia-algorithm:processing_info": [{
                    "type": "powder_rfmonitor_aggregate_psd",
                    "id": "powder_rfmonitor_aggregate_psd",
                    "description": "An aggregate PSD.",
                }]
            },
            "annotations": [],
            "captures": []
        }

        for i, m in enumerate(self.individual_captures):
            #
            # XXX: this is not really compliant with the NTIA SigMF extensions,
            # so disabled, but left here for reference for now.
            #
            # To say more: we cannot call it a data_product because that implies
            # that the data is in the data file, which it is not.  This is
            # more of a metadata-only product.  We would need a custom extension
            # to describe it properly.
            #
            if False:
                plen = len(m["powers"])
                per_capture_dp = {
                    "processing": [f"powder_rfmonitor_psd_{i}"],
                    "name": "powder_rfmonitor_psd",
                    "description": "POWDER rfmonitor aggregate sweep psd",
                    "series": ["psd"],
                    "length": plen,
                    "x_units": "Hz",
                    "x_start": [float(m["freq_min"]) * 1e6],
                    "x_stop": [float(m["freq_max"]) * 1e6],
                    "x_step": [float(m["step_size"]) * 1e6],
                    "y_units": "dBX/Hz"
                }
                metadata["global"]["ntia-algorithm:data_products"].append(per_capture_dp)

                pi = {
                    "type": "powder_rfmonitor_psd",
                    "id": f"powder_rfmonitor_psd_{i}",
                    "description": "An aggregate multi-center sweep power spectral density calculation",
                    "samples": plen,
                }
                metadata["global"]["ntia-algorithm:processing_info"].append(pi)

            metadata["captures"].append({
                "core:sample_start": 0,
                "core:frequency": float(m["center_freq"]),
                # Capture times come from naive utcnow(), hence the literal "Z".
                "core:datetime": m["start"].isoformat() + "Z",
                "ntia-sensor:sigan_settings": {
                    "gain": int(self.config.GAIN)
                }
            })

        return metadata

    def make_sigmf_archive(self, gzip=False):
        """Package the powers and their metadata with ``make_archive``."""
        metadata = self.make_sigmf_metadata()
        # The metadata declares "rf32_le": make sure the payload really is
        # little-endian float32 whatever dtype the powers arrived in.
        payload = np.asarray(self.powers, dtype="<f4").tobytes()
        return make_archive(
            metadata, payload,
            basename="powder-rfmonitor-psd",
            monitor_id=self.config.args.monitor_id, gzip=gzip)

class Spectrogram:
    """Time/frequency power grid computed from one ``IQ`` capture.

    ``spectrogram`` is indexed as ``[time_row][frequency_column]`` (its first
    row's length is reported as the frequency-axis length), with ``times``
    and ``frequencies`` as the matching axes.
    """

    def __init__(self, config, spectrogram, frequencies,
                 times, IQ, start=None, end=None):
        self.config = config
        self.spectrogram = spectrogram
        self.frequencies = frequencies
        self.times = times
        self.start = start
        self.end = end
        self.IQ = IQ

    def plot(self, axis, floor_offset=40, ceiling_offset=10):
        """Draw the spectrogram onto a matplotlib ``axis`` with ``pcolormesh``.

        The colour scale spans ``min + floor_offset`` to ``max - ceiling_offset``.
        When the data's dynamic range is smaller than the two offsets combined,
        that window would be inverted (matplotlib's normalization rejects
        vmin > vmax), so the full data range is used instead.
        """
        lowest = np.min(self.spectrogram)
        highest = np.max(self.spectrogram)
        vmin = lowest + floor_offset
        vmax = highest - ceiling_offset
        if vmin > vmax:
            vmin, vmax = lowest, highest
        axis.pcolormesh(self.frequencies,
                        self.times,
                        self.spectrogram,
                        cmap=plt.colormaps["viridis"], norm=None,
                        vmin=vmin,
                        vmax=vmax)

    def make_sigmf_metadata(self):
        """Build the SigMF metadata dict describing this spectrogram.

        Axis values are converted to plain Python floats/ints so the dict is
        JSON-serializable even when the axes are float32 numpy arrays. An axis
        with a single point gets a step of 0.0.
        """
        freqs = self.frequencies
        times = self.times
        freq_step = float(freqs[1] - freqs[0]) if len(freqs) > 1 else 0.0
        time_step = float(times[1] - times[0]) if len(times) > 1 else 0.0

        metadata = {
            "global": {
                "core:description": "rfmonitor spectrogram",
                "core:datatype": "rf32_le",
                "core:sample_rate": int(self.config.SAMPLE_RATE),
                "core:version": "1.2.5",
                "core:num_channels": 1,
                "core:recorder": "rfmonitor",
                "core:extensions": [
                    {
                        "name": "ntia-algorithm",
                        "version": "v2.0.0",
                        "optional": False
                    },
                    {
                        "name": "ntia-sensor",
                        "version": "v2.0.0",
                        "optional": False
                    },
                    {
                        "name": "openzms-core",
                        "version": "v1.0.0",
                        "optional": True
                    }
                ],
                "openzms-core:kind": "spectrogram",
                "openzms-core:types": "powder.rfmonitor.v2.spectrogram",
                "openzms-core:labels": "ota",
                "openzms-core:min_freq": int(round(float(freqs[0]))),
                "openzms-core:max_freq": int(round(float(freqs[-1]))),
                "ntia-algorithm:data_products": [{
                    "processing": ["powder_rfmonitor_spectrogram"],
                    "name": "powder_rfmonitor_spectrogram",
                    "description": "POWDER rfmonitor spectrogram",
                    "series": ["spectrogram"],
                    "length": len(self.spectrogram[0]),
                    "x_units": "Hz",
                    "x_start": [float(freqs[0])],
                    "x_stop": [float(freqs[-1])],
                    "x_step": [freq_step],
                    "y_units": "Time",
                    "y_start": [float(times[0])],
                    "y_stop": [float(times[-1])],
                    "y_step": [time_step],
                }],
                "ntia-algorithm:processing_info": [{
                    "type": "powder_rfmonitor_spectrogram",
                    "id": "powder_rfmonitor_spectrogram",
                    "description": "A spectrogram.",
                    "samples": len(self.spectrogram),
                }]
            },
            "annotations": [],
            "captures": [{
                "core:sample_start": 0,
                # args.center is scaled by 1e6, i.e. presumably given in MHz.
                "core:frequency": float(self.config.args.center) * 1e6,
                "core:datetime": self.IQ.start.isoformat() + "Z",
                "ntia-sensor:sigan_settings": {
                    "gain": int(self.config.GAIN)
                }
            }]
        }

        return metadata

    def make_sigmf_archive(self, gzip=False):
        """Package the spectrogram and its metadata with ``make_archive``."""
        metadata = self.make_sigmf_metadata()
        # The metadata declares "rf32_le": make sure the payload really is
        # little-endian float32 whatever dtype the analyzer produced.
        payload = np.asarray(self.spectrogram, dtype="<f4").tobytes()
        return make_archive(
            metadata, payload,
            basename="powder-rfmonitor-spectrogram",
            monitor_id=self.config.args.monitor_id, gzip=gzip)

#
# Wrapper class for raw I/Q samples.
#
class IQ:
    """A block of raw I/Q samples captured around one center frequency.

    Carries the config and optional capture start/end times, and delegates
    plotting to ``SpectrumPlotter`` and analysis to ``SpectrumAnalyzer``.
    """

    def __init__(self, config, all_samples, frequency, start=None, end=None):
        self.config = config
        self.samples = all_samples
        self.frequency = frequency
        self.start, self.end = start, end
        # One plotter and one analyzer per capture; the analyzer logs via this module's logger.
        self._plotter = SpectrumPlotter()
        module_logger = logging.getLogger(__name__)
        self._analyzer = SpectrumAnalyzer(self.config, module_logger)

    def dump(self):
        """Print the raw samples to stdout."""
        print(self.samples)

    def plot_spectrogram(self):
        """Generate spectrogram from I/Q samples - delegates to plotter"""
        self._plotter.plot_iq_spectrogram(self.samples, self.config)

    def compute_spectrogram(self):
        """Run the analyzer over the samples and wrap the result in a ``Spectrogram``."""
        grid, freq_axis, time_axis = self._analyzer.compute_spectrogram(
            self.samples, self.frequency)
        return Spectrogram(self.config, grid, freq_axis, time_axis, self,
                           start=self.start, end=self.end)
