import abc
import argparse
import copy
import logging
import time
from threading import (RLock, Condition)
from typing import (Annotated, Any, List, Optional, Union)

from zmsclient.dst.client import ZmsDstClient
from zmsclient.dst.v1.models import (Metric, MetricList, Value)
from zmsclient.dst.v1.models import Error as DstError

# Module-level logger, named after this module via __name__.
LOG = logging.getLogger(__name__)

class ModemBase(abc.ABC):
    """Abstract base for modem types.

    Concrete subclasses must implement add_args(), get_metric_definitions()
    and get_values().  The init(), open() and close() hooks are optional;
    the defaults here do nothing.
    """
    name: str = "modem"

    def __init__(self, args: argparse.Namespace):
        # Retain the whole namespace for subclasses, and copy out the
        # options the base class reads directly.
        self.args = args
        self.device, self.file = args.device, args.file
        self.impotent, self.when_unchanged = args.impotent, args.when_unchanged

    @classmethod
    @abc.abstractmethod
    def add_args(cls, parser: argparse.ArgumentParser):
        """Hook for a modem type to add its own options to ``parser``."""

    def init(self):
        """Initialize the modem, if necessary.  The default does nothing."""

    def open(self):
        """Open the modem for reading, if necessary.  The default does nothing."""

    def close(self):
        """Close the modem, if necessary.  The default does nothing."""

    @abc.abstractmethod
    def get_metric_definitions(self) -> List[Metric]:
        """Return the Metric definitions this modem will collect.

        The base implementation returns an empty list.
        """
        return []

    @abc.abstractmethod
    def get_values(self, metric_name: str) -> List[Value]:
        """Return the current values for the metric named ``metric_name``.

        The caller (the monitor) fills in metric_id, monitor_id and
        element_id on the returned values.  The base implementation returns
        an empty list.
        """
        return []
    
class Monitor:
    """A modem monitor.

    Resolves the modem's metric definitions against the DST (init_metrics),
    then repeatedly polls the modem and uploads the resulting values through
    ``client`` (run) until stop() is called.
    """
    def __init__(self, client: ZmsDstClient, modem: ModemBase,
                 monitor_id: str, element_id: str,
                 interval: float = 10.0, impotent: bool = False):
        """
        :param client: DST client used to list/create metrics and create values.
        :param modem: the modem to poll.
        :param monitor_id: id stamped onto every value before upload.
        :param element_id: element id set on metrics looked up or created.
        :param interval: seconds to wait between polling rounds in run().
        :param impotent: if true, log instead of contacting the DST.
        """
        self.client = client
        self.modem = modem
        self.metrics = {}  # name -> Metric
        self.monitor_id = monitor_id
        self.element_id = element_id
        self.interval = interval
        self.impotent = impotent
        self._running = False
        # Created up front so run() and stop() always share one condition,
        # whichever of them executes first.
        self._lock = RLock()
        self._cv = Condition(lock=self._lock)

    def _upload_value(self, v: Value) -> Optional[DstError]:
        """Stamp ``v`` with our monitor_id and upload it to the DST.

        In impotent mode the value is only logged.  Returns None on success;
        any exception from the client is logged and re-raised.
        """
        v.monitor_id = self.monitor_id
        if self.impotent:
            LOG.info("Impotent mode: not uploading value: %r", v)
            return None
        try:
            response = self.client.create_value(body=v)
        except Exception as e:
            LOG.error("Could not create value %r: %r", v, e)
            raise
        LOG.debug("created value: %r", response)
        return None

    def get_metric_by_name(self, name: str) -> Optional[Metric]:
        """Return the initialized metric called ``name``, or None."""
        return self.metrics.get(name)

    def _lookup_metric(self, name: str, kind: str, **filters) -> Optional[Metric]:
        """Return the first DST metric named ``name`` matching ``filters``.

        ``kind`` is only used in the debug message when nothing matches.
        Returns None if the listing is not a non-empty MetricList; client
        exceptions are logged and re-raised.
        """
        try:
            response = self.client.list_metrics(metric=name, **filters)
        except Exception as e:
            LOG.error("Could not list metric %s: %s", name, e)
            raise
        if isinstance(response, MetricList) and len(response.metrics) > 0:
            return response.metrics[0]
        LOG.debug("No %s metric found for %s (%s)", kind, name, response)
        return None

    def _init_metric(self, metric: Metric,
                     create_if_missing: bool = True) -> Optional[Union[Metric, DstError]]:
        """Find or create the DST metric corresponding to ``metric``.

        Lookup order: a public metric with the same name, then a metric with
        the same name owned by our element.  If neither exists and
        ``create_if_missing`` is true, a per-element metric is created and
        the client's response is returned. Otherwise None is returned.

        :raises Exception: whatever the client raises (after logging).
        """
        found = self._lookup_metric(metric.name, "public", is_public=True)
        if found is not None:
            return found
        found = self._lookup_metric(metric.name, "private",
                                    element_id=self.element_id)
        if found is not None:
            return found
        if not create_if_missing:
            return None
        # Create a per-element metric.
        metric.element_id = self.element_id
        LOG.info("Creating metric %s: %s", metric.name, metric)
        try:
            return self.client.create_metric(body=metric)
        except Exception as e:
            LOG.error("Could not create metric %s: %s", metric.name, e)
            raise

    def init_metrics(self, create_if_missing: bool = True):
        """Load (or create) the metrics we need in the DST.

        Each definition from the modem is shallow-copied, tagged with our
        element_id, resolved via _init_metric() and stored in self.metrics
        by name.  In impotent mode the tagged copy is stored without
        contacting the DST.

        :raises RuntimeError: if a metric could not be found or created.
        """
        for m in self.modem.get_metric_definitions():
            LOG.debug("Metric definition: %s", m)
            # NB: just do a shallow copy; we will be changing the element_id
            # and possibly other fields.  We cannot do a deepcopy because the
            # generated models check to see if fields are UNSET when JSONifying,
            # and a deepcopy will change UNSET to a new Unset instance!
            nm = copy.copy(m)
            nm.element_id = self.element_id
            if self.impotent:
                LOG.info("Impotent mode: not initializing metric %s: %s", m.name, nm)
                self.metrics[nm.name] = nm
                continue
            ret = self._init_metric(nm, create_if_missing)
            if not isinstance(ret, Metric):
                LOG.error("Could not initialize metric %s: %s", nm.name, ret)
                raise RuntimeError(f"Could not initialize metric {nm.name}: {ret}")
            LOG.debug("Initialized metric %s: %s", nm.name, ret)
            self.metrics[nm.name] = ret

    def _collect_values(self) -> List[Value]:
        """Poll the modem once for every initialized metric.

        Each returned value is stamped with its metric's id and our
        monitor_id.
        """
        values = []
        for mname, metric in self.metrics.items():
            mvalues = self.modem.get_values(mname)
            if not mvalues:
                continue
            for mv in mvalues:
                mv.metric_id = metric.id
                mv.monitor_id = self.monitor_id
            values.extend(mvalues)
        return values

    def run(self):
        """Run the monitor, blocking until stop() is called.

        Each round polls the modem for every metric, uploads the values, then
        waits up to ``interval`` seconds, waking early if stop() is called.
        The lock is held while polling and uploading, so stop() blocks until
        the current round finishes.  The modem is closed when the loop exits,
        including when an exception propagates from polling or uploading.

        :raises RuntimeError: if the monitor is already running.
        """
        with self._cv:
            if self._running:
                raise RuntimeError("Monitor is already running")
            self._running = True
        try:
            LOG.info("Starting monitor for modem %s", self.modem.name)
            self.modem.init()
            self.modem.open()
            try:
                with self._cv:
                    while self._running:
                        for v in self._collect_values():
                            self._upload_value(v)
                        # wait_for() releases the lock while waiting so stop()
                        # can acquire it, and returns as soon as stop() has
                        # cleared _running (or after interval seconds).
                        self._cv.wait_for(lambda: not self._running,
                                          timeout=self.interval)
                LOG.info("Stopping monitor for modem %s", self.modem.name)
            finally:
                self.modem.close()
        finally:
            # Allow a later run() even if this one ended via an exception.
            with self._cv:
                self._running = False

    def stop(self):
        """Ask a running monitor to stop; a no-op if it is not running."""
        with self._cv:
            if self._running:
                self._running = False
                self._cv.notify_all()
