import numpy as np
import os
import datetime
from pathlib import Path

# Choose the non-interactive Agg backend only when no GUI display can be
# used. This mirrors PlotterMixin._check_display(): Windows and macOS are
# treated as always having a display, while other systems need an X11
# (DISPLAY) or Wayland (WAYLAND_DISPLAY) session. Checking DISPLAY alone
# forced Agg on Windows/macOS, where plt.show() then displayed nothing even
# though the plotters believed a display was available.
import sys

if (os.name != 'nt'
        and sys.platform != 'darwin'
        and not os.environ.get('DISPLAY')
        and not os.environ.get('WAYLAND_DISPLAY')):
    import matplotlib
    matplotlib.use('Agg')  # Must be before importing pyplot
import matplotlib.pyplot as plt
import scipy.signal
import math


class PlotterMixin:
    """
    Mixin class providing common plotting utilities.

    Offers display detection, timestamped plot titles, and saving figures
    into a ``.tmpimg`` directory under the current working directory.
    """

    def _check_display(self):
        """
        Check whether a GUI display is available.

        Returns:
            True on Windows, when DISPLAY (X11) or WAYLAND_DISPLAY is set to a
            non-empty value, or when running on macOS; False otherwise.
        """
        if os.name == 'nt':  # Windows
            return True

        # Unix-like systems advertise X11 / Wayland sessions via env vars
        if os.environ.get('DISPLAY') or os.environ.get('WAYLAND_DISPLAY'):
            return True

        # macOS reports 'Darwin' as the system name
        try:
            if 'Darwin' in os.uname().sysname:
                return True
        except (AttributeError, OSError):
            # os.uname() is not available on every platform
            pass
        return False

    def _format_title_with_timestamp(self, base_title, timestamp=None):
        """
        Format a plot title with timestamp information.

        Args:
            base_title: The base title for the plot.
            timestamp: datetime object, Unix timestamp (int/float or anything
                ``float()`` accepts), or None to use the current time. Values
                that cannot be converted, including ones out of the platform's
                range, fall back to the current time.

        Returns:
            ``"<base_title> - YYYY-MM-DD HH:MM:SS"`` in local time.
        """
        if timestamp is None:
            timestamp = datetime.datetime.now()
        elif not isinstance(timestamp, datetime.datetime):
            try:
                timestamp = datetime.datetime.fromtimestamp(float(timestamp))
            except (ValueError, TypeError, OverflowError, OSError):
                # fromtimestamp raises OverflowError/OSError for inf or
                # out-of-range values; treat them like any other bad input.
                timestamp = datetime.datetime.now()

        time_str = timestamp.strftime("%Y-%m-%d %H:%M:%S")
        return f"{base_title} - {time_str}"

    def _save_figure(self, fig, prefix="spectrum", plot_count=None):
        """
        Save a figure into ``.tmpimg`` with a timestamped file name.

        Args:
            fig: Object with a Matplotlib-style ``savefig`` method.
            prefix: File name prefix. ``"individual_captures"`` is saved as a
                75-dpi JPEG (quality 85); any other prefix as a 150-dpi PNG.
            plot_count: Optional counter appended to the file name.

        Returns:
            The path of the written file, as a string.
        """
        tmpimg_dir = Path(".tmpimg")
        tmpimg_dir.mkdir(exist_ok=True)

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        stem = f"{prefix}_{timestamp}"
        if plot_count is not None:
            stem = f"{stem}_{plot_count}"

        if prefix == "individual_captures":
            # JPEG quality must be passed through pil_kwargs: savefig's own
            # ``quality`` keyword was deprecated in Matplotlib 3.3 and later
            # removed, so passing it raises TypeError on current versions.
            filepath = tmpimg_dir / f"{stem}.jpg"
            fig.savefig(filepath, dpi=75, bbox_inches='tight', format='jpg',
                        pil_kwargs={'quality': 85})
        else:
            filepath = tmpimg_dir / f"{stem}.png"
            fig.savefig(filepath, dpi=150, bbox_inches='tight')
        print(f"Plot saved to: {filepath}")
        return str(filepath)


class SpectrumPlotter(PlotterMixin):
    """
    Handles all spectrum visualization and plotting functionality.

    Every plot method either shows its figure (when a display was detected
    and ``force_save`` is False) or saves it through ``_save_figure``.
    """

    def __init__(self, force_save=False):
        """
        Args:
            force_save: If True, always save figures to disk instead of
                showing them, even when a display is available.
        """
        self.force_save = force_save
        display_detected = self._check_display()
        self.has_display = display_detected and not force_save
        print(f"[SpectrumPlotter] Display detected: {display_detected}, Force save: {force_save}, Will display: {self.has_display}")
        # Backend already set at import time based on the environment

    @staticmethod
    def _first_timestamp(data):
        """Return ``data.timestamps[0]``, or None if it is missing or empty."""
        stamps = getattr(data, 'timestamps', None)
        if stamps is None or len(stamps) == 0:
            return None
        return stamps[0]

    def plot_combined_spectrum(self, psd_data):
        """
        Plot the combined spectrum (power vs. frequency) from PSD data.

        The title carries the first entry of ``psd_data.timestamps`` when
        present, otherwise the current time.
        """
        plt.figure(figsize=(10, 6))
        plt.plot(psd_data.frequencies, psd_data.powers)
        plt.xlabel('Frequency (MHz)')
        plt.ylabel('Power (dB)')
        title = self._format_title_with_timestamp(
            'RF Spectrum Analysis', self._first_timestamp(psd_data))
        plt.title(title)
        plt.grid(True)
        if self.has_display:
            print("[SpectrumPlotter] Calling plt.show() for combined spectrum")
            plt.show(block=False)  # Non-blocking show
            plt.pause(0.1)  # Let the GUI event loop draw the window
            plt.close()  # Close after display
        else:
            self._save_figure(plt.gcf(), "combined_spectrum")
            plt.close()

    def plot_individual_captures(self, psd_data):
        """
        Plot each original capture in its own subplot of a single figure.

        Each entry of ``psd_data.individual_captures`` must provide
        ``'frequencies'``, ``'powers'`` and ``'center_freq'``. Nothing is
        plotted when the attribute is missing, None, or empty.
        """
        captures = getattr(psd_data, 'individual_captures', None)
        if captures is None or len(captures) == 0:
            print("No individual capture data available for plotting")
            return

        num_plots = len(captures)
        # Near-square grid: ceil(sqrt(n)) rows, then just enough columns.
        rows = math.ceil(math.sqrt(num_plots))
        cols = math.ceil(num_plots / rows)

        # squeeze=False always yields a 2-D array of Axes, so the single
        # plot / single row / single column cases need no special handling.
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 2, rows * 1.5),
                                 squeeze=False)

        suptitle = self._format_title_with_timestamp(
            'Individual PSD Captures by Center Frequency',
            self._first_timestamp(psd_data))
        fig.suptitle(suptitle, fontsize=16)

        # Row-major order: capture i goes to row i // cols, column i % cols.
        flat_axes = axes.ravel()
        for ax, capture in zip(flat_axes, captures):
            ax.plot(capture['frequencies'], capture['powers'])
            ax.set_xlabel('Frequency (MHz)')
            ax.set_ylabel('Power (dB)')
            ax.set_title(f'{capture["center_freq"]:.2f} MHz')
            ax.grid(True)

            # Fit the x-axis to the capture's span plus a 5% margin. A zero
            # span would give identical limits, so leave autoscaling then.
            freq_min = np.min(capture['frequencies'])
            freq_max = np.max(capture['frequencies'])
            margin = (freq_max - freq_min) * 0.05
            if margin > 0:
                ax.set_xlim(freq_min - margin, freq_max + margin)

        # Remove grid cells that received no capture
        for ax in flat_axes[num_plots:]:
            fig.delaxes(ax)

        plt.tight_layout(rect=[0, 0, 1, 0.97])  # Make room for the suptitle

        if self.has_display:
            plt.show()
        else:
            self._save_figure(fig, "individual_captures")
            plt.close(fig)

    def plot_spectrogram_from_psd(self, psd_data):
        """
        Plot power vs. frequency from PSD data in a new figure, with a grid.
        """
        # Start a fresh figure so the line is not drawn into a figure that
        # another plot left open.
        plt.figure()
        plt.plot(psd_data.frequencies, psd_data.powers)
        plt.grid(True)
        if self.has_display:
            plt.show()
        else:
            self._save_figure(plt.gcf(), "spectrogram_psd")
            plt.close()

    def plot_iq_spectrogram(self, samples, config):
        """
        Generate a spectrogram image (in dB) from I/Q samples.

        Args:
            samples: Sample sequence passed to ``scipy.signal.spectrogram``.
                Complex (I/Q) and real input are both accepted.
            config: Object providing ``SAMPLE_RATE`` and ``FFT_SIZE``.
        """
        is_complex = np.iscomplexobj(samples)
        # Complex input needs a two-sided spectrum; SciPy would switch to it
        # anyway (with a warning), so request it explicitly.
        frequencies, times, spectrogram = scipy.signal.spectrogram(
            samples, config.SAMPLE_RATE, nfft=config.FFT_SIZE,
            nperseg=config.FFT_SIZE, return_onesided=not is_complex)
        if is_complex:
            # Two-sided output is in FFT order (0..+fs/2, then -fs/2..0);
            # pcolormesh needs monotonic coordinates, so reorder to -fs/2..+fs/2.
            frequencies = np.fft.fftshift(frequencies)
            spectrogram = np.fft.fftshift(spectrogram, axes=0)

        plt.figure()
        plt.pcolormesh(times, frequencies, 10 * np.log10(spectrogram))
        plt.ylabel('Frequency [Hz]')
        plt.xlabel('Time [sec]')
        if self.has_display:
            plt.show()
        else:
            self._save_figure(plt.gcf(), "iq_spectrogram")
            plt.close()


class Plot(PlotterMixin):
    """
    Legacy plotting class for real-time plot updates.

    Maintains compatibility with existing monitor.py code. The first call to
    ``plotPSD`` builds a figure. In live mode (display available and
    ``interactive`` set) later calls update that figure in place. Otherwise
    every call saves a numbered ``realtime_spectrum`` image.

    Attributes:
        interactive: Whether live, updating plots were requested.
        has_display: Result of ``_check_display()`` at construction.
        graph: None before the first plot. After it, the ``Line2D`` of the
            single-panel layout, or the ``Figure`` of the two-panel
            (PSD + spectrogram) layout.
        axes: The pair of Axes of the two-panel layout, otherwise None.
        plot_count: Number of figures saved so far; used in file names.
    """

    _TITLE = 'Real-time RF Spectrum Monitor'

    def __init__(self, interactive):
        self.interactive = interactive
        self.has_display = self._check_display()
        # The backend was chosen at import time; interactive mode (plt.ion)
        # is switched on lazily by the first live plotPSD call.
        self.graph = None
        self.axes = None
        self.plot_count = 0

    @staticmethod
    def _first_timestamp(data):
        """Return ``data.timestamps[0]``, or None if it is missing or empty."""
        stamps = getattr(data, 'timestamps', None)
        if stamps is None or len(stamps) == 0:
            return None
        return stamps[0]

    @staticmethod
    def _color_limits(values, floor_offset=40, ceiling_offset=10):
        """
        Return ``(vmin, vmax)`` for the spectrogram colour scale.

        Raises the lower limit ``floor_offset`` above the data minimum and
        lowers the upper limit ``ceiling_offset`` below the maximum. If the
        data range is too narrow for that (vmin would exceed vmax, which
        Matplotlib's Normalize rejects with ValueError), the full data range
        is used instead.
        """
        low = np.min(values)
        high = np.max(values)
        vmin = low + floor_offset
        vmax = high - ceiling_offset
        if vmin > vmax:
            return low, high
        return vmin, vmax

    def _draw_iq_panels(self, psd, iq):
        """(Re)draw the PSD in the left panel and the IQ spectrogram in the right."""
        psd_ax, spec_ax = self.axes
        psd_ax.clear()
        psd_ax.plot(psd.frequencies, psd.powers)
        psd_ax.set_xlabel('Frequency (MHz)')
        psd_ax.set_ylabel('Power (dB)')

        spectrogram = iq.compute_spectrogram()
        vmin, vmax = self._color_limits(spectrogram.spectrogram)
        spec_ax.clear()
        spec_ax.pcolormesh(
            spectrogram.frequencies, spectrogram.times, spectrogram.spectrogram,
            cmap=plt.colormaps["viridis"], norm=None, vmin=vmin, vmax=vmax)
        # clear() drops labels, so set them on every redraw
        spec_ax.set_xlabel('Frequency (MHz)')
        spec_ax.set_ylabel('Time')

    def _create_plot(self, psd, iq, title, live):
        """Build the first figure, then show it briefly (live) or save it."""
        if live:
            plt.ion()  # Interactive mode so later updates do not block
        if iq is not None:
            self.graph, self.axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
            self._draw_iq_panels(psd, iq)
            self.graph.suptitle(title)
        else:
            self.graph, = plt.plot(psd.frequencies, psd.powers)
            plt.xlabel('Frequency (MHz)')
            plt.ylabel('Power (dB)')
            plt.grid(True)
            plt.title(title)

        if live:
            plt.pause(2.0)  # Give the GUI event loop time to render the window
        else:
            self._save_figure(plt.gcf(), "realtime_spectrum", self.plot_count)
            self.plot_count += 1
            plt.close()

    def _update_plot(self, psd, iq, title):
        """Refresh the live figure in place with new data."""
        if iq is not None:
            self._draw_iq_panels(psd, iq)
            self.graph.suptitle(title)
        else:
            # set_data (rather than set_ydata) keeps x and y the same length
            # even if the frequency grid changes between calls.
            self.graph.set_data(psd.frequencies, psd.powers)
            self.graph.axes.set_title(title)
        plt.draw()
        plt.pause(0.1)  # Small pause to update the display

    def _save_snapshot(self, psd, title):
        """Non-live mode: save this PSD as a new numbered figure."""
        plt.figure()
        plt.plot(psd.frequencies, psd.powers)
        plt.title(title)
        plt.xlabel('Frequency (MHz)')
        plt.ylabel('Power (dB)')
        plt.grid(True)
        self._save_figure(plt.gcf(), "realtime_spectrum", self.plot_count)
        self.plot_count += 1
        plt.close()

    def plotPSD(self, psd):
        """
        Plot one PSD result, creating the figure on the first call.

        When ``psd.individual_captures`` holds exactly one capture and that
        capture has an ``'IQ'`` entry, a second panel shows its spectrogram
        (via ``IQ.compute_spectrogram()``).

        Args:
            psd: Object exposing ``frequencies`` and ``powers``, and optionally
                ``individual_captures`` and ``timestamps``.
        """
        captures = getattr(psd, 'individual_captures', None)
        iq = None
        if captures and len(captures) == 1:
            iq = captures[0].get('IQ', None)

        live = self.has_display and self.interactive
        title = self._format_title_with_timestamp(
            self._TITLE, self._first_timestamp(psd))

        if (live and self.graph is not None
                and (iq is not None) != (self.axes is not None)):
            # The on-screen layout (single line vs. PSD + spectrogram) no
            # longer matches the data; updating it would fail, so rebuild.
            plt.close(self.graph if self.axes is not None else self.graph.figure)
            self.graph = None
            self.axes = None

        if self.graph is None:
            self._create_plot(psd, iq, title, live)
        elif live:
            self._update_plot(psd, iq, title)
        else:
            self._save_snapshot(psd, title)

    def wait_for_close(self):
        """Wait for the user to press Enter, then close all plot windows.

        Only does something in live mode after at least one plot was drawn.
        """
        if self.has_display and self.interactive and self.graph is not None:
            print("\nPress Enter to close the plot window...")
            input()
            plt.close('all')
