#!/usr/bin/env python3
"""Pull an Ollama model through the configured Spark tunnel with quiet progress."""

from __future__ import annotations

import argparse
import json
import sys
import time
import urllib.error
import urllib.request


def main() -> int:
    """Pull a model from an Ollama server and print throttled progress.

    Sends ``POST {url}/api/pull`` with ``"stream": true`` and reads the
    response line by line, one JSON object per line. A progress line is
    printed only when the status text or the integer percentage changes,
    which keeps multi-gigabyte pulls from flooding the terminal.

    Returns:
        0 when the server reports the final ``"success"`` status.
        1 in any of these cases:
        - the HTTP request fails;
        - the stream contains an ``"error"`` object;
        - the stream ends without ``"success"``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("model", help="name of the model to pull")
    parser.add_argument(
        "--url",
        default="http://127.0.0.1:11435",
        help="base URL of the Ollama server (default: %(default)s)",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=7200,
        # urlopen's timeout applies to each blocking socket operation
        # (connect, each read), so this bounds stalls, not total pull time.
        help="socket timeout in seconds per blocking operation (default: %(default)s)",
    )
    args = parser.parse_args()

    req = urllib.request.Request(
        args.url.rstrip("/") + "/api/pull",
        data=json.dumps({"model": args.model, "stream": True}).encode(),
        headers={"Content-Type": "application/json"},
    )
    # Last (status, percent) pair printed; percent is None for status-only messages.
    last_key = None
    succeeded = False
    start = time.monotonic()
    try:
        with urllib.request.urlopen(req, timeout=args.timeout) as response:
            for raw in response:
                if not raw.strip():
                    continue
                msg = json.loads(raw)
                # The server may report a failure in-stream (e.g. unknown
                # model) as {"error": "..."} rather than via the HTTP status.
                if "error" in msg:
                    print(f"error: {msg['error']}", file=sys.stderr, flush=True)
                    return 1
                status = msg.get("status", "")
                total = msg.get("total") or 0
                completed = msg.get("completed") or 0
                if total:
                    pct = int(completed / total * 100)
                    key = (status, pct)
                    line = (
                        f"{status}: {completed / 1024**3:.2f}/"
                        f"{total / 1024**3:.2f} GiB ({pct}%)"
                    )
                else:
                    key = (status, None)
                    line = status
                # Keying on the status as well means a new layer is announced
                # even when its percentage equals the previous layer's.
                if key != last_key:
                    print(line, flush=True)
                    last_key = key
                if status == "success":
                    succeeded = True
    except urllib.error.HTTPError as exc:  # must precede URLError (subclass)
        detail = exc.read().decode(errors="replace").strip()
        print(
            f"error: HTTP {exc.code} from {req.full_url}: {detail}",
            file=sys.stderr,
            flush=True,
        )
        return 1
    except urllib.error.URLError as exc:
        print(f"error: cannot reach {req.full_url}: {exc.reason}", file=sys.stderr, flush=True)
        return 1

    elapsed_min = (time.monotonic() - start) / 60
    if not succeeded:
        print(
            f"error: stream ended without success after {elapsed_min:.1f} min",
            file=sys.stderr,
            flush=True,
        )
        return 1
    print(f"done in {elapsed_min:.1f} min", flush=True)
    return 0


if __name__ == "__main__":
    # Hand main()'s integer result to the interpreter as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)
