#!/usr/bin/env -S python3 -u

# curl https://ca.caida.org/autoconf
# curl https://ca.caida.org/activity/fireball

from http import HTTPStatus
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
import ipaddress
import json
import os
import re
import subprocess
import sys
from threading import Lock
import time

import caida_oidc_client
from requests_oauthlib import OAuth2Session

# Check if a request comes from the ip address (or carries the auth token)
# of a monitor in dory, and return the matching hostname, ssh port, and a
# certificate token so that a container-based node can configure itself.
# The hostname is the short node name; the token is issued for
# <hostname>.ark.caida.org.
# {
#    "hostname": "foo-zz",
#    "port": 44001,
#    "token": "abcdefg..."
# }

# address and port that the configuration server listens on (see main)
HOSTNAME = "127.0.0.1"
PORT = 8083
# key file and provisioner name passed to `step ca token` when generating
# a certificate token for an autoconfiguring node
KEY = "/etc/step-ca/secrets/jwk-provisioner_ca.ark.caida.org.key"
PROVISIONER = "jwk@ca.ark.caida.org"

# API that is queried for the monitor list, and the OpenID Connect endpoint
# and client id used to authorise that query
REALM = "CAIDA"
API_URL = "https://api.arkmon.caida.org"
AUTH_URL = f"https://auth.caida.org/realms/{REALM}/protocol/openid-connect"
CLIENT_ID = "arkmon-offline"
# JSON file the OAuth2 token is loaded from; the same path is handed to
# caida_oidc_client.make_save_tokens (presumably so refreshed tokens get
# written back here - confirm against that module)
TOKEN_FILE = "/etc/ark/.arkmon-offline.token"

# extracts the node name (the label before .ark.caida.org) from the client
# certificate subject DN
SUBJECT_REGEX = re.compile(r"CN=([a-zA-Z0-9-]+)\.ark\.caida\.org")
# PPS value added to a team probing config that doesn't set one itself
DEFAULT_TEAM_PROBING_PPS = 100

def load_token_info(token_file):
    """Load the saved OAuth2 token from token_file (a JSON document).

    A token that has an "expires_at" value which is not yet in the past is
    returned exactly as stored. Otherwise only the refresh token is kept,
    alongside a negative "expires_in" and a placeholder access token.

    Raises KeyError if an expired or undated token has no "refresh_token",
    plus whatever open() and json.load() raise for a missing or malformed
    file.
    """
    with open(token_file, "r", encoding="ascii") as handle:
        stored = json.load(handle)

    # still usable: there is an expiry time and it hasn't passed yet
    if "expires_at" in stored and not stored["expires_at"] < time.time():
        return stored

    # expired (or undated): keep just enough for oauthlib to refresh it
    return {
        "refresh_token": stored["refresh_token"],
        "expires_in": -1,
        "access_token": "dummy value for oauthlib",
    }


def ark_dory_query():
    """Fetch the monitor list from the API using the stored OAuth2 token.

    Opens an OAuth2 session with the token loaded from TOKEN_FILE (the
    session is configured to refresh the access token automatically and to
    pass new tokens to the updater built from TOKEN_FILE), then requests
    {API_URL}/monitors/.

    Returns the decoded JSON response body, or None if the session is not
    authorized, the request fails, the status code is not 200, the body is
    not valid JSON, or the decoded body is empty. Each failure is reported
    with a message on stdout.

    Errors from load_token_info (e.g. an unreadable token file) are not
    caught here.
    """
    token_info = load_token_info(TOKEN_FILE)
    save_tokens = caida_oidc_client.make_save_tokens(TOKEN_FILE)

    # establish a new session, refreshing access token if necessary
    with OAuth2Session(client_id=CLIENT_ID,
                token=token_info,
                auto_refresh_url=f"{AUTH_URL}/token",
                auto_refresh_kwargs={"client_id": CLIENT_ID},
                token_updater=save_tokens) as session:

        if not session.authorized:
            print("Failed to authorize session, aborting")
            return None

        try:
            response = session.request("GET", f"{API_URL}/monitors/", timeout=3)
        except Exception as e:
            print(f"Failed to fetch {API_URL}/monitors: {e}")
            return None

        if response.status_code != 200:
            print(f"Got status code {response.status_code}, aborting")
            return None

        # a 200 response isn't guaranteed to carry JSON (e.g. an error page
        # from something in between); the decode errors raised by json()
        # are ValueError subclasses
        try:
            monitors = response.json()
        except ValueError as e:
            print(f"Failed to decode monitors response: {e}")
            return None

        # truthiness also covers a JSON null body, which len() can't handle
        if not monitors:
            print("Empty monitors list, aborting")
            return None

    return monitors


class ConfHandler(BaseHTTPRequestHandler):
    # replace the log function with one that understands X-Forwarded-For
    def log_message(self, format, *args):
        address = self.headers.get("X-Forwarded-For", self.address_string())
        message = format % args
        sys.stderr.write("%s - - [%s] %s\n" %
                         (address,
                          self.log_date_time_string(),
                          message.translate(self._control_char_table)))


    def do_GET(self):
        """Dispatch a GET request based on its path.

        /autoconf is handled by autoconf(), /activity/<name> by the config
        method for that activity. An /activity/ path with no activity name
        or an unknown activity gets 400, any other path gets 404.
        """
        if self.path == "/autoconf":
            self.autoconf()
            return

        if not self.path.startswith("/activity/"):
            self.send_error(HTTPStatus.NOT_FOUND)
            return

        # the activity name is the path component following "activity"
        components = self.path.strip("/ ").split("/")
        if len(components) < 2:
            self.send_error(HTTPStatus.BAD_REQUEST)
            print(f"Bad path: {self.path}")
            return
        activity = components[1]

        handlers = {
            "fireball": self.fireball_config,
            "team-probing": self.team_probing_config,
        }
        handler = handlers.get(activity)
        if handler is None:
            self.send_error(HTTPStatus.BAD_REQUEST,
                    message=f"Bad activity: {activity}")
            print(f"Bad activity: {activity}")
            return

        handler()


    def get_name(self):
        """Return the node name from the client certificate subject header.

        The subject is read from the X-SSL-Client-Cert-Subject-DN request
        header and the name is the part SUBJECT_REGEX captures from it. If
        the header is missing or doesn't match, a 400 error is sent and
        None is returned.
        """
        dn = self.headers.get("X-SSL-Client-Cert-Subject-DN")
        if dn is None:
            self.send_error(HTTPStatus.BAD_REQUEST,
                    message="Missing X-SSL-Client-Cert-Subject-DN header")
            return None

        found = SUBJECT_REGEX.search(dn)
        if found:
            return found.group(1)

        self.send_error(HTTPStatus.BAD_REQUEST,
                message=f"Badly formed certificate subject: {dn}")
        return None


    # XXX if we were willing to make primitives public then this could use
    # the https://api.arkmon.caida.org/public/monitors endpoint. There's almost
    # certainly other things we want to configure that are best kept private
    # however, so probably still need a configuration server like this anyway
    def fireball_config(self):
        name = self.get_name()
        if name is None:
            return

        # XXX should get_monitors send an error? Or can we do that here?
        monitors = self.get_monitors()
        if monitors is None or len(monitors) == 0:
            return

        # check we have configuration for this node
        monitor = self.find_monitor("node", name, monitors)
        if monitor is None:
            self.send_error(HTTPStatus.NOT_FOUND, message="Monitor not found")
            return

        # check that fireball is enabled on this node
        if "Fireball" not in monitor.get("activities", []):
            self.send_error(HTTPStatus.NOT_FOUND, message="Fireball disabled")
            return

        # list the availability of all the specific measurement primitives
        primitives = monitor.get("scamper_primitives", [])
        config = (
                f"# ark-activity-fireball configuration for {name}\n"
                "# automatically generated, any changes will be overwritten\n"
                "dealias.enable=1\n"
                f"host.enable={1 if 'dns' in primitives else 0}\n"
                f"http.enable={1 if 'http' in primitives else 0}\n"
                "neighbourdisc.enable=0\n"
                f"owamp.enable={1 if 'owamp' in primitives else 0}\n"
                "ping.enable=1\n"
                f"quic.enable={1 if 'quic' in primitives else 0}\n"
                "sniff.enable=0\n"
                "sting.enable=0\n"
                "tbit.enable=0\n"
                "trace.enable=1\n"
                "tracelb.enable=1\n"
                f"udpprobe.enable={1 if 'udp' in primitives else 0}\n")

        # send it back as plain text for the client to save as a config file
        self.send_response(HTTPStatus.OK)
        self.send_header("Content-Type", "text/plain")
        self.end_headers()
        self.wfile.write(config.encode("utf-8"))


    # TODO lots of this code will be repeated across activities, tidy that up
    def team_probing_config(self):
        name = self.get_name()
        if name is None:
            return

        monitors = self.get_monitors()
        if monitors is None or len(monitors) == 0:
            self.send_error(HTTPStatus.INTERNAL_SERVER_ERROR,
                    message="No monitors found")
            return

        # check we have configuration for this node
        monitor = self.find_monitor("node", name, monitors)
        if monitor is None:
            self.send_error(HTTPStatus.NOT_FOUND, message="Monitor not found")
            return

        # check that team probing is enabled on this node
        if "IPv4 team probing" not in monitor.get("activities", []):
            self.send_error(HTTPStatus.NOT_FOUND, message="Team probing disabled")
            return

        # list all the scamper configuration options
        config = monitor.get("scamper_config", [])
        relevant = [parts[1] for x in config if (parts := x.split(":"))[0] == "team-probing"]
        # add PPS if it isn't explicitly set
        if not any("PPS" in s for s in relevant):
            relevant += [f"PPS={DEFAULT_TEAM_PROBING_PPS}"]
        relevant.sort()

        header = [
            f"# ark-activity-team-probing configuration for {name}",
            "# automatically generated, any changes will be overwritten"
        ]

        # send it back as plain text for the client to save as a config file
        self.send_response(HTTPStatus.OK)
        self.send_header("Content-Type", "text/plain")
        self.end_headers()
        self.wfile.write("\n".join(header + relevant).encode("utf-8"))


    def autoconf(self):
        """Send a node its hostname, port and a freshly generated CA token.

        The node is looked up with get_node_config() using the address in
        the X-Forwarded-For header and, if present, the X-Auth-Token header.
        Responds with:
          - 403 if there is no X-Forwarded-For header,
          - 404 if no configuration matches the node,
          - 500 if the configuration is incomplete or `step ca token`
            can't be run successfully,
          - 200 and a JSON object with "hostname", "port" and "token" keys
            otherwise.
        """
        # there should always be an address
        address = self.headers.get("X-Forwarded-For")
        if address is None:
            self.send_response(HTTPStatus.FORBIDDEN)
            self.end_headers()
            print(f"Unknown ip address: {address}")
            return

        # there might be a token if the monitor thinks it can do autoconf
        auth_token = self.headers.get("X-Auth-Token", None)

        # find the configuration matching the address or token
        config = self.get_node_config(address, auth_token)
        if config is None:
            self.send_response(HTTPStatus.NOT_FOUND)
            self.end_headers()
            print(f"No config for address: {address}")
            return

        # check we've got the bits that we need
        print(config)
        if "hostname" not in config or "port" not in config:
            self.send_response(HTTPStatus.INTERNAL_SERVER_ERROR,
                    message="Missing configuration items")
            self.end_headers()
            print(f"Missing configuration items for: {address}")
            return

        # there is a config, generate and return a token to add to it
        try:
            result = subprocess.run(
                [
                    "step", "ca", "token",
                    "--key", KEY,
                    "--provisioner", PROVISIONER,
                    f"{config['hostname']}.ark.caida.org"
                ],
                capture_output=True,
                check=True
            )
        except (subprocess.CalledProcessError, OSError):
            # CalledProcessError: step ran but exited non-zero.
            # OSError: step couldn't be run at all (e.g. not installed),
            # which should also be reported to the client rather than
            # escaping as an unhandled exception.
            self.send_response(HTTPStatus.INTERNAL_SERVER_ERROR,
                    message="Failed to generate token")
            self.end_headers()
            print("Failed to generate token")
            return

        print(f"Token generated for {config['hostname']} ({address})")

        # send success headers
        self.send_response(HTTPStatus.OK)
        self.send_header("Content-Type", "text/plain")
        self.end_headers()

        # add the token to config and send it all off as json
        config["token"] = result.stdout.decode("utf-8").rstrip()
        self.wfile.write(json.dumps(config).encode("utf-8"))


    def find_monitor(self, field, value, haystack):
        """Return the only entry of haystack whose field equals value.

        Entries that don't have the field are ignored. Returns None when
        nothing matches, and also when more than one entry matches.
        """
        def is_match(entry):
            return field in entry and entry[field] == value

        candidates = list(filter(is_match, haystack))
        return candidates[0] if len(candidates) == 1 else None


    def get_monitors(self):
        now = time.time()
        # TODO is it worth having finer grained locking? Or being able to
        # return old data while the new data is being fetched?
        with self.server.lock:
            # update cache if required
            if self.server.lastfetch < now - self.server.cachetime:
                self.server.lastfetch = now
                monitors = ark_dory_query()
                # don't update the cache with empty data, so we can continue
                # serving if the backend goes away
                if monitors:
                    self.server.cached_monitors = monitors
            # use a copy so it doesn't matter if the monitors list gets
            # updated while a thread is iterating over it
            return self.server.cached_monitors.copy()


    def get_node_config(self, address=None, auth=None):
        """Find the hostname and port of the monitor making a request.

        The monitor is matched on its "auth_token" field if auth is given
        and matches exactly one monitor; otherwise on its "ipv4global" or
        "ipv6global" field (chosen by the version of address), compared
        against address as a string.

        Returns {"hostname": <node>, "port": <natpport as int>}, or None if
        there is no monitor data, address can't be parsed as an ip address,
        nothing matches, or the matching monitor lacks a usable "node" or
        "natpport".
        """
        # XXX is it worth limiting the query to PENDING monitors, and setting
        # them to ACTIVE afterwards, so the config is single use?
        # XXX what about only allowing auth tokens for pending monitors (then
        # set to active), and always allowing ip address configuration in any
        # monitor state. The ip address will get updated by pinghome once the
        # monitor has successfully booted, just might be an issue if people
        # with dynamic ip addresses aren't using persistent storage
        monitors = self.get_monitors()
        if not monitors:
            return None

        monitor = None

        # TODO should auth be AND'd or OR'd with address if both are present?
        # check auth token first if it is present
        if auth:
            monitor = self.find_monitor("auth_token", auth, monitors)

        # if auth token failed or is missing, try using ip address
        if monitor is None:
            # the address comes from a request header, so it may be missing,
            # malformed, or a comma separated list; none of those can be
            # matched against a monitor
            try:
                version = ipaddress.ip_address(address).version
            except ValueError:
                print(f"Unparseable ip address: {address!r}")
                return None
            field = "ipv4global" if version == 4 else "ipv6global"
            monitor = self.find_monitor(field, address, monitors)

        # no match for ip address or auth token, we can't auto configure
        if monitor is None:
            # return an interesting header too?
            return None

        # a node needs a name and a natportal port number; the port may be
        # missing (KeyError), null (TypeError) or not a number (ValueError)
        try:
            name = monitor["node"]
            port = int(monitor["natpport"])
        except (KeyError, TypeError, ValueError):
            return None

        return {
            "hostname": name,
            "port": port,
        }


def main():
    """Run the configuration server on HOSTNAME:PORT until interrupted.

    The server object also carries the monitor cache shared by all request
    handlers: the time of the last fetch, the cache lifetime in seconds,
    the cached monitor list, and the lock protecting them.
    """
    # Requests are handled on separate threads (ThreadingHTTPServer), so
    # handlers take server.lock before touching the cache attributes below.
    #
    # allow_reuse_address is not set here: the socket is bound inside the
    # constructor, so assigning it afterwards has no effect, and HTTPServer
    # already enables it at class level.
    #
    # The with block closes the listening socket however serve_forever()
    # exits, not only on KeyboardInterrupt.
    with ThreadingHTTPServer((HOSTNAME, PORT), ConfHandler) as server:
        server.lastfetch = 0
        server.cachetime = 60
        server.cached_monitors = []
        server.lock = Lock()

        try:
            server.serve_forever()
        except KeyboardInterrupt:
            # normal way to stop the server; nothing more to do
            pass


if __name__ == "__main__":
    main()
