import numpy as np
import time
import os,sys
import uhd

# == DEFAULTS == #
# Stream formats handed to the RX streamer (host format first, over-the-wire format second):
# FC32 - float (float), complex 32 -> host data format
# SC16 - integer (int16_t), complex 16 -> USRP DSP data format (auto conversion happens in the streamer)
STREAM_ARGS = uhd.libpyuhd.usrp.stream_args("fc32", "sc16")

# delay, in seconds, added to the current device time when building the
# stream start command; allows for time sync before rx
INIT_DELAY = 0.05

# indicates the number of times we will allow recv to return zero samples before restarting the stream
ZERO_RX_SAMP_THRESHOLD = 50

# Notes on stream commands:
# * `stream_now = True` for multiple channels on a single streamer will fail to time align!
# * too many consecutive stream_cmd calls may cause the device to lock up

class Device:
    """Receive-side wrapper around a UHD ``MultiUSRP`` radio.

    The constructor opens the radio, selects clock/time references, applies a
    default RX configuration to every configured channel and creates the RX
    streamer.  ``get_N_samples`` then re-tunes the radio and returns a block
    of complex samples per channel.

    Constructor arguments:
        args:   parsed command line options; ``debug`` and
                ``no_restart_streamer`` are read from it.
        logger: object providing ``info`` and ``debug`` methods.
        config: configuration object; ``device_config``, ``FFT_SIZE``,
                ``MAX_INST_BW``, ``GAIN``, ``ANALOG_BW``, ``LO_OFFSET`` and
                ``N`` are read from it.
    """

    # Polling parameters used while waiting for the LO to lock after a re-tune.
    LO_LOCK_POLL_INTERVAL = 0.01  # seconds slept between "lo_locked" polls
    LO_LOCK_MAX_POLLS = 100       # stop waiting (and log) after this many polls

    def __init__(self, args, logger, config):
        self.args = args
        self.logger = logger
        self.config = config
        self.device_config = self.config.device_config
        self.is_bs = False
        self.txrx_receive_hack = None
        self.recv_frame_size = 8176
        self.num_recv_frames = 50
        self.num_samps_req = config.FFT_SIZE
        self.rate = config.MAX_INST_BW
        self.gain = config.GAIN
        self.debug = self.args.debug

        self.logger.info(str(self.device_config))

        #
        # devices looks like this:
        #
        # "devices" : {
        #    "nuc1:rf0" : {
        #       "name"     : "nuc1:rf0",
        #       "channels" : {"0" : "RX2", "1" : "RX2"},
        #    },
        # }
        #
        # Note though, we are not using more than a single channel.
        #
        self.devs = self.device_config["devices"]
        self.channels, self.antennas = self._parse_channels(self.devs)

        if "txrx_receive_hack" in self.device_config:
            self.txrx_receive_hack = self.device_config["txrx_receive_hack"]

        self._open_usrp()
        self._configure_reference_sources()
        self._configure_rx_channels()

        if self.txrx_receive_hack:
            self.logger.info("Issuing set_gpio_attr calls to the RF frontend")
            self.usrp.set_gpio_attr("FP0", "CTRL", 0)
            self.usrp.set_gpio_attr("FP0", "DDR", 0x10)
            self.usrp.set_gpio_attr("FP0", "OUT", 0x10)

        # initial set up of data streamer
        # NOTE: STREAM_ARGS is a module-level object, so setting ``channels``
        # here modifies the shared instance.
        self.stream_args = STREAM_ARGS
        self.stream_args.channels = self.channels
        self.streamer = self.usrp.get_rx_stream(self.stream_args)
        # B210, USB3.0 with default buffer size -> 2040 samples per call to recv
        self.dev_buffer_size = self.streamer.get_max_num_samps()

        # The garbage buffer holds data immediately collected after a re-tune.
        # data here isn't trusted so we just throw it away
        self.garbage_buffer = np.empty(
            (len(self.channels), self.dev_buffer_size), dtype=np.complex64)

        # this buffer holds all of the samples of interest
        self.default_result = np.empty(
            (len(self.channels), self.num_samps_req), dtype=np.complex64)

        # synchronize channels
        # TODO update this after getting GPSDO installed. Sync to GPS clock if possible
        self.usrp.set_time_unknown_pps(uhd.types.TimeSpec(0.0))
        #gps_time = self.get_gps_time()
        #self.usrp.set_time_next_pps(uhd.uhd.types.TimeSpec(gps_time+1))

        # In "no restart" mode the stream is started once here and kept
        # running; get_N_samples then flushes stale data instead of
        # starting/stopping the stream around every capture.
        if self.args.no_restart_streamer:
            self.start_stream()

        # report the configured parameters
        self.logger.info("Actual gain: {}".format(self.usrp.get_rx_gain(0)))
        self.logger.info("Actual rate: {:.2f} MSps".format(self.usrp.get_rx_rate(0)/1e6))
        self.logger.info("Analog bandwidth set to {:.2f} MHz".format(self.usrp.get_rx_bandwidth()/1e6))
        self.logger.info("Keeping {} samples from FFT size {} ".format(self.config.N, self.config.FFT_SIZE))
        self.logger.info("USRP radio initialized successfully")

        # handling tunes
        self.times = []
        self.end_collect_time = 0
        self.end_tune_time = 0

        # Default number of samples to discard after a re-tune when the
        # measured re-tune time is unusable: 50 recv buffers (usually
        # 50 * 2040 = 102000 samples; the original note puts a typical
        # re-tune at roughly 3.3 ms).
        self.default_drop = 50*self.dev_buffer_size
        self.drops = []

    @staticmethod
    def _parse_channels(devs):
        """Extract channel numbers and antenna names from the "devices" config.

        The first device with a non-empty channel map defines the channel
        list; every later device must declare the same number of channels.

        Returns:
            (channels, antennas): two lists of equal length where
            ``antennas[i]`` is the antenna configured for ``channels[i]``.

        Raises:
            ValueError: a later device declares a different channel count.
        """
        channels = []
        antennas = []
        for devname in devs:
            dev_channels = devs[devname]["channels"]
            if not channels:
                for chan in dev_channels:
                    channels.append(int(chan))
                    antennas.append(dev_channels[chan])
            elif len(dev_channels) != len(channels):
                raise ValueError("Channel length mismatch in configuration")
        return channels, antennas

    def _open_usrp(self):
        """Build the device args string and open the radio.

        An X310 is re-opened with an explicit master clock rate and flagged
        through ``self.is_bs``.
        """
        self.device_args = "recv_frame_size=" + str(self.recv_frame_size)
        self.device_args += ",num_recv_frames=" + str(self.num_recv_frames)
        self.logger.info("device args: " + self.device_args)

        # setup USRP radio device
        self.usrp = uhd.usrp.MultiUSRP(self.device_args)

        # Watch for X310, gets different handling than a B210
        mboard = self.usrp.get_mboard_name(0)
        if mboard == "X310":
            self.logger.info("This is an X310, " +
                             "restarting with correct master clock rate")
            self.is_bs = True
            # drop the first handle before opening the device again
            self.usrp = None
            self.device_args += ",master_clock_rate=184.32e6"
            self.usrp = uhd.usrp.MultiUSRP(self.device_args)

    def _configure_reference_sources(self):
        """Select clock and time references, falling back to internal ones.

        The X310 path uses the internal references; every other board tries
        the GPSDO first.  Any failure is logged and both sources are set to
        "internal" (best effort, never fatal).
        """
        try:
            if self.is_bs:
                self.usrp.set_clock_source("internal")
                self.usrp.set_time_source("internal")
                self.logger.info("using internal reference source")
            else:
                self.usrp.set_clock_source("gpsdo")
                self.usrp.set_time_source("gpsdo")
                self.logger.info("using GPSDO reference source")
        except Exception as err:
            self.logger.info(
                "reference source setup failed ({}); falling back".format(err))
            self.usrp.set_clock_source("internal")
            self.usrp.set_time_source("internal")
            self.logger.info("using internal reference source")

    def _configure_rx_channels(self):
        """Apply default rate, frequency, gain, antenna and bandwidth."""
        # set to a default value to initialize the radio
        for chan, antenna in zip(self.channels, self.antennas):
            self.usrp.set_rx_rate(self.rate, chan)
            self.usrp.set_rx_freq(uhd.libpyuhd.types.tune_request(2400e6), chan)
            self.usrp.set_rx_gain(self.gain, chan)

            # antennas is ordered like channels, so pair them positionally
            # instead of indexing the list by channel number
            self.usrp.set_rx_antenna(antenna, chan)

            # Analog band width is not properly set when using two channels currently (likely a UHD bug)
            # reduce the analog bandwidth to ~90% of the sample rate. See radio/constants.py for more details.
            # Both frontends are set explicitly on every pass, as in the
            # original workaround.
            self.usrp.set_rx_bandwidth(self.config.ANALOG_BW, 0)
            self.usrp.set_rx_bandwidth(self.config.ANALOG_BW, 1)

    def _get_stream_start_cmd(self):
        """ Create a start stream cmd based on the current time

        Returns a continuous-streaming start command.  ``time_spec`` is set
        to "device time now + INIT_DELAY", but ``stream_now`` is True, and
        per the UHD stream command documentation the time spec is then not
        used (see also the module notes about multi-channel alignment).
        """
        #self.usrp.set_time_unknown_pps(uhd.types.TimeSpec(0.0))
        stream_cmd = uhd.types.StreamCMD(uhd.types.StreamMode.start_cont)
        # TODO test this
        # indicate that we want N samples plus more later
        #stream_cmd = uhd.types.StreamCMD(uhd.types.StreamMode.num_more)

        # TODO test this
        stream_cmd.stream_now = True
        stream_cmd.time_spec = uhd.types.TimeSpec(
            self.usrp.get_time_now().get_real_secs()+INIT_DELAY)
        return stream_cmd

    def _tune(self, frequency):
        """Tune every configured channel and wait for its LO to lock.

        The wait is bounded by LO_LOCK_MAX_POLLS so a frontend that never
        reports lock cannot hang the acquisition; the condition is logged
        and tuning continues.
        """
        for chan in self.channels:
            started = time.time()
            self.usrp.set_rx_freq(
                uhd.libpyuhd.types.tune_request(frequency, self.config.LO_OFFSET), chan)
            polls = 0
            # the sensor call returns a sensor value object; its boolean
            # reading must be extracted explicitly
            while not self.usrp.get_rx_sensor("lo_locked", chan).to_bool():
                if polls >= self.LO_LOCK_MAX_POLLS:
                    self.logger.info(
                        "LO on channel {} did not report lock after {} polls; continuing".format(chan, polls))
                    break
                time.sleep(self.LO_LOCK_POLL_INTERVAL)
                polls += 1
            elapsed = time.time() - started
            if self.debug:
                self.logger.debug("Tuning time: %f: %f (%d)" %
                                  (frequency, elapsed, polls))

    def _recv_with_recovery(self, buffer, rx_metadata, zero_streak):
        """Receive one buffer and restart the stream if it looks stuck.

        Args:
            buffer:      array the streamer writes the samples into.
            rx_metadata: metadata object filled in by ``recv``.
            zero_streak: consecutive zero-sample receives seen so far.

        Returns:
            (num_samps, zero_streak): samples received by this call and the
            updated streak counter (reset to 0 after a restart for that
            reason or after any non-empty receive).
        """
        num_samps = self.streamer.recv(buffer, rx_metadata)
        zero_streak = zero_streak + 1 if num_samps == 0 else 0

        needs_restart = False
        # catch hang condition where no samples are received
        if zero_streak >= ZERO_RX_SAMP_THRESHOLD:
            self.logger.info("received zero samples {} times in a row. Restarting the stream".format(zero_streak))
            zero_streak = 0
            needs_restart = True
        # sometimes we timeout and need to restart an idle radio core
        if rx_metadata.error_code == uhd.types.RXMetadataErrorCode.late:
            self.logger.info("received a LATE error code. Restarting stream")
            needs_restart = True

        # restart at most once per receive: back-to-back stream commands may
        # lock up the device (see the module notes)
        if needs_restart:
            self.restart_stream()
        return num_samps, zero_streak

    def get_N_samples(self, frequency, N):
        """ Default method

        Tune to ``frequency`` and collect ``N`` samples from every channel.

        Args:
            frequency: center frequency passed to the tune request.
            N:         number of samples to collect per channel (cast to int).

        Returns:
            complex64 array of shape (number of channels, N).
        """
        N = int(N)
        # tune to the indicated center frequency
        self._tune(frequency)

        if self.args.no_restart_streamer:
            # stream kept running during the re-tune: discard stale samples
            self.flush_stream()
        else:
            self.start_stream()

        rx_metadata = uhd.types.RXMetadata()
        result = np.empty((len(self.channels), N), dtype=np.complex64)
        recv_buffer = np.empty(
            (len(self.channels), self.dev_buffer_size), dtype=np.complex64)

        num_accum_samps = 0
        zero_streak = 0
        while num_accum_samps < N:
            # copy from internal buffer to local recv buffer
            num_rx_samps, zero_streak = self._recv_with_recovery(
                recv_buffer, rx_metadata, zero_streak)

            # if we received > 0 samples, store them to the result array,
            # clipping the last buffer so we never write past N
            if num_rx_samps:
                keep = min(N - num_accum_samps, num_rx_samps)
                result[:, num_accum_samps:num_accum_samps + keep] = recv_buffer[:, :keep]
            num_accum_samps += num_rx_samps

        if not self.args.no_restart_streamer:
            self.stop_stream()
        self.end_collect_time = time.time()
        return result

    def flush_stream(self):
        """Discard samples captured while the radio was re-tuning.

        The number of samples dropped is the sample rate multiplied by the
        time elapsed since the previous collection ended.  If that estimate
        is negative or exceeds one second of data, ``default_drop`` is used.
        """
        self.end_tune_time = time.time()

        # data in the buffer is invalid (we didn't stop streaming during the re-tune), so empty
        # this could be avoided if the frequency didn't change. (TODO)
        rx_metadata = uhd.types.RXMetadata()

        # drop packets that are invalid due to re-tune. Number of samples are proportional to re-tune time
        tdiff = self.end_tune_time - self.end_collect_time
        num_samples_to_drop = self.rate * tdiff
        if num_samples_to_drop < 0 or num_samples_to_drop > self.rate:
            num_samples_to_drop = self.default_drop

        recvd = 0
        zero_streak = 0
        while recvd < num_samples_to_drop:
            num_recvd, zero_streak = self._recv_with_recovery(
                self.garbage_buffer, rx_metadata, zero_streak)
            recvd += num_recvd

    def stop_stream(self):
        """ Sends a stop command to the FPGA. Subsequent calls to recv will return 0 bytes.
        """
        stream_stop_cmd = uhd.types.StreamCMD(uhd.types.StreamMode.stop_cont)
        self.streamer.issue_stream_cmd(stream_stop_cmd)

    def start_stream(self):
        """ Sends a start command to the FPGA.
        """
        start_cmd = self._get_stream_start_cmd()
        self.streamer.issue_stream_cmd(start_cmd)

    def restart_stream(self):
        """ Stops, then restarts streaming on the FPGA

        The streamer object is re-created between the stop and the start.

        # TODO: test calling recv until timeout occurs. Should empty internal buffers.
        """
        self.stop_stream()
        self._recreate_streamer()
        self.start_stream()

    def _recreate_streamer(self):
        """ using the default stream args, re-create the streamer
        This should force a clear of the buffer
        """
        self.streamer = self.usrp.get_rx_stream(self.stream_args)

    def setGain(self, gain):
        """Set the RX gain on every configured channel and log the result."""
        self.gain = gain
        for chan in self.channels:
            self.usrp.set_rx_gain(self.gain, chan)
            self.logger.info("Actual gain now: {}".format(self.usrp.get_rx_gain(chan)))

