#! /usr/bin/python3
# pylint: disable=R0903
from __future__ import annotations
import argparse
import codecs
import json
import ssl
import sys
import traceback
from typing import Any, Iterator, Optional, List, NoReturn
import urllib.request


# Data Representation
# -------------------

class InfoParseError(ValueError):
    """Signals that a fragment of the information document is unusable.

    The fragment that could not be interpreted is stored on the ``data``
    attribute so that it can be shown next to the error message.
    """

    def __init__(self, msg: str, data: Any):
        # Keep the offending fragment, then let ValueError store the message.
        self.data = data
        super().__init__(msg)

class ComputingShare:
    """A computing share and the endpoints that have been linked to it.

    Built from a mapping that must provide the ``"ID"`` and ``"Name"``
    keys; a missing key surfaces as ``KeyError``.
    """

    id: str
    name: str
    associated_endpoints: List[ComputingEndpoint]

    def __init__(self, data: Any):
        self.id, self.name = data["ID"], data["Name"] # pylint: disable=C0103
        # Starts out empty; entries are added through add_endpoint().
        self.associated_endpoints = []

    def add_endpoint(self, endpoint: ComputingEndpoint):
        """Append ``endpoint`` to the endpoints associated with this share."""
        self.associated_endpoints += [endpoint]

class ComputingEndpoint:
    """A computing endpoint as described by one entry of the info document.

    Built from a mapping that must provide the ``"ID"``, ``"Name"``,
    ``"HealthState"`` and ``"InterfaceName"`` keys (a missing one surfaces
    as ``KeyError``).  ``"HealthStateInfo"`` is optional and defaults to
    ``None``.
    """

    # All instance attributes are declared here, matching ComputingShare.
    id: str
    name: str
    health_state: str
    health_state_info: Optional[str]
    interface_name: str

    def __init__(self, data: Any):
        self.id = data["ID"] # pylint: disable=C0103
        self.name = data["Name"]
        self.health_state = data["HealthState"]
        # Optional detail accompanying the health state.
        self.health_state_info = data.get("HealthStateInfo")
        self.interface_name = data["InterfaceName"]

    def __str__(self) -> str:
        """Return the endpoint ID, which is how endpoints are displayed."""
        return self.id

# For backwards compatibility, every element mentioned in info_json_arrays in
# src/services/a-rex/rest/rest.cpp from the ARC source has to be passed
# through this function when it is extracted.
def _fixup_list(xs: Any) -> List[Any]:
    """Return ``xs`` itself if it is a list, else a one-element list of it."""
    return xs if isinstance(xs, list) else [xs]

class ComputingService:
    """The shares and endpoints of one computing service, indexed by ID.

    ``shares_by_id`` maps share IDs to ``ComputingShare`` objects and
    ``endpoints_by_id`` maps endpoint IDs to ``ComputingEndpoint`` objects.
    Every endpoint is also attached to each share listed under its
    ``Associations.ComputingShareID`` entry.

    Raises:
        InfoParseError: if the service data, a share or an endpoint cannot
            be interpreted.  The offending fragment is available as
            ``data`` on the exception and the underlying error is chained
            as its cause.
    """

    def __init__(self, data: Any):
        self.endpoints_by_id = {}
        self.shares_by_id = {}

        try:
            shares_data = _fixup_list(data.get("ComputingShare", []))
            endpoints_data = _fixup_list(data.get("ComputingEndpoint", []))
        except AttributeError as exn:
            # The service entry itself is not a mapping.
            raise InfoParseError(
                    f"Failed to parse service: {exn}", data) from exn

        for share_data in shares_data:
            try:
                share = ComputingShare(share_data)
                self.shares_by_id[share.id] = share
            except (KeyError, ValueError, TypeError) as exn:
                raise InfoParseError(
                        f"Failed to parse share: {exn}", share_data) from exn

        for endpoint_data in endpoints_data:
            try:
                endpoint = ComputingEndpoint(endpoint_data)
                self.endpoints_by_id[endpoint.id] = endpoint
                # A reference to a share that was not parsed above raises
                # KeyError and is reported as a parse error of the endpoint.
                for share_id in _fixup_list(
                        endpoint_data.get("Associations", {})
                                     .get("ComputingShareID", [])):
                    self.shares_by_id[share_id].add_endpoint(endpoint)
            except (KeyError, ValueError, TypeError, AttributeError) as exn:
                # AttributeError covers an "Associations" value that is not
                # a mapping (e.g. null), so the .get() chain cannot escape
                # as an unrelated error.
                raise InfoParseError(
                        f"Failed to parse endpoint: {exn}",
                        endpoint_data) from exn

    def dump(self) -> None:
        """Print all endpoints, then each share with its endpoints."""
        print("Found endpoints:")
        for endpoint in self.endpoints_by_id.values():
            print(f"- {endpoint}")
        for share in self.shares_by_id.values():
            print(f"Found share {share.id} using endpoints:")
            for endpoint in share.associated_endpoints:
                print(f"- {endpoint}")


# Checks
# ------

class ServiceError:
    """A single problem found on a service.

    Carries a short summary in ``brief_message`` and, optionally, a longer
    description in ``full_message`` (``None`` when there is none).
    """

    def __init__(self, brief_message: str, full_message: Optional[str] = None):
        self.full_message = full_message
        self.brief_message = brief_message

def check_service(
        service: ComputingService, *,
        min_endpoint_count: int,
        min_share_count: int,
        required_interfaces: List[str]) -> int:
    """Check one computing service and print the outcome.

    The following conditions are checked:

    * at least ``min_share_count`` shares and ``min_endpoint_count``
      endpoints are present,
    * every endpoint has health state ``"ok"``,
    * every share has at least one endpoint with health state ``"ok"``,
    * for every share, each name in ``required_interfaces`` is the
      interface name of at least one of its ``"ok"`` endpoints.

    Args:
        service: the service to inspect.
        min_endpoint_count: minimum number of endpoints expected.
        min_share_count: minimum number of shares expected.
        required_interfaces: interface names every share must offer.

    Returns:
        2 if at least one problem was found, otherwise 0.
    """
    errors: List[ServiceError] = []

    def critical(*args, **kwargs):
        errors.append(ServiceError(*args, **kwargs))

    # Check that the required minimum number of shares and endpoints exist.
    if len(service.shares_by_id) < min_share_count:
        critical(f"Only {len(service.shares_by_id)} "
                 f"of {min_share_count} ComputingShare(s) found.")
    if len(service.endpoints_by_id) < min_endpoint_count:
        critical(f"Only {len(service.endpoints_by_id)} "
                 f"of {min_endpoint_count} ComputingEndpoint(s) found.")

    # Check that all endpoints are active.
    for endpoint in service.endpoints_by_id.values():
        if endpoint.health_state != "ok":
            msg = f"Endpoint {endpoint.id} is {endpoint.health_state}"
            if endpoint.health_state_info:
                critical(msg, msg + ": " + endpoint.health_state_info)
            else:
                critical(msg)

    # Check that all shares have at least one active endpoint and offer
    # every required interface through an active endpoint.
    required = set(required_interfaces)
    for share in service.shares_by_id.values():
        endpoints = [
            endpoint for endpoint in share.associated_endpoints
            if endpoint.health_state == "ok"
        ]
        if not endpoints:
            critical(f"Share {share.id} has no working endpoint.")
        interfaces = {endpoint.interface_name for endpoint in endpoints}
        # Sorted so that the report is the same from run to run; plain set
        # iteration order of strings is not stable across processes.
        for interface in sorted(required.difference(interfaces)):
            critical(f"Interface {interface} missing for share {share.id}.")

    # Report the errors and return the status for this service.
    if errors:
        if len(errors) == 1:
            print(errors[0].brief_message)
            if errors[0].full_message:
                print(errors[0].full_message)
        else:
            print("Multiple errors found, see details.")
            for error in errors:
                print(error.full_message or error.brief_message)
        return 2
    print("No problems found.")
    return 0


# Main Program
# ------------

def report_uncaught_exception_and_exit(msg: str) -> NoReturn:
    """Print ``msg`` and the current traceback to stdout, then exit with 3."""
    print(msg)
    # The traceback goes to standard output, right after the message.
    print(traceback.format_exc(), end="")
    raise SystemExit(3)

def report_parse_error_and_exit(exn: InfoParseError, data: Any) -> NoReturn:
    """Print a parse error report to stdout and exit with status 2.

    The report consists of the error message, the traceback of the
    exception currently being handled, the JSON fragment stored on
    ``exn.data`` and finally the complete document ``data``.

    Args:
        exn: the parse error to report.
        data: the full JSON document the fragment was taken from.
    """
    print(exn)
    # Send the traceback to stdout like the rest of the report (and like
    # report_uncaught_exception_and_exit) instead of the stderr default, so
    # the whole report ends up in a single stream.
    traceback.print_exc(file=sys.stdout)
    print("The unparsed JSON fragment is:")
    print(json.dumps(exn.data, indent=4, sort_keys=True))
    print("The full JSON document is:")
    print(json.dumps(data, indent=4, sort_keys=True))
    sys.exit(2)

def fetch_services_or_exit(
        url: str,
        tls_ca_dir: Optional[str] = None,
        tls_key: Optional[str] = None,
        tls_cert: Optional[str] = None) -> Iterator[ComputingService]:
    """Fetch the info document from ``url`` and yield its computing services.

    This is a generator: nothing is fetched until the first item is
    requested.

    Args:
        url: location of the JSON information document.
        tls_ca_dir: directory of CA certificates used to verify the server;
            only loaded when given.
        tls_key: client key file; only used when ``tls_cert`` is given too.
        tls_cert: client certificate file; only used when ``tls_key`` is
            given too.

    Yields:
        One ``ComputingService`` per ``ComputingService`` entry found under
        ``Domains.AdminDomain[].Services``.

    Exits:
        With status 2 after printing a message if the document cannot be
        fetched or decoded, and through ``report_parse_error_and_exit`` if
        it does not have the expected structure.
    """
    try:
        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        if tls_ca_dir:
            context.load_verify_locations(capath=tls_ca_dir)
        if tls_key and tls_cert:
            context.load_cert_chain(keyfile=tls_key, certfile=tls_cert)
        headers = {"Accept": "application/json"}
        request = urllib.request.Request(url, headers=headers)
        with urllib.request.urlopen(request, context=context) as fh:
            data = json.load(codecs.getreader("utf-8")(fh))
    except (urllib.error.HTTPError, urllib.error.URLError, ssl.SSLError,
            json.decoder.JSONDecodeError) as exn:
        print(f"Failed to fetch from {url}: {exn}")
        sys.exit(2)
    try:
        for ad_data in _fixup_list(data["Domains"]["AdminDomain"]):
            for cs_data in _fixup_list(ad_data["Services"]["ComputingService"]):
                yield ComputingService(cs_data)
    except (KeyError, IndexError) as exn:
        # A missing mapping key raises KeyError, whose first argument is the
        # key; that is what the "<key> not found." message is built from.
        exnp = InfoParseError(f"{exn.args[0]} not found.", data)
        report_parse_error_and_exit(exnp, data)
    except (TypeError, AttributeError) as exn:
        # Subscripting or calling .get() on a value of the wrong type, e.g.
        # a string or null where a mapping was expected.
        exnp = InfoParseError(f"Unexpected shape of data. {exn}.", data)
        report_parse_error_and_exit(exnp, data)
    except InfoParseError as exn:
        report_parse_error_and_exit(exn, data)

def main() -> None:
    """Parse the command line, check every service and exit with the result.

    The endpoint URL is taken from --endpoint/-U if given, otherwise it is
    built from --host/-H and the optional --port/-P.  The exit status is the
    highest status returned by check_service over all services (0 if there
    are none); any unexpected exception is reported through
    report_uncaught_exception_and_exit.
    """
    parser = argparse.ArgumentParser(
            description="""
                NAGIOS probe to check the status of an ARC CE using the
                org.nordugrid.arcrest interface.
            """)

    # (flags, keyword arguments) for every option, in registration order.
    option_specs = [
        (("--host", "-H"),
         {"type": str, "help": "host name of the CE to check"}),
        (("--port", "-P"),
         {"type": int,
          "help": "port number of the information system endpoint"}),
        (("--endpoint", "-U"),
         {"help": "URL of the information system endpoint"}),
        (("--tls-ca-dir",),
         {"type": str, "default": "/etc/grid-security/certificates",
          "help": "directory containing accepted X.509 CA certificates"}),
        (("--tls-cert",),
         {"type": str,
          "help": "client certificate used to authenticate to the CE"}),
        (("--tls-key",),
         {"type": str,
          "help": "client key used to authenticate to the CE"}),
        (("--require-min-share-count",),
         {"type": int, "default": 1,
          "help": "require that there are at least this number of shares"}),
        (("--require-min-endpoint-count",),
         {"type": int, "default": 1,
          "help": "require that there are at least this number of endpoints"}),
        (("--require-interface",),
         {"type": str, "nargs": "*", "default": [],
          "help": "require that there is an endpoint supporting the given "
                  "interface for each share"}),
        (("--dump",),
         {"action": "store_true",
          "help": "dump some of the gathered information at the end of the "
                  "output, for debugging or casual inspection"}),
    ]
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)
    args = parser.parse_args()

    # An explicit endpoint URL wins over one derived from host and port.
    if args.endpoint is not None:
        url = args.endpoint
    elif args.host is not None:
        if args.port is not None:
            url = f"https://{args.host}:{args.port}/arex/rest/1.0/info"
        else:
            url = f"https://{args.host}/arex/rest/1.0/info"
    else:
        parser.error("Either --host/-H or --endpoint/-U is required.")

    try:
        worst_status = 0
        found_services = fetch_services_or_exit(
                url=url,
                tls_ca_dir=args.tls_ca_dir,
                tls_cert=args.tls_cert,
                tls_key=args.tls_key)
        for service in found_services:
            status = check_service(
                    service,
                    min_share_count=args.require_min_share_count,
                    min_endpoint_count=args.require_min_endpoint_count,
                    required_interfaces=args.require_interface)
            # The overall result is the worst status seen so far.
            worst_status = max(status, worst_status)
            if not args.dump:
                continue
            for heading_line in ("", "## Service Dump ##", ""):
                print(heading_line)
            service.dump()
    except Exception: # pylint: disable=W0718
        report_uncaught_exception_and_exit("Uncaught exception.")
    sys.exit(worst_status)

# Run the probe only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
