#!/usr/bin/env python3
"""
ark-data-lint - Sanity-check CAIDA Ark warts files.

Aggregates statistics per VP, separately for each dataset (the
--ipv4-dir cycle, the --ipv6-dir day, and any files named on the
command line).  Flags VPs whose TTL rewrite rate exceeds
--ttl-rewrite-threshold (TTL) or whose response rate falls below
--response-threshold (RESP).  A VP can carry either, both, or
neither flag.

Run with --help for the full option list.

Usage:
    python3 ark-data-lint [options] <file> [<file> ...]
    python3 ark-data-lint [options] --ipv4-dir PATH [--ipv6-dir PATH]
    python3 ark-data-lint [options] --ipv6-dir PATH
"""

import argparse
import enum
import glob
import os
import subprocess
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import date, timedelta
from typing import NamedTuple

try:
    from scamper import ScamperFile, ScamperTrace, ScamperTraceStop
except ImportError:
    print(
        "ERROR: scamper Python module not found.\n"
        "Build from https://www.caida.org/catalog/software/scamper/",
        file=sys.stderr,
    )
    sys.exit(1)


# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Filename suffixes treated as warts input: plain and gzip-compressed.
# Kept as a tuple so it can be passed straight to str.endswith().
WARTS_EXTENSIONS = (".warts", ".warts.gz")

# Leading dot-component that identifies the IPv6 prefix-probing
# filename format; used to distinguish it from the IPv4 team-probing
# format when parsing filenames.
IPV6_FILENAME_PREFIX = "topo-v6"

# Stop reasons that indicate the destination produced a terminal
# reply.  Completed = probe reached destination; Unreach = ICMP
# unreachable from destination; Icmp = other ICMP error (typically
# from destination or very near it).  All other stop reasons
# (GapLimit, Loop, HopLimit, NoReason, GSS, Halted, Error,
# InProgress) indicate no response from the destination.
# A trace whose stop_reason is in this set counts as "responding"
# for the RESP check (see check_trace()).
RESPONDING_STOP_REASONS = frozenset({
    ScamperTraceStop.Completed,
    ScamperTraceStop.Unreach,
    ScamperTraceStop.Icmp,
})

# Flag labels emitted by vp_flags() and joined with "+" in the
# FLAGS column of the report.
# NOTE(review): the --help text in parse_args() ("flagged TTL",
# "flagged RESP") and the summary line in format_report() spell the
# labels/conditions out literally instead of using these constants;
# keep them in step by hand when adding or renaming a check.
FLAG_TTL = "TTL"
FLAG_RESP = "RESP"


# ---------------------------------------------------------------------------
# Trace classification
# ---------------------------------------------------------------------------


class TtlCheck(enum.Enum):
    """Per-trace verdict produced by the TTL rewrite check."""

    # A terminal ICMP error hop was seen and its quoted TTL sat more
    # than the configured tolerance above the probe TTL.
    REWRITTEN = "rewritten"

    # A terminal ICMP error hop was seen, but no quoted TTL exceeded
    # the probe TTL by more than the tolerance.
    NO_REWRITE = "no_rewrite"

    # The trace held no terminal ICMP error hop, so the TTL check had
    # nothing to examine.  This is unrelated to whether the
    # destination answered: a stop_reason=Completed trace usually has
    # no such hop yet still reached its target.
    NO_TERMINAL_HOP = "no_terminal_hop"


class TraceResult(NamedTuple):
    """
    Signals gathered from one trace in a single pass.

    The fields are independent: ``ttl`` feeds the TTL rewrite check,
    ``responded`` feeds the response-rate check.
    """

    ttl: TtlCheck
    responded: bool


# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------


@dataclass
class VpStats:
    """Running counters for one VP plus the rates derived from them."""

    vp: str
    traces: int = 0           # every trace read for this VP
    ttl_rewrites: int = 0     # traces classified as TTL-rewritten
    ttl_no_terminal: int = 0  # traces lacking a terminal ICMP error hop
    responding: int = 0       # traces whose destination replied
    errors: int = 0           # files whose processing raised an error

    @property
    def ttl_eligible(self) -> int:
        """Number of traces on which the TTL check reached a verdict."""
        return self.traces - self.ttl_no_terminal

    @property
    def ttl_rewrite_rate(self) -> float:
        """Rewritten share of TTL-eligible traces (0.0 when none)."""
        eligible = self.ttl_eligible
        if eligible <= 0:
            return 0.0
        return self.ttl_rewrites / eligible

    @property
    def response_rate(self) -> float:
        """Responding share of all traces (0.0 when there are none)."""
        if self.traces <= 0:
            return 0.0
        return self.responding / self.traces


@dataclass
class DatasetResult:
    """
    Everything the report needs for one dataset.

    A dataset is the IPv4 directory, the IPv6 directory, or the set of
    files named explicitly on the command line.
    """

    label: str        # section title, e.g. "IPv4 team-probing"
    directory: str    # source directory, or "(command line)"
    target_date: str  # YYYY-MM-DD, or "" when no date applies
    file_count: int   # warts files selected for this dataset
    vp_stats: list[VpStats] = field(default_factory=list)


# ---------------------------------------------------------------------------
# Filename parsing and filtering
# ---------------------------------------------------------------------------


def is_warts_file(path: str) -> bool:
    """Tell whether *path* ends in one of the WARTS_EXTENSIONS."""
    return any(path.endswith(suffix) for suffix in WARTS_EXTENSIONS)


def vp_from_filename(path: str) -> str:
    """
    Extract the VP name from a warts filename.

    IPv4 format: {vp}.team-probing.{cycle}.{date}.warts.gz
      -> first dot-separated component

    IPv6 format: topo-v6.l8.{date}.{timestamp}.{vp}.warts.gz
      -> last dot-separated component before .warts.gz
    """
    # removesuffix is a no-op when the suffix isn't present; chaining
    # handles both .warts and .warts.gz without an explicit branch.
    base = (
        os.path.basename(path)
        .removesuffix(".gz")
        .removesuffix(".warts")
    )
    parts = base.split(".")
    if parts[0] == IPV6_FILENAME_PREFIX:
        return parts[-1]
    return parts[0]


def date_from_ipv6_filename(path: str) -> str | None:
    """
    Extract the date string from an IPv6 warts filename.

    Format: topo-v6.l8.{YYYYMMDD}.{timestamp}.{vp}.warts.gz
    Returns the YYYYMMDD string, or None if parsing fails.
    """
    base = (
        os.path.basename(path)
        .removesuffix(".gz")
        .removesuffix(".warts")
    )
    parts = base.split(".")
    if len(parts) >= 3 and parts[0] == IPV6_FILENAME_PREFIX:
        return parts[2]
    return None


# ---------------------------------------------------------------------------
# Directory discovery
# ---------------------------------------------------------------------------


def discover_ipv4_files(
    base_dir: str, target_date: date
) -> list[str]:
    """
    Find warts files for a specific date under the IPv4
    team-probing directory structure.

    Structure: {base_dir}/YYYY/cycle-YYYYMMDD/*.warts.gz

    Returns the sorted list of warts paths in that cycle directory.
    Returns an empty list, after printing a warning to stderr, when the
    directory is missing; a warning is also printed when it exists but
    holds no warts files.
    """
    date_str = target_date.strftime("%Y%m%d")
    year_str = target_date.strftime("%Y")
    cycle_dir = os.path.join(
        base_dir, year_str, f"cycle-{date_str}"
    )
    if not os.path.isdir(cycle_dir):
        print(
            f"WARNING: IPv4 cycle directory not found: "
            f"{cycle_dir}",
            file=sys.stderr,
        )
        return []
    # Escape the directory part so glob metacharacters ("[", "*", "?")
    # in base_dir match literally; only the trailing "*" is a pattern.
    pattern = os.path.join(glob.escape(cycle_dir), "*")
    files = sorted(
        f for f in glob.glob(pattern)
        if is_warts_file(f)
    )
    if not files:
        print(
            f"WARNING: no warts files in {cycle_dir}",
            file=sys.stderr,
        )
    return files


def discover_ipv6_files(
    base_dir: str, target_date: date
) -> list[str]:
    """
    Find warts files for a specific date under the IPv6
    prefix-probing directory structure.

    Structure: {base_dir}/YYYY/MM/<files with date in name>
    Files: topo-v6.l8.YYYYMMDD.{timestamp}.{vp}.warts.gz

    A file is kept only if it has a warts extension and the date in
    its name equals target_date.  Returns the sorted matches, or an
    empty list (after a warning on stderr) if the month directory is
    missing.  A warning is also printed when nothing matches.
    """
    date_str = target_date.strftime("%Y%m%d")
    year_str = target_date.strftime("%Y")
    month_str = target_date.strftime("%m")
    month_dir = os.path.join(base_dir, year_str, month_str)
    if not os.path.isdir(month_dir):
        print(
            f"WARNING: IPv6 month directory not found: "
            f"{month_dir}",
            file=sys.stderr,
        )
        return []
    # Escape the directory part so glob metacharacters ("[", "*", "?")
    # in base_dir match literally; only the trailing "*" is a pattern.
    pattern = os.path.join(glob.escape(month_dir), "*")
    files = sorted(
        f for f in glob.glob(pattern)
        if is_warts_file(f)
        and date_from_ipv6_filename(f) == date_str
    )
    if not files:
        print(
            f"WARNING: no IPv6 warts files for {date_str} in "
            f"{month_dir}",
            file=sys.stderr,
        )
    return files


# ---------------------------------------------------------------------------
# Detection logic
# ---------------------------------------------------------------------------


def _classify_ttl(
    trace: ScamperTrace, ttl_tolerance: int
) -> TtlCheck:
    """
    Decide the TTL rewrite verdict for one trace.

    A hop counts as a terminal ICMP error when it quotes the probe
    (is_icmp_q()) but is not a TTL-exceeded reply.  TTL-exceeded hops
    are skipped on purpose: they quote the probe too, but their
    icmp_q_ttl is ~0 by construction (the probe had just expired
    there), so comparing them would hide genuine rewrites.

    For each terminal hop whose probe_ttl and icmp_q_ttl are both
    known, a quoted TTL more than *ttl_tolerance* above the probe TTL
    means something on the path raised the TTL; REWRITTEN is returned
    at once.

    Otherwise the result is NO_REWRITE if any terminal hop was seen
    (a terminal hop whose TTL fields are None still counts as seen),
    and NO_TERMINAL_HOP if none was.
    """
    saw_terminal = False
    for hop in trace.hops():
        is_terminal = (
            hop is not None
            and hop.is_icmp_q()
            and not hop.is_icmp_ttl_exp()
        )
        if not is_terminal:
            continue

        saw_terminal = True
        sent_ttl = hop.probe_ttl
        quoted_ttl = hop.icmp_q_ttl
        comparable = sent_ttl is not None and quoted_ttl is not None
        if comparable and quoted_ttl > sent_ttl + ttl_tolerance:
            return TtlCheck.REWRITTEN

    if saw_terminal:
        return TtlCheck.NO_REWRITE
    return TtlCheck.NO_TERMINAL_HOP


def check_trace(
    trace: ScamperTrace, ttl_tolerance: int
) -> TraceResult:
    """
    Gather every per-trace signal the checks need, in one pass.

    The TTL verdict comes from _classify_ttl(); ``responded`` is True
    when the trace's stop_reason is one of RESPONDING_STOP_REASONS.
    The two signals are computed independently of each other.
    """
    return TraceResult(
        ttl=_classify_ttl(trace, ttl_tolerance),
        responded=trace.stop_reason in RESPONDING_STOP_REASONS,
    )


# ---------------------------------------------------------------------------
# File processing
# ---------------------------------------------------------------------------


def process_file(
    path: str, stats: VpStats, ttl_tolerance: int
) -> None:
    """
    Read every trace in one warts file and fold it into *stats*.

    *stats* is mutated in place.  traces, ttl_rewrites,
    ttl_no_terminal and responding are bumped per trace.  If the file
    fails with RuntimeError, a warning naming it goes to stderr and
    errors is bumped once; counts gathered before the failure are
    kept.

    Only RuntimeError is caught.  Per the scamper Python bindings, that
    is what they raise for file-open failures and for unknown object
    types met during iteration.  The per-trace accessors used here
    (hops(), is_icmp_q, is_icmp_ttl_exp, probe_ttl, icmp_q_ttl,
    stop_reason) null-check internally.  Corrupt or truncated files do
    not raise either: scamper's read() returns None on an unrecoverable
    parse error, which ends iteration cleanly.
    """
    try:
        with ScamperFile(path, filter_types=[ScamperTrace]) as warts:
            for trace in warts:
                stats.traces += 1
                verdict = check_trace(trace, ttl_tolerance)
                if verdict.responded:
                    stats.responding += 1
                if verdict.ttl is TtlCheck.NO_TERMINAL_HOP:
                    stats.ttl_no_terminal += 1
                elif verdict.ttl is TtlCheck.REWRITTEN:
                    stats.ttl_rewrites += 1
                # NO_REWRITE needs no counter of its own: it is
                # ttl_eligible - ttl_rewrites.
    except RuntimeError as exc:
        print(
            f"WARNING: could not process {path}: {exc}",
            file=sys.stderr,
        )
        stats.errors += 1


def analyze_files(
    files: list[str],
    ttl_tolerance: int,
) -> list[VpStats]:
    """
    Group warts files by VP and build one VpStats per VP.

    Paths without a warts extension are dropped without comment.  The
    result is ordered by VP name, and each VP's files are processed in
    sorted path order.
    """
    by_vp: dict[str, list[str]] = defaultdict(list)
    for path in filter(is_warts_file, files):
        by_vp[vp_from_filename(path)].append(path)

    results: list[VpStats] = []
    for vp in sorted(by_vp):
        stats = VpStats(vp=vp)
        for path in sorted(by_vp[vp]):
            process_file(path, stats, ttl_tolerance)
        results.append(stats)
    return results


# ---------------------------------------------------------------------------
# Reporting
# ---------------------------------------------------------------------------


def vp_flags(
    s: VpStats,
    ttl_rewrite_threshold: float,
    response_threshold: float,
) -> list[str]:
    """
    List the check labels this VP trips, in TTL-then-RESP order.

    Each check speaks only when it has data: TTL requires at least one
    TTL-eligible trace and RESP requires at least one trace.  Both
    comparisons are strict (rate > TTL threshold, rate < RESP
    threshold).
    """
    ttl_hit = (
        s.ttl_eligible > 0
        and s.ttl_rewrite_rate > ttl_rewrite_threshold
    )
    resp_hit = (
        s.traces > 0
        and s.response_rate < response_threshold
    )
    checks = ((FLAG_TTL, ttl_hit), (FLAG_RESP, resp_hit))
    return [label for label, hit in checks if hit]


def format_report(
    vp_stats: list[VpStats],
    ttl_rewrite_threshold: float,
    response_threshold: float,
    verbose: bool,
    sort_by: str,
) -> tuple[str, list[VpStats]]:
    """
    Render the per-VP table for one dataset.

    Rows follow *sort_by*; an unrecognised key falls back to "vp".
    Unless *verbose* is set, only VPs carrying at least one flag get a
    row, but the closing summary line counts every flagged VP either
    way.

    Returns (report_text, flagged_vps), flagged_vps in table order.
    """
    orderings = {
        "vp":            lambda s: s.vp,
        "traces":        lambda s: -s.traces,
        "ttl-rewrites":  lambda s: -s.ttl_rewrites,
        "ttl":           lambda s: -s.ttl_rewrite_rate,
        "response":      lambda s: s.response_rate,
    }
    ordered = sorted(
        vp_stats, key=orderings.get(sort_by, orderings["vp"])
    )

    def pct(rate: float, available: bool) -> str:
        # Both branches are 7 characters so the columns stay aligned.
        return f"{rate * 100:6.2f}%" if available else "    N/A"

    header = (
        f"{'VP':<20} {'TRACES':>9} {'TTL_REWRITES':>12}"
        f" {'TTL%':>7} {'RESP%':>7}  FLAGS"
    )
    rule = "-" * len(header)

    rows: list[str] = [header, rule]
    flagged: list[VpStats] = []
    for s in ordered:
        flags = vp_flags(s, ttl_rewrite_threshold, response_threshold)
        if flags:
            flagged.append(s)
        elif not verbose:
            continue

        flag_str = ("*** " + "+".join(flags)) if flags else ""
        err_note = f" ({s.errors} err)" if s.errors else ""
        ttl_str = pct(s.ttl_rewrite_rate, s.ttl_eligible != 0)
        resp_str = pct(s.response_rate, s.traces != 0)
        rows.append(
            f"{s.vp:<20} {s.traces:>9,} {s.ttl_rewrites:>12,}"
            f" {ttl_str} {resp_str}"
            f"  {flag_str}{err_note}"
        )

    rows.append(rule)
    rows.append(
        f"\n{len(flagged)} VP(s) flagged "
        f"(TTL rewrite > {ttl_rewrite_threshold * 100:.1f}%"
        f" or response rate < "
        f"{response_threshold * 100:.1f}%)."
    )
    return "\n".join(rows), flagged


def format_full_report(
    results: list[DatasetResult],
    ttl_rewrite_threshold: float,
    response_threshold: float,
    verbose: bool,
    sort_by: str,
) -> tuple[str, list[VpStats]]:
    """
    Stitch the per-dataset reports into one document.

    Datasets with no VP stats are left out entirely.  Each remaining
    dataset gets a header (label, source, optional date, file count)
    followed by its format_report() table; sections are separated by
    a blank line.

    Returns (report_text, flagged VPs from every dataset, in order).
    """
    rendered: list[str] = []
    flagged_everywhere: list[VpStats] = []

    for ds in results:
        if not ds.vp_stats:
            continue

        table, flagged = format_report(
            ds.vp_stats,
            ttl_rewrite_threshold=ttl_rewrite_threshold,
            response_threshold=response_threshold,
            verbose=verbose,
            sort_by=sort_by,
        )

        block = [f"=== {ds.label} ===", f"Source: {ds.directory}"]
        if ds.target_date:
            block.append(f"Date:   {ds.target_date}")
        block += [f"Files:  {ds.file_count}", "", table]

        rendered.append("\n".join(block))
        flagged_everywhere += flagged

    return "\n\n".join(rendered), flagged_everywhere


# ---------------------------------------------------------------------------
# Email
# ---------------------------------------------------------------------------


def send_email(
    recipients: str, subject: str, body: str
) -> bool:
    """
    Mail *body* to comma-separated *recipients* with the system
    ``mail`` command.

    Blank entries in *recipients* are ignored.  Every failure is
    reported on stderr and turned into a False return: no usable
    recipient, a missing ``mail`` binary, a non-zero exit status, or
    any other exception (e.g. the 30 s timeout).  Nothing is raised.

    Returns True only when ``mail`` exits with status 0.
    """
    stripped = (piece.strip() for piece in recipients.split(","))
    addrs = [addr for addr in stripped if addr]
    if not addrs:
        print(
            "ERROR: --mailto has no valid recipients.",
            file=sys.stderr,
        )
        return False

    command = [
        "mail", "-s", subject,
        "-a", "From: ark-status@caida.org",
        *addrs,
    ]
    try:
        proc = subprocess.run(
            command,
            input=body,
            text=True,
            capture_output=True,
            timeout=30,
        )
        if proc.returncode == 0:
            return True
        print(
            f"WARNING: mail command failed "
            f"(rc={proc.returncode}): {proc.stderr}",
            file=sys.stderr,
        )
        return False
    except FileNotFoundError:
        print(
            "ERROR: 'mail' command not found. Install mailutils "
            "or equivalent.",
            file=sys.stderr,
        )
        return False
    except Exception as exc:
        print(
            f"ERROR: could not send email: {exc}",
            file=sys.stderr,
        )
        return False


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------


def parse_args() -> argparse.Namespace:
    """
    Parse the command line (sys.argv) into a Namespace.

    ArgumentDefaultsHelpFormatter appends "(default: …)" to every
    option's help.  --date therefore gets a concrete default
    (yesterday, local time) rather than None, so --help shows the real
    date instead of "(default: yesterday) (default: None)".
    """
    p = argparse.ArgumentParser(
        description=(
            "Sanity-check CAIDA Ark warts files for TTL rewrites "
            "and low response rates."
        ),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    p.add_argument("files", nargs="*", metavar="FILE")
    p.add_argument(
        "--ttl-rewrite-threshold",
        type=float,
        default=0.5,
        help=(
            "TTL rewrite rate above which a VP is flagged TTL"
        ),
    )
    p.add_argument(
        "--response-threshold",
        type=float,
        default=0.10,
        help=(
            "Response rate below which a VP is flagged RESP"
        ),
    )
    p.add_argument(
        "--ttl-tolerance",
        type=int,
        default=2,
        help=(
            "Allowable excess of icmp_q_ttl over probe_ttl"
            " before counting as a TTL rewrite"
        ),
    )
    p.add_argument(
        "--sort-by",
        default="vp",
        choices=[
            "vp", "traces", "ttl-rewrites", "ttl", "response",
        ],
        help=(
            "Column to sort by.  'ttl' = TTL rewrite rate"
            " (descending); 'response' = response rate"
            " (ascending, worst first)"
        ),
    )
    p.add_argument(
        "--verbose",
        action="store_true",
        help="Show all VPs, not just flagged ones",
    )
    p.add_argument(
        "--ipv4-dir",
        metavar="PATH",
        help=(
            "Base directory for IPv4 team-probing data "
            "(e.g. .../team-probing/list-7.allpref24/"
            "team-1/daily)"
        ),
    )
    p.add_argument(
        "--ipv6-dir",
        metavar="PATH",
        help=(
            "Base directory for IPv6 prefix-probing data "
            "(e.g. .../topo-v6/list-8.ipv6.allpref)"
        ),
    )
    p.add_argument(
        "--date",
        metavar="YYYY-MM-DD",
        type=date.fromisoformat,
        # A real date (not None) so the formatter's auto-appended
        # "(default: …)" shows the actual day that will be scanned.
        # argparse applies `type` only to string defaults, so this
        # date object is used as-is.
        default=date.today() - timedelta(days=1),
        help=(
            "Target date for directory discovery "
            "(yesterday if omitted)"
        ),
    )
    p.add_argument(
        "--mailto",
        metavar="ADDRESS",
        help=(
            "Send email alert if any VPs are flagged. "
            "Comma-separated for multiple recipients."
        ),
    )
    return p.parse_args()


def main() -> int:
    """
    Command-line entry point.

    Validates the thresholds and analyses each requested dataset
    (IPv4 directory, IPv6 directory, explicit files).  Without
    --mailto the report is printed to stdout.  With --mailto it is
    emailed only when at least one VP is flagged.

    Returns the process exit status: 0 on success, 1 on invalid
    arguments or when the alert email could not be sent (the report
    is then printed to stdout so it is not lost).
    """
    args = parse_args()

    for name, val in (
        ("--ttl-rewrite-threshold", args.ttl_rewrite_threshold),
        ("--response-threshold", args.response_threshold),
    ):
        if not 0.0 <= val <= 1.0:
            print(
                f"ERROR: {name} must be in [0.0, 1.0], "
                f"got {val}.",
                file=sys.stderr,
            )
            return 1

    has_dirs = bool(args.ipv4_dir or args.ipv6_dir)
    has_files = bool(args.files)

    if not has_dirs and not has_files:
        print(
            "ERROR: no input files or directories specified.\n"
            "Provide files as arguments, or use --ipv4-dir / "
            "--ipv6-dir for auto-discovery.",
            file=sys.stderr,
        )
        return 1

    target = args.date or (date.today() - timedelta(days=1))
    # ISO 8601 (YYYY-MM-DD) for user-facing display.  Filename
    # matching in discover_* uses YYYYMMDD internally from the same
    # `target` date object.
    target_display = target.isoformat()
    results: list[DatasetResult] = []

    # --- Directory discovery mode ---
    if args.ipv4_dir:
        files = discover_ipv4_files(args.ipv4_dir, target)
        vp_stats = analyze_files(files, args.ttl_tolerance)
        results.append(DatasetResult(
            label="IPv4 team-probing",
            directory=args.ipv4_dir,
            target_date=target_display,
            file_count=len(files),
            vp_stats=vp_stats,
        ))

    if args.ipv6_dir:
        files = discover_ipv6_files(args.ipv6_dir, target)
        vp_stats = analyze_files(files, args.ttl_tolerance)
        results.append(DatasetResult(
            label="IPv6 prefix-probing",
            directory=args.ipv6_dir,
            target_date=target_display,
            file_count=len(files),
            vp_stats=vp_stats,
        ))

    # --- Explicit files mode ---
    if has_files:
        warts_files = [
            f for f in args.files if is_warts_file(f)
        ]
        skipped = len(args.files) - len(warts_files)
        if skipped:
            print(
                f"NOTE: skipped {skipped} non-warts file(s).",
                file=sys.stderr,
            )
        vp_stats = analyze_files(
            warts_files, args.ttl_tolerance
        )
        results.append(DatasetResult(
            label="Explicit files",
            directory="(command line)",
            target_date="",
            file_count=len(warts_files),
            vp_stats=vp_stats,
        ))

    report, all_flagged = format_full_report(
        results,
        ttl_rewrite_threshold=args.ttl_rewrite_threshold,
        response_threshold=args.response_threshold,
        verbose=args.verbose,
        sort_by=args.sort_by,
    )

    # --- Report to stdout (suppressed in --mailto mode) ---
    if not args.mailto:
        if report:
            print(report)
        return 0

    # --- Email if flagged VPs exist ---
    if all_flagged:
        # Dedupe by VP name: one VP flagged in both IPv4 and IPv6
        # is still one VP.  Per-dataset detail is in the body.
        n = len({s.vp for s in all_flagged})
        subject = f"Daily Ark data scan: {n} VP(s) flagged"
        # The target date only describes directory-discovery datasets;
        # for explicit files it is just "yesterday" and misleading.
        if has_dirs:
            subject += f" -- {target_display}"
        if not send_email(args.mailto, subject, report):
            # Don't lose the alert: emit the report on stdout so a
            # wrapper (e.g. cron) can still capture it.
            print(report)
            return 1

    return 0


if __name__ == "__main__":
    # main() returns the exit status; SystemExit carries it to the shell.
    raise SystemExit(main())
