#!/usr/bin/env python3

# fetch metadata from dory, and format it for the scamper python api

import argparse
from datetime import datetime, timedelta
import gzip
import io
import json
import re
import sys
import time
from unidecode import unidecode

import caida_oidc_client
import radix
import requests
from requests_oauthlib import OAuth2Session

# OpenID Connect realm; interpolated into AUTH_URL below
REALM = "CAIDA"
# base URL of the dory (arkmon) API that is queried for the monitor list
API_URL = "https://api.arkmon.caida.org"
# OpenID Connect endpoint prefix; its "/token" path is used to refresh tokens
AUTH_URL = f"https://auth.caida.org/realms/{REALM}/protocol/openid-connect"
# OAuth2 client id used both for the session and for token refreshes
CLIENT_ID = "arkmon-offline"
# JSON file holding the saved OAuth2 tokens; read by load_token_info and
# handed to caida_oidc_client.make_save_tokens (presumably so refreshed
# tokens are written back here)
TOKEN_FILE = "/etc/ark/.arkmon-offline.token"
# directory of the routeviews prefix2as dataset (creation log + snapshots)
P2A_URL = "https://publicdata.caida.org/datasets/routing/routeviews-prefix2as"

# scamper stores each monitor's metadata in a structure roughly like this:
#   char              *name;
#   char              *ipv4;
#   char              *asn4;
#   char              *cc;
#   char              *st;
#   char              *place;
#   char              *latlong;
#   char              *shortname;
#
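# and this script writes it to the metadata file as "<name> <key> <value>"
# lines, one line per field (list-valued fields get one line per item):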
#   hlz2-nz.ark.caida.org cc nz
#   hlz2-nz.ark.caida.org st wko
#   hlz2-nz.ark.caida.org shortname hlz2-nz
#   hlz2-nz.ark.caida.org place Hamilton
#
# notes:
#   * all output is ASCII text; UTF-8 characters are transliterated
#   * if there are multiple ASNs, just pick the first one for now
#   * lat/long are combined into a single comma-separated string
#   * fields that have no value are omitted entirely


def prefix2as_most_recent(url: str) -> str:
    """Return the path of the most recent prefix2as snapshot listed under url.

    Fetches ``{url}/pfx2as-creation.log`` and returns the third
    whitespace-separated field of its last data line.  Lines starting with
    '#', and blank or truncated lines, are ignored.
    NOTE(review): this assumes the log is in chronological order with the
    snapshot path in column three; confirm against the published log.

    Raises:
        requests.RequestException: if the log cannot be fetched
            (including non-2xx responses via raise_for_status).
        ValueError: if the log contains no usable data line.
    """
    log_url = f"{url}/pfx2as-creation.log"
    response = requests.get(log_url, timeout=10)
    response.raise_for_status()

    most_recent = None
    for line in response.content.decode().splitlines():
        if line.startswith("#"):
            continue
        row = line.split()
        # skip blank or truncated lines instead of failing on row[2]
        if len(row) < 3:
            continue
        most_recent = row[2]

    if most_recent is None:
        raise ValueError(f"no snapshot entries found in {log_url}")
    return most_recent


def prefix2as_read() -> radix.Radix | None:
    """Download the newest prefix2as snapshot and load it into a radix tree.

    Only prefixes with a length between 8 and 24 (inclusive) are kept.  The
    origin field is split on '_' and ',' into individual ASNs, and each
    tree node stores them, sorted, in ``node.data["origin"]``.

    Returns:
        The populated tree, or None if anything goes wrong.  Failures are
        reported on stderr rather than raised, because the ASN data is
        optional for the caller.
    """
    tree = radix.Radix()

    try:
        most_recent = prefix2as_most_recent(P2A_URL)

        response = requests.get(f"{P2A_URL}/{most_recent}", timeout=10)
        response.raise_for_status()

        with gzip.open(io.BytesIO(response.content), mode="rt") as inf:
            for line in inf:
                row = line.split()
                # skip comments and blank/truncated lines; one bad line
                # should not discard the whole dataset
                if line.startswith("#") or len(row) < 3:
                    continue
                pfx, pfx_len, origin_str = row[0], int(row[1]), row[2]
                if not 8 <= pfx_len <= 24:
                    continue

                # a single ASN splits into a one-element list as well
                origin = [int(a) for a in re.split(r"[,_]", origin_str)]

                node = tree.add(f"{pfx}/{pfx_len}")
                node.data["origin"] = sorted(origin)

    except Exception as exc:
        # deliberate best effort: report and continue without ASN data
        print(f"prefix2as: {exc}", file=sys.stderr, flush=True)
        return None

    return tree


def load_token_info(token_file):
    """Read the saved OAuth2 token dictionary from token_file.

    A token whose ``expires_at`` is still in the future is returned as-is.
    Otherwise, including when ``expires_at`` is absent, the dictionary is
    stripped down to its ``refresh_token``.  It also gets ``expires_in = -1``
    and a placeholder ``access_token``.  This is presumably so that oauthlib
    accepts the token but treats it as expired and refreshes it.

    Raises:
        KeyError: if the token is expired and has no ``refresh_token``.
    """
    with open(token_file, "r", encoding="ascii") as fh:
        info = json.load(fh)

    still_valid = "expires_at" in info and not info["expires_at"] < time.time()
    if still_valid:
        return info

    refresh = info["refresh_token"]
    info.clear()
    info.update(
        refresh_token=refresh,
        expires_in=-1,
        access_token="dummy value for oauthlib",
    )
    return info


def ark_dory_query():
    """Fetch the list of Ark monitors from the dory API.

    Loads the saved token from TOKEN_FILE and builds an OAuth2 session.  The
    session has an auto-refresh URL and a token updater built by
    caida_oidc_client.make_save_tokens.  It then GETs
    ``{API_URL}/monitors/``.

    Returns:
        The decoded JSON list of monitors, or None on any failure.  The
        reason for a failure is printed on stderr, keeping stdout free for
        the metadata output.
    """
    token_info = load_token_info(TOKEN_FILE)
    save_tokens = caida_oidc_client.make_save_tokens(TOKEN_FILE)

    # establish a new session, refreshing access token if necessary
    session = OAuth2Session(
        client_id=CLIENT_ID,
        token=token_info,
        auto_refresh_url=f"{AUTH_URL}/token",
        auto_refresh_kwargs={"client_id": CLIENT_ID},
        token_updater=save_tokens,
    )

    if not session.authorized:
        print("Failed to authorize session, aborting", file=sys.stderr)
        return None

    url = f"{API_URL}/monitors/"
    try:
        # without a timeout a stalled server would hang the script forever
        response = session.request("GET", url, timeout=10)
    except Exception as e:
        print(f"Failed to fetch {API_URL}/monitors: {e}", file=sys.stderr)
        return None

    if response.status_code != 200:
        print(f"Got status code {response.status_code}, aborting", file=sys.stderr)
        return None

    try:
        monitors = response.json()
    except ValueError as e:
        # requests' JSON decode errors subclass ValueError
        print(f"Invalid JSON from {url}: {e}", file=sys.stderr)
        return None

    # also covers a JSON null body, which len() would choke on
    if not monitors:
        print("Empty monitors list, aborting", file=sys.stderr)
        return None

    return monitors


def get_os_labels(monitor):
    """Return the "os:*" tags derived from the monitor's ``osname`` field.

    The name is lower-cased first, and then:
      * FreeBSD gives ``os:freebsd`` plus the full name with spaces replaced
        by underscores (e.g. ``os:freebsd_10.1``);
      * Debian/Raspbian gives ``os:linux`` and ``os:debian``, plus
        ``os:debian_<codename>`` when the name contains a parenthesised
        codename;
      * Ubuntu gives ``os:linux`` and ``os:ubuntu``, plus
        ``os:ubuntu_<major.minor>`` when a version follows a word;
      * anything else gives ``os:<osname>`` unchanged.
    A missing or empty ``osname`` produces no tags.
    """
    osname = monitor.get("osname")
    if not osname:
        return []
    osname = osname.lower()

    if "freebsd" in osname:
        # e.g. FreeBSD 10.1
        return ["os:freebsd", f"os:{osname.replace(' ', '_')}"]

    if "debian" in osname or "raspbian" in osname:
        # e.g. Debian GNU/Linux 12 (bookworm)
        tags = ["os:linux", "os:debian"]
        codename = re.search(r".*\(([a-z0-9]+)\)", osname)
        if codename:
            tags.append(f"os:debian_{codename.group(1)}")
        return tags

    if "ubuntu" in osname:
        # e.g. Ubuntu 22.04.3 LTS
        tags = ["os:linux", "os:ubuntu"]
        version = re.search(r"[a-z]+ ([0-9]+\.[0-9]+).*", osname)
        if version:
            tags.append(f"os:ubuntu_{version.group(1)}")
        return tags

    return [f"os:{osname}"]


def get_hardware_labels(monitor):
    """Return the "hardware:*" tags derived from ``hwtype`` and ``cpuarch``.

    When ``hwtype`` is set, it is lower-cased.  A type mentioning
    "container" or "vm" yields ``hardware:virtual`` and ``hardware:<hwtype>``.
    Any other type yields ``hardware:physical``.  Then ``hardware:<cpuarch>``
    (lower-cased) is added if ``cpuarch`` is set.  ``cpuarch`` is only
    consulted when ``hwtype`` is present; a missing or empty ``hwtype``
    produces no tags.
    """
    hwtype = monitor.get("hwtype")
    if not hwtype:
        return []
    hwtype = hwtype.lower()

    is_virtual = any(marker in hwtype for marker in ("container", "vm"))
    if is_virtual:
        tags = ["hardware:virtual", f"hardware:{hwtype}"]
    else:
        tags = ["hardware:physical"]

    cpu = monitor.get("cpuarch")
    if cpu:
        tags.append(f"hardware:{cpu.lower()}")
    return tags


def per_monitor_metadata(monitor, tree, cutoff):
    """Extract and format the metadata from a single monitor dictionary.

    Args:
        monitor: one monitor dict from the dory ``/monitors/`` response.
        tree: prefix2as radix tree whose nodes carry a sorted
            ``data["origin"]`` ASN list, or None if it could not be loaded.
        cutoff: unix timestamp.  A monitor whose ``lastping`` is older is
            skipped.  An address family whose ``lastping4``/``lastping6`` is
            not newer is not advertised.

    Returns:
        A dict of metadata key -> value (``tag`` is a list of strings), or
        None if the monitor has no ``node`` name or has not pinged since
        ``cutoff``.
    """

    def last_ping(key):
        # a missing or null timestamp counts as "never pinged" (comparing
        # None with a float would raise TypeError)
        return monitor.get(key) or 0

    if monitor.get("node") is None:
        return None

    if last_ping("lastping") < cutoff:
        return None

    metadata = {
        "name": f"{monitor['node']}.ark.caida.org",
        "shortname": monitor["node"],
        "tag": get_os_labels(monitor) + get_hardware_labels(monitor),
    }

    # copy simple string fields across under their scamper key names
    for src_key, dst_key in (
        ("country", "cc"),
        ("state", "st"),
        ("city", "place"),
        ("iata", "iata"),
    ):
        if value := monitor.get(src_key):
            metadata[dst_key] = value

    # 0 is a valid latitude/longitude, so test for presence, not truthiness
    lat, lon = monitor.get("latitude"), monitor.get("longitude")
    if lat not in (None, "") and lon not in (None, ""):
        metadata["latlong"] = f"{float(lat):.2f},{float(lon):.2f}"

    # ip addresses never get removed from dory, so instead check we've
    # received a recent ping for each address family before adding them
    if last_ping("lastping4") > cutoff:
        metadata["tag"].append("network:ipv4")
        if monitor.get("nat"):
            metadata["tag"].append("network:nat4")
        if ipv4 := monitor.get("ipv4global"):
            metadata["ipv4"] = ipv4
        dory_asn = monitor.get("asn")
        if ipv4 and tree is not None and (node := tree.search_best(ipv4)):
            origins = node.data["origin"]
            # default to the first (lowest, as the list is sorted) prefix2as
            # origin, but prefer dory's ASN if prefix2as also lists it.
            # dory's asn may be an int or a str, so normalise via str().
            metadata["asn4"] = origins[0]
            if dory_asn and str(dory_asn).isdecimal() and int(dory_asn) in origins:
                metadata["asn4"] = int(dory_asn)
        elif dory_asn:
            # no prefix2as match: fall back to dory's ASN, keeping only the
            # first entry of an '_'-separated list
            if isinstance(dory_asn, str):
                metadata["asn4"] = dory_asn.split("_")[0]
            else:
                metadata["asn4"] = dory_asn

    if last_ping("lastping6") > cutoff:
        metadata["tag"].append("network:ipv6")
        if monitor.get("nat6"):
            metadata["tag"].append("network:nat6")

    if primitives := monitor.get("scamper_primitives"):
        metadata["tag"].extend(f"primitive:{p}" for p in primitives)

    return metadata


def get_metadata(max_age_days=7):
    """Get a list of metadata for all valid, recently active monitors.

    Args:
        max_age_days: monitors, and individual address families, whose last
            ping is older than this many days are left out.

    Returns:
        A list of per-monitor metadata dicts, or None if the monitor list
        could not be fetched or was empty.
    """
    monitors = ark_dory_query()
    if not monitors:
        return None

    # lastping values from the api are unix timestamps.  Subtract absolute
    # seconds: naive local-time datetime arithmetic would skew the cutoff
    # by an hour whenever a DST change falls inside the window.
    cutoff = time.time() - timedelta(days=max_age_days).total_seconds()

    # prefix2as tree used to fill out the ASN field; None if unavailable
    tree = prefix2as_read()

    return [md for m in monitors if (md := per_monitor_metadata(m, tree, cutoff))]


def write_metadata(metadata, filename):
    """Write all the metadata in the fireball text format.

    Args:
        metadata: list of per-monitor dicts, each containing a "name" key.
        filename: output path, or a falsy value to write to stdout.
    """
    if filename:
        with open(filename, "w", encoding="ascii") as f:
            _write_metadata_lines(metadata, f)
    else:
        # write straight to stdout; using sys.stdout as a context manager
        # would close it when the block exits
        _write_metadata_lines(metadata, sys.stdout)


def _write_metadata_lines(metadata, f):
    """Write one "<name> <key> <value>" ASCII line per field (per list item)."""
    for monitor in metadata:
        name = monitor["name"]
        for key, value in monitor.items():
            if key == "name":
                continue
            items = value if isinstance(value, list) else [value]
            for item in items:
                # unidecode transliterates any non-ASCII characters
                f.write(unidecode(f"{name} {key} {item}\n"))


def main():
    """Command-line entry point: fetch the metadata and write it out.

    Writes to the file named on the command line, or to stdout if none is
    given.  Exits with status 1 and a message on stderr if no metadata could
    be fetched.
    """
    parser = argparse.ArgumentParser(description="Fetch fireball metadata")
    parser.add_argument(
        "filename",
        nargs="?",
        default=None,
        help="File to write metadata to (default: stdout)",
    )
    args = parser.parse_args()

    metadata = get_metadata()
    if not metadata:
        # sys.exit(str) prints to stderr and exits with status 1, so the
        # message can never be mistaken for metadata on stdout
        sys.exit("Failed to fetch metadata")
    write_metadata(metadata, args.filename)


if __name__ == "__main__":
    main()
