#! /usr/bin/python3

# Stripped down version of
# https://gitlab.caida.org/CAIDA/ark/ark-ssh-aliases/-/blob/main/bin/ark-ssh.py

import argparse
import json
import os
import sys
import time

import caida_oidc_client
from requests_oauthlib import OAuth2Session

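#+
# NAME:
#	ark-ssh-aliases
# PURPOSE:
#	Generates the ssh_config and known_hosts data for Ark nodes,
#	including the aliases used to log into the nodes.
# CALLING SEQUENCE:
#	ark-ssh-aliases [--token-file TOKEN_FILE] [--config-file CONFIG_FILE]
#	                [--known-hosts-file KNOWN_HOSTS_FILE] [-o Option=value ...]
# OPTIONAL INPUTS:
#	-t, --token-file TOKEN_FILE
#			file holding the arkmon offline OAuth token
#			(default: /etc/ark/.arkmon-offline.token)
#	--config-file CONFIG_FILE
#			file to receive the ssh config (the ssh alias definitions);
#			printed to stdout if omitted
#	--known-hosts-file KNOWN_HOSTS_FILE
#			file to receive the nodes' public host keys;
#			printed to stdout if omitted
#	-o, --option Option=value
#			ssh config option added to the "Host *" section
#			(can be used multiple times)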

# OAuth2 client used for offline (refresh-token based) access to arkmon.
CLIENT_ID = "arkmon-offline"

# Keycloak realm, and the OpenID Connect endpoint base derived from it
# (get_all_monitors appends "/token" to it for token refreshes).
REALM = "CAIDA"
AUTH_URL = f"https://auth.caida.org/realms/{REALM}/protocol/openid-connect"

# Root of the Ark monitoring REST API.
API_URL = "https://api.arkmon.caida.org"


def load_token_info(token_file):
    """Read the saved OAuth2 token dictionary from *token_file* (JSON).

    A token whose "expires_at" is missing or already in the past is reduced
    to its refresh token only, plus "expires_in" = -1 and a placeholder
    access token. The OAuth2 session built from it then sees an expired
    token and, given an auto-refresh URL, should fetch a new one. A token
    that is still valid is returned exactly as stored.

    Raises:
        KeyError: if a refresh is needed but "refresh_token" is absent.
    """
    with open(token_file, "r", encoding="ascii") as handle:
        stored = json.load(handle)

    still_valid = ("expires_at" in stored
                   and not stored["expires_at"] < time.time())
    if still_valid:
        return stored

    return {
        "refresh_token": stored["refresh_token"],
        # negative lifetime => treated as already expired
        "expires_in": -1,
        # placeholder: oauthlib needs some access token value to be present
        "access_token": "dummy value for oauthlib",
    }


def get_all_monitors(token_file, timeout=60):
    """Fetch the full list of monitors from the Ark API.

    Loads the stored OAuth tokens from *token_file*, opens an OAuth2 session
    that auto-refreshes the access token via AUTH_URL (refreshed tokens are
    passed to the saver from caida_oidc_client.make_save_tokens), and GETs
    the monitors endpoint.

    Args:
        token_file: path to the JSON token file (see load_token_info).
        timeout: seconds to wait for the server, passed through to requests.
            Without it, a stalled server would hang the script indefinitely.

    Returns:
        The non-empty list decoded from the API's JSON response.

    Exits with status 1, after printing a message to stderr, if the session
    is not authorized, the request fails, the status is not 200, or the body
    is not a non-empty JSON list.
    """
    token_info = load_token_info(token_file)
    save_tokens = caida_oidc_client.make_save_tokens(token_file)

    # establish a new session, refreshing access token if necessary
    session = OAuth2Session(client_id=CLIENT_ID,
                token=token_info,
                auto_refresh_url=f"{AUTH_URL}/token",
                auto_refresh_kwargs={"client_id": CLIENT_ID},
                token_updater=save_tokens)

    if not session.authorized:
        print("Failed to authorize session, aborting", file=sys.stderr)
        sys.exit(1)

    url = f"{API_URL}/monitors/"
    try:
        response = session.request("GET", url, timeout=timeout)
    except Exception as e:
        # Deliberately broad: transport errors and (presumably) failures
        # during the automatic token refresh both surface here.
        print(f"Failed to fetch {url}: {e}", file=sys.stderr)
        sys.exit(1)

    if response.status_code != 200:
        print(f"Got status code {response.status_code}, aborting",
              file=sys.stderr)
        sys.exit(1)

    try:
        monitors = response.json()
    except ValueError as e:
        print(f"Invalid JSON from {url}: {e}", file=sys.stderr)
        sys.exit(1)

    # Callers iterate the result as a list of monitor dicts.
    if not isinstance(monitors, list):
        print(f"Unexpected response from {url} (expected a list), aborting",
              file=sys.stderr)
        sys.exit(1)

    if not monitors:
        print("Empty monitors list, aborting", file=sys.stderr)
        sys.exit(1)

    return monitors


def filter_monitor(monitor):
    """Reduce one API monitor record to the fields used for ssh config.

    Only monitors with status "Active" or "Inactive" are kept. A monitor
    must also have a non-empty "node" (its name) and "sshpubkey" (its host
    key line). A missing, null, or empty value would otherwise crash
    format_known_hosts (None has no .startswith) or yield a malformed
    known_hosts line.

    Returns:
        A dict with keys name, ssh_host_key_pub, ostype, natpport,
        ipv4global and ipv6global (the last four None when absent), or None
        if the monitor should be skipped.
    """
    if monitor.get("status") not in ("Active", "Inactive"):
        return None

    name = monitor.get("node")
    host_key = monitor.get("sshpubkey")
    if not name or not host_key:
        return None

    return {
        "name": name,
        "ssh_host_key_pub": host_key,
        "ostype": monitor.get("ostype"),
        "natpport": monitor.get("natpport"),
        "ipv4global": monitor.get("ipv4global"),
        "ipv6global": monitor.get("ipv6global"),
    }


def format_known_hosts(monitor):
    """Return the known_hosts line for one filtered monitor.

    A host key entry that is already a complete "@cert-authority" line is
    returned verbatim. Otherwise the monitor name is prepended as the host
    field: "<name> <key>".
    """
    key_line = monitor['ssh_host_key_pub']
    return (key_line if key_line.startswith("@cert-authority")
            else "{} {}".format(monitor['name'], key_line))


def format_base_ssh_config(options):
    """Return the leading, host-independent part of the ssh_config.

    Args:
        options: iterable of "Option=value" strings; may be empty or None.
            Only the first "=" separates the keyword from the value, so
            values that themselves contain "=" (e.g. "SetEnv=LANG=C") are
            kept intact. Surrounding whitespace is stripped from both parts.

    Returns:
        A list whose first element is the "autogenerated" warning banner.
        When *options* is non-empty, it is followed by a "Host *" stanza
        with one indented line per option. The caller joins the list with
        newlines.

    Raises:
        ValueError: if an option lacks an "=" or has an empty keyword.
    """
    banner = ("#----------------------------------------------#\n"
              "# WARNING: Autogenerated. Do not edit by hand. #\n"
              "#----------------------------------------------#\n")

    if not options:
        return [banner]

    lines = ["Host *\n"]
    for option in options:
        key, sep, val = option.partition("=")
        key = key.strip()
        if not sep or not key:
            raise ValueError(
                f"invalid ssh option {option!r}: expected Option=value")
        lines.append(f"  {key} {val.strip()}\n")

    return [banner, "".join(lines)]


def format_ssh_config(monitor):
    """Render the ssh_config stanza(s) for one filtered monitor.

    A monitor with a truthy "natpport" gets a single stanza connecting to
    <name>.natp.caida.org on that port through the NAT portal. Otherwise
    one stanza is produced per global address: the IPv4 stanza is labelled
    <name>. The IPv6 stanza is labelled <name>-ipv6 when an IPv4 stanza
    also exists, and <name> otherwise. FreeBSD monitors additionally get
    "PubkeyAcceptedKeyTypes +ssh-rsa" on these direct stanzas. The result
    is "" when there is neither a NAT portal port nor any global address.

    HostKeyAlias is the monitor name, except for monitors whose host key
    entry is an "@cert-authority" line, which use <name>.natp.caida.org.
    """
    name = monitor["name"]
    # TODO is this distinction still required when there is only a single
    # entry in the config per host?
    if monitor['ssh_host_key_pub'].startswith("@cert-authority"):
        key_alias = f"{name}.natp.caida.org"
    else:
        key_alias = name

    # Preferred path: ideally every monitor is reached via the NAT portal.
    natp_port = monitor["natpport"]
    if natp_port:
        natp_lines = (
            f"Host {name}",
            f"  HostKeyAlias {key_alias}",
            f"  HostName {name}.natp.caida.org",
            f"  Port {natp_port}",
            "  KeepAlive yes",
            "  ServerAliveInterval 90",
        )
        return "".join(f"{line}\n" for line in natp_lines)

    # Deprecated fallback: assume direct IPv4 and/or IPv6 connectivity.
    extra = ("  PubkeyAcceptedKeyTypes +ssh-rsa\n"
             if monitor["ostype"] == "FreeBSD" else "")
    ipv4 = monitor["ipv4global"]
    ipv6 = monitor["ipv6global"]

    targets = []
    if ipv4:
        targets.append((name, ipv4))
    if ipv6:
        # IPv4 keeps the bare name when both address families exist.
        targets.append((f"{name}-ipv6" if ipv4 else name, ipv6))

    return "\n".join(
        f"Host {host}\n  HostKeyAlias {key_alias}\n  HostName {addr}\n{extra}"
        for host, addr in targets
    )


def _merge_options(defaults, overrides):
    """Combine default and user-supplied "Option=value" ssh_config options.

    ssh uses the first value it obtains for a keyword, so a user option
    placed after a default with the same keyword would be ignored. Defaults
    whose keyword also appears in *overrides* are therefore dropped.
    Keywords are compared case-insensitively, as ssh_config keywords are.
    Returns the surviving defaults in order, followed by *overrides*.
    """
    def keyword(option):
        return option.partition("=")[0].strip().lower()

    overridden = {keyword(opt) for opt in overrides}
    kept = [opt for opt in defaults if keyword(opt) not in overridden]
    return kept + list(overrides)


def main():
    """Entry point: emit known_hosts and ssh_config data for Ark monitors.

    Fetches the monitor list using the token in --token-file and keeps the
    usable monitors (see filter_monitor). It then writes the known_hosts
    lines and the ssh_config (banner, a "Host *" section built from the
    merged -o options, then the per-monitor stanzas) to --known-hosts-file
    and --config-file. When a file is not given, that output goes to stdout
    under a "=== ..." header. Exits with status 1, printing a message to
    stderr, when there is no known_hosts or ssh_config data to write.
    """
    default_options = [
        "ForwardX11Trusted=yes",
        "ForwardAgent=yes",
        "HostKeyAlgorithms=+ssh-rsa,ssh-dss",
        "StrictHostKeyChecking=yes",
        "CheckHostIP=no",
    ]

    parser = argparse.ArgumentParser(
            description="Export SSH config for ark nodes")

    parser.add_argument("-t", "--token-file",
            default="/etc/ark/.arkmon-offline.token",
            help="name of file containing offline token (default: /etc/ark/.arkmon-offline.token)")

    parser.add_argument("--config-file",
            dest="config_file",
            action="store",
            default=None,
            help="Ark ssh_config file, default: stdout",
            )

    parser.add_argument("--known-hosts-file",
            dest="known_hosts_file",
            action="store",
            default=None,
            help="Ark ssh_known_hosts file, default: stdout",
            )

    # The defaults are merged in after parsing rather than passed as the
    # argparse default. With action="append", argparse would add user values
    # after a list default, where ssh ignores them for keywords the
    # defaults already set.
    parser.add_argument("-o", "--option",
            dest="option",
            action="append",
            default=None,
            help="ssh-config option as Option=value (repeatable); replaces "
                 "the default for the same Option. Defaults: "
                 + ", ".join(default_options),
            )

    args = parser.parse_args()
    options = _merge_options(default_options, args.option or [])

    # get the list of valid monitors that need ssh configuration
    all_monitors = get_all_monitors(args.token_file)
    monitors = [y for x in all_monitors if (y := filter_monitor(x))]

    # extract the ssh host keys for each host
    known_hosts = [format_known_hosts(x) for x in monitors]
    if not known_hosts:
        print("Empty known hosts data, aborting", file=sys.stderr)
        sys.exit(1)

    # generate host specific config stanzas; format_ssh_config() returns ""
    # for a monitor with no NAT portal port and no global address, so drop
    # those rather than leaving blank gaps in the output
    monitor_ssh_config = [s for x in monitors if (s := format_ssh_config(x))]
    if not monitor_ssh_config:
        print("Empty ssh config file, aborting", file=sys.stderr)
        sys.exit(1)

    # data is all generated ok, so write out the known hosts file
    if args.known_hosts_file:
        with open(args.known_hosts_file, "w", encoding="ascii") as hostsfile:
            hostsfile.write("\n".join(known_hosts) + "\n")
    else:
        print("=== ark_known_hosts\n" + "\n".join(known_hosts))

    # add config boilerplate and write out the ssh config too
    config = format_base_ssh_config(options) + monitor_ssh_config
    if args.config_file:
        with open(args.config_file, "w", encoding="ascii") as conffile:
            conffile.write("\n".join(config) + "\n")
    else:
        print("=== ark_ssh_config\n" + "\n".join(config))


if __name__ == "__main__":
    # Run the exporter only when executed as a script, not when imported.
    main()
