#!/bin/python

import time
import uuid
import os
import sys
import signal
import traceback
import argparse
import logging
import datetime
import base64
import select
import random
import datetime
import json
import struct
import asyncio
import math

from . import DefaultDestEnvAction
from zmsclient.dst.client import ZmsDstClient
from zmsclient.dst.v1.models import (Observation, ObservationKind)
from zmsclient.zmc.client import ZmsZmcClient
from zmsclient.zmc.v1.types import UNSET
from zmsclient.zmc.v1.models import Subscription, EventFilter, Grant
from zmsclient.common.subscription import ZmsSubscriptionCallback

# Module-wide logger; its level is adjusted by init_main() from --debug.
LOG: logging.Logger = logging.getLogger(__name__)

# Default --logfile destination; only used when running with --daemon.
LOGFILE: str = "/local/logs/zms-monitor-sim.log"

class FreqRangeDesc:
    """Describe a transmitter (or the noise floor) over a frequency range.

    ``min``/``max`` bound the range (the simulator treats them as Hz) and
    ``power`` is the nominal power level.  ``power_factor`` scales the
    uniform random jitter the simulator adds: the jitter amplitude is
    ``|power| * power_factor``.  The optional fields carry the originating
    grant ID, the transmitter location, its distance from the monitor, and
    whether the transmitter is currently active.
    """

    def __init__(self, min, max, power, power_factor=None, grant_id=None,
                 lat=None, lon=None, dist=None, enabled=True):
        self.min = min
        self.max = max
        self.power = power
        self.power_factor = power_factor
        self.grant_id = grant_id
        self.lat = lat
        self.lon = lon
        self.dist = dist
        self.enabled = enabled

    @classmethod
    def from_string(cls, s):
        """Parse ``'min,max,power,power_factor'`` into a new descriptor.

        Raises:
            ValueError: if ``s`` does not have exactly four comma-separated
                fields, or a field is not a valid float.
        """
        fields = s.split(",")
        if len(fields) != 4:
            raise ValueError("invalid frequency range descriptor")
        return cls(*(float(x) for x in fields))

    def __repr__(self):
        return (
            f"FreqRangeDesc(min={self.min},max={self.max},power={self.power},"
            f"power_factor={self.power_factor},grant_id={self.grant_id},"
            f"lat={self.lat},lon={self.lon},dist={self.dist},"
            f"enabled={self.enabled})")

    def violate(self, shift, gain):
        """Return a copy shifted by ``shift`` in frequency and ``gain`` in power.

        A ``None`` shift or gain is treated as 0 so callers may pass unset
        configuration values straight through.  All other fields are copied.
        """
        shift = 0.0 if shift is None else shift
        gain = 0.0 if gain is None else gain
        return type(self)(
            self.min + shift, self.max + shift, self.power + gain,
            power_factor=self.power_factor, grant_id=self.grant_id,
            lat=self.lat, lon=self.lon, dist=self.dist, enabled=self.enabled)

class Simulator:
    """Generate simulated PSD observations and post them to an OpenZMS DST.

    The simulated spectrum is the noise floor described by ``band`` with the
    enabled static ``grants``, the dynamically tracked grants (see
    :meth:`start_grant`) and the ``incumbents`` layered on top.  Each call to
    :meth:`create_observation` samples the band every ``step_size`` Hz and
    encodes the result as SigMF (JSON metadata plus float32 samples) or CSV.
    """

    def __init__(self, dst_client, zmc_client, monitor_id, element_id,
                 band, step_size, grants=None, incumbents=None,
                 type="openzms.monitorsim.v1.psd", format="sigmf",
                 labels=None, description="", lat=None, lon=None,
                 violation_shift=None, violation_gain=None,
                 interference_power=None,
                 interval=10, impotent=False):
        """Store the simulation configuration.

        Args:
            dst_client: DST client used to post observations.
            zmc_client: ZMC client used to look up grant locations; may be
                None when grants are not tracked dynamically.
            monitor_id, element_id: IDs attached to each observation.
            band: FreqRangeDesc describing the band and its noise floor.
            step_size: frequency bin width (Hz).
            grants, incumbents: lists of FreqRangeDesc static transmitters;
                default to empty lists.
            type, format, labels, description: observation attributes;
                ``format`` is "csv", anything else produces SigMF.
            lat, lon: monitor location (EPSG:4326); when both are set,
                dynamic grants are attenuated by free-space path loss.
            violation_shift, violation_gain: frequency shift (Hz) and power
                gain applied to grants while violations are enabled; None is
                treated as 0.
            interference_power: bins at or above this power are flagged as
                interference; None disables the check.
            interval: seconds between observations in :meth:`run`.
            impotent: if true, log each observation header but never compute
                its data or post it.
        """
        self.dst_client = dst_client
        self.zmc_client = zmc_client
        self.monitor_id = monitor_id
        self.element_id = element_id
        self.band = band
        self.step_size = step_size
        # Avoid shared mutable defaults; keep caller-supplied lists as-is.
        self.grants = grants if grants is not None else []
        self.incumbents = incumbents if incumbents is not None else []
        self.type = type
        self.format = format
        self.labels = labels
        self.description = description
        self.violation_shift = violation_shift
        self.violation_gain = violation_gain
        self.interference_power = interference_power
        self.interval = interval
        self.lat = lat
        self.lon = lon
        self.impotent = impotent

        self._violations_enabled = False
        self.dynamic_grants = {}
        self.locations = []

    async def run(self):
        """Create an observation every ``interval`` seconds, forever.

        NOTE: create_observation() runs synchronously on the event loop, so
        other coroutines (e.g. subscription callbacks) wait while it runs.
        """
        while True:
            self.create_observation()
            await asyncio.sleep(self.interval)

    def start_grant(self, grant: "Grant"):
        """Start simulating ``grant`` as a dynamic transmitter.

        Only grants with exactly one constraint, exactly one radio port and a
        ``max_eirp`` are supported; others are logged and ignored.  When the
        monitor location is known, the grant's antenna location is fetched
        from ZMC to compute its distance for path-loss attenuation.
        """
        if not grant.constraints or len(grant.constraints) != 1:
            LOG.warning("start_grant: only support single constraint grants, ignoring: %r", grant)
            return
        if not grant.radio_ports or len(grant.radio_ports) != 1:
            LOG.warning("start_grant: only support single radio port grants, ignoring: %r", grant)
            return
        constraint = grant.constraints[0].constraint
        if constraint.max_eirp in [None, UNSET]:
            LOG.warning("start_grant: no max_eirp on grant constraint, ignoring: %r", grant)
            return
        (glat, glon, gdist) = (None, None, None)
        # Explicit None checks: 0.0 is a valid coordinate, and a half-set
        # location cannot be used for a distance computation.
        if self.lat is not None and self.lon is not None:
            loc_id = grant.radio_ports[0].radio_port.antenna_location_id
            loc = self.zmc_client.get_location(location_id=loc_id)
            LOG.debug("start_grant: grant location: %r", loc)
            if loc.srid != 4326:
                LOG.warning("start_grant: location CRS is not EPSG:4326, ignoring: %r (%r)", loc, grant.id)
                return
            (glat, glon) = (loc.y, loc.x)
            gdist = self.dist(self.lat, self.lon, glat, glon)

        LOG.debug("started grant %s: %r (dist %r)", grant.id, grant, gdist)
        frd = FreqRangeDesc(
            min=constraint.min_freq,
            max=constraint.max_freq,
            power=constraint.max_eirp,
            power_factor=0.05,
            grant_id=grant.id, lat=glat, lon=glon, dist=gdist)
        LOG.debug("start_grant: frd: %r", frd)
        self.dynamic_grants[grant.id] = frd

    def stop_grant(self, grant: "Grant"):
        """Disable (but keep tracking) the dynamic transmitter for ``grant``."""
        frd = self.dynamic_grants.get(grant.id)
        if frd is not None:
            frd.enabled = False
            LOG.debug("stopped grant %s", grant.id)

    def remove_grant(self, grant: "Grant"):
        """Stop tracking the dynamic transmitter for ``grant`` entirely."""
        if self.dynamic_grants.pop(grant.id, None) is not None:
            LOG.debug("removed grant %s", grant.id)

    def dist(self, lat1, lon1, lat2, lon2):
        """Haversine great-circle distance in meters between two lat/lon points (degrees)."""
        R = 6371000.0
        phi1 = math.radians(lat1)
        phi2 = math.radians(lat2)
        delta_phi = math.radians(lat2 - lat1)
        delta_lambda = math.radians(lon2 - lon1)
        a = math.sin(delta_phi / 2.0) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(delta_lambda / 2.0) ** 2
        c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
        return R * c

    def fspl(self, freq, dist):
        """Free-space path loss (dB) at ``freq`` Hz over ``dist`` meters.

        FSPL = 20 log10(d) + 20 log10(f) + 20 log10(4*pi/c).  Non-positive
        distances or frequencies return 0.0: this is not a high-fidelity
        simulation and the near field is ignored (and log10 would fail).
        """
        if dist <= 0 or freq <= 0:
            return 0.0
        c = 299792458.0
        return (20.0 * math.log10(dist) + 20.0 * math.log10(freq)
                + 20.0 * math.log10(4.0 * math.pi / c))

    def _edge_taper(self, fi, tx, txpower):
        """Pull ``txpower`` toward the noise floor in the two bins at each edge of ``tx``.

        The outermost bin keeps 25% of the excess over the floor, the next
        bin in keeps 75%; interior bins are returned unchanged.  The check
        order (low edge first) matters for ranges narrower than four bins.
        """
        step = self.step_size
        if fi < tx.min + step:
            factor = 0.75
        elif fi < tx.min + step * 2:
            factor = 0.25
        elif fi > tx.max - step:
            factor = 0.75
        elif fi > tx.max - step * 2:
            factor = 0.25
        else:
            return txpower
        return txpower - factor * (txpower - self.band.power)

    def _compute_powers(self, txlist):
        """Sample the band every ``step_size`` Hz.

        Each bin starts from the jittered noise floor and takes the maximum
        of it and every enabled transmitter (``txlist`` plus incumbents)
        whose range covers the bin.

        Returns:
            ``(powers, is_interference)`` where ``powers`` is a list of dicts
            with keys freq, power, center, above, violation, interference,
            and ``is_interference`` is True if any bin reached
            ``interference_power``.
        """
        band = self.band
        center = band.min + (band.max - band.min) / 2
        active = [tx for tx in txlist if tx.enabled]
        incumbents = [tx for tx in self.incumbents if tx.enabled]
        powers = []
        is_interference = False
        for fi in range(int(band.min), int(band.max) + 1, int(self.step_size)):
            is_violation = 0
            power = band.power + (random.random() * 2 - 1) * band.power * band.power_factor
            for tx in active:
                if fi < tx.min or fi > tx.max:
                    continue
                fspl_db = self.fspl(fi, tx.dist) if tx.dist else 0.0
                txpower = tx.power - fspl_db + (random.random() * 2 - 1) * tx.power * tx.power_factor
                power = max(self._edge_taper(fi, tx, txpower), power)
                if self._violations_enabled:
                    is_violation = 1
            for tx in incumbents:
                if fi < tx.min or fi > tx.max:
                    continue
                txpower = tx.power + (random.random() * 2 - 1) * tx.power * tx.power_factor
                power = max(self._edge_taper(fi, tx, txpower), power)
            power_is_interference = (
                self.interference_power is not None
                and power >= self.interference_power)
            if power_is_interference:
                is_interference = True
            powers.append(dict(freq=fi, power=power, center=center,
                               above=power - band.power,
                               violation=is_violation,
                               interference=power_is_interference))
        return powers, is_interference

    @staticmethod
    def _encode_csv(powers):
        """Return ``powers`` as base64-encoded CSV (frequencies in MHz)."""
        lines = ["frequency,power,center_freq,abovefloor,violation\n"]
        lines.extend(
            f"{float(p['freq'])/1e6},{p['power']},{p['center']/1e6},{p['above']},{p['violation']}\n"
            for p in powers)
        return base64.b64encode("".join(lines).encode()).decode()

    def _encode_sigmf(self, powers, starts_at, is_interference):
        """Return base64 ``(metadata, data)`` for a non-empty ``powers`` list.

        ``metadata`` is the JSON SigMF metadata document; ``data`` is the
        per-bin power values packed as little-endian float32 (``rf32_le``).
        """
        first_freq = powers[0]['freq']
        last_freq = powers[-1]['freq']
        step = int(self.step_size)
        capture = {
            "core:sample_start": 0,
            "core:frequency": float(first_freq),
            "core:datetime": starts_at.isoformat(),
            "ntia-sensor:sigan_settings": {
                "gain": 10
            }
        }
        data_products = [{
            "processing": ["openzms_monitorsim_psd"],
            "name": "openzms_monitorsim_psd",
            "description": "OpenZMS monitorsim simulated PSD",
            "series": ["power"],
            "length": len(powers),
            "x_units": "Hz",
            "x_start": [int(first_freq)],
            "x_stop": [int(last_freq)],
            "x_step": [step],
            "y_units": "dBx"
        }]
        global_md = {
            "core:description": "monitorsim psd",
            "core:datatype": "rf32_le",
            "core:sample_rate": 1,
            "core:version": "1.2.5",
            "core:num_channels": 1,
            "core:recorder": "monitorsim",
            "core:extensions": [
                {
                    "name": "ntia-algorithm",
                    "version": "v2.0.0",
                    "optional": False
                },
                {
                    "name": "ntia-sensor",
                    "version": "v2.0.0",
                    "optional": False
                },
                {
                    "name": "openzms-core",
                    "version": "v1.0.0",
                    "optional": True
                }
            ],
            "openzms-core:kind": "psd",
            # Keep the metadata type consistent with the Observation's types.
            "openzms-core:types": self.type,
            "openzms-core:labels": "simulated",
            "openzms-core:min_freq": int(first_freq),
            "openzms-core:max_freq": int(last_freq),
            "openzms-core:freq_step": step,
            "ntia-algorithm:data_products": data_products,
            "ntia-algorithm:processing_info": [{
                "type": "openzms_monitorsim_psd",
                "id": "openzms_monitorsim_psd",
                "description": "A simulated PSD.",
            }]
        }
        if self._violations_enabled:
            global_md["openzms-core:violation"] = True
        if is_interference:
            global_md["openzms-core:interference"] = True
        metadata = {
            "global": global_md,
            "annotations": [],
            "captures": [capture]
        }
        values = [p['power'] for p in powers]
        data = struct.pack(f"<{len(values)}f", *values)
        return (base64.b64encode(json.dumps(metadata).encode()).decode(),
                base64.b64encode(data).decode())

    #
    # Create an observation.
    #
    def create_observation(self):
        """Build one simulated PSD observation and post it to the DST.

        Returns:
            The DST response, or None when running impotent, when the band
            contains no frequency bins, or when the request fails (the
            error is logged).
        """
        starts_at = datetime.datetime.now(datetime.timezone.utc)
        o = Observation(
            element_id=self.element_id, monitor_id=self.monitor_id,
            kind=ObservationKind.PSD, types=self.type, format_=self.format,
            description=self.description, min_freq=int(self.band.min),
            max_freq=int(self.band.max), starts_at=starts_at,
            violation=self._violations_enabled)
        if self.labels:
            o.labels = self.labels

        LOG.info("Creating observation: %r", o)
        if self.impotent:
            return None

        txlist = [*self.grants, *self.dynamic_grants.values()]
        # An unset shift/gain means "no change" rather than a TypeError.
        shift = self.violation_shift or 0.0
        gain = self.violation_gain or 0.0
        if self._violations_enabled and (shift or gain):
            txlist = [tx.violate(shift, gain) for tx in txlist]

        for tx in txlist:
            if tx.enabled:
                LOG.debug("including tx: %r", tx)
        for tx in self.incumbents:
            if tx.enabled:
                LOG.debug("including tx (incumbent): %r", tx)

        powers, is_interference = self._compute_powers(txlist)
        if not powers:
            LOG.error("Band %r contains no frequency bins; not creating observation", self.band)
            return None

        if self.format == "csv":
            o.data = self._encode_csv(powers)
        else:
            o.metadata, o.data = self._encode_sigmf(powers, starts_at, is_interference)

        o.violation = self._violations_enabled
        o.interference = is_interference

        try:
            response = self.dst_client.create_observation(body=o)
        except Exception:
            # Best effort: log and keep the simulation loop alive.
            LOG.error("Error creating observation: %s", traceback.format_exc())
            return None
        LOG.debug("created observation: %r", response)
        return response

    def enable_violations(self):
        """Start shifting/boosting grants and flagging their bins as violations."""
        self._violations_enabled = True

    def disable_violations(self):
        """Return grants to their configured ranges and stop flagging violations."""
        self._violations_enabled = False

#
# Event header constants used by ZMCSubscriptionCallback.on_event().
# These need to go in a common lib.
#

# Event types, compared against evt.header.type.
ET_DELETED = 4
ET_REVOKED = 5
ET_STARTED = 18
ET_STOPPED = 19
ET_PENDING = 23

# Event source types, compared against evt.header.source_type.
EST_ZMC = 2
EST_DST = 3

# Event codes, compared against evt.header.code.
EC_GRANT = 2006

class ZMCSubscriptionCallback(ZmsSubscriptionCallback):
    """Forward ZMC grant lifecycle events to a Simulator.

    Only events whose source type is EST_ZMC, whose code is EC_GRANT, and
    whose element ID is in ``grant_element_ids`` are acted on; all other
    events delivered by the subscription are ignored.
    """

    def __init__(self, zmc_client, simulator, grant_element_ids=None,
                 impotent=False, **kwargs):
        """Create the callback.

        Args:
            zmc_client: ZMC client passed through to the base class.
            simulator: Simulator whose dynamic grants are updated.
            grant_element_ids: element IDs whose grant events are handled;
                defaults to an empty list (no events handled).
            impotent: stored for callers; not consulted by on_event().
            **kwargs: forwarded to ZmsSubscriptionCallback.
        """
        super().__init__(zmc_client, **kwargs)
        self.simulator = simulator
        # Avoid a shared mutable default; keep caller-supplied lists as-is.
        self.grant_element_ids = (
            grant_element_ids if grant_element_ids is not None else [])
        self.impotent = impotent

    def on_event(self, ws, evt, message):
        """Dispatch a grant event to the simulator's start/stop/remove hooks."""
        header = evt.header
        if header.source_type != EST_ZMC:
            LOG.error("on_event: unexpected source type: %r (%r)",
                      header.source_type, message)
            return

        # Since we are subscribed to all ZMC events for this element,
        # there will be chatter we do not care about.
        if header.code != EC_GRANT:
            return

        if header.element_id not in self.grant_element_ids:
            return

        grant = evt.object_
        if header.type == ET_STARTED:
            LOG.info("Grant started: %r", grant)
            self.simulator.start_grant(grant)
        elif header.type == ET_STOPPED:
            LOG.info("Grant stopped: %r", grant)
            self.simulator.stop_grant(grant)
        elif header.type == ET_PENDING:
            LOG.info("Grant pending: %r", grant)
            self.simulator.stop_grant(grant)
        elif header.type in (ET_REVOKED, ET_DELETED):
            LOG.info("Grant revoked/deleted: %r", grant)
            self.simulator.remove_grant(grant)

def init_main():
    """Parse command-line arguments and construct the OpenZMS clients.

    Also applies the log levels requested via ``--debug`` and validates the
    band and transmitter frequency ranges against ``--step-size``.

    Returns:
        ``(dst_client, zmc_client, args)``; ``zmc_client`` is None unless
        ``--dynamic-grants`` was given.

    Raises:
        ValueError: if the step size is below 1, the band is inverted, a
            range edge is not a multiple of the step size, or
            ``--dynamic-grants`` is given without ``--zmc-http``.
    """
    parser = argparse.ArgumentParser(
        prog="monitorsim",
        description="Send simulated power spectral density data to an OpenZMS DST.")
    parser.register("type", "FreqRangeDesc", FreqRangeDesc.from_string)
    parser.add_argument(
        "-b", "--daemon", default=False, action="store_true",
        help="Daemonize")
    parser.add_argument(
        "-d", "--debug", default=0, action="count",
        help="Increase debug level: defaults to INFO; add once for zmsclient DEBUG; add twice to set the root logger level to DEBUG")
    parser.add_argument(
        "-n", "--impotent", default=False, action="store_true",
        help="Impotent: do not create observations in DST")
    parser.add_argument(
        "--logfile", default=LOGFILE, type=str,
        help="Redirect logging to a file when daemonizing.")
    parser.add_argument(
        "--token", action=DefaultDestEnvAction, type=str, required=True,
        help="OpenZMS token")
    parser.add_argument(
        "--monitor-id", action=DefaultDestEnvAction, type=str, required=True,
        help="Monitor ID to associate with observation")
    parser.add_argument(
        "--element-id", action=DefaultDestEnvAction, type=str, required=True,
        help="Element ID containing monitor")
    parser.add_argument(
        "--zmc-http", action=DefaultDestEnvAction, type=str, required=False,
        help="ZMC URL")
    parser.add_argument(
        "--dst-http", action=DefaultDestEnvAction, type=str, required=True,
        help="DST URL")
    parser.add_argument(
        "--dynamic-grants", default=False, action="store_true",
        help="Simulate grants dynamically via ZMC subscription")
    parser.add_argument(
        "--dynamic-fspl", default=False, action="store_true",
        help="Simulate path loss dynamically using free-space model via grant and monitor location from ZMC")
    parser.add_argument(
        "--grant-element-id", type=str, default=[], action="append",
        help="Element ID(s) to monitor for grants via ZMC subscription (only with --dynamic-grants)")
    parser.add_argument(
        "--interference-power", type=float, default=None,
        help="Power level (dBx) at which interference should be flagged.")
    parser.add_argument(
        "--type", default="openzms.monitorsim.v1.psd", type=str,
        help="Specify the monitor observation type: e.g. openzms.monitorsim.v1.psd or powder.rfmonitor.v2.psd.")
    parser.add_argument(
        "--format", default="sigmf", type=str, choices=["sigmf", "csv"],
        help="Specify the monitor observation format: sigmf or csv.")
    parser.add_argument(
        "--labels", default=None, type=str,
        help="Specify the monitor observation labels as a comma-separated list: ota,sweep,inline... .")
    parser.add_argument(
        "--band", type='FreqRangeDesc', default="900e6,928e6,-120.0,0.005",
        help="A frequency range descriptor that defines the noise characteristics for the PSD: 'minfreq,maxfreq,noise-power,noise-power-rand-pct' .")
    parser.add_argument(
        "--grant", type='FreqRangeDesc', default=[], action="append",
        help="A frequency range descriptor whose range is in the band that defines power characteristics for a particular OpenZMS grant transmitter in the band: 'minfreq,maxfreq,power,power-rand-pct' .")
    parser.add_argument(
        "--incumbent", type='FreqRangeDesc', default=[], action="append",
        help="A frequency range descriptor whose range is in the band that defines power characteristics for a particular incumbent transmitter in the band: 'minfreq,maxfreq,power,power-rand-pct' .")
    parser.add_argument(
        "--step-size", type=float, default=50e3,
        help="Step size for power bins in PSD.")
    parser.add_argument(
        "--interval", type=float, default=15.0,
        help="Interval (seconds) between observations.")
    parser.add_argument(
        "--violation-shift", type=float, default=-2e6,
        help="Frequency shift (Hz) when violations are toggled on.")
    parser.add_argument(
        "--violation-gain", type=float, default=0.0,
        help="Power gain (dB) when violations are toggled on.")

    args = parser.parse_args(sys.argv[1:])

    if args.debug:
        LOG.setLevel(logging.DEBUG)
        logging.getLogger('zmsclient').setLevel(logging.DEBUG)
    else:
        LOG.setLevel(logging.INFO)
        logging.getLogger('zmsclient').setLevel(logging.INFO)
    if args.debug > 1:
        logging.getLogger().setLevel(logging.DEBUG)

    step = args.step_size
    # The simulator truncates step-size with int() when generating bins, so
    # anything below 1 would produce a zero (invalid) range step.
    if step < 1:
        raise ValueError(f"step-size ({step}) must be at least 1")
    if args.band.max < args.band.min:
        raise ValueError(f"band max ({args.band.max}) must not be less than band min ({args.band.min})")
    # Bins start at band min and advance by step-size; requiring every range
    # edge to be a multiple of step-size makes the edges coincide with bins.
    if args.band.min % step != 0:
        raise ValueError(f"band min ({args.band.min}) must be a multiple of step-size")
    if args.band.max % step != 0:
        raise ValueError(f"band max ({args.band.max}) must be a multiple of step-size")
    for tx in [*args.grant, *args.incumbent]:
        if tx.min % step != 0:
            raise ValueError(f"tx min ({tx.min}) must be a multiple of step-size")
        if tx.max % step != 0:
            raise ValueError(f"tx max ({tx.max}) must be a multiple of step-size")

    if args.dynamic_fspl and not args.dynamic_grants:
        LOG.warning("--dynamic-fspl has no effect without --dynamic-grants")

    zmc_client = None
    if args.dynamic_grants:
        if not args.zmc_http:
            raise ValueError("--zmc-http is required with --dynamic-grants")
        zmc_client = ZmsZmcClient(
            args.zmc_http, args.token,
            detailed=False, raise_on_unexpected_status=True)
    dst_client = ZmsDstClient(args.dst_http, args.token,
                              detailed=False, raise_on_unexpected_status=True)

    return dst_client, zmc_client, args

def sigh_exit(signalnum, frame):
    """Signal handler that exits the process with status 0.

    Uses ``sys.exit`` rather than the ``exit`` builtin, which is injected by
    the ``site`` module and is missing under ``python -S`` or frozen builds.
    """
    sys.exit(0)

# The active Simulator, assigned by main(); None until then.  Read by the
# SIGUSR1/SIGUSR2 handlers below.
simulator = None

def sigh_enable_violate(*args):
    """Signal handler: turn on violation simulation, if a simulator exists."""
    current = simulator
    if current:
        current.enable_violations()
    LOG.info("Enabled violations")

def sigh_disable_violate(*args):
    """Signal handler: turn off violation simulation, if a simulator exists."""
    current = simulator
    if current:
        current.disable_violations()
    LOG.info("Disabled violations")

# The handler has to be defined outside the async main: Python runs
# signal.signal() handlers in the main thread between bytecodes, outside the
# event loop's control, so the handler only schedules the cancellation.
def set_signal_handler(signum, task_to_cancel):
    """Install a handler for ``signum`` that cancels ``task_to_cancel``.

    Must be called from a coroutine on the loop that owns ``task_to_cancel``.
    That loop is captured here, so a signal arriving when no loop is
    "running" (e.g. during shutdown) no longer raises inside the handler.

    Raises:
        RuntimeError: if called without a running event loop.
    """
    loop = asyncio.get_running_loop()

    def handler(_signum, _frame):
        if loop.is_closed():
            return
        loop.call_soon_threadsafe(task_to_cancel.cancel)

    signal.signal(signum, handler)

async def async_main(*args, tasks=None, subs=None):
    """Run simulator tasks and subscription callbacks until cancelled.

    SIGINT, SIGHUP and SIGTERM are redirected to cancel this coroutine's
    task.  Cancellation is treated as a normal shutdown and swallowed, so
    ``asyncio.run()`` returns cleanly.  On any exit, every subscription
    with an ID is unsubscribed.

    Args:
        *args: ignored; accepted for backward compatibility.
        tasks: objects with an async ``run()`` method; defaults to none.
        subs: subscription callbacks with async ``run_callbacks()``, an
            ``id`` attribute and ``unsubscribe()``; defaults to none.
    """
    tasks = tasks if tasks is not None else []
    subs = subs if subs is not None else []

    this_task = asyncio.current_task()
    for signum in (signal.SIGINT, signal.SIGHUP, signal.SIGTERM):
        set_signal_handler(signum, this_task)

    try:
        await asyncio.gather(
            *(task.run() for task in tasks),
            *(sub.run_callbacks() for sub in subs))
    except asyncio.CancelledError:
        pass
    finally:
        # Clean up even when a task fails, and don't let one failing
        # unsubscribe prevent the others.
        for sub in subs:
            if not sub.id:
                continue
            try:
                sub.unsubscribe()
            except Exception:
                LOG.exception("Failed to unsubscribe %s", sub.id)

def _daemonize(logfile):
    """Redirect stdio and logging to ``logfile`` and fork into the background.

    The parent exits with status 0 and the child continues in a new
    session.  Exits with status 1 if the log file cannot be opened.
    """
    try:
        fp = open(logfile, "a")
    except OSError:
        print(f"Could not open log file ({logfile}); aborting.")
        sys.exit(1)
    sys.stdout = fp
    sys.stderr = fp
    sys.stdin.close()
    logging.basicConfig(stream=fp)
    if os.fork():
        sys.exit(0)
    os.setsid()


def _monitor_location(zmc_client, monitor_id):
    """Return ``(lat, lon)`` of the monitor's antenna location from ZMC.

    Raises:
        ValueError: if the location is not in EPSG:4326.
    """
    monitor = zmc_client.get_monitor(monitor_id=monitor_id, elaborate=True)
    loc = zmc_client.get_location(location_id=monitor.radio_port.antenna_location_id)
    if loc.srid != 4326:
        raise ValueError("monitor location CRS is not EPSG:4326")
    return (loc.y, loc.x)


def main():
    """Entry point: configure, optionally daemonize, and run the simulator.

    With ``--dynamic-grants``, subscribes to ZMC grant events and seeds the
    simulator with the currently approved grants of each
    ``--grant-element-id`` before starting the event loop.
    """
    global simulator

    dst_client, zmc_client, args = init_main()

    if args.daemon:
        _daemonize(args.logfile)
    else:
        logging.basicConfig()

    signal.signal(signal.SIGINT, sigh_exit)
    signal.signal(signal.SIGTERM, sigh_exit)
    signal.signal(signal.SIGHUP, sigh_exit)
    signal.signal(signal.SIGUSR1, sigh_enable_violate)
    signal.signal(signal.SIGUSR2, sigh_disable_violate)

    LOG.debug("band: %s", args.band)

    (mlat, mlon) = (None, None)
    if args.dynamic_grants and args.dynamic_fspl:
        (mlat, mlon) = _monitor_location(zmc_client, args.monitor_id)
        LOG.debug("monitor location: %r,%r", mlat, mlon)

    simulator = Simulator(
        dst_client, zmc_client, args.monitor_id, args.element_id,
        args.band, args.step_size, args.grant, args.incumbent,
        type=args.type, format=args.format, labels=args.labels,
        lat=mlat, lon=mlon, violation_shift=args.violation_shift,
        violation_gain=args.violation_gain, interval=args.interval,
        interference_power=args.interference_power,
        impotent=args.impotent)

    subs = []
    if args.dynamic_grants:
        filters = UNSET
        if args.grant_element_id:
            filters = [EventFilter(element_ids=args.grant_element_id)]
        subscription = Subscription(
            id=str(uuid.uuid4()), filters=filters)
        zmc_sub = ZMCSubscriptionCallback(
            zmc_client, simulator, args.grant_element_id, impotent=args.impotent,
            subscription=subscription, reconnect_on_error=True)
        subs.append(zmc_sub)

        # Load the current grants for the element(s)
        for element_id in args.grant_element_id:
            ret = zmc_client.list_grants(
                element_id=element_id, current=True, approved=True,
                denied=False, revoked=False, deleted=False, claim=False,
                elaborate=True)
            for grant in ret.grants:
                simulator.start_grant(grant)

    asyncio.run(async_main(tasks=[simulator], subs=subs))

# Run the simulator when this module is executed as a script.
if __name__ == "__main__":
    sys.exit(main())
