import numpy as np
import scipy.signal
import scipy.fft


class SpectrumAnalyzer:
    """
    Spectral analysis helpers: two-sided PSD estimates (Welch, Bartlett,
    single periodogram) and spectrograms of a block of samples.

    All estimates are computed two-sided (``return_onesided=False``) and are
    returned fftshifted, so the DC bin sits in the middle and frequencies
    ascend.

    ``config`` must provide ``FFT_SIZE`` (segment / FFT length in samples) and
    ``SAMPLE_RATE`` (sampling rate passed to scipy as ``fs``). ``logger`` must
    provide a ``debug(msg)`` method.
    """

    def __init__(self, config, logger):
        self.config = config
        self.logger = logger

    @staticmethod
    def _finalize_psd(freqs, psdlin, frequency, return_linear):
        """Turn a raw two-sided scipy PSD into the layout callers receive.

        Converts to dB unless *return_linear*, offsets the baseband bin
        frequencies by *frequency* (Hz), and fftshifts both arrays so DC is
        centred.

        Returns:
            (psd, freqs, freq_min, freq_max): freqs in Hz, freq_min/freq_max
            (first/last shifted bin) in MHz.
        """
        if return_linear:
            psd = psdlin
        else:
            # log10(0) -> -inf; nan_to_num maps it to the most negative finite
            # float so downstream min/max stay finite. Silence the warning.
            with np.errstate(divide="ignore"):
                psd = np.nan_to_num(10.0 * np.log10(psdlin))

        # put DC in the center
        psd = np.fft.fftshift(psd)
        freqs = np.fft.fftshift(freqs + float(frequency))
        return psd, freqs, freqs[0] / 1e6, freqs[-1] / 1e6

    def compute_psd_welch(self, samples, frequency, noverlap=None, window="hann", return_linear=False):
        """
        Compute a two-sided PSD using Welch's method.

        Args:
            samples: Input samples (1-D array).
            frequency: Center frequency in Hz; added to the baseband bin
                frequencies.
            noverlap: Segment overlap as a *fraction* of FFT_SIZE (e.g. 0.5).
                ``None`` uses scipy's default overlap; ``0`` means no overlap.
            window: Window name/array accepted by ``scipy.signal.welch``.
            return_linear: If True, return linear PSD values instead of dB.

        Returns:
            (psd, freqs, freq_min, freq_max): fftshifted PSD, absolute bin
            frequencies in Hz, and the lowest/highest bin frequency in MHz.
        """
        fft_size = self.config.FFT_SIZE
        # Compare against None explicitly: an overlap of 0 (Bartlett) is a
        # real request and must not fall back to scipy's 50% default.
        actual_noverlap = None if noverlap is None else int(noverlap * fft_size)

        # Debug: Print Welch PSD computation parameters
        self.logger.debug(f"  Welch PSD params - Center: {frequency/1e6:.3f} MHz, "
                          f"FFT size: {fft_size}, "
                          f"Window: {window}, "
                          f"Overlap: {noverlap} ({actual_noverlap} samples), "
                          f"Sample count: {len(samples)}")

        freqs, psdlin = scipy.signal.welch(
            samples,
            fs=self.config.SAMPLE_RATE,
            window=window,
            nperseg=fft_size,
            noverlap=actual_noverlap,
            detrend=False,
            return_onesided=False
        )

        psd, freqs, freq_min, freq_max = self._finalize_psd(
            freqs, psdlin, frequency, return_linear)

        # Debug: Print final PSD statistics
        scale_str = "linear" if return_linear else "dB"
        self.logger.debug(f"  Final PSD - Freq range: {freq_min:.3f}-{freq_max:.3f} MHz, "
                          f"Power range: {np.min(psd):.1f}-{np.max(psd):.1f} {scale_str}, "
                          f"Mean power: {np.mean(psd):.1f} {scale_str}")

        return psd, freqs, freq_min, freq_max

    def compute_psd_bartlett(self, samples, frequency, return_linear=False):
        """
        Compute a PSD with Bartlett's method: Welch averaging over
        non-overlapping segments using a Bartlett window.

        Returns the same (psd, freqs, freq_min, freq_max) tuple as
        ``compute_psd_welch``.
        """
        return self.compute_psd_welch(
            samples, frequency, noverlap=0, window="bartlett", return_linear=return_linear)

    def compute_psd_bartlettX(self, samples, frequency, return_linear=False):
        """
        Alternative PSD estimate from a single Hann-windowed periodogram.

        Despite the name, this neither averages segments nor uses a Bartlett
        window: it is one ``scipy.signal.periodogram`` call with
        ``nfft=FFT_SIZE``, so it looks at no more than FFT_SIZE samples.

        Returns the same (psd, freqs, freq_min, freq_max) tuple as
        ``compute_psd_welch`` (freqs in Hz, bounds in MHz). All FFT bins are
        kept.
        """
        freqs, psdlin = scipy.signal.periodogram(
            samples, self.config.SAMPLE_RATE,
            window="hann",
            nfft=self.config.FFT_SIZE, return_onesided=False)

        return self._finalize_psd(freqs, psdlin, frequency, return_linear)

    def compute_spectrogram(self, samples, frequency, noverlap=None, window="hann"):
        """
        Generate a two-sided spectrogram in dB.

        Args:
            samples: Input samples (1-D array).
            frequency: Center frequency in Hz.
            noverlap: Overlap between segments in *samples* (passed straight
                to scipy; unlike ``compute_psd_welch`` this is not a fraction).
                ``None`` uses scipy's default.
            window: Window name/array accepted by ``scipy.signal.spectrogram``.

        Returns:
            (spectrogram, frequencies, times): ``spectrogram`` has shape
            (len(times), len(frequencies)), rows in chronological order and
            columns in ascending frequency; ``frequencies`` are absolute bin
            frequencies in MHz; ``times`` are scipy's segment times.
            Zero-power bins come out as -inf (no nan_to_num here).
        """
        frequencies, times, sxx = scipy.signal.spectrogram(
            samples, fs=self.config.SAMPLE_RATE, nfft=self.config.FFT_SIZE,
            nperseg=self.config.FFT_SIZE, noverlap=noverlap, window=window,
            # To avoid a warning
            return_onesided=False)

        # scipy returns Sxx as (freq, time). Shift only the frequency axis to
        # centre DC; the time axis must stay in order to match `times`.
        shifted = scipy.fft.fftshift(sxx, axes=0)
        spectrogram = 10 * np.log10(shifted.T)

        # Offset baseband bins (Hz) by the center frequency, then convert to MHz.
        frequencies = (scipy.fft.fftshift(frequencies) + float(frequency)) / 1e6

        return spectrogram, frequencies, times
