#!/usr/bin/env python3

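# Newer versions of the crypto libraries appear to parse OpenSSH certificates
# themselves, but the versions available here don't, so they are decoded by hand.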
# GPLv3 references:
# https://github.com/scheiblingco/sshkey-tools/blob/main/src/sshkey_tools/fields.py
# https://github.com/scheiblingco/sshkey-tools/blob/main/src/sshkey_tools/cert.py

from collections import defaultdict
from datetime import datetime, timedelta, timezone
from struct import unpack

from cryptography import x509
from cryptography.hazmat.backends import default_backend
# python3-psycopg (psycopg3) isn't available in jammy
import psycopg2


# XXX could there be so many certs we can't return them all, and should
# instead print directly from the cursor? or return the cursor?
# fetch all ssh certificates and print them in a list

def get_int32(data):
    """Read a big-endian unsigned 32-bit integer from the front of *data*.

    Returns ``(value, rest)`` where *rest* is *data* minus the four bytes read.
    """
    (number,) = unpack(">I", data[:4])
    remainder = data[4:]
    return number, remainder

def get_int64(data):
    """Read a big-endian unsigned 64-bit integer from the front of *data*.

    Returns ``(value, rest)`` where *rest* is *data* minus the eight bytes read.
    """
    (number,) = unpack(">Q", data[:8])
    remainder = data[8:]
    return number, remainder

# do I care about the different encodings, bytestring vs string? assume
# all my data is well formed for now
def get_string(data):
    """Read an SSH wire-format ``string`` and decode it as UTF-8.

    The field is a big-endian uint32 length followed by that many bytes.
    Returns ``(text, rest)`` where *rest* is the unread remainder of *data*.

    Raises:
        struct.error: if *data* is shorter than the 4-byte length prefix.
        ValueError: if the declared length runs past the end of *data*
            (raised rather than returning a silently truncated value that
            would misalign every later field), or if the bytes are not
            valid UTF-8 (UnicodeDecodeError is a ValueError).
    """
    length = unpack(">I", data[:4])[0]
    end = 4 + length
    if end > len(data):
        raise ValueError(
            f"string field declares {length} bytes but only "
            f"{len(data) - 4} remain")
    return data[4:end].decode("utf-8"), data[end:]

def get_bytestring(data):
    """Read an SSH wire-format ``string`` field as raw bytes.

    The field is a big-endian uint32 length followed by that many bytes.
    Returns ``(value, rest)`` where *rest* is the unread remainder of *data*.

    Raises:
        struct.error: if *data* is shorter than the 4-byte length prefix.
        ValueError: if the declared length runs past the end of *data*,
            rather than silently returning a short value that would
            misalign every later field.
    """
    length = unpack(">I", data[:4])[0]
    end = 4 + length
    if end > len(data):
        raise ValueError(
            f"string field declares {length} bytes but only "
            f"{len(data) - 4} remain")
    return data[4:end], data[end:]

def get_datetime(data):
    """Read a uint64 Unix timestamp and return it as a naive local datetime.

    Returns ``(value, rest)``.  Timestamps too large to represent -- notably
    0xFFFFFFFFFFFFFFFF, which the OpenSSH certificate format uses for
    "valid forever" -- become ``datetime.max`` instead of raising, so one
    non-expiring certificate cannot abort the whole report.
    """
    timestamp, data = get_int64(data)
    try:
        value = datetime.fromtimestamp(timestamp)
    except (OverflowError, OSError, ValueError):
        value = datetime.max
    return value, data

def get_list(data):
    """Read a string field that itself packs a sequence of strings.

    Used for the certificate's principals.  Returns ``(names, rest)`` where
    *names* is a list of decoded strings in wire order.
    """
    packed, remainder = get_bytestring(data)
    names = []
    while packed:
        name, packed = get_string(packed)
        names.append(name)
    return names, remainder

def get_keyvalue(data):
    """Read a string field packing (name, value) pairs, e.g. options/extensions.

    Each value is itself a string field; a non-empty value is unwrapped once
    more and decoded as text, and an empty one becomes "".

    Returns ``(parsed, rest)``.  *parsed* is a dict of name -> value, except
    when every value is empty, in which case it is just the list of names.
    """
    packed, remainder = get_bytestring(data)
    options = {}
    while packed:
        name, packed = get_string(packed)
        payload, packed = get_bytestring(packed)
        options[name] = get_string(payload)[0] if payload else ""

    if not any(options.values()):
        return list(options), remainder

    return options, remainder


# TODO possibly worth trying to validate fields, like the sshkey-tools does?
class SSLCertificate:
    """An OpenSSH certificate decoded from its binary wire format.

    Despite the class name this handles SSH certificates (the
    ``*-cert-v01@openssh.com`` key types), not SSL/TLS ones.  Fields are read
    in wire order from the blob given to ``__init__`` and stored as
    attributes; ``validafter``/``validbefore`` are naive local datetimes.
    """

    def __init__(self, data, now=None):
        """Decode the certificate blob *data* (bytes).

        Args:
            data: the raw certificate bytes.
            now: reference time for ``expiry`` as a naive local datetime.
                Defaults to the time this object is created (resolved here,
                not as a default argument, which would be frozen at import).
        """
        self.key_type, data = get_string(data)
        self.nonce, data = get_bytestring(data)
        # The public-key fields between the nonce and the serial depend on
        # the key type: Ed25519 keys carry only the key, ECDSA keys carry a
        # curve name followed by the key.  Other types are read with the
        # ECDSA two-field layout.
        # NOTE(review): that layout fits RSA (e, n) but not DSA (p, q, g, y).
        if "ed25519" in self.key_type:
            self.curve = None
            self.key, data = get_bytestring(data)
        else:
            self.curve, data = get_string(data)
            self.key, data = get_bytestring(data)
        # Security-key ("sk-") variants append an application string.
        if self.key_type.startswith("sk-"):
            self.application, data = get_string(data)
        else:
            self.application = None
        self.serial, data = get_int64(data)
        self._cert_type, data = get_int32(data)
        self.key_id, data = get_string(data)
        self.principals, data = get_list(data)
        self.validafter, data = get_datetime(data)
        self.validbefore, data = get_datetime(data)
        self.critical, data = get_keyvalue(data)
        self.extensions, data = get_keyvalue(data)
        self.reserved, data = get_string(data)
        # CA signature key and signature; each blob begins with its type name.
        self.pubkey, data = get_bytestring(data)
        self.pubkey_type, _ = get_string(self.pubkey)
        self.signature, data = get_bytestring(data)
        self.signature_type, _ = get_string(self.signature)
        if now is None:
            now = datetime.now()
        # Truncated to the minute; expiry is measured from this.
        self._now = now.replace(second=0, microsecond=0)

    def __str__(self):
        #return f"{self.cert_type} {self.serial} {self.key_id} {self.expiry}"
        return f"{self.cert_type} {self.key_id} {self.expiry}"

    @property
    def expiry(self):
        """Time until ``validbefore`` as text, prefixed with '-' when not in the future."""
        seconds = (self.validbefore - self._now).total_seconds()
        delta = timedelta(seconds=abs(seconds))
        return f"{delta}" if seconds > 0 else f"-{delta}"

    @property
    def cert_type(self):
        """``"USER"`` for certificate type 1, otherwise ``"HOST"``."""
        return "USER" if self._cert_type == 1 else "HOST"


# TODO make a common certificate class? or at least make expiry fields similar
class X509Certificate():
    """Adapter giving a ``cryptography`` X.509 certificate the report interface.

    Attributes not defined here are forwarded to the wrapped certificate.
    Validity bounds are exposed as ``validbefore``/``validafter`` in naive
    *local* time, so they compare directly with SSLCertificate times and
    ``datetime.now()``; the wrapped certificate reports them in UTC.
    """

    def __init__(self, base, now=None):
        """Wrap *base*.

        Args:
            base: a certificate object from ``cryptography.x509``.
            now: reference time for ``expiry`` as a naive local datetime.
                Defaults to the time this object is created (resolved here,
                not as a default argument, which would be frozen at import).
        """
        self._base = base
        self.cert_type = "X509"
        if now is None:
            now = datetime.now()
        self._now = now.replace(second=0, microsecond=0)

    def __getattr__(self, attr):
        # Only reached when normal lookup fails.  Read _base from __dict__ so
        # an instance lacking it (e.g. created via __new__) cannot recurse.
        base = self.__dict__.get("_base")
        if base is not None and hasattr(base, attr):
            return getattr(base, attr)
        raise AttributeError(attr)

    def __str__(self):
        #return f"{self.cert_type} {self.serial_number} {self.key_id} {self.expiry}"
        return f"{self.cert_type} {self.key_id} {self.expiry}"

    @property
    def expiry(self):
        """Time until ``validbefore`` as text, prefixed with '-' when not in the future."""
        seconds = (self.validbefore - self._now).total_seconds()
        delta = timedelta(seconds=abs(seconds))
        return f"{delta}" if seconds > 0 else f"-{delta}"

    @property
    def key_id(self):
        """The subject's first common name, or the whole subject if it has none."""
        names = self.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)
        if names:
            return names[0].value
        # Certificates without a CN (e.g. SAN-only) would otherwise raise
        # IndexError and abort the whole report.
        return self.subject.rfc4514_string()

    def _local_time(self, name):
        """Return the wrapped certificate's *name* validity bound as naive local time."""
        # Newer cryptography releases provide timezone-aware "<name>_utc"
        # attributes; older ones only have the naive-UTC "<name>".
        moment = getattr(self._base, f"{name}_utc", None)
        if moment is None:
            moment = getattr(self._base, name).replace(tzinfo=timezone.utc)
        return moment.astimezone().replace(tzinfo=None)

    @property
    def validbefore(self):
        """End of the validity period as a naive local datetime."""
        return self._local_time("not_valid_after")

    @property
    def validafter(self):
        """Start of the validity period as a naive local datetime."""
        return self._local_time("not_valid_before")




def get_ssh_certs():
    """Return SSLCertificate objects for recently stored SSH certificates.

    Reads the ``ssh_certs`` table of the ``stepdb`` PostgreSQL database and
    decodes each ``nvalue`` blob.  Only rows whose writing transaction
    committed within the last 90 days are returned (this needs commit
    timestamps enabled on the server).

    Returns:
        list of SSLCertificate, in the order the database returned them.
    """
    certs = []
    conn = psycopg2.connect("dbname=stepdb")
    try:
        # "with conn" only wraps a transaction; psycopg2 does not close the
        # connection on exit, so close it explicitly below.
        with conn, conn.cursor() as cursor:
            cursor.execute("""
                SELECT nkey, nvalue from ssh_certs
                WHERE pg_xact_commit_timestamp(xmin) > now()-interval '90 days'
                """)
            for _nkey, nvalue in cursor:
                certs.append(SSLCertificate(bytes(nvalue)))
    finally:
        conn.close()
    return certs


# NOTE: both queries use commit timestamps to restrict results to the last
# 90 days; rows committed before commit timestamps were enabled should have
# no timestamp and so stay excluded until those certs are renewed.
def get_x509_certs():
    """Return X509Certificate objects for recently stored X.509 certificates.

    Reads the ``x509_certs`` table of the ``stepdb`` PostgreSQL database.
    Only rows whose writing transaction committed within the last 90 days
    are returned (this needs commit timestamps enabled on the server).

    Returns:
        list of X509Certificate, in the order the database returned them.
    """
    certs = []
    conn = psycopg2.connect("dbname=stepdb")
    try:
        # "with conn" only wraps a transaction; psycopg2 does not close the
        # connection on exit, so close it explicitly below.
        with conn, conn.cursor() as cursor:
            cursor.execute("""
                SELECT nkey, nvalue from x509_certs
                WHERE pg_xact_commit_timestamp(xmin) > now()-interval '90 days'
                """)
            for _nkey, nvalue in cursor:
                # nvalue holds the certificate's DER bytes (its base64 is a
                # valid PEM body), so load it directly instead of
                # round-tripping through base64 and PEM armour.
                cert = x509.load_der_x509_certificate(bytes(nvalue), default_backend())
                certs.append(X509Certificate(cert))
    finally:
        conn.close()
    return certs


# print - certs about to expire that should have been renewed already
#       - certs that did expire
#       - certs that were renewed recently
#       - all cert changes for the last month
def print_certs(certs, now=None, window=timedelta(days=7)):
    """Print a summary report for *certs*, then every certificate by identity.

    Certificates are grouped by ``(cert_type, key_id)``.  Within a group they
    are ordered by ``validbefore``, latest first, so the first entry is the
    current one.  Summary sections, each judged on that current cert:

    - expired and not renewed: expired within the last *window*
    - expiring soon: expires within the next *window* (soonest listed first)
    - newly issued: only cert in its group, valid from within the last
      *window*, and not yet expired
    - newly renewed: as above, but the group also has older certs

    Args:
        certs: objects exposing ``cert_type``, ``key_id``, ``validafter`` and
            ``validbefore`` (naive local datetimes).  The list is not modified.
        now: reference time as a naive local datetime; defaults to now.
        window: look-back / look-ahead period; defaults to 7 days.
    """
    if now is None:
        now = datetime.now()
    later = now + window
    then = now - window

    grouped = defaultdict(list)
    for cert in sorted(certs, key=lambda c: c.validbefore, reverse=True):
        grouped[(cert.cert_type, cert.key_id)].append(cert)
    # Current certificate of each identity, paired with its group size.
    latest = [(group[0], len(group)) for group in grouped.values()]

    # %Z is empty for a naive datetime, so attach the local zone first.
    print(f"========== {now.astimezone().strftime('%a %d %b %Y %Z')} ==========")

    _print_section("Expired and not renewed",
                   [c for c, _ in latest if then < c.validbefore < now])
    _print_section("Expiring soon",
                   [c for c, _ in reversed(latest) if now < c.validbefore < later])

    fresh = [(c, n) for c, n in latest
             if then < c.validafter < now and c.validbefore > now]
    _print_section("Newly issued", [c for c, n in fresh if n == 1])
    _print_section("Newly renewed", [c for c, n in fresh if n > 1])

    print("\n\nAll issued or expired certificates:")
    for group in grouped.values():
        for cert in group:
            print(f"\t{cert}")
        print("\t.")


def _print_section(title, certs):
    """Print a '<title>: <count> certificates' header and one tab-indented line per cert."""
    print(f"\n{title}: {len(certs)} certificates")
    for cert in certs:
        print(f"\t{cert}")


# TODO filter by hostname, cert type
# TODO don't print IDs in summary?
# TODO summary report vs all certs in last year (configurable time period)
def main():
    """Print one combined report covering SSH and X.509 certificates."""
    ssh_certs = get_ssh_certs()
    x509_certs = get_x509_certs()
    print_certs(ssh_certs + x509_certs)

# Produce the report only when run as a script, not when imported.
if __name__ == "__main__":
    main()
