#!/usr/bin/env python3
"""Benchmark local coding/chat models behind ds4-server, Ollama, and llama.cpp.

The goal is not a synthetic tokens/sec beauty contest.  It records latency,
reported token rates when an API exposes them, and whether the model actually
followed the instruction for a few agent-shaped tasks.
"""

from __future__ import annotations

import argparse
import datetime as dt
import json
import os
import pathlib
import re
import shutil
import statistics
import subprocess
import sys
import time
import socket
import urllib.error
import urllib.parse
import urllib.request
from typing import Any


# Default endpoint URLs and model names for each stack. Every value can be
# overridden through the environment variable named in the lookup.
DEFAULT_DS4_URL = os.environ.get("DS4_URL", "http://127.0.0.1:8000/v1/chat/completions")
DEFAULT_DS4_MODEL = os.environ.get("DS4_MODEL", "deepseek-v4-flash")
DEFAULT_SPARK_URL = os.environ.get("SPARK_OLLAMA_URL", "http://127.0.0.1:11435")
DEFAULT_SPARK_MODEL = os.environ.get("SPARK_MODEL", "qwen3-coder:30b")
DEFAULT_LLAMA_URL = os.environ.get("LLAMA_URL", "http://127.0.0.1:18080/v1/chat/completions")
DEFAULT_LLAMA_MODEL = os.environ.get("LLAMA_MODEL", "llama-cpp")
# Directory that contains this script (symlinks resolved).
SCRIPT_DIR = pathlib.Path(__file__).resolve().parent
# Every match of these patterns is replaced with "<redacted>" by
# redact_private_text().
PRIVATE_TEXT_PATTERNS = [
    re.compile(r"\b10\.(?:[0-9]{1,3}\.){2}[0-9]{1,3}\b"),  # 10.x.x.x addresses
    re.compile(r"\b192\.168\.[0-9]{1,3}\.[0-9]{1,3}\b"),  # 192.168.x.x addresses
    re.compile(r"\b172\.(?:1[6-9]|2[0-9]|3[0-1])\.[0-9]{1,3}\.[0-9]{1,3}\b"),  # 172.16-31.x.x addresses
    re.compile(r"\b(?:[0-9A-Fa-f]{2}[:-]){5}[0-9A-Fa-f]{2}\b"),  # six hex pairs joined by ":" or "-" (MAC-address shape)
    re.compile(r"/Users/[A-Za-z0-9._-]+"),  # "/Users/<name>" path prefixes
    re.compile(r"\b[A-Za-z0-9._-]+\.local\b"),  # "*.local" host names
    re.compile(r"SHA256:[A-Za-z0-9+/=]+"),  # "SHA256:" + base64 characters (presumably key fingerprints)
]

# System message used by task_messages() when a task does not provide its own
# "system" value.
SYSTEM_PROMPT = (
    "You are being benchmarked for local coding-agent use. Follow the user's "
    "instruction exactly. Be concise. Do not add analysis unless asked."
)

# Five sentences repeated 120 times (600 lines joined by newlines); appended to
# the long-prefill-summary prompt to produce a long input.
LONG_CONTEXT_BLOCK = "\n".join(
    [
        "Redis streams store append-only event entries addressed by stream IDs.",
        "Consumer groups let many workers divide pending entries while preserving acknowledgements.",
        "A worker should claim old pending entries when peers crash, then acknowledge completed work.",
        "Backpressure can be managed with bounded reads, explicit retries, and dead-letter streams.",
        "Operational dashboards should track lag, pending entry age, retry count, and processing rate.",
    ]
    * 120
)

# Built-in benchmark tasks. Each entry carries a unique "name", a "suite" used
# for selection (see selected_tasks), a "kind", the user "prompt", a
# "max_tokens" budget, and a "pass" rule that pass_check() evaluates together
# with its rule-specific fields (min_words, keys, contains, contains_any,
# not_contains).
TASKS: list[dict[str, Any]] = [
    {
        "name": "exact-ok",
        "suite": "smoke",
        "kind": "obedience",
        "prompt": "Reply with exactly OK.",
        "max_tokens": 32,
        "pass": "exact_ok",
    },
    {
        "name": "short-technical",
        "suite": "smoke",
        "kind": "chat",
        "prompt": "Write a short technical explanation of Redis streams and consumer groups.",
        "max_tokens": 256,
        "pass": "min_words",
        "min_words": 60,
    },
    {
        "name": "json-plan",
        "suite": "smoke",
        "kind": "agent",
        "prompt": (
            "Return only compact JSON with keys summary, files, and commands. "
            "Plan a small docs site for setting up a Coldcard Mk5."
        ),
        "max_tokens": 256,
        "pass": "json_keys",
        "keys": ["summary", "files", "commands"],
    },
    {
        "name": "code-debug",
        "suite": "smoke",
        "kind": "coding",
        "prompt": (
            "Find the bug and provide a corrected Python function only.\n\n"
            "def median(xs):\n"
            "    xs = xs.sort()\n"
            "    n = len(xs)\n"
            "    mid = n // 2\n"
            "    if n % 2:\n"
            "        return xs[mid]\n"
            "    return (xs[mid - 1] + xs[mid]) / 2\n"
        ),
        "max_tokens": 256,
        "pass": "contains",
        "contains": ["sorted", "def median"],
    },
    {
        "name": "long-prefill-summary",
        "suite": "smoke",
        "kind": "long-context",
        "prompt": (
            "Summarize the operational risks in five bullets. Keep it under 120 words.\n\n"
            + LONG_CONTEXT_BLOCK
        ),
        "max_tokens": 192,
        "pass": "min_words",
        "min_words": 40,
    },
    {
        "name": "code-surgical-edit",
        "suite": "code",
        "kind": "coding",
        "prompt": (
            "You are reviewing this HTML row:\n\n"
            "<tr><td>YYY XX<br/>cocktail venue, food and beverages, entertainment, furniture, and insurance</td>"
            "<td>1</td><td>$45,000.00</td><td>$45,000.00</td></tr>\n\n"
            "Return only the corrected row. Change only the first description line back to "
            "\"YYY XX Side Event - BX Forum 2026\". Keep the second description line unchanged."
        ),
        "max_tokens": 256,
        "pass": "contains",
        "contains": [
            "YYY XX Side Event - BX Forum 2026",
            "cocktail venue, food and beverages, entertainment, furniture, and insurance",
            "<tr>",
        ],
        "not_contains": ["YYY XX<br/>"],
    },
    {
        "name": "code-review-seeded-bug",
        "suite": "code",
        "kind": "coding",
        "prompt": (
            "Review this JavaScript function. Return compact JSON with keys bug, risk, and fix.\n\n"
            "function isAdmin(user) {\n"
            "  if (user.role = 'admin') return true;\n"
            "  return false;\n"
            "}\n"
        ),
        "max_tokens": 256,
        "pass": "json_keys",
        "keys": ["bug", "risk", "fix"],
        "contains": ["=", "==="],
    },
    {
        "name": "code-repo-location",
        "suite": "code",
        "kind": "coding",
        "prompt": (
            "A repo has these files:\n"
            "app/router.py defines route_request(path)\n"
            "app/auth.py defines require_admin(user)\n"
            "tests/test_router.py checks redirects\n\n"
            "Question: which file should be edited to change admin authorization behavior? "
            "Reply with only the file path."
        ),
        "max_tokens": 64,
        "pass": "contains",
        "contains": ["app/auth.py"],
    },
    {
        "name": "question-local-benchmark",
        "suite": "question",
        "kind": "qa",
        "prompt": (
            "Using these measured local benchmark facts only:\n"
            "- Spark Ollama qwen3-coder:30b json-plan median: 13162 ms\n"
            "- Spark llama.cpp qwen3-coder-30b.gguf json-plan median: 2632 ms\n"
            "- Spark llama.cpp long-prefill-summary: 5596 ms\n\n"
            "Which Spark runtime was faster for json-plan, and what was its median time? "
            "Reply in one sentence."
        ),
        "max_tokens": 96,
        "pass": "contains",
        "contains": ["llama.cpp", "2632"],
    },
    {
        "name": "question-abstain-missing-context",
        "suite": "question",
        "kind": "qa",
        "prompt": (
            "Answer using only this context:\n"
            "The benchmark compared Spark Ollama and Spark llama.cpp on qwen3-coder:30b.\n\n"
            "Question: What is the Spark device serial number?"
        ),
        "max_tokens": 96,
        "pass": "contains_any",
        "contains_any": ["not provided", "not enough", "cannot determine", "not in the context"],
    },
    {
        "name": "question-cited-answer",
        "suite": "question",
        "kind": "qa",
        "prompt": (
            "Answer from the source snippets and cite the source id in brackets.\n\n"
            "[A] Spark llama.cpp qwen3-coder-30b.gguf short-technical median wall time was 1839 ms.\n"
            "[B] Spark Ollama qwen3-coder:30b short-technical median wall time was 8894 ms.\n\n"
            "Which runtime was faster for the short-technical task?"
        ),
        "max_tokens": 128,
        "pass": "contains",
        "contains": ["llama.cpp", "[A]"],
    },
    {
        "name": "wiki-query-citation",
        "suite": "wiki",
        "kind": "wiki",
        "prompt": (
            "You are answering from a tiny llm-wiki fixture. Use only these articles:\n\n"
            "Article: local-ai-benchmark-registry.md\n"
            "Claim: Spark llama.cpp with qwen3-coder-30b.gguf is the Spark performance default.\n\n"
            "Article: profiling-local-agent-setups.md\n"
            "Claim: Profile promotion should prioritize correctness, tool-call validity, stall recovery, then wall time.\n\n"
            "Question: What should be the Spark performance default, and what should be optimized before wall time? "
            "Cite both article filenames."
        ),
        "max_tokens": 192,
        "pass": "contains",
        "contains": [
            "qwen3-coder-30b.gguf",
            "local-ai-benchmark-registry.md",
            "profiling-local-agent-setups.md",
        ],
    },
    {
        "name": "wiki-ingest-frontmatter",
        "suite": "wiki",
        "kind": "wiki",
        "prompt": (
            "Return only YAML frontmatter for a raw llm-wiki note about today's Spark llama.cpp benchmark. "
            "It must include title, source_type, created, updated, tags, summary, and sources. "
            "Use created and updated date 2026-05-10."
        ),
        "max_tokens": 256,
        "pass": "contains",
        "contains": ["title:", "source_type:", "created:", "updated:", "tags:", "summary:", "sources:"],
    },
    {
        "name": "wiki-audit-unsupported-claim",
        "suite": "wiki",
        "kind": "wiki",
        "prompt": (
            "Audit this wiki claim against the source. Return compact JSON with keys supported and issue.\n\n"
            "Source: Spark llama.cpp completed long-prefill-summary in 5596 ms.\n"
            "Claim: Spark Ollama completed long-prefill-summary in 5596 ms."
        ),
        "max_tokens": 160,
        "pass": "json_keys",
        "keys": ["supported", "issue"],
        "contains_any": ["false", "unsupported", "Ollama", "llama.cpp"],
    },
]


def now_slug() -> str:
    """Return the current local time formatted as ``YYYYMMDD-HHMMSS``."""
    return f"{dt.datetime.now():%Y%m%d-%H%M%S}"


def post_json(url: str, payload: dict[str, Any], timeout: float) -> tuple[int, dict[str, Any]]:
    """POST ``payload`` as JSON to ``url`` and return ``(status, decoded_body)``.

    The request carries a JSON content type and a fixed ``Bearer local``
    authorization header.

    Returns:
        * the HTTP status and the decoded JSON body for any HTTP response,
          including error responses delivered as ``HTTPError``;
        * the HTTP status and ``{"error": <raw body>}`` when the body is not
          valid JSON (for success and error responses alike);
        * ``599`` and ``{"error": {"message", "type"}}`` when no HTTP response
          was obtained (timeout, refused or reset connection, protocol error).
    """
    import http.client  # local import: only needed for the HTTPException guard

    request = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json", "Authorization": "Bearer local"},
        method="POST",
    )
    try:
        try:
            with urllib.request.urlopen(request, timeout=timeout) as response:
                status = response.status
                body = response.read().decode("utf-8", "replace")
        except urllib.error.HTTPError as exc:
            # The server answered with an error status; its body is still useful.
            status = exc.code
            body = exc.read().decode("utf-8", "replace")
    except (OSError, http.client.HTTPException) as exc:
        # OSError covers URLError, timeouts and resets raised while reading.
        return 599, {"error": {"message": str(exc), "type": type(exc).__name__}}

    try:
        return status, json.loads(body)
    except json.JSONDecodeError:
        # A non-JSON body must not abort the whole benchmark run.
        return status, {"error": body}


def get_json(url: str, timeout: float) -> tuple[int, dict[str, Any]]:
    """GET ``url`` and return ``(status, decoded_body)``.

    Returns:
        * the HTTP status and the decoded JSON body for any HTTP response,
          including error responses delivered as ``HTTPError``;
        * the HTTP status and ``{"error": <raw body>}`` when the body is not
          valid JSON (for success and error responses alike);
        * ``599`` and ``{"error": {"message", "type"}}`` when no HTTP response
          was obtained (timeout, refused or reset connection, protocol error).
    """
    import http.client  # local import: only needed for the HTTPException guard

    try:
        try:
            with urllib.request.urlopen(url, timeout=timeout) as response:
                status = response.status
                body = response.read().decode("utf-8", "replace")
        except urllib.error.HTTPError as exc:
            # The server answered with an error status; its body is still useful.
            status = exc.code
            body = exc.read().decode("utf-8", "replace")
    except (OSError, http.client.HTTPException) as exc:
        # OSError covers URLError, timeouts and resets raised while reading.
        return 599, {"error": {"message": str(exc), "type": type(exc).__name__}}

    try:
        return status, json.loads(body)
    except json.JSONDecodeError:
        # A non-JSON body must not abort the caller.
        return status, {"error": body}


def _run_curl_json(
    command: list[str],
    timeout: float,
    stdin_text: str | None = None,
) -> tuple[int, dict[str, Any]]:
    """Run a curl ``command`` and decode its output as ``(status, body)``.

    ``command`` must make curl append a newline and the HTTP status code to the
    response body (``-w "\\n%{http_code}"``). ``stdin_text``, when given, is
    fed to curl on standard input.

    Returns ``599`` with an ``{"error": {"message", "type"}}`` payload when
    curl cannot be started, times out, exits non-zero, or prints no parsable
    status code. A body that is not valid JSON is returned as
    ``{"error": <raw body>}`` together with the real HTTP status.
    """
    try:
        proc = subprocess.run(
            command,
            input=stdin_text,
            text=True,
            capture_output=True,
            # Safety net; curl's own --max-time should normally fire first.
            timeout=timeout + 5,
            check=False,
        )
    except subprocess.TimeoutExpired as exc:
        # The message embeds the whole command line, including the URL.
        return 599, {"error": {"message": redact_private_text(str(exc)), "type": type(exc).__name__}}
    except OSError as exc:
        # For example FileNotFoundError when curl is not installed.
        return 599, {"error": {"message": str(exc), "type": type(exc).__name__}}

    if proc.returncode != 0:
        message = redact_private_text(proc.stderr.strip() or proc.stdout.strip())
        return 599, {"error": {"message": message, "type": "curl"}}

    # The last line is the status code written by -w; the rest is the body.
    body, _, status_text = proc.stdout.rpartition("\n")
    try:
        status = int(status_text)
    except ValueError:
        return 599, {"error": {"message": redact_private_text(proc.stdout[-500:]), "type": "curl"}}
    try:
        data: dict[str, Any] = json.loads(body)
    except json.JSONDecodeError:
        data = {"error": body}
    return status, data


def post_json_curl(url: str, payload: dict[str, Any], timeout: float) -> tuple[int, dict[str, Any]]:
    """POST ``payload`` as JSON to ``url`` through the ``curl`` executable.

    The JSON document is passed on standard input (``--data-binary @-``).
    Returns ``(status, decoded_body)``; see ``_run_curl_json`` for the failure
    conventions (status ``599`` when no HTTP response was obtained).
    """
    command = [
        "curl",
        "-sS",
        "--max-time",
        str(timeout),
        "-H",
        "Content-Type: application/json",
        "-H",
        "Authorization: Bearer local",
        "-X",
        "POST",
        "--data-binary",
        "@-",
        "-w",
        "\n%{http_code}",
        url,
    ]
    return _run_curl_json(command, timeout, json.dumps(payload))


def get_json_curl(url: str, timeout: float) -> tuple[int, dict[str, Any]]:
    """GET ``url`` through the ``curl`` executable.

    Returns ``(status, decoded_body)``; see ``_run_curl_json`` for the failure
    conventions (status ``599`` when no HTTP response was obtained).
    """
    command = ["curl", "-sS", "--max-time", str(timeout), "-w", "\n%{http_code}", url]
    return _run_curl_json(command, timeout)


def redact_private_text(text: str) -> str:
    """Return ``text`` with every PRIVATE_TEXT_PATTERNS match replaced by ``<redacted>``.

    The patterns are applied one after another, each to the output of the
    previous substitution.
    """
    scrubbed = text
    for matcher in PRIVATE_TEXT_PATTERNS:
        scrubbed = matcher.sub("<redacted>", scrubbed)
    return scrubbed


def public_endpoint(url: str, show_endpoints: bool) -> str:
    """Return a form of ``url`` that is safe to write into reports.

    The URL is returned unchanged when ``show_endpoints`` is true or when its
    host is ``127.0.0.1``, ``localhost`` or ``::1``. Otherwise the network
    location is replaced by ``<redacted>`` (keeping the port, if any) and the
    query string and fragment are dropped.
    """
    if show_endpoints:
        return url
    parts = urllib.parse.urlsplit(url)
    if (parts.hostname or "") in {"127.0.0.1", "localhost", "::1"}:
        return url
    port = parts.port
    masked_netloc = f"<redacted>:{port}" if port else "<redacted>"
    return urllib.parse.urlunsplit((parts.scheme, masked_netloc, parts.path, "", ""))


def public_argv(argv: list[str], args: argparse.Namespace) -> list[str]:
    """Return a copy of ``argv`` that is safe to write into reports.

    With ``args.show_endpoints`` set, ``argv`` itself is returned untouched.
    Otherwise every occurrence of the configured ds4, Spark and llama URLs is
    replaced by its ``public_endpoint`` form, and each argument is then passed
    through ``redact_private_text``.
    """
    if args.show_endpoints:
        return argv
    masked_by_url = {
        raw_url: public_endpoint(raw_url, False)
        for raw_url in (args.ds4_url, args.spark_url, args.llama_url)
    }
    sanitized: list[str] = []
    for argument in argv:
        for raw_url, masked in masked_by_url.items():
            argument = argument.replace(raw_url, masked)
        sanitized.append(redact_private_text(argument))
    return sanitized


def extract_openai_text(data: dict[str, Any]) -> str:
    """Return the text of the first choice of a chat-completions response.

    ``content`` may be a plain string or a list of parts; for a list, the
    string ``text`` values of its dict items are concatenated. Any other
    shape yields ``""`` instead of raising: a non-dict ``data``, missing or
    empty ``choices``, or a non-dict choice or message.
    """
    if not isinstance(data, dict):
        return ""
    choices = data.get("choices")
    if not isinstance(choices, list) or not choices:
        return ""
    first_choice = choices[0]
    message = first_choice.get("message") if isinstance(first_choice, dict) else None
    if not isinstance(message, dict):
        return ""
    content = message.get("content")
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "".join(
            part["text"]
            for part in content
            if isinstance(part, dict) and isinstance(part.get("text"), str)
        )
    return ""


def task_messages(task: dict[str, Any]) -> list[dict[str, Any]]:
    """Build the chat message list for ``task``.

    A system message comes first, taken from ``task["system"]`` or from
    SYSTEM_PROMPT when the key is absent; a falsy value suppresses it. If
    ``task["messages"]`` is a list, its dict entries that have both ``role``
    and ``content`` follow (other entries are skipped); otherwise a single
    user message is built from ``task["prompt"]``.
    """
    system = task.get("system", SYSTEM_PROMPT)
    conversation: list[dict[str, Any]] = (
        [{"role": "system", "content": str(system)}] if system else []
    )
    scripted = task.get("messages")
    if not isinstance(scripted, list):
        conversation.append({"role": "user", "content": str(task["prompt"])})
        return conversation
    conversation.extend(
        {"role": str(entry["role"]), "content": str(entry["content"])}
        for entry in scripted
        if isinstance(entry, dict) and "role" in entry and "content" in entry
    )
    return conversation


def task_max_tokens(task: dict[str, Any]) -> int:
    """Return the output token budget for ``task``.

    The first truthy value among ``max_tokens`` and ``num_predict`` is used,
    converted with ``int``; 256 is returned when neither is set.
    """
    for field in ("max_tokens", "num_predict"):
        budget = task.get(field)
        if budget:
            return int(budget)
    return 256


def needles(value: Any) -> list[str]:
    """Normalize a task field into a list of search strings.

    ``None`` gives an empty list, a string gives a one-element list, a list
    gives its items converted with ``str``, and any other value gives a
    one-element list holding ``str(value)``.
    """
    if value is None:
        return []
    if isinstance(value, list):
        return [str(entry) for entry in value]
    return [value if isinstance(value, str) else str(value)]


def pass_check(task: dict[str, Any], text: str) -> tuple[bool, str]:
    """Check model output ``text`` against the pass rule of ``task``.

    The primary rule is ``task["pass"]`` (default ``"none"``):

    * ``exact_ok``: the stripped text equals ``OK``.
    * ``min_words``: at least ``task["min_words"]`` whitespace-separated words.
    * ``json_keys``: the stripped text is a JSON object containing every key
      in ``task["keys"]``.
    * ``contains`` / ``contains_any``: all / at least one of the listed
      needles occur, compared case-insensitively.
    * ``regex``: ``task["regex"]`` matches (MULTILINE and DOTALL).

    Unless the rule returned early (``exact_ok``, ``min_words``, or a failed
    check), the optional ``contains``, ``not_contains`` and ``contains_any``
    fields are enforced as well, whatever the primary rule is.

    Returns:
        ``(passed, description)``. ``description`` names the rule that failed,
        or on success the primary rule (``"no check"`` for ``"none"``).
    """
    stripped = text.strip()
    rule = task.get("pass", "none")
    if rule == "exact_ok":
        return stripped == "OK", "text == OK"
    if rule == "min_words":
        words = stripped.split()
        return len(words) >= int(task["min_words"]), f"words >= {task['min_words']}"
    if rule == "json_keys":
        try:
            data = json.loads(stripped)
        except json.JSONDecodeError:
            return False, "valid JSON"
        # Scalars would raise on ``in``; strings and lists would match the key
        # names by accident. Only a JSON object can carry the keys.
        if not isinstance(data, dict):
            return False, "JSON object"
        if any(key not in data for key in task["keys"]):
            return False, "keys " + ",".join(task["keys"])

    lowered = stripped.lower()

    if rule == "contains_any":
        options = needles(task.get("contains_any"))
        if not any(option.lower() in lowered for option in options):
            return False, "contains any " + ",".join(options)
    if rule == "regex":
        pattern = str(task.get("regex", ""))
        if not re.search(pattern, stripped, flags=re.MULTILINE | re.DOTALL):
            return False, "regex " + pattern

    # ``contains`` is the primary check for the "contains" rule and an extra
    # requirement for every other rule, like the two fields below.
    required = needles(task.get("contains"))
    if any(needle.lower() not in lowered for needle in required):
        return False, "contains " + ",".join(required)

    banned = needles(task.get("not_contains"))
    if any(needle.lower() in lowered for needle in banned):
        return False, "not contains " + ",".join(banned)

    # Already evaluated above when it is the primary rule.
    extra_any = needles(task.get("contains_any")) if rule != "contains_any" else []
    if extra_any and not any(needle.lower() in lowered for needle in extra_any):
        return False, "contains any " + ",".join(extra_any)

    return True, ("no check" if rule == "none" else str(rule))


def call_ds4(args: argparse.Namespace, task: dict[str, Any]) -> dict[str, Any]:
    """Run ``task`` against the ds4 chat-completions endpoint and return its record.

    The request is sent with ``post_json``; ``reasoning_effort`` is added only
    when ``args.ds4_reasoning_effort`` is set. The record holds the HTTP
    status, wall-clock time in milliseconds, the extracted text, the
    ``pass_check`` verdict (forced to false for status >= 400), the reported
    usage, and the request settings.
    """
    chat = task_messages(task)
    budget = task_max_tokens(task)
    effort = args.ds4_reasoning_effort
    request_body: dict[str, Any] = {
        "model": args.ds4_model,
        "messages": chat,
        "max_tokens": budget,
        "temperature": args.temperature,
        "stream": False,
    }
    if effort:
        request_body["reasoning_effort"] = effort

    started = time.perf_counter()
    status, data = post_json(args.ds4_url, request_body, args.timeout)
    elapsed_ms = (time.perf_counter() - started) * 1000

    failed = status >= 400
    text = "" if failed else extract_openai_text(data)
    passed, pass_rule = pass_check(task, text)
    fields = data if isinstance(data, dict) else None

    return {
        "stack": "ds4",
        "model": args.ds4_model,
        "endpoint": public_endpoint(args.ds4_url, args.show_endpoints),
        "task": task["name"],
        "suite": task.get("suite", "smoke"),
        "kind": task["kind"],
        "status": status,
        "wall_ms": round(elapsed_ms, 1),
        "pass": bool(passed and not failed),
        "pass_rule": pass_rule,
        "output_chars": len(text),
        "usage": fields.get("usage") if fields is not None else None,
        "text": text,
        "text_preview": text[:500],
        "error": fields.get("error") if failed and fields is not None else None,
        "settings": {
            "temperature": args.temperature,
            "max_tokens": budget,
            "reasoning_effort": effort,
        },
    }


def call_llama_cpp(args: argparse.Namespace, task: dict[str, Any]) -> dict[str, Any]:
    """Run ``task`` against the llama.cpp chat-completions endpoint and return its record.

    The request is sent with ``post_json_curl``. Besides the common record
    fields (status, wall time, text, pass verdict, usage, settings), the
    record has a ``metrics`` dict filled from the response: ``prefill_tps``
    and ``generation_tps`` from numeric ``timings`` values, and
    ``prompt_tokens`` from ``usage`` when present and non-zero.
    """
    chat = task_messages(task)
    budget = task_max_tokens(task)
    request_body: dict[str, Any] = {
        "model": args.llama_model,
        "messages": chat,
        "max_tokens": budget,
        "temperature": args.temperature,
        "stream": False,
    }

    started = time.perf_counter()
    status, data = post_json_curl(args.llama_url, request_body, args.timeout)
    elapsed_ms = (time.perf_counter() - started) * 1000

    failed = status >= 400
    text = "" if failed else extract_openai_text(data)
    passed, pass_rule = pass_check(task, text)
    fields = data if isinstance(data, dict) else None
    usage = fields.get("usage") if fields is not None else None
    timings = fields.get("timings") if fields is not None else None

    metrics: dict[str, Any] = {}
    if isinstance(timings, dict):
        # Insertion order (prefill before generation) is part of the record.
        for metric_name, timing_key in (
            ("prefill_tps", "prompt_per_second"),
            ("generation_tps", "predicted_per_second"),
        ):
            rate = timings.get(timing_key)
            if isinstance(rate, (int, float)):
                metrics[metric_name] = round(rate, 2)
    if isinstance(usage, dict) and usage.get("prompt_tokens"):
        metrics["prompt_tokens"] = usage["prompt_tokens"]

    return {
        "stack": "spark-llama.cpp",
        "model": args.llama_model,
        "endpoint": public_endpoint(args.llama_url, args.show_endpoints),
        "task": task["name"],
        "suite": task.get("suite", "smoke"),
        "kind": task["kind"],
        "status": status,
        "wall_ms": round(elapsed_ms, 1),
        "pass": bool(passed and not failed),
        "pass_rule": pass_rule,
        "output_chars": len(text),
        "usage": usage,
        "metrics": metrics,
        "text": text,
        "text_preview": text[:500],
        "error": fields.get("error") if failed and fields is not None else None,
        "settings": {
            "temperature": args.temperature,
            "max_tokens": budget,
        },
    }


def call_ollama(args: argparse.Namespace, task: dict[str, Any]) -> dict[str, Any]:
    """Run ``task`` against the Ollama ``/api/chat`` endpoint and return its record.

    The request is sent with ``post_json`` to ``args.spark_url`` + ``/api/chat``.
    Token rates are derived from the counts and durations in the response;
    durations are divided by 1e9 before use (i.e. treated as nanoseconds).
    ``usage`` holds whichever of the six timing/count fields the response
    contains.
    """
    chat = task_messages(task)
    budget = task_max_tokens(task)
    request_body: dict[str, Any] = {
        "model": args.spark_model,
        "messages": chat,
        "stream": False,
        "think": args.ollama_think,
        "keep_alive": args.ollama_keepalive,
        "options": {
            "temperature": args.temperature,
            "num_predict": budget,
            "num_ctx": args.ollama_num_ctx,
        },
    }

    started = time.perf_counter()
    status, data = post_json(args.spark_url.rstrip("/") + "/api/chat", request_body, args.timeout)
    elapsed_ms = (time.perf_counter() - started) * 1000

    # An empty stand-in keeps the lookups below uniform for non-dict bodies.
    fields = data if isinstance(data, dict) else {}
    message = fields.get("message")
    text = message.get("content", "") if isinstance(message, dict) else ""
    passed, pass_rule = pass_check(task, text)

    metrics: dict[str, Any] = {}
    generated, generation_time = fields.get("eval_count"), fields.get("eval_duration")
    if generated and generation_time:
        metrics["generation_tps"] = round(generated / (generation_time / 1_000_000_000), 2)
    prompted, prompt_time = fields.get("prompt_eval_count"), fields.get("prompt_eval_duration")
    if prompted and prompt_time:
        metrics["prefill_tps"] = round(prompted / (prompt_time / 1_000_000_000), 2)

    usage_fields = (
        "total_duration",
        "load_duration",
        "prompt_eval_count",
        "prompt_eval_duration",
        "eval_count",
        "eval_duration",
    )
    usage = {name: fields[name] for name in usage_fields if name in fields}

    return {
        "stack": "spark-ollama",
        "model": args.spark_model,
        "endpoint": public_endpoint(args.spark_url, args.show_endpoints),
        "task": task["name"],
        "suite": task.get("suite", "smoke"),
        "kind": task["kind"],
        "status": status,
        "wall_ms": round(elapsed_ms, 1),
        "pass": bool(passed and status < 400),
        "pass_rule": pass_rule,
        "output_chars": len(text),
        "usage": usage,
        "metrics": metrics,
        "text": text,
        "text_preview": text[:500],
        "error": fields.get("error") if status >= 400 and isinstance(data, dict) else None,
        "settings": {
            "temperature": args.temperature,
            "num_predict": budget,
            "num_ctx": args.ollama_num_ctx,
            "think": args.ollama_think,
            "keep_alive": args.ollama_keepalive,
        },
    }


def percentile(values: list[float], pct: float) -> float:
    """Return the ``pct`` quantile of ``values`` by linear interpolation.

    ``pct`` is a fraction (0.95 for the 95th percentile). The position
    ``(len - 1) * pct`` in the sorted data is interpolated between its two
    neighbouring elements. An empty input yields ``0.0``.
    """
    if not values:
        return 0.0
    ranked = sorted(values)
    last = len(ranked) - 1
    position = last * pct
    below = int(position)
    above = min(below + 1, last)
    fraction = position - below
    return ranked[below] * (1 - fraction) + ranked[above] * fraction


def summarize(records: list[dict[str, Any]]) -> str:
    """Render ``records`` as a Markdown summary.

    Records are grouped by (stack, model, suite, task). Each group becomes one
    table row with run count, pass count, median and 95th-percentile wall
    time, median output length, and median prefill/generation token rates
    (``-`` when no record in the group reports one). Records with a status of
    400 or above, or a truthy ``error``, are listed under an ``Errors``
    heading. The returned text ends with a newline.
    """
    report = [
        "# Local AI Benchmark Summary",
        "",
        f"Generated: {dt.datetime.now().isoformat(timespec='seconds')}",
        "",
        "| Stack | Model | Suite | Task | Runs | Pass | Median wall | P95 wall | Median out | Median prefill | Median gen |",
        "| --- | --- | --- | --- | ---: | ---: | ---: | ---: | ---: | ---: | ---: |",
    ]

    by_group: dict[tuple[str, str, str, str], list[dict[str, Any]]] = {}
    for entry in records:
        group_key = (entry["stack"], entry["model"], entry.get("suite", "smoke"), entry["task"])
        by_group.setdefault(group_key, []).append(entry)

    def field_values(items: list[dict[str, Any]], name: str) -> list[float]:
        # Values of a top-level record field, skipping records without it.
        return [float(item[name]) for item in items if item.get(name) is not None]

    def metric_values(items: list[dict[str, Any]], name: str) -> list[float]:
        # Values of one entry of the optional "metrics" dict.
        found = []
        for item in items:
            value = (item.get("metrics") or {}).get(name)
            if value is not None:
                found.append(float(value))
        return found

    def rate_cell(rates: list[float]) -> str:
        # Median token rate, or a dash when nothing was reported.
        return f"{statistics.median(rates):.1f} t/s" if rates else "-"

    for group_key in sorted(by_group):
        stack, model, suite, task = group_key
        items = by_group[group_key]
        walls = field_values(items, "wall_ms")
        outputs = field_values(items, "output_chars")
        prefill = metric_values(items, "prefill_tps")
        gen = metric_values(items, "generation_tps")
        runs = len(items)
        passes = sum(1 for item in items if item.get("pass"))
        median_wall = statistics.median(walls) if walls else 0
        p95_wall = percentile(walls, 0.95) if walls else 0
        median_out = statistics.median(outputs) if outputs else 0
        report.append(
            f"| {stack} | `{model}` | {suite} | {task} | {runs} | {passes}/{runs} | {median_wall:.0f} ms | "
            f"{p95_wall:.0f} ms | {median_out:.0f} chars | {rate_cell(prefill)} | {rate_cell(gen)} |"
        )

    failures = [item for item in records if item.get("status", 0) >= 400 or item.get("error")]
    if failures:
        report.extend(["", "## Errors", ""])
        report.extend(
            f"- {item['stack']} {item['task']}: status={item.get('status')} error={item.get('error')}"
            for item in failures
        )
    return "\n".join(report) + "\n"


def check_endpoints(args: argparse.Namespace) -> None:
    """Probe the endpoint of every selected stack and exit on the first failure.

    ds4 and llama.cpp are probed at ``<base>/models``, where ``<base>`` is the
    configured URL cut at ``/chat/completions``; Spark Ollama is probed at
    ``/api/tags``. Each probe uses a timeout of at most five seconds.

    ``get_json`` and ``get_json_curl`` report transport failures as status 599
    rather than raising, so a status of 400 or above covers both HTTP errors
    and unreachable hosts.

    Raises:
        SystemExit: when a probe returns a status >= 400. The message names
            the endpoint through ``public_endpoint``, and any response data it
            includes is passed through ``redact_private_text``.
    """
    stacks = selected_stacks(args.stack)
    probe_timeout = min(args.timeout, 5)

    if "ds4" in stacks:
        models_url = args.ds4_url.split("/chat/completions", 1)[0] + "/models"
        status, _ = get_json(models_url, probe_timeout)
        if status >= 400:
            endpoint = public_endpoint(args.ds4_url, args.show_endpoints)
            raise SystemExit(f"ds4 model endpoint failed with HTTP {status}: {endpoint}")

    if "spark" in stacks:
        status, data = get_json(args.spark_url.rstrip("/") + "/api/tags", probe_timeout)
        if status >= 400:
            endpoint = public_endpoint(args.spark_url, args.show_endpoints)
            # Redact like the llama.cpp branch: error text can contain
            # private addresses or host names.
            raise SystemExit(
                f"Spark Ollama endpoint failed with HTTP {status}: {endpoint}: {redact_private_text(str(data))}"
            )

    if "llama" in stacks:
        models_url = args.llama_url.split("/chat/completions", 1)[0] + "/models"
        status, data = get_json_curl(models_url, probe_timeout)
        if status >= 400:
            endpoint = public_endpoint(args.llama_url, args.show_endpoints)
            raise SystemExit(f"llama.cpp endpoint failed with HTTP {status}: {endpoint}: {redact_private_text(str(data))}")


def selected_stacks(stack: str) -> list[str]:
    """Expand a stack selector into a list of stack names.

    ``"both"`` means ds4 and spark, ``"all"`` adds llama, and any other value
    is returned as a single-element list.
    """
    groups = {
        "both": ("ds4", "spark"),
        "all": ("ds4", "spark", "llama"),
    }
    # A fresh list per call, so callers may mutate the result.
    return list(groups.get(stack, (stack,)))


def load_tasks(task_files: list[str] | None) -> list[dict[str, Any]]:
    """Return shallow copies of the built-in TASKS followed by tasks from ``task_files``.

    Each file is read as UTF-8 JSON and must hold either a list of task
    objects or an object whose ``tasks`` entry is such a list.

    Raises:
        SystemExit: when a file has neither shape, or when one of its tasks
            is not a JSON object.
    """
    combined = [dict(builtin) for builtin in TASKS]
    for file_name in task_files or []:
        source = pathlib.Path(file_name)
        parsed = json.loads(source.read_text(encoding="utf-8"))
        entries = parsed.get("tasks") if isinstance(parsed, dict) else parsed
        if not isinstance(entries, list):
            raise SystemExit(f"Task file must contain a list or {{'tasks': [...]}}: {source}")
        for entry in entries:
            if not isinstance(entry, dict):
                raise SystemExit(f"Task file contains a non-object task: {source}")
            combined.append(entry)
    return combined


def selected_tasks(args: argparse.Namespace) -> list[dict[str, Any]]:
    """Return the tasks selected by the command-line filters.

    Tasks come from ``load_tasks(args.tasks_file)``. They are kept when their
    suite (default ``smoke``) is in ``args.suite`` (default ``["smoke"]``; the
    value ``all`` disables this filter), their kind is in
    ``args.include_kind`` when that is given, and their kind is not in
    ``args.exclude_kind`` when that is given.

    Raises:
        SystemExit: when no task survives the filters.
    """
    candidates = load_tasks(args.tasks_file)
    suites = set(args.suite or ["smoke"])
    any_suite = "all" in suites
    include = set(args.include_kind) if args.include_kind else None
    exclude = set(args.exclude_kind) if args.exclude_kind else None

    def keep(task: dict[str, Any]) -> bool:
        # Suite first, so "kind" is only read for tasks that pass it.
        if not any_suite and str(task.get("suite", "smoke")) not in suites:
            return False
        if include is not None and task["kind"] not in include:
            return False
        if exclude is not None and task["kind"] in exclude:
            return False
        return True

    chosen = [task for task in candidates if keep(task)]
    if not chosen:
        raise SystemExit("No tasks selected.")
    return chosen


def safe_slug(value: str) -> str:
    """Turn ``value`` into a string that is safe to use in a file name.

    Every run of characters other than ASCII letters, digits, ``.``, ``_``
    and ``-`` becomes a single ``-``; leading and trailing ``-`` are removed.
    ``"item"`` is returned when nothing remains.
    """
    collapsed = re.sub(r"[^A-Za-z0-9._-]+", "-", value)
    trimmed = collapsed.strip("-")
    if trimmed:
        return trimmed
    return "item"


def prepare_workspace(
    args: argparse.Namespace,
    task: dict[str, Any],
    out_dir: pathlib.Path,
    record: dict[str, Any],
) -> pathlib.Path | None:
    """Create a fresh workspace directory for ``task`` and return its path.

    Returns ``None`` when the task has no truthy ``verifier``, ``fixture`` or
    ``workspace`` entry. Otherwise a directory named after the run, stack,
    model and task is created under ``args.workspace_root`` (or
    ``out_dir / "workspaces"`` when that is unset); an existing directory of
    the same name is deleted first. A ``fixture`` directory is copied into
    the workspace, a fixture file is copied to its top level; relative
    fixture paths are resolved against ``args.fixtures_root``.

    Raises:
        SystemExit: when the fixture path is neither a directory nor a file.
    """
    fixture = task.get("fixture")
    if not (task.get("verifier") or fixture or task.get("workspace")):
        return None

    if args.workspace_root:
        root = pathlib.Path(args.workspace_root)
    else:
        root = out_dir / "workspaces"
    folder = (
        f"run-{record['run']}-{safe_slug(record['stack'])}-{safe_slug(record['model'])}-{safe_slug(task['name'])}"
    )
    workspace = root / folder

    # Always start from an empty directory.
    if workspace.exists():
        shutil.rmtree(workspace)
    workspace.mkdir(parents=True, exist_ok=True)

    if not fixture:
        return workspace

    source = pathlib.Path(str(fixture))
    if not source.is_absolute():
        source = pathlib.Path(args.fixtures_root) / source
    if source.is_dir():
        shutil.copytree(source, workspace, dirs_exist_ok=True)
    elif source.is_file():
        shutil.copy2(source, workspace / source.name)
    else:
        raise SystemExit(f"Fixture not found: {source}")
    return workspace


def run_verifier(
    args: argparse.Namespace,
    task: dict[str, Any],
    record: dict[str, Any],
    out_dir: pathlib.Path,
) -> dict[str, Any] | None:
    verifier = task.get("verifier")
    if not verifier:
        return None

    workspace = prepare_workspace(args, task, out_dir, record)
    if workspace is None:
        workspace = out_dir

    output_dir = out_dir / "outputs"
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / (
        f"run-{record['run']}-{safe_slug(record['stack'])}-{safe_slug(record['model'])}-{safe_slug(task['name'])}.txt"
    )
    output_path.write_text(str(record.get("text") or ""), encoding="utf-8")

    env = os.environ.copy()
    env.update(
        {
            "BENCH_OUTPUT": str(record.get("text") or ""),
            "BENCH_OUTPUT_FILE": str(output_path),
            "BENCH_TASK": task["name"],
            "BENCH_SUITE": str(task.get("suite", "smoke")),
            "BENCH_STACK": record["stack"],
            "BENCH_MODEL": record["model"],
            "BENCH_WORKSPACE": str(workspace),
        }
    )

    shell = isinstance(verifier, str)
    command = verifier if shell else [str(part) for part in verifier]
    try:
        proc = subprocess.run(
            command,
            cwd=workspace,
            env=env,
            shell=shell,
            text=True,
            capture_output=True,
            timeout=args.verifier_timeout,
            check=False,
        )
        return {
            "command": verifier,
            "status": proc.returncode,
            "pass": proc.returncode == 0,
            "stdout": redact_private_text(proc.stdout[-2000:]),
            "stderr": redact_private_text(proc.stderr[-2000:]),
            "workspace": redact_private_text(str(workspace)),
            "output_file": redact_private_text(str(output_path)),
        }
    except (TimeoutError, subprocess.TimeoutExpired) as exc:
        return {
            "command": verifier,
            "status": 124,
            "pass": False,
            "stdout": "",
            "stderr": str(exc),
            "workspace": redact_private_text(str(workspace)),
            "output_file": redact_private_text(str(output_path)),
        }


def write_manifest(args: argparse.Namespace, tasks: list[dict[str, Any]], out_dir: pathlib.Path) -> None:
    """Write ``manifest.json`` describing this benchmark invocation.

    The manifest records the creation time, the command line (via
    ``public_argv``), the selected stacks and suites, the run settings, one
    summary entry per task, and the endpoints and models of each backend.
    Endpoint URLs are passed through ``public_endpoint`` together with
    ``args.show_endpoints``.
    """
    manifest: dict[str, Any] = {
        "created_at": dt.datetime.now().isoformat(timespec="seconds"),
        "command": public_argv(sys.argv, args),
        "stack": args.stack,
        "stacks": selected_stacks(args.stack),
        "suites": args.suite or ["smoke"],
        "runs": args.runs,
        "temperature": args.temperature,
        "timeout": args.timeout,
    }

    task_rows: list[dict[str, Any]] = []
    for task in tasks:
        task_rows.append(
            {
                "name": task["name"],
                "suite": task.get("suite", "smoke"),
                "kind": task["kind"],
                "max_tokens": task_max_tokens(task),
                "pass": task.get("pass", "none"),
            }
        )
    manifest["tasks"] = task_rows

    endpoint_urls = (
        ("ds4", args.ds4_url),
        ("spark_ollama", args.spark_url),
        ("llama", args.llama_url),
    )
    manifest["endpoints"] = {
        label: public_endpoint(url, args.show_endpoints) for label, url in endpoint_urls
    }
    manifest["models"] = {
        "ds4": args.ds4_model,
        "spark": args.spark_model,
        "llama": args.llama_model,
    }

    manifest_path = out_dir / "manifest.json"
    manifest_path.write_text(json.dumps(manifest, indent=2), encoding="utf-8")


def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    Endpoint and model defaults come from the module-level ``DEFAULT_*``
    constants; the ds4 reasoning effort and the Ollama keep-alive, context
    size and think settings default to their environment variables.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    add = parser.add_argument

    add("--stack", choices=["ds4", "spark", "llama", "both", "all"], default="both")
    for flag, caster, fallback in (
        ("--runs", int, 3),
        ("--timeout", float, 180),
        ("--temperature", float, 0),
    ):
        add(flag, type=caster, default=fallback)
    add("--out", default="")
    add(
        "--suite",
        action="append",
        choices=["smoke", "code", "question", "wiki", "all"],
        help="Benchmark suite to run. Defaults to smoke. Can be repeated.",
    )
    add("--tasks-file", action="append", help="Load additional JSON task definitions.")
    add("--fixtures-root", default="bench-local-ai/fixtures")
    add("--workspace-root", default="")
    add("--verifier-timeout", type=float, default=30)

    # Plain string options that only differ in their flag and default.
    for flag, fallback in (
        ("--ds4-url", DEFAULT_DS4_URL),
        ("--ds4-model", DEFAULT_DS4_MODEL),
        ("--ds4-reasoning-effort", os.environ.get("DS4_REASONING_EFFORT", "")),
        ("--spark-url", DEFAULT_SPARK_URL),
        ("--spark-model", DEFAULT_SPARK_MODEL),
        ("--llama-url", DEFAULT_LLAMA_URL),
        ("--llama-model", DEFAULT_LLAMA_MODEL),
        ("--ollama-keepalive", os.environ.get("OLLAMA_KEEPALIVE", "1h")),
    ):
        add(flag, default=fallback)

    add("--ollama-num-ctx", type=int, default=int(os.environ.get("OLLAMA_NUM_CTX", "32768")))
    add(
        "--ollama-think",
        default=os.environ.get("OLLAMA_THINK", "false"),
        choices=["true", "false", "high", "medium", "low"],
    )

    for flag, description in (
        ("--include-kind", "Only run task kinds listed here. Can be repeated."),
        ("--exclude-kind", "Skip task kinds listed here. Can be repeated."),
    ):
        add(flag, action="append", help=description)

    add("--no-check", action="store_true", help="Skip endpoint preflight checks.")
    add(
        "--show-endpoints",
        action="store_true",
        help="Store full endpoint URLs in runs.jsonl and manifest.json. Default redacts non-local hosts.",
    )
    return parser.parse_args()


def main() -> int:
    """Run the selected tasks against each selected stack and write results.

    For every run, task and stack the matching backend is called, the task's
    verifier (if any) is run, and the record is appended to ``runs.jsonl`` in
    the output directory and echoed as a one-line PASS/FAIL status. A record
    passes only if both the backend's own check and the verifier pass. After
    the loop a summary is written to ``summary.md`` and printed.

    Returns:
        ``0`` once all runs have completed.

    Raises:
        SystemExit: If ``--runs`` is less than 1.
    """
    args = parse_args()
    if args.runs < 1:
        raise SystemExit("--runs must be >= 1")
    # "true"/"false" become real booleans; the effort levels stay strings.
    if args.ollama_think in ("true", "false"):
        args.ollama_think = args.ollama_think == "true"
    if not args.no_check:
        check_endpoints(args)

    tasks = selected_tasks(args)

    if args.out:
        out_dir = pathlib.Path(args.out).expanduser()
    else:
        out_dir = SCRIPT_DIR / "results" / now_slug()
    out_dir.mkdir(parents=True, exist_ok=True)
    write_manifest(args, tasks, out_dir)

    jsonl_path = out_dir / "runs.jsonl"
    summary_path = out_dir / "summary.md"
    records: list[dict[str, Any]] = []
    backends = {
        "ds4": call_ds4,
        "spark": call_ollama,
        "llama": call_llama_cpp,
    }

    stacks = selected_stacks(args.stack)
    with jsonl_path.open("w", encoding="utf-8") as fp:
        for run in range(1, args.runs + 1):
            for task in tasks:
                for stack in stacks:
                    backend = backends.get(stack)
                    if backend is None:
                        raise AssertionError(f"unknown stack: {stack}")
                    record = backend(args, task)
                    record["run"] = run
                    record["created_at"] = dt.datetime.now().isoformat(timespec="seconds")

                    verdict = run_verifier(args, task, record, out_dir)
                    if verdict is not None:
                        record["verifier"] = verdict
                        record["pass"] = bool(record["pass"] and verdict["pass"])

                    records.append(record)
                    # Flush per record so partial results survive an interrupted run.
                    fp.write(json.dumps(record, ensure_ascii=False) + "\n")
                    fp.flush()

                    status = "PASS" if record["pass"] else "FAIL"
                    line = (
                        f"{status} {record['stack']} {record['task']} "
                        f"{record['wall_ms']:.0f}ms {record['output_chars']} chars"
                    )
                    print(line, flush=True)

    summary = summarize(records)
    summary_path.write_text(summary, encoding="utf-8")
    print(f"\nWrote {jsonl_path}")
    print(f"Wrote {summary_path}")
    print(summary)
    return 0


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
