#! /usr/bin/python3

import argparse
from datetime import datetime, timedelta
from email.message import EmailMessage
import html
import json
import smtplib
import sys
import time

import caida_oidc_client
import psutil
from requests_oauthlib import OAuth2Session

# Report email: sender, subject line, and the SMTP server send_email uses.
FROM = "monitor-info@caida.org"
SUBJECT = "non-responsive ark monitors"
MAILSERVER = "rommie.caida.org"

# arkmon API base URL and the OAuth2 client id used to query it.
API_URL = "https://api.arkmon.caida.org"
CLIENT_ID = "arkmon-offline"

# Default saved-token path (JSON); overridable with --token-file.
TOKEN_FILE = "/etc/ark/.arkmon-offline.token"

def load_token_info(token_file):
    """Load the saved OAuth2 token dict from the JSON file *token_file*.

    If the stored token is still valid, it is returned unchanged. "Valid"
    means it has an ``expires_at`` timestamp that is not earlier than the
    current time.

    Otherwise a minimal replacement dict is returned. It contains only:

    - the stored ``refresh_token``;
    - ``expires_in`` set to -1;
    - a placeholder ``access_token``.

    This makes the OAuth2 session treat the token as already expired and
    refresh it.

    Raises:
        KeyError: if the token has expired (or is undated) and has no
            ``refresh_token``.
    """
    with open(token_file, "r", encoding="ascii") as fh:
        saved = json.load(fh)

    expired = "expires_at" not in saved or saved["expires_at"] < time.time()
    if not expired:
        return saved

    # Keep only what is needed to obtain a fresh access token.
    return {
        'refresh_token': saved['refresh_token'],
        'expires_in': -1,
        'access_token': 'dummy value for oauthlib',
    }


def get_listening_tcp_sockets():
    """Yield the psutil TCP connection entries whose status is 'LISTEN'.

    Entries come from psutil.net_connections(kind='tcp'). This is a lazy
    generator, so nothing is queried until the first item is requested.
    """
    yield from (sock for sock in psutil.net_connections(kind='tcp')
                if sock.status == 'LISTEN')


def get_monitors(client_id: str, url: str, token_file: str,
                 realm: str = 'CAIDA', timeout: float = 60.0):
    """Fetch the monitor list from the arkmon API.

    Loads the saved OAuth2 tokens from *token_file* and opens an
    OAuth2Session. The session refreshes the access token when needed and
    passes new tokens to a saver built for *token_file*. It then GETs
    ``{url}/monitors/``.

    Args:
        client_id: OAuth2 client id for the session and for token refresh.
        url: API base URL, without a trailing slash.
        token_file: Path of the JSON token file (read by load_token_info).
        realm: Not used by this function; kept for interface compatibility.
        timeout: Seconds to wait on the API (and on a token refresh)
            before giving up, so an unresponsive server cannot hang the run.

    Returns:
        The decoded JSON monitor list. Returns None on failure; the
        reason is printed to stderr.
    """
    token_info = load_token_info(token_file)
    save_tokens = caida_oidc_client.make_save_tokens(token_file)
    # The token endpoint is derived from the refresh token's issuer ("iss").
    refresh_data = caida_oidc_client.jwt_decode(token_info['refresh_token'])
    token_url = refresh_data["iss"] + '/protocol/openid-connect/token'

    # Establish a new session. It refreshes the access token via token_url
    # if necessary and hands refreshed tokens to save_tokens.
    session = OAuth2Session(client_id=client_id,
                            token=token_info,
                            auto_refresh_url=token_url,
                            auto_refresh_kwargs={"client_id": client_id},
                            token_updater=save_tokens)
    if not session.authorized:
        print("Failed to authorize session, aborting", file=sys.stderr)
        return None

    monitors_url = f"{url}/monitors/"
    try:
        # Deliberately broad: both network errors and token-refresh errors
        # can surface from this call, and either should just abort the run.
        response = session.request("GET", monitors_url, timeout=timeout)
    except Exception as e:
        print(f"Failed to fetch {monitors_url}: {e}", file=sys.stderr)
        return None

    if response.status_code != 200:
        print(f"Got status {response.status_code}, aborting", file=sys.stderr)
        return None

    try:
        monitors = response.json()
    except ValueError as e:  # body was not valid JSON
        print(f"Invalid JSON from {monitors_url}: {e}", file=sys.stderr)
        return None

    # Also covers a JSON null body, which would otherwise break len().
    if not monitors:
        print("Empty monitors list, aborting", file=sys.stderr)
        return None
    return monitors


def send_email(mon_info: dict, recipient: str):
    """Email *recipient* a plain-text and HTML report of problem monitors.

    Args:
        mon_info: Maps monitor id -> monitor record. Each record needs a
            numeric ``lastping`` (epoch seconds) and may carry a truthy
            ``no_ssh`` flag, which adds " (no SSH)" to that row.
        recipient: Address used for the To: header.

    Rows are ordered by ``lastping`` descending (most recently seen first).
    The message is sent through MAILSERVER with FROM and SUBJECT.
    """
    # One reference time, so every row's age is measured from the same instant.
    now_sec = datetime.now().timestamp()
    html_rows = []
    txt_rows = []
    items = sorted(mon_info.items(), key=lambda kv: kv[1]['lastping'],
                   reverse=True)
    for mon_id, info in items:
        td = timedelta(seconds=now_sec - info['lastping'])
        ago = f'{td.days:2} days {td.seconds / 3600:5.02f} hours'
        ssh_info = ' (no SSH)' if info.get('no_ssh') else ''
        ping_url = f'https://dory.caida.org/node/ping/{mon_id}'
        view_url = f'https://dory.caida.org/node/view/{mon_id}'
        # Ids come from the API, so escape them before embedding in HTML.
        html_rows.append(f'<a href="{html.escape(view_url)}">'
                         f'{html.escape(str(mon_id))}</a> '
                         f'<a href="{html.escape(ping_url)}">last pinged</a> '
                         f'{ago} ago{ssh_info}')
        txt_rows.append(f'{mon_id:7} last pinged {ago} ago{ssh_info}\n  {ping_url}')

    msg = EmailMessage()
    msg['Subject'] = SUBJECT
    msg['From'] = FROM
    msg['To'] = recipient
    msg.set_content('\n'.join(txt_rows))
    br = '<br/>\n  '
    msg.add_alternative(f'''
<html>
  <head></head>
  <body>
  {br.join(html_rows)}
  </body>
</html>
    ''', subtype='html')
    # The context manager QUITs/closes the connection even if sending fails.
    with smtplib.SMTP(MAILSERVER) as s:
        s.send_message(msg)


#====================================================================

def main():
    """Check Ark monitors and email a report of unresponsive ones.

    A monitor is *active* if its ``lastping`` is within the last 30 days.
    An active monitor is reported in either of two cases:

    - its last ping is older than 24 hours;
    - it has a ``natpport`` that no local TCP socket is listening on.
      These are flagged ``no_ssh``.

    Exits with status 1 if the monitor list could not be fetched.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--token-file', default=TOKEN_FILE)
    parser.add_argument('--recipient', required=True)
    args = parser.parse_args()

    month_sec = timedelta(days=30).total_seconds()
    day_sec = timedelta(days=1).total_seconds()
    now_sec = datetime.now().timestamp()

    all_mons = get_monitors(client_id=CLIENT_ID, url=API_URL,
                            token_file=args.token_file)
    if all_mons is None:
        # get_monitors already printed why. A bare sys.exit() would report
        # success (status 0), so the failed check would go unnoticed.
        sys.exit(1)

    # Only monitors seen in the last month are active.
    active_mons = {m['node']: m for m in all_mons
                   if m.get('lastping') is not None
                   and m['lastping'] > now_sec - month_sec}
    # Warn if an active monitor hasn't been seen in the last 24 hours.
    warn_mons = {node: m for node, m in active_mons.items()
                 if m['lastping'] < now_sec - day_sec}

    listening_ports = {int(conn.laddr.port)
                       for conn in get_listening_tcp_sockets()}

    # An active monitor with a natpport should have a matching listening
    # TCP port here. If not, report it and flag it no_ssh. Monitors
    # without a natpport are skipped.
    # NOTE(review): assumes natpport is an int in the API JSON; a string
    # value would never match listening_ports — confirm.
    for mon_id, info in active_mons.items():
        port = info.get("natpport")
        if port and port not in listening_ports:
            info['no_ssh'] = True
            warn_mons[mon_id] = info

    if warn_mons:
        send_email(warn_mons, args.recipient)


if __name__ == '__main__':
    # Run the check only when executed as a script, not on import.
    main()
