#!/usr/bin/env python3

from datetime import datetime, timedelta
import io
import json
import re
import time

import gzip
import radix
import requests
import yaml

import caida_oidc_client
from requests_oauthlib import OAuth2Session

# OIDC realm on auth.caida.org that issues the arkmon tokens
REALM = "CAIDA"
# base URL of the arkmon REST API; the monitor list is read from /monitors/
API_URL = "https://api.arkmon.caida.org"
# OpenID Connect endpoint base for REALM; tokens are refreshed at {AUTH_URL}/token
AUTH_URL = "https://auth.caida.org/realms/" + REALM + "/protocol/openid-connect"
# OAuth client id, used for the session and for refresh requests
CLIENT_ID = "arkmon-offline"
# JSON file with the saved OAuth tokens (presumably rewritten on refresh via
# caida_oidc_client.make_save_tokens -- confirm)
TOKEN_FILE = "/etc/ark/.arkmon-offline.token"
# optional YAML file of per-node check exceptions and ASN sibling groups
EXCEPTIONS_FILE = "/etc/ark/metadata-exceptions.yaml"

def load_exceptions(path):
    """Load the per-node exceptions YAML file at *path*.

    Returns the parsed top-level mapping. An empty dict is returned when the
    file does not exist, is empty, cannot be read or parsed, or does not
    contain a mapping at the top level. A warning is printed in the last two
    cases; a missing file is silently treated as "no exceptions".
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
    except FileNotFoundError:
        return {}
    except Exception as e:
        # best effort: a broken exceptions file should not stop the checks
        print(f"warning: could not load exceptions file: {e}")
        return {}

    if data is None:
        # empty document
        return {}
    if not isinstance(data, dict):
        # callers use .get() on the result, so anything else would crash them
        print(f"warning: exceptions file {path} is not a mapping, ignoring it")
        return {}
    return data

def _are_asn_siblings(asn, origin, siblings_groups):
    """Return True when any one group in *siblings_groups* contains both
    *asn* and *origin*; False otherwise (including for no groups)."""
    return any(asn in grp and origin in grp for grp in siblings_groups)

def _read_prefix2as():
    """Load the latest RouteViews prefix-to-AS snapshot into a radix tree.

    The CAIDA ``pfx2as-creation.log`` is fetched. The path in the third
    column of its last non-comment line is taken as the snapshot to
    download; this assumes the log lists snapshots oldest-first.

    Each line of that gzipped snapshot is whitespace-separated
    ``<prefix> <length> <origin>``. Every prefix becomes a tree node whose
    ``data`` holds:

    * ``'prefix'``: the ``"addr/len"`` string.
    * ``'origin'``: the third column, kept as an unparsed string.

    This is best-effort. Download and decompression failures are printed,
    and whatever was loaded so far (possibly nothing) is returned. Blank or
    malformed lines are skipped instead of aborting the whole load.
    """
    url = 'https://publicdata.caida.org/datasets/routing/routeviews-prefix2as'
    tree = radix.Radix()

    most_recent = None
    try:
        response = requests.get(f'{url}/pfx2as-creation.log', timeout=30)
        response.raise_for_status()
        for line in response.text.splitlines():
            if line.startswith('#'):
                continue
            row = line.split()
            if len(row) < 3:
                # skip blank/truncated lines rather than letting an
                # IndexError throw away the whole log
                continue
            most_recent = row[2]
    except Exception as e:
        print(f'could not read pfx2as-creation.log: {e}')
        return tree

    if not most_recent:
        print('no most recent file')
        return tree

    skipped = 0
    try:
        response = requests.get(f'{url}/{most_recent}', timeout=30)
        response.raise_for_status()
        with gzip.open(io.BytesIO(response.content), mode='rt') as inf:
            for line in inf:
                if line.startswith('#'):
                    continue
                row = line.split()
                try:
                    pfx_str = f'{row[0]}/{int(row[1])}'
                    origin = row[2]
                    node = tree.add(pfx_str)
                except (IndexError, ValueError):
                    # short line, non-numeric length, or unparseable prefix
                    skipped += 1
                    continue
                node.data['prefix'] = pfx_str
                node.data['origin'] = origin
    except Exception as e:
        print(f'could not read most recent: {e}')

    if skipped:
        print(f'skipped {skipped} malformed lines in {most_recent}')
    return tree


def load_token_info(token_file):
    """Read the saved OAuth token dict from the JSON file *token_file*.

    If the dict records an ``expires_at`` time that has not yet passed, it
    is returned unchanged. Otherwise it is reduced in place to the refresh
    token plus ``expires_in = -1`` and a placeholder ``access_token``.
    Presumably this makes the OAuth session treat the token as expired and
    refresh it on first use -- confirm against requests-oauthlib.

    Raises KeyError if a refresh is needed but no ``refresh_token`` is saved.
    """
    with open(token_file, "r", encoding="ascii") as handle:
        info = json.load(handle)

    has_expiry = "expires_at" in info
    if has_expiry and not info["expires_at"] < time.time():
        return info

    # keep only what is needed to obtain a fresh access token
    saved_refresh = info['refresh_token']
    info.clear()
    info.update(
        refresh_token=saved_refresh,
        expires_in=-1,
        access_token='dummy value for oauthlib',
    )
    return info

def ark_dory_query():
    """Fetch the list of Ark monitors from ``{API_URL}/monitors/``.

    Builds an OAuth2 session from the tokens saved in TOKEN_FILE. The
    session is configured to refresh the access token at ``{AUTH_URL}/token``
    and hand refreshed tokens to ``caida_oidc_client.make_save_tokens``.

    Returns the decoded JSON monitor list, or None on any failure. The
    failures covered are: unreadable token file, unauthorized session,
    request error or timeout, non-200 status, undecodable body, or an empty
    list. The reason is printed.
    """
    try:
        token_info = load_token_info(TOKEN_FILE)
    except (OSError, ValueError, KeyError) as e:
        # missing/unreadable file, bad JSON, or no refresh_token saved
        print(f"Failed to load tokens from {TOKEN_FILE}: {e}")
        return None
    save_tokens = caida_oidc_client.make_save_tokens(TOKEN_FILE)

    # establish a new session, refreshing access token if necessary
    session = OAuth2Session(client_id=CLIENT_ID,
                token=token_info,
                auto_refresh_url=f"{AUTH_URL}/token",
                auto_refresh_kwargs={"client_id": CLIENT_ID},
                token_updater=save_tokens)

    if not session.authorized:
        print("Failed to authorize session, aborting")
        return None

    try:
        # without a timeout a stalled server would hang the script forever;
        # requests-oauthlib forwards the kwarg to any automatic refresh too
        response = session.request("GET", f"{API_URL}/monitors/", timeout=30)
    except Exception as e:
        print(f"Failed to fetch {API_URL}/monitors: {e}")
        return None

    if response.status_code != 200:
        print(f"Got status code {response.status_code}, aborting")
        return None

    try:
        monitors = response.json()
    except ValueError as e:
        # JSON decode errors from requests/json are ValueError subclasses
        print(f"Could not decode {API_URL}/monitors response: {e}")
        return None

    if not monitors:
        print("Empty monitors list, aborting")
        return None

    return monitors


def check_os_labels(monitor, ignore=()) -> None:
    """Print a warning for a missing, unrecognised, or outdated OS name.

    ``monitor["osname"]`` is lowercased and matched as follows:

    * FreeBSD is accepted without a version check.
    * Debian/Raspbian must end in a parenthesised codename; "trixie" is
      expected.
    * Ubuntu must contain a ``<major>.<minor>`` version; "24.04" is expected.
    * Anything else is reported as an unknown OS.

    The expected-release comparison is skipped when ``'os_release'`` is in
    *ignore*. A release string that cannot be parsed is always reported.
    """
    osname = monitor.get("osname")
    if not osname:
        print(f"{monitor['node']}: missing osname")
        return
    osname = osname.lower()

    if "freebsd" in osname:
        # e.g. FreeBSD 10.1
        return

    if "debian" in osname or "raspbian" in osname:
        # e.g. Debian GNU/Linux 12 (bookworm)
        distro, expected = "debian", "trixie"
        release = re.search(r".*\(([a-z0-9]+)\)", osname)
    elif "ubuntu" in osname:
        # e.g. Ubuntu 22.04.3 LTS
        distro, expected = "ubuntu", "24.04"
        release = re.search(r"[a-z]+ ([0-9]+\.[0-9]+).*", osname)
    else:
        print(f"{monitor['node']}: unknown osname {osname}")
        return

    if not release:
        # nothing to compare; stop instead of calling .group() on None
        print(f"{monitor['node']}: unknown release for {osname}")
        return

    codename = release.group(1)
    if codename != expected and 'os_release' not in ignore:
        print(f"{monitor['node']}: unexpected {distro} release {codename}")

def check_hardware_labels(monitor) -> None:
    """Report a missing (falsy) ``hwtype``; if hwtype is present, report a
    ``cpuarch`` that is absent or None."""
    if not monitor.get("hwtype"):
        print(f"{monitor['node']}: missing hwtype")
    elif monitor.get("cpuarch") is None:
        print(f"{monitor['node']}: missing cpuarch")

def check_state(monitor, cc, st) -> None:
    """Sanity-check the ISO state/subdivision code *st* for country *cc*.

    Only the length of *st* is checked, against the lengths listed for *cc*
    in the table below. For countries not in the table:

    * a state code is reported as unlikely;
    * a missing one is reported, except for the few codes that are expected
      to have no state.
    """
    # there are probably errors in this table, and its contents should
    # not be taken as ground truth.
    state_code_lengths = {
        "US": {2}, "CA": {2}, "MX": {3}, "CR": {1, 2}, "PA": {1, 2},
        "BR": {2}, "AR": {1}, "CL": {2}, "CO": {2, 3}, "PE": {3},
        "VE": {1}, "EC": {1, 2}, "UY": {2}, "PY": {1,2,3}, "BO": {1},
        "GB": {2, 3}, "FR": {2, 3}, "DE": {2}, "ES": {2}, "IT": {2, 3},
        "NL": {2}, "BE": {3}, "CH": {2}, "SE": {2}, "NO": {2},
        "FI": {2}, "DK": {2}, "PL": {2}, "CZ": {2}, "SK": {2}, "HU": {2},
        "RO": {1,2}, "BG": {2}, "PT": {2}, "GR": {1,2}, "IE": {1, 2}, "IS": {1},
        "LU": {2}, "JP": {2}, "CN": {2}, "IN": {2}, "KR": {2}, "ID": {2},
        "MY": {2}, "PH": {2}, "TH": {2}, "VN": {2}, "IR": {2}, "TR": {2},
        "RU": {2,3}, "KZ": {2}, "AU": {2, 3}, "NZ": {3}, "ZA": {2}, "NG": {2},
        "KE": {2}, "GH": {2}, "EG": {2}, "MA": {2}, "DZ": {2}, "TN": {2},
        "HK": {2}, "UK": {2, 3}, "AT": {1}, "LV": {3}, "AE": {2}, "BH": {2},
        "TW": {3}, "QA": {2}, "PK": {2}, "SG": {2}, "IL": {1, 2}, "LT": {2},
        "RS": {2}, "GM": {1}, "BD": {1, 2}, "TZ": {2}, "UG": {1, 3},
        "UA": {2}, "RW": {2}, "NP": {2}, "ZM": {2}, "MU": {2}, "CY": {2},
        "BT": {2}, "BA": {3}, "MG": {1}, "MN": {1, 3}, "BF": {3},
    }
    # codes for which no state is expected at all
    countries_without_states = {"GU", "XK", "VI"}

    lengths = state_code_lengths.get(cc)
    if lengths is None:
        if st:
            print(f"{monitor['node']}: unlikely ISO state code {st} for {cc}")
        elif cc not in countries_without_states:
            print(f"{monitor['node']}: missing ISO state?")
    elif not st:
        print(f"{monitor['node']}: missing ISO state")
    elif len(st) not in lengths:
        print(f"{monitor['node']}: unexpected ISO state {st} for country {cc}")

def check_continent(monitor, continent, cc) -> None:
    """Check that *continent* is one of the continents listed for ISO
    country code *cc*.

    An unknown *cc* is reported as such. Otherwise a continent not listed
    for *cc* is reported. Transcontinental countries list more than one
    continent.
    """
    # there are probably errors in this table, and its contents should
    # not be taken as ground truth.
    continents_for_cc = {
        "DZ": ["Africa"], "AO": ["Africa"], "BJ": ["Africa"], "BW": ["Africa"],
        "BF": ["Africa"], "BI": ["Africa"], "CV": ["Africa"], "CM": ["Africa"],
        "CF": ["Africa"], "TD": ["Africa"], "KM": ["Africa"], "CG": ["Africa"],
        "CD": ["Africa"], "CI": ["Africa"], "DJ": ["Africa"], "GQ": ["Africa"],
        "ER": ["Africa"], "SZ": ["Africa"], "ET": ["Africa"], "GA": ["Africa"],
        "GM": ["Africa"], "GH": ["Africa"], "GN": ["Africa"], "GW": ["Africa"],
        "KE": ["Africa"], "LS": ["Africa"], "LR": ["Africa"], "LY": ["Africa"],
        "MG": ["Africa"], "MW": ["Africa"], "ML": ["Africa"], "MR": ["Africa"],
        "MU": ["Africa"], "MA": ["Africa"], "MZ": ["Africa"], "NA": ["Africa"],
        "NE": ["Africa"], "NG": ["Africa"], "RW": ["Africa"], "ST": ["Africa"],
        "SN": ["Africa"], "SC": ["Africa"], "SL": ["Africa"], "SO": ["Africa"],
        "ZA": ["Africa"], "SS": ["Africa"], "SD": ["Africa"], "TZ": ["Africa"],
        "TG": ["Africa"], "TN": ["Africa"], "UG": ["Africa"], "ZM": ["Africa"],
        "ZW": ["Africa"],
        "AL": ["Europe"], "AD": ["Europe"], "AT": ["Europe"], "BE": ["Europe"],
        "BA": ["Europe"], "BG": ["Europe"], "HR": ["Europe"], "CY": ["Europe"],
        "CZ": ["Europe"], "DK": ["Europe"], "EE": ["Europe"], "FI": ["Europe"],
        "FR": ["Europe"], "DE": ["Europe"], "GR": ["Europe"], "HU": ["Europe"],
        "IS": ["Europe"], "IE": ["Europe"], "IT": ["Europe"], "LV": ["Europe"],
        "LI": ["Europe"], "LT": ["Europe"], "LU": ["Europe"], "MT": ["Europe"],
        "MD": ["Europe"], "MC": ["Europe"], "ME": ["Europe"], "NL": ["Europe"],
        "MK": ["Europe"], "NO": ["Europe"], "PL": ["Europe"], "PT": ["Europe"],
        "RO": ["Europe"], "SM": ["Europe"], "RS": ["Europe"], "SK": ["Europe"],
        "SI": ["Europe"], "ES": ["Europe"], "SE": ["Europe"], "CH": ["Europe"],
        "UA": ["Europe"], "GB": ["Europe"], "UK": ["Europe"], "VA": ["Europe"],
        "XK": ["Europe"],
        "AF": ["Asia"], "BH": ["Asia"], "BD": ["Asia"], "BT": ["Asia"],
        "BN": ["Asia"], "KH": ["Asia"], "CN": ["Asia"], "IN": ["Asia"],
        "IR": ["Asia"], "IQ": ["Asia"], "IL": ["Asia"], "JP": ["Asia"],
        "JO": ["Asia"], "KW": ["Asia"], "KG": ["Asia"], "LA": ["Asia"],
        "LB": ["Asia"], "MY": ["Asia"], "MV": ["Asia"], "MN": ["Asia"],
        "MM": ["Asia"], "NP": ["Asia"], "KP": ["Asia"], "OM": ["Asia"],
        "PK": ["Asia"], "PH": ["Asia"], "QA": ["Asia"], "SA": ["Asia"],
        "SG": ["Asia"], "KR": ["Asia"], "LK": ["Asia"], "SY": ["Asia"],
        "TW": ["Asia"], "TJ": ["Asia"], "TH": ["Asia"], "TL": ["Asia"],
        "TM": ["Asia"], "AE": ["Asia"], "UZ": ["Asia"], "VN": ["Asia"],
        "YE": ["Asia"], "HK": ["Asia"],
        "AG": ["North America"], "BS": ["North America"],
        "BB": ["North America"], "BZ": ["North America"],
        "CA": ["North America"], "CR": ["North America"],
        "CU": ["North America"], "DM": ["North America"],
        "DO": ["North America"], "SV": ["North America"],
        "GD": ["North America"], "GT": ["North America"],
        "HT": ["North America"], "HN": ["North America"],
        "JM": ["North America"], "MX": ["North America"],
        "NI": ["North America"], "KN": ["North America"],
        "LC": ["North America"], "VC": ["North America"],
        "TT": ["North America"], "US": ["North America"],
        "VI": ["North America"],
        "AR": ["South America"], "BO": ["South America"],
        "BR": ["South America"], "CL": ["South America"],
        "CO": ["South America"], "EC": ["South America"],
        "GY": ["South America"], "PY": ["South America"],
        "PE": ["South America"], "SR": ["South America"],
        "UY": ["South America"], "VE": ["South America"],
        "AU": ["Oceania"], "FJ": ["Oceania"], "KI": ["Oceania"],
        "MH": ["Oceania"], "FM": ["Oceania"], "NR": ["Oceania"],
        "NZ": ["Oceania"], "PW": ["Oceania"], "PG": ["Oceania"],
        "WS": ["Oceania"], "SB": ["Oceania"], "TO": ["Oceania"],
        "TV": ["Oceania"], "VU": ["Oceania"], "GU": ["Oceania"],
        "AQ": ["Antarctica"],
        "AM": ["Europe", "Asia"], "AZ": ["Europe", "Asia"],
        "PA": ["North America", "South America"],
        "EG": ["Africa", "Asia"],
        "ID": ["Asia", "Oceania"],
        "RU": ["Europe", "Asia"],
        "GE": ["Europe", "Asia"],
        "KZ": ["Europe", "Asia"],
        "TR": ["Europe", "Asia"],
    }

    allowed = continents_for_cc.get(cc)
    if allowed is None:
        print(f"{monitor['node']}: unexpected ISO CC {cc}")
        return
    if continent not in frozenset(allowed):
        print(f"{monitor['node']}: unexpected continent {continent} for {cc}")

def check_kernel(monitor, ignore=()) -> None:
    """Print a warning if the monitor's kernel is missing or unexpected.

    Skipped entirely when ``'kernel'`` is in *ignore*. The expected kernel
    depends on ``hwtype``:

    * "Raspberry Pi" hardware must run a Linux 6.x ``+rpt-rpi-`` kernel
      (v7, v8 or 2712 flavour).
    * "BMAX", "Rack Server" and "VM" hosts must run Linux 6.x, unless the
      OS is FreeBSD.
    * Any other hardware type (e.g. container hosts) is not checked.
    """
    if 'kernel' in ignore:
        return
    kernel = monitor.get("oskernel")
    if not kernel:
        print(f"{monitor['node']}: missing kernel")
        return

    # hwtype/osname may be absent or null; other checks report that, so
    # treat them as empty here rather than raising TypeError/AttributeError
    hwtype = monitor.get("hwtype") or ""
    if "Raspberry Pi" in hwtype:
        # e.g. Linux 6.12.47+rpt-rpi-v7, Linux 6.12.62+rpt-rpi-2712
        version = re.search(r"Linux 6\..*\+rpt-rpi-(v[78]|2712)", kernel)
        if not version:
            print(f"{monitor['node']}: unexpected kernel {kernel}")
    elif hwtype in ("BMAX", "Rack Server", "VM"):
        osname = (monitor.get("osname") or "").lower()
        # FreeBSD kernels (e.g. FreeBSD 10.1-RELEASE-p35) are not checked
        if "freebsd" not in osname:
            # e.g. Linux 6.12.57+deb13-amd64, Linux 6.8.0-87-generic
            version = re.search(r"Linux 6\..*", kernel)
            if not version:
                print(f"{monitor['node']}: unexpected kernel {kernel}")
    # else: we can't upgrade kernels on container hosts

def check_activities(monitor, cutoff, ignore=()) -> None:
    """Check the monitor's ``activities`` against its probing capability.

    Skipped when ``'activities'`` is in *ignore*. Otherwise the following
    are reported:

    * missing or empty activities;
    * a missing "Fireball" entry (unless ``'fireball'`` is ignored).

    For each address family, an IPv4/IPv6 probing activity is also checked
    against the family's last ping time (``lastping4``/``lastping6``),
    compared with *cutoff*:

    * a recent ping (after *cutoff*) without the activity is reported;
    * the activity without a recent ping (before *cutoff*) is reported.

    A missing or null last-ping value counts as 0. The per-family checks are
    skipped via ``'activity_ipv4_probing'`` / ``'activity_ipv6_probing'`` in
    *ignore*.
    """
    if 'activities' in ignore:
        return
    activities = monitor.get("activities")
    if not activities:
        print(f"{monitor['node']}: missing activities")
        return

    # one consistent membership test for every check below
    activity_set = set(activities)
    if 'Fireball' not in activity_set and 'fireball' not in ignore:
        print(f"{monitor['node']}: fireball not in activities")

    # (ignore key, last-ping field, activity name, family label)
    family_checks = (
        ('activity_ipv4_probing', 'lastping4', 'IPv4 team probing', 'IPv4'),
        ('activity_ipv6_probing', 'lastping6', 'IPv6 probing', 'IPv6'),
    )
    for ignore_key, ping_field, activity, family in family_checks:
        if ignore_key in ignore:
            continue
        # `or 0`: a null lastping would otherwise raise TypeError on compare
        lastping = monitor.get(ping_field) or 0
        if lastping > cutoff and activity not in activity_set:
            print(f"{monitor['node']}: {family} capable but {activity} not in activities")
        elif lastping < cutoff and activity in activity_set:
            print(f"{monitor['node']}: Not {family} capable but {activity} in activities")

def _check_asn(monitor, name, tree, siblings_groups, allowed_asns) -> None:
    """Report a missing ipv4global/asn, or an asn that disagrees with prefix2as."""
    ipv4 = monitor.get("ipv4global")
    if not ipv4:
        print(f"{name}: missing ipv4global")
    # longest-prefix match of the monitor's IPv4 address, if it has one
    node = tree.search_best(ipv4) if ipv4 else None

    asn = monitor.get("asn")
    if not asn:
        if node:
            print(f"{name}: missing asn, suggest {node.data['origin']}")
        else:
            print(f"{name}: missing asn")
        return

    if node:
        origin = node.data['origin']
        # prefix2as origins and the sibling/allowed sets hold strings, while
        # the API's asn is not guaranteed to be one; compare string forms
        asn_str = str(asn)
        if (origin != asn_str
                and not _are_asn_siblings(asn_str, origin, siblings_groups)
                and origin not in allowed_asns):
            print(f"{name}: asn {asn} != origin {origin}")
    if isinstance(asn, str) and not asn.isdecimal():
        print(f"{name}: asn {asn} contains non-digits")

def _check_natpport(monitor, name) -> None:
    """Report a missing natpport, noting whether sshall mentions natp."""
    if monitor.get("natpport"):
        return
    sshall = monitor.get("sshall")
    if not sshall:
        print(f"{name}: missing natpport, missing sshall")
    elif "natp" in sshall:
        print(f"{name}: missing natpport, natp in sshall")
    else:
        print(f"{name}: missing natpport, no natp in sshall")

def check_monitor_metadata(monitor, cutoff, tree, exceptions, siblings_groups) -> None:
    """Print one line per metadata problem found on a single monitor dict.

    These monitors are skipped:

    * monitors with no ``node`` name;
    * nodes named ``*-zz`` or ``xxx*``;
    * monitors whose ``lastping`` is before *cutoff*;
    * nodes whose exceptions entry ignores ``'all'``.

    *exceptions* is the parsed exceptions YAML. Its ``nodes.<name>.ignore``
    entry is a list of check names, or a single name such as ``'all'``. Its
    ``nodes.<name>.asns`` entry lists extra acceptable origin ASNs.

    *tree* is the prefix2as radix tree; *siblings_groups* are sets of ASN
    strings that may stand in for one another in the origin check.
    """
    name = monitor.get("node")
    if name is None:
        return
    if name.endswith("-zz") or name.startswith("xxx"):
        return

    # a null lastping counts as never pinged
    if (monitor.get("lastping") or 0) < cutoff:
        return

    # the YAML may leave 'nodes' or a node's entry empty (null)
    node_exc = (exceptions.get('nodes') or {}).get(name) or {}
    raw_ignore = node_exc.get('ignore') or []
    # accept a single check name as well as a list of them
    ignore = {raw_ignore} if isinstance(raw_ignore, str) else set(raw_ignore)
    if 'all' in ignore:
        return
    allowed_asns = {str(asn) for asn in (node_exc.get('asns') or [])}

    check_os_labels(monitor, ignore)
    check_hardware_labels(monitor)
    check_kernel(monitor, ignore)

    osname = monitor.get("osname")
    freebsd = bool(osname) and "freebsd" in osname.lower()

    country = monitor.get("country")
    if not country:
        print(f"{name}: unknown cc")
    else:
        if len(country) != 2:
            print(f"{name}: unlikely ISO country code {country}")
        check_state(monitor, country, monitor.get("state"))

    if not monitor.get("city"):
        print(f"{name}: missing city")

    continent = monitor.get("continent")
    if not continent:
        print(f"{name}: missing continent")
    elif country:
        check_continent(monitor, continent, country)

    if not monitor.get("iata"):
        print(f"{name}: missing iata")

    if not monitor.get("latitude"):
        print(f"{name}: missing lat")

    if not monitor.get("longitude"):
        print(f"{name}: missing long")

    if not monitor.get("orgtype"):
        print(f"{name}: missing orgtype")

    if not freebsd:
        check_activities(monitor, cutoff, ignore)

    # ip addresses never get removed from dory, so instead check we've
    # received a recent ping for each address family before adding them
    if (monitor.get("lastping4") or 0) > cutoff:
        _check_asn(monitor, name, tree, siblings_groups, allowed_asns)

    if not freebsd:
        _check_natpport(monitor, name)

def main() -> None:
    """Fetch every Ark monitor and print the metadata problems for each."""
    monitors = ark_dory_query()
    if not monitors:
        print('no monitors')
        return

    tree = _read_prefix2as()
    exceptions = load_exceptions(EXCEPTIONS_FILE)
    # Each asn_siblings entry lists ASNs that may stand in for one another in
    # the origin check. They are stringified to match prefix2as origins. A
    # null list, null entry, or entry without 'asns' is tolerated.
    siblings_groups = [
        {str(asn) for asn in ((entry or {}).get('asns') or [])}
        for entry in (exceptions.get('asn_siblings') or [])
    ]

    # Unix-timestamp cutoff compared against the API's lastping values.
    # Subtracting seconds from time.time() avoids the one-hour skew that
    # naive local-time arithmetic picks up across a DST change.
    cutoff = time.time() - timedelta(days=7).total_seconds()
    for monitor in monitors:
        check_monitor_metadata(monitor, cutoff, tree, exceptions, siblings_groups)

if __name__ == "__main__":
    main()
