import numpy as np
from typing import Iterator, Literal, Optional
import warnings
import matplotlib.pyplot as plt


class PSDStitcher:
    """
    Stitch a stack of equally sized PSD rows into a single 1-D spectrum.

    The stitcher works purely in bin indices and never looks at frequencies:
    every row must have the same number of evenly spaced bins and the bin
    offset between consecutive rows must be constant.

    Three strategies are available:

    * Divisor-based masking (default): edge and centre bins of every row are
      masked out and the rows are laid down ``shift`` bins apart, combining
      whatever survives the masks with ``mean``/``max``/``min``.
    * ``no_masking=True``: rows are concatenated unchanged.
    * ``overlap_ratio``: rows are laid down ``num_bins - overlap`` bins apart
      and the overlapping bins are combined (see ``stitch_with_overlap_ratio``).
    """

    @staticmethod
    def _is_power_of_two(n: int) -> bool:
        """Return True when ``n`` is a positive power of two."""
        return n > 0 and (n & (n - 1)) == 0

    @staticmethod
    def _nan_aware_mean(a, b):
        """
        Element-wise mean of two arrays that ignores NaN on either side.

        Where only one operand is valid that operand is returned; where both
        are NaN the result is NaN. Unlike ``np.nanmean`` this never emits an
        "empty slice" RuntimeWarning for the both-NaN case, which is a normal
        situation while stitching masked rows.
        """
        a = np.asarray(a)
        b = np.asarray(b)
        return np.where(np.isnan(a), b, np.where(np.isnan(b), a, (a + b) / 2))

    def __init__(
        self,
        edge_bin_divisor: int = 8,
        mid_bin_divisor: int = 8,
        dtype: type = np.float32,
        no_masking: bool = False,
        overlap_ratio: Optional[float] = None
    ):
        """
        Initialize the PSD stitcher

        Args:
            edge_bin_divisor: Divisor to determine number of edge bins (must be power of 2)
            mid_bin_divisor: Divisor to determine number of middle bins (must be power of 2)
            dtype: Data type for calculations (default: float32 for memory efficiency)
            no_masking: If True, disable all masking and simply concatenate PSDs
            overlap_ratio: Fraction of overlap between adjacent PSDs (0.0 = no overlap, 0.5 = 50% overlap)
                          If provided, ``stitch`` uses overlap stitching instead of divisor-based
                          masking, and ``no_masking`` is overridden (True only for a ratio of 0.0)

        Raises:
            ValueError: If divisors aren't powers of 2 (when masking is enabled)
        """
        self.overlap_ratio = overlap_ratio
        self.use_overlap_stitching = overlap_ratio is not None
        self.no_masking = no_masking

        # An explicit overlap ratio takes precedence over the no_masking flag:
        # zero overlap is plain concatenation, anything else is not.
        if self.use_overlap_stitching:
            self.no_masking = overlap_ratio == 0.0

        # The divisors only matter for divisor-based masking.
        if not self.no_masking and not self.use_overlap_stitching:
            if not self._is_power_of_two(edge_bin_divisor):
                raise ValueError(f"edge_bin_divisor ({edge_bin_divisor}) must be a power of 2")

            if not self._is_power_of_two(mid_bin_divisor):
                raise ValueError(f"mid_bin_divisor ({mid_bin_divisor}) must be a power of 2")

        self.edge_bin_divisor = edge_bin_divisor
        self.mid_bin_divisor = mid_bin_divisor
        self.dtype = dtype
        self._masks_initialized = False

    def _ensure_masks(self, num_bins: int) -> None:
        """(Re)build the masks unless they already match ``num_bins``."""
        # Masks are sized for one specific row length; reusing the stitcher
        # with a different row length must rebuild them, otherwise the masks
        # no longer line up with the rows.
        if not self._masks_initialized or self.num_bins != num_bins:
            self._init_masks(num_bins)

    def _init_masks(self, num_bins: int) -> None:
        """
        Initialize masks based on input array dimensions

        Sets ``num_bins``, ``edge_bins``, ``mid_bins``, ``shift`` and the six
        uint8 mask arrays (1 = keep, 0 = drop).

        Raises:
            ValueError: In masking mode, if ``num_bins`` is not a power of 2 or
                the masks would leave uncovered bins.
        """
        self.num_bins = num_bins

        if self.no_masking:
            # No masking mode: just concatenate PSDs with no overlap
            self.edge_bins = 0
            self.mid_bins = 0
            self.shift = num_bins  # No overlap, each PSD is fully shifted

            # Every mask keeps everything; each one is its own array.
            self.left_mask = np.ones(num_bins, dtype=np.uint8)
            self.right_mask = np.ones(num_bins, dtype=np.uint8)
            self.middle_mask = np.ones(num_bins, dtype=np.uint8)
            self.current_mask = np.ones(num_bins, dtype=np.uint8)
            self.full_mask = np.ones(num_bins, dtype=np.uint8)
            self.next_mask = np.ones(num_bins, dtype=np.uint8)
        else:
            # Normal masking mode
            if not self._is_power_of_two(num_bins):
                raise ValueError(f"Number of bins ({num_bins}) must be a power of 2")

            self.edge_bins = num_bins // self.edge_bin_divisor
            self.mid_bins = num_bins // self.mid_bin_divisor
            # TODO: the original author flagged this shift as "not correct".
            # It is independent of the divisors; kept unchanged because the
            # intended value is not known.
            self.shift = 7 * num_bins // 32

            # Create masks using uint8 for memory efficiency
            self.left_mask = np.ones(num_bins, dtype=np.uint8)
            self.left_mask[:self.edge_bins] = 0  # drop the leftmost edge bins

            # Mirror image of the left mask: drops the rightmost edge bins.
            self.right_mask = self.left_mask[::-1].copy()

            # Drop mid_bins bins centred on the middle of the row.
            self.middle_mask = np.ones(num_bins, dtype=np.uint8)
            mid_start = num_bins // 2 - self.mid_bins // 2
            self.middle_mask[mid_start:mid_start + self.mid_bins] = 0

            # Validate mask overlap
            self._validate_masks()

            # Pre-compute combined masks:
            #   current_mask - first row (its left edge is kept)
            #   full_mask    - interior rows
            #   next_mask    - last row (its right edge is kept)
            self.current_mask = self.right_mask & self.middle_mask
            self.full_mask = self.left_mask & self.right_mask & self.middle_mask
            self.next_mask = self.left_mask & self.middle_mask

        self._masks_initialized = True

    def _get_combine_function(self, method: Literal['mean', 'max', 'min']):
        """
        Get the combination function for the specified method

        The returned callable takes two arrays and combines them element-wise
        while ignoring NaN: a NaN on one side yields the other side's value and
        NaN on both sides yields NaN. No RuntimeWarning is emitted for the
        both-NaN case.

        Raises:
            ValueError: If ``method`` is not 'mean', 'max' or 'min'.
        """
        valid_methods = {
            'mean': self._nan_aware_mean,
            'max': np.fmax,   # fmax/fmin ignore NaN unless both operands are NaN
            'min': np.fmin,
        }
        if method not in valid_methods:
            raise ValueError(f"Invalid method: {method}. Must be one of {list(valid_methods.keys())}")
        return valid_methods[method]

    def _process_overlap_region(self, result, row, current_start, current_end,
                                overlap_start, overlap_end, next_mask, combine_func):
        """
        Process the overlapping region between current and next segments

        Combines, in place, the part of ``result`` spanning the intersection of
        ``[current_start, current_end)`` and ``[overlap_start, overlap_end)``
        with the matching part of ``row``. Values dropped by
        ``self.current_mask`` (existing data) or ``next_mask`` (new row) are
        treated as NaN before ``combine_func`` is applied.

        Not called by the stitching methods of this class.
        """
        overlap_region_start = max(current_start, overlap_start)
        overlap_region_end = min(current_end, overlap_end)

        # Translate absolute output positions into indices relative to the
        # current segment and to the incoming row respectively.
        current_rel_start = overlap_region_start - current_start
        current_rel_end = overlap_region_end - current_start
        next_rel_start = overlap_region_start - overlap_start
        next_rel_end = overlap_region_end - overlap_start

        # Extract values with masks
        current_vals = np.where(
            self.current_mask[current_rel_start:current_rel_end],
            result[overlap_region_start:overlap_region_end],
            np.nan
        )

        next_vals = np.where(
            next_mask[next_rel_start:next_rel_end],
            row[next_rel_start:next_rel_end],
            np.nan
        )

        # Combine and write
        result[overlap_region_start:overlap_region_end] = combine_func(current_vals, next_vals)

    def _process_non_overlap_region(self, result, row, current_end, overlap_start,
                                    overlap_end, next_mask):
        """
        Process the non-overlapping region of the next segment

        Copies the masked part of ``row`` that lies between ``overlap_end`` and
        ``current_end + self.shift`` into ``result`` (masked bins become NaN).
        Does nothing when that span is empty.

        Not called by the stitching methods of this class.
        """
        if overlap_end < current_end + self.shift:
            non_overlap_start = overlap_end
            non_overlap_end = current_end + self.shift
            rel_start = non_overlap_start - overlap_start
            rel_end = rel_start + (non_overlap_end - non_overlap_start)

            result[non_overlap_start:non_overlap_end] = np.where(
                next_mask[rel_start:rel_end],
                row[rel_start:rel_end],
                np.nan
            )

    def _validate_masks(self) -> None:
        """
        Validate that masks will not create gaps in output

        Checks, within one row's span, that every bin dropped from a first row
        is supplied by the following row once that row is moved by ``shift``.

        Raises:
            ValueError: If some bin is covered by neither mask; the message
                lists exactly the uncovered bin indices.
        """
        # Skip validation when no masking is enabled
        if self.no_masking:
            return

        current_mask = self.right_mask & self.middle_mask
        next_mask = self.left_mask & self.middle_mask

        # Simulate shift
        next_mask_shifted = np.zeros(self.num_bins, dtype=np.uint8)
        if self.shift < self.num_bins:
            next_mask_shifted[self.shift:] = next_mask[:self.num_bins - self.shift]

        # Check for gaps
        combined = current_mask | next_mask_shifted
        if not np.all(combined):
            # Compare against 0 explicitly: "~" on uint8 is a bitwise NOT
            # (1 -> 254), which would flag every index as a gap.
            gap_indices = np.flatnonzero(combined == 0)
            raise ValueError(
                f"Divisors would create gaps at indices {gap_indices}. "
                f"Try smaller divisors to ensure mask overlap."
            )

    def stitch_streaming(
        self,
        psd_iterator: Iterator[np.ndarray],
        num_rows: int,
        method: Literal['mean', 'max', 'min'] = 'mean'
    ) -> np.ndarray:
        """
        Stream-process PSD data with minimal memory usage

        Rows are placed ``self.shift`` bins apart. The first row is masked with
        ``current_mask``, row ``num_rows - 1`` with ``next_mask`` and all rows
        in between with ``full_mask``. Each new row is combined with whatever
        is already in the output at its position.

        NOTE: the combination is applied row by row, so with ``method='mean'``
        a bin covered by more than two rows receives a running pairwise mean
        rather than the arithmetic mean of all contributions.

        Args:
            psd_iterator: Iterator (or iterable) yielding 1-D PSD rows one at a time
            num_rows: Total number of rows to process; extra rows are ignored.
                If fewer rows are yielded the tail of the output stays NaN.
            method: Combination method ('mean', 'max', 'min')

        Returns:
            Stitched PSD array of length ``num_bins + shift * (num_rows - 1)``;
            an empty array if ``num_rows <= 0`` or the iterator yields nothing.

        Raises:
            ValueError: If ``method`` is invalid or the masks cannot be built.
        """
        combine_func = self._get_combine_function(method)

        if num_rows <= 0:
            return np.array([], dtype=self.dtype)

        rows = iter(psd_iterator)
        try:
            first_row = next(rows)
        except StopIteration:
            # Nothing to stitch; avoid leaking StopIteration to the caller.
            return np.array([], dtype=self.dtype)

        num_bins = first_row.shape[0]
        self._ensure_masks(num_bins)

        # The output is sized up front, so every row window fits inside it.
        output_size = num_bins + self.shift * (num_rows - 1)
        result = np.full(output_size, np.nan, dtype=self.dtype)

        # First row: keep the left edge, drop the right edge and the centre.
        result[:num_bins] = np.where(self.current_mask, first_row.astype(self.dtype), np.nan)

        last_index = num_rows - 1
        for i, row in enumerate(rows, 1):
            if i >= num_rows:
                break

            # The last row keeps its right edge; interior rows drop both edges.
            row_mask = self.next_mask if i == last_index else self.full_mask
            masked_row = np.where(row_mask, row.astype(self.dtype), np.nan)

            start = i * self.shift
            window = slice(start, start + num_bins)
            # Bins not yet written are NaN and are simply replaced; bins that
            # already hold data are combined with the new row.
            result[window] = combine_func(result[window], masked_row)

        return result

    def stitch(
        self,
        psd_array_2d: np.ndarray,
        method: Literal['mean', 'max', 'min'] = 'mean'
    ) -> np.ndarray:
        """
        Stitch all rows of a 2-D PSD array using the configured strategy

        * ``overlap_ratio`` given to the constructor (and non-zero): delegates
          to ``stitch_with_overlap_ratio``.
        * ``no_masking``: rows are concatenated.
        * otherwise: divisor-based masking via ``stitch_streaming``.

        Args:
            psd_array_2d: 2D array of shape (num_rows, num_bins)
            method: Combination method

        Returns:
            Stitched PSD output

        Raises:
            TypeError: If the input is not a numpy array.
            ValueError: If the input is not 2-D, or masks cannot be built for
                ``num_bins``.

        Warns:
            UserWarning: If the input dtype differs from the configured dtype.
        """
        if not isinstance(psd_array_2d, np.ndarray):
            raise TypeError("Input must be a numpy array")

        if psd_array_2d.ndim != 2:
            raise ValueError("Input must be a 2D array")

        num_rows, num_bins = psd_array_2d.shape

        if psd_array_2d.dtype != self.dtype:
            warnings.warn(
                f"Converting input from {psd_array_2d.dtype} to {self.dtype}",
                stacklevel=2,
            )

        # Honour an overlap ratio configured at construction time. A ratio of
        # 0.0 sets no_masking and is handled by the concatenation branch below.
        if self.use_overlap_stitching and not self.no_masking:
            return self.stitch_with_overlap_ratio(psd_array_2d, self.overlap_ratio, method)

        self._ensure_masks(num_bins)

        # Simple concatenation for no_masking mode
        if self.no_masking:
            return psd_array_2d.flatten().astype(self.dtype)

        # Iterating a 2-D array yields its rows one at a time.
        return self.stitch_streaming(iter(psd_array_2d), num_rows, method)

    def stitch_with_overlap_ratio(
        self,
        psd_array_2d: np.ndarray,
        overlap_ratio: float,
        method: Literal['mean', 'max', 'min'] = 'mean'
    ) -> np.ndarray:
        """
        Stitch PSDs using overlap ratio to determine bin ranges

        Consecutive rows are placed ``num_bins - int(overlap_ratio * num_bins)``
        bins apart. The leading ``int(overlap_ratio * num_bins)`` bins of each
        new row are combined with the data already present; the rest is copied.
        Ratios within 0.01 of 0.5 take the ``_stitch_50_percent_overlap``
        shortcut, which selects bins instead of combining them and therefore
        ignores ``method``.

        NOTE: for ratios above 0.5 a bin can be covered by more than two rows;
        with ``method='mean'`` the result there is a running pairwise mean.

        Args:
            psd_array_2d: 2D array of shape (num_rows, num_bins)
            overlap_ratio: Fraction of overlap between adjacent PSDs (0.0 <= ratio < 1.0)
            method: Combination method for overlapping regions

        Returns:
            Stitched PSD array

        Raises:
            TypeError: If the input is not a numpy array.
            ValueError: If the input is not 2-D, ``overlap_ratio`` is outside
                [0.0, 1.0), or ``method`` is invalid.
        """
        if not isinstance(psd_array_2d, np.ndarray):
            raise TypeError("Input must be a numpy array")

        if psd_array_2d.ndim != 2:
            raise ValueError("Input must be a 2D array")

        if not 0.0 <= overlap_ratio < 1.0:
            raise ValueError("overlap_ratio must be between 0.0 and 1.0")

        num_rows, num_bins = psd_array_2d.shape

        # Handle trivial cases
        if num_rows == 0:
            return np.array([], dtype=self.dtype)
        if num_rows == 1:
            return psd_array_2d[0].astype(self.dtype)

        # For no overlap, simple concatenation
        if overlap_ratio == 0.0:
            return psd_array_2d.flatten().astype(self.dtype)

        # Check for 50% overlap optimization
        if abs(overlap_ratio - 0.5) < 0.01:
            return self._stitch_50_percent_overlap(psd_array_2d, method)

        # General overlap case
        combine_func = self._get_combine_function(method)

        overlap_bins = int(overlap_ratio * num_bins)
        step_bins = num_bins - overlap_bins

        # Sized so that the last row ends exactly at the end of the output.
        output_size = num_bins + step_bins * (num_rows - 1)
        result = np.full(output_size, np.nan, dtype=self.dtype)

        # Place first PSD completely
        result[:num_bins] = psd_array_2d[0].astype(self.dtype)

        for i in range(1, num_rows):
            row = psd_array_2d[i].astype(self.dtype)
            start_pos = i * step_bins
            seam = start_pos + overlap_bins  # first output bin with no earlier data

            # Leading part of the row overlaps existing data: combine.
            if overlap_bins:
                result[start_pos:seam] = combine_func(result[start_pos:seam], row[:overlap_bins])

            # Trailing part extends the output: plain copy.
            result[seam:start_pos + num_bins] = row[overlap_bins:]

        return result

    def _stitch_50_percent_overlap(
        self,
        psd_array_2d: np.ndarray,
        method: Literal['mean', 'max', 'min'] = 'mean'
    ) -> np.ndarray:
        """
        Optimized stitching for 50% overlap case

        Instead of combining overlapping bins, each output bin is taken from
        exactly one row:
        - First PSD: everything up to the upper cut (3/4 of the bins)
        - Middle PSDs: the bins between the cuts (middle 1/2)
        - Last PSD: everything from the lower cut (last 3/4)

        The cuts are ``step`` bins apart, where ``step = num_bins - num_bins // 2``
        is the row-to-row offset, so the pieces are contiguous and the output
        has ``num_bins + step * (num_rows - 1)`` bins for every ``num_bins``.
        For ``num_bins`` divisible by 4 the cuts are exactly 1/4 and 3/4.

        Args:
            psd_array_2d: 2D array of shape (num_rows, num_bins)
            method: Accepted for signature parity; unused because nothing is combined.

        Returns:
            Stitched PSD array
        """
        num_rows, num_bins = psd_array_2d.shape

        if num_rows == 0:
            return np.array([], dtype=self.dtype)
        if num_rows == 1:
            # Single PSD - use entire thing
            return psd_array_2d[0].astype(self.dtype)

        step = num_bins - num_bins // 2
        lower_cut = (num_bins - step) // 2
        upper_cut = lower_cut + step

        pieces = [
            psd_array_2d[0, :upper_cut],                    # first row: keep left side
            psd_array_2d[1:-1, lower_cut:upper_cut].ravel(),  # middle rows, in row order
            psd_array_2d[-1, lower_cut:],                   # last row: keep right side
        ]
        return np.concatenate(pieces).astype(self.dtype, copy=False)


def test_stitching_with_ones():
    """
    Stitch an all-ones input in every mode and report on the output.

    Runs PSDStitcher.stitch with masking enabled and disabled, for each of the
    'mean', 'max' and 'min' methods, printing the output shape, the NaN count
    and whether all non-NaN values are 1.0. Nothing is asserted; the function
    only prints.

    Returns:
        The stitched array from the last combination that was run.
    """
    print("Testing stitching with arrays of 1.0...")

    # Every input value is 1.0, so every valid output value should be 1.0 too.
    rows, bins = 5, 64
    ones = np.ones((rows, bins), dtype=np.float32)

    print(f"Input shape: {ones.shape}")
    print("Input values: all 1.0")

    banner = "=" * 50
    for masking_disabled in (False, True):
        print(f"\n{banner}")
        print(f"Testing with no_masking={masking_disabled}")
        print(f"{banner}")

        stitcher = PSDStitcher(edge_bin_divisor=8, mid_bin_divisor=8, no_masking=masking_disabled)

        for method in ('mean', 'max', 'min'):
            print(f"\nTesting with method: {method}")

            result = stitcher.stitch(ones, method=method)
            print(f"Output shape: {result.shape}")

            # Report any bins that no row contributed to.
            is_nan = np.isnan(result)
            nan_count = is_nan.sum()
            print(f"NaN values: {nan_count}")
            if nan_count > 0:
                print("WARNING: Found NaN values in output!")
                print(f"NaN indices: {np.where(is_nan)[0][:10]}...")  # first 10 only

            valid_values = result[~is_nan]
            if valid_values.size == 0:
                print("ERROR: No valid values found!")
                continue

            print(
                f"Valid values - Min: {np.min(valid_values):.6f}, "
                f"Max: {np.max(valid_values):.6f}, Mean: {np.mean(valid_values):.6f}"
            )

            if np.allclose(valid_values, 1.0, atol=1e-6):
                print("✓ All valid values are close to 1.0")
            else:
                print("✗ Some values are not close to 1.0")
                # Show the worst offender.
                deviation = np.abs(valid_values - 1.0)
                print(f"Max difference from 1.0: {deviation[np.argmax(deviation)]:.8f}")

    print("\nTest completed.")
    return result


def test_no_masking_simple():
    """
    Compare the output of no_masking mode with divisor-based masking.

    Stitches the same all-ones input once with ``no_masking=True`` and once
    with the default divisors, printing the resulting shapes and lengths.
    Nothing is asserted; the function only prints.

    Returns:
        Tuple ``(no_masking_result, masked_result)``.
    """
    print("Testing no_masking mode...")

    # Same all-ones input as the other test.
    shape = (5, 64)
    n_rows, n_bins = shape
    ones = np.ones(shape, dtype=np.float32)

    print(f"Input shape: {ones.shape}")
    print("Input values: all 1.0")

    # Concatenation only.
    plain = PSDStitcher(no_masking=True).stitch(ones)
    print(f"No masking result shape: {plain.shape}")
    print(f"No masking result length: {len(plain)} (expected: {n_rows * n_bins})")
    print(f"All values are 1.0: {np.allclose(plain, 1.0)}")

    # Divisor-based masking.
    masked = PSDStitcher(edge_bin_divisor=8, mid_bin_divisor=8, no_masking=False).stitch(ones)
    print(f"With masking result shape: {masked.shape}")
    print(f"With masking result length: {len(masked)}")

    return plain, masked


def visualize_stitching_process() -> np.ndarray:
    """
    Create a matplotlib heatmap showing the stitching process

    Builds a 5 x 64 input in which row ``i`` is the constant ``i + 1``,
    stitches it with a default-divisor PSDStitcher using 'mean', and draws a
    3 x 2 grid of panels: the input, the masks, the rows at their shifted
    positions, the rows after masking, the stitched output and an overlap
    curve. The figure is written to '.tmpimg/psd_stitching_visualization.png'
    (the directory is created if needed) and a short summary is printed.

    Returns:
        The stitched 1-D array.
    """
    
    # Create simple test data with different values for each row
    num_rows, num_bins = 5, 64
    test_data = np.zeros((num_rows, num_bins), dtype=np.float32)
    
    # Fill each row with a different constant value to see the effect
    for i in range(num_rows):
        test_data[i, :] = i + 1  # Row 0 = 1.0, Row 1 = 2.0, etc.
    
    print(f"Input shape: {test_data.shape}")
    print("Input values: Row 0=1.0, Row 1=2.0, Row 2=3.0, Row 3=4.0, Row 4=5.0")
    
    # Create stitcher and get masks
    stitcher = PSDStitcher(edge_bin_divisor=8, mid_bin_divisor=8)
    
    # Initialize explicitly so the mask and shift attributes read below exist
    stitcher._init_masks(num_bins)
    
    # Get the final result
    result = stitcher.stitch(test_data, method='mean')
    
    # Create visualization
    fig, axes = plt.subplots(3, 2, figsize=(15, 12))
    fig.suptitle('PSD Stitching Process Visualization', fontsize=16)
    
    # Plot 1: Original input data
    ax = axes[0, 0]
    im1 = ax.imshow(test_data, aspect='auto', cmap='viridis', interpolation='nearest')
    ax.set_title('Original Input PSDs')
    ax.set_xlabel('Frequency Bins')
    ax.set_ylabel('PSD Index')
    plt.colorbar(im1, ax=ax)
    
    # Plot 2: Show masks, one image row per mask (1 = keep, 0 = remove)
    ax = axes[0, 1]
    mask_data = np.zeros((4, num_bins))
    mask_data[0, :] = stitcher.left_mask
    mask_data[1, :] = stitcher.right_mask  
    mask_data[2, :] = stitcher.middle_mask
    mask_data[3, :] = stitcher.full_mask
    
    im2 = ax.imshow(mask_data, aspect='auto', cmap='RdYlBu', interpolation='nearest')
    ax.set_title('Masks (Blue=Keep, Red=Remove)')
    ax.set_xlabel('Frequency Bins')
    ax.set_yticks(range(4))
    ax.set_yticklabels(['Left', 'Right', 'Middle', 'Full'])
    plt.colorbar(im2, ax=ax)
    
    # Plot 3: Show positioned PSDs (unmasked); cells a row does not cover stay NaN
    ax = axes[1, 0]
    output_len = len(result)
    positioned_data = np.full((num_rows, output_len), np.nan)
    
    # Place each PSD in its shifted position
    for i in range(num_rows):
        start_pos = i * stitcher.shift
        end_pos = start_pos + num_bins
        # Rows that would run past the end of the output are skipped
        if end_pos <= output_len:
            positioned_data[i, start_pos:end_pos] = test_data[i, :]
    
    im3 = ax.imshow(positioned_data, aspect='auto', cmap='viridis', interpolation='nearest')
    ax.set_title('PSDs Positioned with Shifts')
    ax.set_xlabel('Output Position')
    ax.set_ylabel('PSD Index')
    plt.colorbar(im3, ax=ax)
    
    # Plot 4: Show masked PSDs; the per-row mask choice is recomputed here
    # from the stitcher's individual masks
    ax = axes[1, 1]
    masked_data = np.full((num_rows, output_len), np.nan)
    
    for i in range(num_rows):
        start_pos = i * stitcher.shift
        end_pos = start_pos + num_bins
        if end_pos <= output_len:
            if i == 0:
                # First row uses right + middle mask
                mask = stitcher.right_mask & stitcher.middle_mask
            elif i == num_rows - 1:
                # Last row uses left + middle mask  
                mask = stitcher.left_mask & stitcher.middle_mask
            else:
                # Middle rows use full mask
                mask = stitcher.full_mask
            
            masked_values = np.where(mask, test_data[i, :], np.nan)
            masked_data[i, start_pos:end_pos] = masked_values
    
    im4 = ax.imshow(masked_data, aspect='auto', cmap='viridis', interpolation='nearest')
    ax.set_title('PSDs After Masking')
    ax.set_xlabel('Output Position')
    ax.set_ylabel('PSD Index')
    plt.colorbar(im4, ax=ax)
    
    # Plot 5: Final result, drawn as a one-row image
    ax = axes[2, 0]
    result_2d = result.reshape(1, -1)
    im5 = ax.imshow(result_2d, aspect='auto', cmap='viridis', interpolation='nearest')
    ax.set_title('Final Stitched Output')
    ax.set_xlabel('Output Position')
    ax.set_ylabel('Result')
    plt.colorbar(im5, ax=ax)
    
    # Plot 6: Show overlap regions
    # NOTE(review): only adjacent row pairs (i, i + 1) are counted, and masks
    # are not taken into account, so this curve is the number of adjacent-pair
    # overlaps per position rather than the number of rows contributing to it
    # - confirm this matches the intent of the panel title.
    ax = axes[2, 1]
    overlap_map = np.zeros(output_len)
    
    for i in range(num_rows - 1):
        current_start = i * stitcher.shift
        current_end = current_start + num_bins
        next_start = (i + 1) * stitcher.shift
        next_end = next_start + num_bins
        
        # Intersection of the two row windows
        overlap_start = max(current_start, next_start)
        overlap_end = min(current_end, next_end)
        
        if overlap_start < overlap_end:
            overlap_map[overlap_start:overlap_end] += 1
    
    # im6 is not used afterwards
    im6 = ax.plot(overlap_map, 'r-', linewidth=2)
    ax.set_title('Overlap Count (How Many PSDs Contribute)')
    ax.set_xlabel('Output Position')
    ax.set_ylabel('Number of Contributing PSDs')
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Create .tmpimg directory if it doesn't exist (path is relative to the
    # current working directory)
    import os
    os.makedirs('.tmpimg', exist_ok=True)
    filepath = '.tmpimg/psd_stitching_visualization.png'
    plt.savefig(filepath, dpi=150, bbox_inches='tight')
    # NOTE(review): the figure is not closed after saving; consider
    # plt.close(fig) if this is ever called repeatedly.
    # plt.show()
    
    print(f"Visualization saved as '{filepath}'")
    print(f"Final result shape: {result.shape}")
    print(f"NaN count in result: {np.isnan(result).sum()}")
    
    return result




if __name__ == "__main__":
    separator = "=" * 50

    # Render and save the diagnostic figure first.
    print(separator)
    visualize_stitching_process()

    # All-ones check across masking modes and combination methods.
    print("\n" + separator)
    test_stitching_with_ones()

    # Side-by-side comparison of no_masking and masked output.
    print("\n" + separator)
    print("Running no masking test...")
    test_no_masking_simple()
    
