#!/usr/bin/env python3

# TODO can we watch the database for changes and update as they happen?
# TODO how can we filter this down to recently added nodes? Track the last
# run time, or just look at nodes from a fixed recent time period?

import argparse
import json
import os
import pwd
import subprocess
import time

import caida_oidc_client
import pystemd
from requests_oauthlib import OAuth2Session

# OpenID Connect realm, endpoint and client used to obtain arkmon API tokens.
REALM = "CAIDA"
AUTH_URL = "https://auth.caida.org/realms/{}/protocol/openid-connect".format(REALM)
CLIENT_ID = "arkmon-offline"

# Default location of the saved offline/refresh token (see --token-file).
TOKEN_FILE = "/etc/ark/.arkmon-offline.token"

# Base URL of the arkmon API queried for the monitor list.
API_URL = "https://api.arkmon.caida.org"

# Monitors whose natportal port is below this value are left alone by main().
#   40000: transitional value until arkmon-ui is used
#   44000: normal allocation range
CONTAINER_PORT_BASE = 40000


def load_token_info(token_file):
    """Load saved OAuth2 token data from the JSON file *token_file*.

    If the data has an ``expires_at`` that is not in the past, it is
    returned unchanged. Otherwise only the refresh token is kept. The result
    then carries ``expires_in`` of -1 and a placeholder ``access_token``, so
    the OAuth2 session sees an (already expired) access token and refreshes
    it on first use.

    Raises:
        KeyError: if a refresh is needed but the file has no refresh_token.
    """
    with open(token_file, encoding="ascii") as fh:
        saved = json.load(fh)

    expired = "expires_at" not in saved or saved["expires_at"] < time.time()
    if not expired:
        return saved

    # Keep only the refresh token; the placeholder access token exists so
    # the session treats itself as holding a token (see ark_dory_query).
    return {
        "refresh_token": saved["refresh_token"],
        "expires_in": -1,
        "access_token": "dummy value for oauthlib",
    }


def ark_dory_query(token_file):
    """Fetch the list of monitors from the arkmon API.

    Tokens are loaded from *token_file* via load_token_info(). The
    OAuth2Session is configured to refresh the access token at
    ``{AUTH_URL}/token``. The saver from
    ``caida_oidc_client.make_save_tokens(token_file)`` is its token_updater.

    Args:
        token_file: path to the JSON token file.

    Returns:
        The decoded JSON body of ``GET {API_URL}/monitors/`` if the request
        succeeds with HTTP 200 and the body decodes to a non-empty value.
        Otherwise a diagnostic is printed and None is returned.
    """
    token_info = load_token_info(token_file)
    save_tokens = caida_oidc_client.make_save_tokens(token_file)
    url = f"{API_URL}/monitors/"

    # establish a new session, refreshing access token if necessary; the
    # with-statement makes sure the session's connections are closed
    with OAuth2Session(client_id=CLIENT_ID,
                token=token_info,
                auto_refresh_url=f"{AUTH_URL}/token",
                auto_refresh_kwargs={"client_id": CLIENT_ID},
                token_updater=save_tokens) as session:

        if not session.authorized:
            print("Failed to authorize session, aborting")
            return None

        try:
            response = session.request("GET", url)
        except Exception as e:
            # top-level boundary of this script: network, TLS and token
            # refresh failures all abort the run with a message
            print(f"Failed to fetch {url}: {e}")
            return None

        if response.status_code != 200:
            print(f"Got status code {response.status_code}, aborting")
            return None

        try:
            monitors = response.json()
        except ValueError as e:
            # body was not valid JSON (e.g. an HTML error page)
            print(f"Failed to decode response from {url}: {e}")
            return None

    # `not` also covers a JSON null body, which len() would reject
    if not monitors:
        print("Empty monitors list, aborting")
        return None

    return monitors


def is_passwd_set(name):
    """Return True if the password database has an account called *name*."""
    try:
        pwd.getpwnam(name)
    except KeyError:
        found = False
    else:
        found = True
    return found


def is_sshd_set(name):
    """Return True if the sshd drop-in config for *name* already exists.

    The file checked is /etc/ark/ssh/sshd_config.d/<name>.conf. It must exist
    and be a regular file (os.path.isfile semantics).
    """
    conf_path = f"/etc/ark/ssh/sshd_config.d/{name}.conf"
    return os.path.isfile(conf_path)


def set_passwd(name, port, dryrun):
    """Create the login-less account used by a node's NAT reverse tunnel.

    Runs ``useradd`` with uid *port*, group ``nogroup``, no home directory
    (/nonexistent, not created), shell /bin/false and password field ``*``.

    Args:
        name: account name to create (the node name).
        port: uid to assign; main() passes the node's natportal port.
        dryrun: if true, only print what would be done.

    Returns:
        True if the account was created (or would be, in dry-run mode).
        False if useradd failed or could not be executed.
    """
    # Pass an argument vector instead of a shell string: *name* comes from
    # the remote monitors API, and interpolating it into a shell command
    # would allow shell injection. "--" stops option parsing before the name.
    command = [
        "useradd",
        "--shell", "/bin/false",
        "--home-dir", "/nonexistent",
        "--no-create-home",
        "--uid", str(port),
        "--gid", "nogroup",
        "--no-user-group",
        "--comment", "NAT reverse tunnel",
        "--password", "*",
        "--",
        name,
    ]

    print(f"Adding user {name} ({port})...")

    if dryrun:
        return True

    try:
        # errors="replace": never crash just because useradd printed
        # non-ASCII (e.g. localized) diagnostics
        subprocess.run(command, check=True, capture_output=True,
                encoding="ascii", errors="replace")
    except subprocess.CalledProcessError as e:
        print(e)
        print(e.stdout)
        print(e.stderr)
        return False
    except OSError as e:
        # useradd missing or not executable (the shell used to report this
        # as a non-zero exit status)
        print(e)
        return False
    return True


# update sshd server conf to allow port forwarding of one specific port
def set_sshd(name, port, dryrun):
    """Write an sshd drop-in that lets user *name* listen on *port* only.

    The file /etc/ark/ssh/sshd_config.d/<name>.conf gets a ``Match User``
    block with a single ``PermitListen`` line.

    Args:
        name: account/node name; must match [A-Za-z0-9_][A-Za-z0-9._-]*.
        port: the one port the user may listen on.
        dryrun: if true, only print what would be done.

    Returns:
        True if the file was written (or would be, in dry-run mode).
        False if *name* is rejected or the file could not be written.
    """
    import re

    print(f"Configuring sshd for {name}...")

    # *name* comes from the remote monitors API. It is used both as a file
    # name and as an sshd "Match User" pattern. Reject anything that could
    # escape the config directory ("/"), inject extra config lines (newlines)
    # or act as a Match pattern metacharacter ("*", "?", "!", ",").
    if not re.fullmatch(r"[A-Za-z0-9_][A-Za-z0-9._-]*", str(name)):
        print(f"Refusing unsafe user name {name!r}")
        return False

    filename = f"/etc/ark/ssh/sshd_config.d/{name}.conf"
    print(f"Writing sshd config to {filename}")

    if dryrun:
        return True

    try:
        with open(filename, "w", encoding="ascii") as f:
            f.write(f"Match User {name}\n\tPermitListen {port}\n")
    except OSError as e:  # PermissionError is a subclass of OSError
        print(e)
        return False

    return True


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
            description="Add natp-ssh config for container-based ark monitors")

    parser.add_argument(
        "-s", "--seconds",
        default=60*60,
        type=int,
        help="time in seconds to search backwards for new nodes "
             "(default: %(default)s)")

    parser.add_argument(
        "-d", "--dry-run",
        default=False,
        dest="dryrun",
        action="store_true",
        help="don't make any changes, just print what would happen")

    parser.add_argument(
        "-t", "--token-file",
        default=TOKEN_FILE,
        help=f"file containing offline token (default: {TOKEN_FILE})")

    return parser.parse_args()


def _monitor_target(monitor, threshold):
    """Return (name, port) if *monitor* needs natp-ssh config, else None.

    A monitor qualifies when all of the following hold:
    - hwtype is "Container" and status is "Active";
    - activation is numeric and >= *threshold*;
    - it has a "node" name and an integer-convertible "natpport";
    - that port is >= CONTAINER_PORT_BASE.
    """
    # only deal with active container-based nodes
    if monitor.get("hwtype") != "Container":
        return None
    if monitor.get("status") != "Active":
        return None

    # ignore any nodes that are too old; a missing, null or non-numeric
    # activation time counts as "too old" instead of crashing the run
    activation = monitor.get("activation", 0)
    if not isinstance(activation, (int, float)) or activation < threshold:
        return None

    # a node needs a name and a natportal port number; a null or malformed
    # port skips this node rather than aborting every other one
    try:
        name = monitor["node"]
        port = int(monitor["natpport"])
    except (KeyError, TypeError, ValueError):
        return None

    # some older containers might actually be VMs and use the regular
    # natportal configuration, leave them alone
    if port < CONTAINER_PORT_BASE:
        return None

    return name, port


def main():
    """Create accounts and sshd config for recently activated container nodes.

    Fetches the monitor list from the arkmon API. For each qualifying monitor
    (see _monitor_target), a user account and an sshd drop-in are created
    if missing. If any sshd config was written outside dry-run mode,
    natp-sshd-container.service is reloaded.
    """
    args = _parse_args()

    # fetch the list of monitors
    monitors = ark_dory_query(args.token_file)
    if not monitors:
        return

    changed = False
    threshold = time.time() - args.seconds

    # make sure they all have accounts and ssh config
    for monitor in monitors:
        target = _monitor_target(monitor, threshold)
        if target is None:
            continue
        name, port = target

        # create the user account if it doesn't exist; without it there is
        # no point writing sshd config for this node
        if not is_passwd_set(name) and not set_passwd(name, port, args.dryrun):
            continue

        # create the sshd config if it doesn't exist
        if not is_sshd_set(name) and set_sshd(name, port, args.dryrun):
            changed = True

    # reload ssh to pick up newly allowed forwarding ports
    if changed and not args.dryrun:
        with pystemd.systemd1.Unit("natp-sshd-container.service") as sshd:
            print("Reloading natp-sshd-container.service")
            sshd.Reload(b"replace")


if __name__ == "__main__":
    # main() returns None, so this exits with status 0 on normal completion
    raise SystemExit(main())
