diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..9b2892d --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "nanochat/nanochat-upstream"] + path = nanochat/nanochat-upstream + url = https://github.com/karpathy/nanochat.git diff --git a/image-resize/src/manifest.rs b/image-resize/src/manifest.rs index 8866885..5e1df87 100644 --- a/image-resize/src/manifest.rs +++ b/image-resize/src/manifest.rs @@ -41,7 +41,7 @@ mod tests { let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); assert!(parsed.is_object(), "Manifest must be valid JSON object"); assert_eq!(parsed["name"], "image-resize"); - assert_eq!(parsed["version"], "0.1.0"); + assert_eq!(parsed["version"], env!("CARGO_PKG_VERSION")); } #[test] diff --git a/nanochat/README.md b/nanochat/README.md new file mode 100644 index 0000000..a0f4ba0 --- /dev/null +++ b/nanochat/README.md @@ -0,0 +1,162 @@ +# nanochat worker + +A Python worker that brings [Karpathy's nanochat](https://github.com/karpathy/nanochat) onto the III engine. 20 functions covering the full LLM pipeline: tokenizer training, base pretraining, supervised fine-tuning, RL fine-tuning (GRPO), CORE/BPB/ChatCORE evaluation, inference with tool use, checkpoint management, and conversation persistence. + +nanochat trains a GPT-2 level model in ~2 hours on 8xH100 for ~$48. This worker wraps the entire pipeline as iii functions that any connected worker (Rust, TypeScript, Python) can call. Training runs the actual nanochat scripts as subprocesses via a pre-forked launcher, so you get 100% fidelity to the original implementation. Inference, evaluation, and tokenization run in-process for speed. + +## Prerequisites + +- Python 3.10+ +- PyTorch 2.0+ +- iii-sdk 0.10.0+ +- nanochat dependencies: tiktoken, tokenizers, rustbpe, pyarrow, wandb +- A running iii engine on `ws://localhost:49134` +- For training/inference: CUDA GPU recommended. CPU and MPS work but are slow. + +## Quick start + +```bash +git clone --recurse-submodules https://github.com/iii-hq/workers.git +cd workers/nanochat + +pip install iii-sdk torch tiktoken tokenizers rustbpe pyarrow wandb pydantic +cd nanochat-upstream && pip install -e . && cd .. + +# Start without loading a model +python worker.py --no-autoload + +# Start with a trained SFT model +python worker.py --source sft --device cuda +``` + +The nanochat source is included as a git submodule at `nanochat-upstream/`. Training functions run the actual nanochat scripts (`scripts/base_train.py`, `scripts/chat_sft.py`, etc.) as subprocesses from this directory. + +## Functions + +20 functions, 20 triggers (all HTTP). Every handler uses Pydantic type hints for automatic request/response schema extraction. + +**Chat** + +- `nanochat.chat.complete` POST - Generate a chat completion. Takes OpenAI-style messages, returns content + session_id. Conversation persisted to iii state. +- `nanochat.chat.stream` POST - Same as complete but generates token-by-token internally. +- `nanochat.chat.history` GET - Read conversation history from iii state by session_id. + +**Model** + +- `nanochat.model.load` POST - Load a checkpoint into memory. Accepts source (base/sft/rl), model_tag, step, device. +- `nanochat.model.status` GET - Current model config: loaded, source, device, n_layer, n_embd, vocab_size, parameters. +- `nanochat.model.sample` POST - Generate raw text samples with configurable prompt, temperature, top_k, num_samples. + +**Tokenizer** + +- `nanochat.tokenizer.encode` POST - Text to BPE token IDs. 
+- `nanochat.tokenizer.decode` POST - Token IDs to text. + +**Training** (runs actual nanochat scripts via pre-forked subprocess launcher) + +- `nanochat.train.tokenizer` POST - Train BPE tokenizer from dataset. Runs `scripts/tok_train.py`. +- `nanochat.train.base` POST - Pretrain base GPT model. Runs `scripts/base_train.py` with full Muon optimizer, gradient accumulation, LR scheduling, FP8, checkpoint saving. +- `nanochat.train.sft` POST - Supervised fine-tuning with real task mixture (SmolTalk, MMLU, GSM8K, SpellingBee). Runs `scripts/chat_sft.py`. +- `nanochat.train.rl` POST - GRPO reinforcement learning on GSM8K. Runs `scripts/chat_rl.py`. +- `nanochat.train.status` GET - Training run progress from iii state. + +**Evaluation** (imports and calls real nanochat eval functions) + +- `nanochat.eval.core` POST - CORE benchmark (DCLM). Calls `base_eval.evaluate_core()`. +- `nanochat.eval.loss` POST - Bits-per-byte on validation set. Calls `loss_eval.evaluate_bpb()`. +- `nanochat.eval.chat` POST - ChatCORE evaluation (GSM8K, MMLU, ARC-Easy, ARC-Challenge, HumanEval, SpellingBee). Calls `chat_eval.run_chat_eval()`. + +**Checkpoints** + +- `nanochat.checkpoint.save` POST - Save current model to disk. +- `nanochat.checkpoint.list` GET - List available checkpoints by source. + +**Health** + +- `nanochat.health` GET - Worker health, model loaded status, device. +- `nanochat.tools.execute` POST - Execute Python code in-process (not sandboxed). + +## State scopes + +All state goes through iii `state::get/set`. Five scopes: + +- **nanochat:sessions** - Conversation history keyed by session_id. +- **nanochat:models** - Model metadata. The `current` key reflects the loaded model. +- **nanochat:training** - Training run progress keyed by run_id. Updated with parsed metrics from subprocess stdout (step, loss, tok/sec, MFU, BPB, CORE scores). +- **nanochat:evals** - Evaluation results keyed by type and timestamp. +- **nanochat:checkpoints** - Checkpoint metadata. + +## How training works + +Training functions can't fork subprocesses from inside iii-sdk handlers (fork corrupts the WebSocket on macOS). The worker solves this with a pre-forked subprocess launcher: + +1. Before connecting to the iii engine, the worker forks a child process using `multiprocessing` with explicit fork context. +2. The child process waits for job requests on a Pipe. +3. When a training function is triggered, it sends the script name and arguments to the child via the Pipe. +4. The child runs `subprocess.Popen` (safe because it was forked before the WebSocket existed). +5. The child captures all stdout and sends it back. +6. The handler parses stdout for metrics (step, loss, BPB, CORE, ChatCORE, reward) and writes them to iii state. + +This gives 100% fidelity to nanochat's training scripts while keeping the iii worker alive. + +## E2E test results + +Tested on macOS (Apple Silicon, CPU) with iii engine v0.10.0 and Python 3.11. Trained a 2-layer, 1.9M parameter GPT model from scratch (5 steps on CPU), loaded the checkpoint, and ran inference through the worker. + +```text +1. Load model -> loaded=True, params=1,966,134, n_layer=2, n_embd=128 +2. Sample -> "<|bos|>Hello! if ifite Sther made Oite were are..." +3. Chat -> completion with session tracking (26 tokens) +4. History -> 1 session stored in iii state +5. Tokenizer -> encode: 5 tokens, decode roundtrip OK +6. Tools -> print(42) = 42 +7. Model status -> full config visible (device, layers, vocab, params) +8. 
Health -> worker alive after all operations + +8/8 passed +``` + +The generated text is gibberish because the model was only trained for 5 steps. With real GPU training (8xH100, ~2 hours), the model produces coherent chat responses, solves math problems with tool use, and scores competitively on CORE benchmarks. + +## Calling from other workers + +```python +from iii import register_worker +iii = register_worker("ws://localhost:49134") + +result = iii.trigger({ + "function_id": "nanochat.chat.complete", + "payload": { + "messages": [{"role": "user", "content": "What is the capital of France?"}], + "temperature": 0.8, + } +}) +print(result["content"]) +``` + +```typescript +import { registerWorker } from 'iii-sdk' +const iii = registerWorker('ws://localhost:49134') + +const result = await iii.trigger({ + function_id: 'nanochat.chat.complete', + payload: { + messages: [{ role: 'user', content: 'What is the capital of France?' }], + temperature: 0.8, + }, +}) +``` + +## Known issues + +**Null payloads time out.** iii-sdk v0.10.0 drops invocations with `payload: None`. Always pass `{}`. + +**Handler exceptions crash WebSocket.** Unhandled exceptions corrupt the SDK's connection. Every handler is wrapped with `safe()` which logs server-side and returns `{"error": "..."}`. + +**fork() from handler threads crashes WebSocket.** Both `subprocess.Popen` and `os.system` from inside `run_in_executor` or `asyncio.to_thread` corrupt the asyncio event loop on macOS. The pre-forked launcher solves this for training. `tools.execute` uses in-process `exec()`. + +**torch.compile hangs on CPU.** nanochat's `base_train.py` calls `torch.compile(model)` which takes extremely long on CPU. Use GPU for real training. + +## License + +Apache-2.0 diff --git a/nanochat/__pycache__/worker.cpython-311.pyc b/nanochat/__pycache__/worker.cpython-311.pyc new file mode 100644 index 0000000..31415ec Binary files /dev/null and b/nanochat/__pycache__/worker.cpython-311.pyc differ diff --git a/nanochat/nanochat-upstream b/nanochat/nanochat-upstream new file mode 160000 index 0000000..a445144 --- /dev/null +++ b/nanochat/nanochat-upstream @@ -0,0 +1 @@ +Subproject commit a445144d3905c6845fda2d3cab8e63248a70cd32 diff --git a/nanochat/pyproject.toml b/nanochat/pyproject.toml new file mode 100644 index 0000000..25a9840 --- /dev/null +++ b/nanochat/pyproject.toml @@ -0,0 +1,26 @@ +[build-system] +requires = ["setuptools>=64", "wheel"] +build-backend = "setuptools.backends._legacy:_Backend" + +[project] +name = "iii-nanochat" +version = "0.1.0" +description = "nanochat LLM worker for iii-engine" +license = "Apache-2.0" +requires-python = ">=3.10" +dependencies = [ + "iii-sdk>=0.10.0", + "torch>=2.0", + "pydantic>=2.0", + "tiktoken", + "tokenizers", + "datasets", + "pyarrow", + "psutil", +] + +[project.optional-dependencies] +train = ["wandb"] + +[project.scripts] +iii-nanochat = "worker:main" diff --git a/nanochat/worker.py b/nanochat/worker.py new file mode 100644 index 0000000..0ec1287 --- /dev/null +++ b/nanochat/worker.py @@ -0,0 +1,1016 @@ +""" +nanochat worker for iii-engine (v0.10.0 SDK). + +Covers the full nanochat pipeline: tokenizer training, base pretraining, +supervised fine-tuning, RL fine-tuning, CORE/BPB/ChatCORE evaluation, +inference with tool use, and checkpoint management. + +Every capability is a registered function + trigger. Pydantic type hints +on every handler for auto schema extraction. Async handlers for state I/O. +safe() wrapper on every handler for zero-crash guarantee. 
+ +Usage: + python worker.py --no-autoload + python worker.py --source sft --device cuda +""" + +import argparse +import contextlib +import io +import os +import signal +import sys +import threading +import time +import traceback +import uuid +from pathlib import Path +from typing import Any + +from pydantic import BaseModel, Field + +from iii import InitOptions, Logger, TelemetryOptions, register_worker + +NANOCHAT_DIR = os.environ.get("NANOCHAT_DIR", str(Path(__file__).parent / "nanochat-upstream" / "nanochat")) + +logger = Logger(service_name="iii-nanochat") + +iii_client = None +_nanochat_imported = False + + +def _ensure_nanochat(): + global _nanochat_imported + if _nanochat_imported: + return + parent = str(Path(NANOCHAT_DIR).parent) + if parent not in sys.path: + sys.path.insert(0, parent) + import torch # noqa: F401 + _nanochat_imported = True + + +def safe(fn): + import asyncio as _aio + if _aio.iscoroutinefunction(fn): + async def wrapper(data): + try: + return await fn(data) + except Exception as e: + logger.error(f"Handler {fn.__name__} failed", {"error": str(e), "traceback": traceback.format_exc()}) + return {"error": str(e)} + else: + def wrapper(data): + try: + return fn(data) + except Exception as e: + logger.error(f"Handler {fn.__name__} failed", {"error": str(e), "traceback": traceback.format_exc()}) + return {"error": str(e)} + wrapper.__name__ = fn.__name__ + wrapper.__annotations__ = fn.__annotations__ + return wrapper + + +# --------------------------------------------------------------------------- +# Pydantic schemas +# --------------------------------------------------------------------------- + +class ChatMessage(BaseModel): + role: str + content: str + +class ChatCompleteInput(BaseModel): + messages: list[ChatMessage] + temperature: float = Field(0.6, ge=0.0, le=2.0) + top_k: int = Field(50, ge=0, le=200) + max_tokens: int = Field(2048, ge=1, le=4096) + session_id: str | None = None + +class ChatCompleteOutput(BaseModel): + content: str + tokens_generated: int + session_id: str + +class ChatHistoryInput(BaseModel): + session_id: str | None = None + +class ModelLoadInput(BaseModel): + source: str = "sft" + model_tag: str | None = None + step: int | None = None + device: str | None = None + +class ModelStatusOutput(BaseModel): + loaded: bool + source: str | None = None + model_tag: str | None = None + device: str | None = None + n_layer: int | None = None + n_embd: int | None = None + vocab_size: int | None = None + sequence_len: int | None = None + parameters: int | None = None + +class ModelSampleInput(BaseModel): + prompt: str = "" + max_tokens: int = 256 + temperature: float = 0.8 + top_k: int = 50 + num_samples: int = 1 + +class TokenizeInput(BaseModel): + text: str | list[str] + +class DecodeInput(BaseModel): + tokens: list[int] + +class ExecuteCodeInput(BaseModel): + code: str + timeout: float = 5.0 + +class TrainTokenizerInput(BaseModel): + max_chars: int = 2_000_000_000 + doc_cap: int = 10_000 + vocab_size: int = 32_768 + +class TrainBaseInput(BaseModel): + depth: int = 20 + aspect_ratio: int = 64 + head_dim: int = 128 + max_seq_len: int = 2048 + window_pattern: str = "SSSL" + target_param_data_ratio: float = 12.0 + num_iterations: int = -1 + device_batch_size: int = 32 + warmup_steps: int = 40 + warmdown_ratio: float = 0.65 + eval_every: int = 250 + save_every: int = -1 + device: str | None = None + run_name: str = "base" + model_tag: str | None = None + fp8: bool = False + +class TrainSFTInput(BaseModel): + source: str = "base" + model_tag: str | None = 
None + step: int | None = None + num_iterations: int = -1 + device_batch_size: int | None = None + mmlu_epochs: int = 3 + gsm8k_epochs: int = 4 + eval_every: int = 200 + save_every: int = -1 + warmdown_ratio: float = 0.5 + device: str | None = None + run_name: str = "sft" + +class TrainRLInput(BaseModel): + source: str = "sft" + model_tag: str | None = None + step: int | None = None + num_epochs: int = 1 + examples_per_step: int = 16 + num_samples: int = 16 + max_new_tokens: int = 256 + temperature: float = 1.0 + top_k: int = 50 + device_batch_size: int = 8 + eval_every: int = 60 + save_every: int = 60 + device: str | None = None + run_name: str = "rl" + +class TrainStatusInput(BaseModel): + run_id: str | None = None + +class EvalCoreInput(BaseModel): + max_per_task: int = -1 + +class EvalLossInput(BaseModel): + split: str = "val" + steps: int = 50 + device_batch_size: int = 4 + +class EvalChatInput(BaseModel): + task_name: str | None = None + temperature: float = 0.0 + max_new_tokens: int = 512 + num_samples: int = 1 + top_k: int = 50 + batch_size: int = 8 + max_problems: int | None = None + +class CheckpointSaveInput(BaseModel): + tag: str | None = None + step: int | None = None + +class CheckpointListInput(BaseModel): + source: str = "sft" + +class HealthOutput(BaseModel): + status: str + model_loaded: bool + device: str | None = None + source: str | None = None + worker: str = "iii-nanochat" + + +# --------------------------------------------------------------------------- +# GPU state +# --------------------------------------------------------------------------- + +class GPUState: + def __init__(self): + self.model = None + self.tokenizer = None + self.engine = None + self.meta: dict | None = None + self.source: str | None = None + self.model_tag: str | None = None + self.device: str | None = None + self._lock = threading.Lock() + + def load(self, source, device, model_tag=None, step=None): + _ensure_nanochat() + import torch + from nanochat.checkpoint_manager import load_model + from nanochat.engine import Engine + with self._lock: + phase = "eval" + dev = torch.device(device) + model, tokenizer, meta = load_model(source, dev, phase, model_tag=model_tag, step=step) + model.eval() + self.model = model + self.tokenizer = tokenizer + self.engine = Engine(model, tokenizer) + self.meta = meta + self.source = source + self.model_tag = model_tag + self.device = device + + def snapshot(self): + """Return a consistent snapshot under lock.""" + with self._lock: + return self.model, self.tokenizer, self.engine, self.meta, self.source, self.device, self.model_tag + + @property + def ready(self): + return self.engine is not None + +gpu = GPUState() + + +# --------------------------------------------------------------------------- +# Async state helpers +# --------------------------------------------------------------------------- + +async def state_get(scope, key): + return await iii_client.trigger_async({"function_id": "state::get", "payload": {"scope": scope, "key": key}}) + +async def state_set(scope, key, value): + return await iii_client.trigger_async({"function_id": "state::set", "payload": {"scope": scope, "key": key, "value": value}}) + +async def state_list(scope): + return await iii_client.trigger_async({"function_id": "state::list", "payload": {"scope": scope}}) + + +# --------------------------------------------------------------------------- +# Chat handlers +# --------------------------------------------------------------------------- + +async def fn_chat_complete(data: 
ChatCompleteInput) -> ChatCompleteOutput: + _ensure_nanochat() + import torch + if not gpu.ready: + raise RuntimeError("No model loaded. Trigger 'nanochat.model.load' first.") + + model, tokenizer, engine, _meta, source, _device, _tag = gpu.snapshot() + inp = ChatCompleteInput.model_validate(data) if isinstance(data, dict) else data + session_id = inp.session_id or str(uuid.uuid4()) + messages = [{"role": m["role"] if isinstance(m, dict) else m.role, "content": m["content"] if isinstance(m, dict) else m.content} for m in inp.messages] + conversation = {"messages": messages} + + if hasattr(tokenizer, "render_conversation"): + tokens, _mask = tokenizer.render_conversation(conversation, max_tokens=model.config.sequence_len) + else: + tokens = tokenizer.render_for_completion(conversation) + + with torch.no_grad(): + results, _masks = engine.generate_batch( + tokens, num_samples=1, + max_tokens=inp.max_tokens, temperature=inp.temperature, top_k=inp.top_k, + ) + + generated_ids = results[0] + text = tokenizer.decode(generated_ids) + if "<|assistant_end|>" in text: + text = text[:text.index("<|assistant_end|>")] + + messages.append({"role": "assistant", "content": text.strip()}) + await state_set("nanochat:sessions", session_id, { + "messages": messages, "model": source, "tokens_generated": len(generated_ids), + }) + logger.info("Chat completion", {"session_id": session_id, "tokens": len(generated_ids)}) + return ChatCompleteOutput(content=text.strip(), tokens_generated=len(generated_ids), session_id=session_id).model_dump() + + +async def fn_chat_stream(data: ChatCompleteInput) -> ChatCompleteOutput: + _ensure_nanochat() + import torch + if not gpu.ready: + raise RuntimeError("No model loaded. Trigger 'nanochat.model.load' first.") + + model, tokenizer, engine, _meta, source, _device, _tag = gpu.snapshot() + inp = ChatCompleteInput.model_validate(data) if isinstance(data, dict) else data + session_id = inp.session_id or str(uuid.uuid4()) + messages = [{"role": m["role"] if isinstance(m, dict) else m.role, "content": m["content"] if isinstance(m, dict) else m.content} for m in inp.messages] + conversation = {"messages": messages} + + if hasattr(tokenizer, "render_conversation"): + tokens, _mask = tokenizer.render_conversation(conversation, max_tokens=model.config.sequence_len) + else: + tokens = tokenizer.render_for_completion(conversation) + + chunks = [] + with torch.no_grad(): + for token_col, _token_masks in engine.generate( + [tokens], num_samples=1, + max_tokens=inp.max_tokens, temperature=inp.temperature, top_k=inp.top_k, + ): + token_id = token_col[0].item() + piece = tokenizer.decode([token_id]) + if "<|assistant_end|>" in piece: + break + chunks.append(piece) + + full_text = "".join(chunks) + messages.append({"role": "assistant", "content": full_text.strip()}) + await state_set("nanochat:sessions", session_id, { + "messages": messages, "model": source, "tokens_generated": len(chunks), + }) + return ChatCompleteOutput(content=full_text.strip(), tokens_generated=len(chunks), session_id=session_id).model_dump() + + +async def fn_chat_history(data: ChatHistoryInput) -> dict: + inp = ChatHistoryInput.model_validate(data) if isinstance(data, dict) else data + if not inp.session_id: + return {"sessions": await state_list("nanochat:sessions")} + return {"session_id": inp.session_id, "data": await state_get("nanochat:sessions", inp.session_id)} + + +# --------------------------------------------------------------------------- +# Model handlers +# 
--------------------------------------------------------------------------- + +async def fn_model_load(data: ModelLoadInput) -> ModelStatusOutput: + _ensure_nanochat() + from nanochat.common import autodetect_device_type + inp = ModelLoadInput.model_validate(data) if isinstance(data, dict) else data + device = inp.device or autodetect_device_type() + gpu.load(inp.source, device, model_tag=inp.model_tag, step=inp.step) + model, _tok, _eng, meta, source, dev, tag = gpu.snapshot() + await state_set("nanochat:models", "current", { + "source": source, "model_tag": tag, "device": dev, + "config": meta.get("model_config", {}) if meta else {}, + "parameters": sum(p.numel() for p in model.parameters()), + }) + logger.info("Model loaded", {"source": source, "device": dev}) + return await fn_model_status({}) + + +async def fn_model_status(data: dict) -> ModelStatusOutput: + if not gpu.ready: + return ModelStatusOutput(loaded=False).model_dump() + model, _tok, _eng, meta, source, device, model_tag = gpu.snapshot() + config = meta.get("model_config", {}) if meta else {} + return ModelStatusOutput( + loaded=True, source=source, model_tag=model_tag, device=device, + n_layer=config.get("n_layer"), n_embd=config.get("n_embd"), + vocab_size=config.get("vocab_size"), sequence_len=config.get("sequence_len"), + parameters=sum(p.numel() for p in model.parameters()) if model else None, + ).model_dump() + + +async def fn_model_sample(data: ModelSampleInput) -> dict: + _ensure_nanochat() + import torch + if not gpu.ready: + raise RuntimeError("No model loaded. Trigger 'nanochat.model.load' first.") + + _model, tokenizer, engine, _meta, _source, _device, _tag = gpu.snapshot() + inp = ModelSampleInput.model_validate(data) if isinstance(data, dict) else data + bos = tokenizer.get_bos_token_id() + if inp.prompt: + encoded = tokenizer.encode(inp.prompt) + tokens = [bos] + (list(encoded) if not isinstance(encoded, list) else encoded) + else: + tokens = [bos] + + samples = [] + with torch.no_grad(): + results, _masks = engine.generate_batch( + tokens, num_samples=inp.num_samples, + max_tokens=inp.max_tokens, temperature=inp.temperature, top_k=inp.top_k, + ) + for result_ids in results: + text = tokenizer.decode(result_ids) + if "<|assistant_end|>" in text: + text = text[:text.index("<|assistant_end|>")] + samples.append(text) + + return {"samples": samples, "num_samples": len(samples)} + + +# --------------------------------------------------------------------------- +# Tokenizer handlers +# --------------------------------------------------------------------------- + +async def fn_tokenizer_encode(data: TokenizeInput) -> dict: + _ensure_nanochat() + from nanochat.tokenizer import get_tokenizer + inp = TokenizeInput.model_validate(data) if isinstance(data, dict) else data + _model, tokenizer, _eng, _meta, _src, _dev, _tag = gpu.snapshot() + if tokenizer is None: + tokenizer = get_tokenizer() + bos = tokenizer.get_bos_token_id() + encoded = tokenizer.encode(inp.text, prepend=bos) + count = sum(len(t) for t in encoded) if isinstance(inp.text, list) else len(encoded) + return {"tokens": encoded, "count": count} + + +async def fn_tokenizer_decode(data: DecodeInput) -> dict: + _ensure_nanochat() + from nanochat.tokenizer import get_tokenizer + inp = DecodeInput.model_validate(data) if isinstance(data, dict) else data + _model, tokenizer, _eng, _meta, _src, _dev, _tag = gpu.snapshot() + if tokenizer is None: + tokenizer = get_tokenizer() + return {"text": tokenizer.decode(inp.tokens)} + + +# 
--------------------------------------------------------------------------- +# Tools handler +# --------------------------------------------------------------------------- + +async def fn_tools_execute(data: ExecuteCodeInput) -> dict: + inp = ExecuteCodeInput.model_validate(data) if isinstance(data, dict) else data + stdout_buf, stderr_buf = io.StringIO(), io.StringIO() + try: + with contextlib.redirect_stdout(stdout_buf), contextlib.redirect_stderr(stderr_buf): + exec(inp.code, {"__builtins__": __builtins__}, {}) + return {"success": True, "stdout": stdout_buf.getvalue(), "stderr": stderr_buf.getvalue(), "error": None} + except Exception as e: + return {"success": False, "stdout": stdout_buf.getvalue(), "stderr": stderr_buf.getvalue(), "error": str(e)} + + +# --------------------------------------------------------------------------- +# Subprocess runner with real-time stdout parsing -> iii state +# --------------------------------------------------------------------------- + +def _nanochat_repo_dir() -> str: + """Root of the nanochat repo (contains scripts/, tasks/, nanochat/).""" + return str(Path(NANOCHAT_DIR).parent) + + +def _parse_training_line(line: str) -> dict | None: + """Parse nanochat stdout into structured metrics. Returns None for non-metric lines.""" + import re + + # base_train / chat_sft step line: + # "step 00100/05000 (2.00%) | loss: 4.123456 | lrm: 0.50 | dt: 123.45ms | tok/sec: 123,456 | bf16_mfu: 0.45" + m = re.match(r"step\s+(\d+)(?:/(\d+))?\s+\((\d+\.\d+)%\)\s*\|(.+)", line) + if m: + metrics = {"step": int(m.group(1)), "pct": float(m.group(3))} + if m.group(2): + metrics["total_steps"] = int(m.group(2)) + for pair in m.group(4).split("|"): + pair = pair.strip() + kv = pair.split(":") + if len(kv) == 2: + key = kv[0].strip().replace(" ", "_") + val = kv[1].strip().replace(",", "").rstrip("ms").rstrip("m") + try: + metrics[key] = float(val) + except ValueError: + metrics[key] = val + return metrics + + # Validation BPB: "Step 00250 | Validation bpb: 1.234567" + m = re.match(r"Step\s+(\d+)\s+\|\s+Validation bpb:\s+(\S+)", line) + if m: + return {"step": int(m.group(1)), "val_bpb": float(m.group(2)), "event": "eval_bpb"} + + # CORE metric: "Step 00250 | CORE metric: 0.1234" + m = re.match(r"Step\s+(\d+)\s+\|\s+CORE metric:\s+(\S+)", line) + if m: + return {"step": int(m.group(1)), "core_metric": float(m.group(2)), "event": "eval_core"} + + # ChatCORE: "Step 00200 | ChatCORE: 0.1234 | ChatCORE_cat: 0.2345" + m = re.match(r"Step\s+(\d+)\s+\|\s+ChatCORE:\s+(\S+)\s+\|\s+ChatCORE_cat:\s+(\S+)", line) + if m: + return {"step": int(m.group(1)), "chatcore": float(m.group(2)), "chatcore_cat": float(m.group(3)), "event": "eval_chatcore"} + + # RL step: "Step 10/100 | Average reward: 0.5 | Average sequence length: 128.00" + m = re.match(r"Step\s+(\d+)/(\d+)\s+\|\s+Average reward:\s+(\S+)\s+\|\s+Average sequence length:\s+(\S+)", line) + if m: + return {"step": int(m.group(1)), "total_steps": int(m.group(2)), "avg_reward": float(m.group(3)), "avg_seq_len": float(m.group(4))} + + # RL pass@k: "Step 10 | pass@1: 0.25, pass@16: 0.75" + m = re.match(r"Step\s+(\d+)\s+\|\s+(pass@.+)", line) + if m: + metrics = {"step": int(m.group(1)), "event": "eval_passk"} + for pair in m.group(2).split(","): + kv = pair.strip().split(":") + if len(kv) == 2: + metrics[kv[0].strip()] = float(kv[1].strip()) + return metrics + + return None + + +# --------------------------------------------------------------------------- +# Pre-forked subprocess launcher (forked BEFORE iii connects, safe from WebSocket 
corruption) +# --------------------------------------------------------------------------- + +_launcher_conn = None + + +def _launcher_child(conn, python_exe: str, repo_dir: str): + """Child process: receives (module, args) over pipe, runs subprocess, sends back (returncode, lines).""" + import subprocess as sp + while True: + try: + msg = conn.recv() + except EOFError: + break + if msg is None: + break + + module, args = msg["module"], msg["args"] + cmd = [python_exe, "-m", module] + args + try: + proc = sp.Popen( + cmd, cwd=repo_dir, + stdout=sp.PIPE, stderr=sp.STDOUT, + text=True, bufsize=1, + ) + lines = [] + for line in proc.stdout: + lines.append(line.rstrip()) + proc.wait() + conn.send({"returncode": proc.returncode, "lines": lines}) + except Exception as e: + conn.send({"returncode": -1, "lines": [f"launcher error: {e}"]}) + + +def _start_launcher(): + """Fork a child process BEFORE iii connects. Uses fork (not spawn) since no iii state exists yet.""" + import multiprocessing as mp + ctx = mp.get_context("fork") + parent_conn, child_conn = ctx.Pipe() + child = ctx.Process(target=_launcher_child, args=(child_conn, sys.executable, _nanochat_repo_dir()), daemon=True) + child.start() + child_conn.close() + return parent_conn + + +async def _run_training(module: str, args: list[str], run_id: str, train_type: str, extra_state: dict | None = None) -> dict: + """Run a nanochat training script via the pre-forked launcher. + The launcher child does Popen (safe, forked before iii). Results come back over a Pipe.""" + import asyncio + + base_state = {"status": "running", "type": train_type, **(extra_state or {})} + await state_set("nanochat:training", run_id, base_state) + logger.info(f"Running: {module}", {"run_id": run_id, "type": train_type}) + + def _send_and_recv(): + _launcher_conn.send({"module": module, "args": args}) + return _launcher_conn.recv() + + result = await asyncio.to_thread(_send_and_recv) + + returncode = result["returncode"] + lines = result["lines"] + + last_metrics = {} + for line in lines: + metrics = _parse_training_line(line) + if metrics: + last_metrics.update(metrics) + event = metrics.get("event") + if event: + await state_set("nanochat:evals", f"{train_type}-{event}-{metrics.get('step', 0)}", { + "type": event, "run_id": run_id, **metrics, + }) + + status = "complete" if returncode == 0 else "failed" + final_state = { + **base_state, **last_metrics, + "status": status, "returncode": returncode, + "output_tail": "\n".join(lines[-50:]), + } + await state_set("nanochat:training", run_id, final_state) + logger.info(f"{train_type} training {status}", {"run_id": run_id, "returncode": returncode}) + + return {"status": status, "run_id": run_id, "returncode": returncode, **last_metrics} + + +# --------------------------------------------------------------------------- +# Training handlers (all queued, run actual nanochat scripts with live state) +# --------------------------------------------------------------------------- + +async def fn_train_tokenizer(data: TrainTokenizerInput) -> dict: + inp = TrainTokenizerInput.model_validate(data) if isinstance(data, dict) else data + run_id = str(uuid.uuid4())[:8] + + args = [ + "--max-chars", str(inp.max_chars), + "--doc-cap", str(inp.doc_cap), + "--vocab-size", str(inp.vocab_size), + ] + + return await _run_training("scripts.tok_train", args, run_id, "tokenizer", + {"vocab_size": inp.vocab_size}) + + +async def fn_train_base(data: TrainBaseInput) -> dict: + inp = TrainBaseInput.model_validate(data) if isinstance(data, dict) else 
data + run_id = str(uuid.uuid4())[:8] + + args = [ + "--run", inp.run_name, + "--depth", str(inp.depth), + "--aspect-ratio", str(inp.aspect_ratio), + "--head-dim", str(inp.head_dim), + "--max-seq-len", str(inp.max_seq_len), + "--window-pattern", inp.window_pattern, + "--target-param-data-ratio", str(inp.target_param_data_ratio), + "--device-batch-size", str(inp.device_batch_size), + "--warmup-steps", str(inp.warmup_steps), + "--warmdown-ratio", str(inp.warmdown_ratio), + "--eval-every", str(inp.eval_every), + ] + if inp.num_iterations > 0: + args += ["--num-iterations", str(inp.num_iterations)] + if inp.save_every > 0: + args += ["--save-every", str(inp.save_every)] + if inp.device: + args += ["--device-type", inp.device] + if inp.model_tag: + args += ["--model-tag", inp.model_tag] + if inp.fp8: + args += ["--fp8"] + + return await _run_training("scripts.base_train", args, run_id, "base", + {"depth": inp.depth, "model_tag": inp.model_tag or f"d{inp.depth}"}) + + +async def fn_train_sft(data: TrainSFTInput) -> dict: + inp = TrainSFTInput.model_validate(data) if isinstance(data, dict) else data + run_id = str(uuid.uuid4())[:8] + + args = [ + "--run", inp.run_name, + "--mmlu-epochs", str(inp.mmlu_epochs), + "--gsm8k-epochs", str(inp.gsm8k_epochs), + "--eval-every", str(inp.eval_every), + "--warmdown-ratio", str(inp.warmdown_ratio), + ] + if inp.num_iterations > 0: + args += ["--num-iterations", str(inp.num_iterations)] + if inp.device_batch_size: + args += ["--device-batch-size", str(inp.device_batch_size)] + if inp.save_every > 0: + args += ["--save-every", str(inp.save_every)] + if inp.device: + args += ["--device-type", inp.device] + if inp.model_tag: + args += ["--model-tag", inp.model_tag] + if inp.step: + args += ["--model-step", str(inp.step)] + + return await _run_training("scripts.chat_sft", args, run_id, "sft", + {"source": inp.source}) + + +async def fn_train_rl(data: TrainRLInput) -> dict: + inp = TrainRLInput.model_validate(data) if isinstance(data, dict) else data + run_id = str(uuid.uuid4())[:8] + + args = [ + "--run", inp.run_name, + "--num-epochs", str(inp.num_epochs), + "--examples-per-step", str(inp.examples_per_step), + "--num-samples", str(inp.num_samples), + "--max-new-tokens", str(inp.max_new_tokens), + "--temperature", str(inp.temperature), + "--top-k", str(inp.top_k), + "--device-batch-size", str(inp.device_batch_size), + "--eval-every", str(inp.eval_every), + "--save-every", str(inp.save_every), + ] + if inp.device: + args += ["--device-type", inp.device] + if inp.model_tag: + args += ["--model-tag", inp.model_tag] + if inp.step: + args += ["--model-step", str(inp.step)] + + return await _run_training("scripts.chat_rl", args, run_id, "rl") + + +async def fn_train_status(data: TrainStatusInput) -> dict: + inp = TrainStatusInput.model_validate(data) if isinstance(data, dict) else data + if inp.run_id: + return await state_get("nanochat:training", inp.run_id) or {"error": "run not found"} + return {"runs": await state_list("nanochat:training")} + + +# --------------------------------------------------------------------------- +# Evaluation handlers (import and call real nanochat functions) +# --------------------------------------------------------------------------- + +async def fn_eval_core(data: EvalCoreInput) -> dict: + if not gpu.ready: + raise RuntimeError("No model loaded. 
Trigger 'nanochat.model.load' first.") + _ensure_nanochat() + + model, tokenizer, _engine, _meta, source, device, _tag = gpu.snapshot() + inp = EvalCoreInput.model_validate(data) if isinstance(data, dict) else data + + scripts_dir = os.path.join(_nanochat_repo_dir(), "scripts") + if scripts_dir not in sys.path: + sys.path.insert(0, scripts_dir) + from base_eval import evaluate_core + + dev = model.get_device() if hasattr(model, "get_device") else device + result = evaluate_core(model, tokenizer, dev, max_per_task=inp.max_per_task) + + await state_set("nanochat:evals", f"core-{int(time.time())}", { + "type": "core", "core_metric": result["core_metric"], + "results": result["results"], "model": source, + }) + + return { + "core_metric": result["core_metric"], + "results": result.get("results", {}), + "centered_results": result.get("centered_results", {}), + } + + +async def fn_eval_loss(data: EvalLossInput) -> dict: + if not gpu.ready: + raise RuntimeError("No model loaded. Trigger 'nanochat.model.load' first.") + _ensure_nanochat() + from nanochat.loss_eval import evaluate_bpb + from nanochat.tokenizer import get_token_bytes + from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit + + model, tokenizer, _engine, _meta, source, device, _tag = gpu.snapshot() + inp = EvalLossInput.model_validate(data) if isinstance(data, dict) else data + token_bytes = get_token_bytes(device) + B, T = inp.device_batch_size, model.config.sequence_len + batches = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, B, T, inp.split, device=device) + bpb = evaluate_bpb(model, batches, steps=inp.steps, token_bytes=token_bytes) + + await state_set("nanochat:evals", f"loss-{int(time.time())}", { + "type": "bpb", "bpb": bpb, "split": inp.split, "model": source, + }) + return {"bits_per_byte": bpb, "split": inp.split, "model": source} + + +async def fn_eval_chat(data: EvalChatInput) -> dict: + if not gpu.ready: + raise RuntimeError("No model loaded. 
Trigger 'nanochat.model.load' first.") + _ensure_nanochat() + + model, tokenizer, engine, _meta, source, _device, _tag = gpu.snapshot() + inp = EvalChatInput.model_validate(data) if isinstance(data, dict) else data + + scripts_dir = os.path.join(_nanochat_repo_dir(), "scripts") + tasks_dir = os.path.join(_nanochat_repo_dir(), "tasks") + if scripts_dir not in sys.path: + sys.path.insert(0, scripts_dir) + if tasks_dir not in sys.path: + sys.path.insert(0, tasks_dir) + + from chat_eval import run_chat_eval + + available_tasks = ["GSM8K", "MMLU", "ARC-Easy", "ARC-Challenge", "HumanEval", "SpellingBee"] + + if inp.task_name: + task_names = [inp.task_name] + else: + task_names = available_tasks + + results = {} + for task_name in task_names: + try: + acc = run_chat_eval( + task_name, model, tokenizer, engine, + batch_size=inp.batch_size, num_samples=inp.num_samples, + max_new_tokens=inp.max_new_tokens, temperature=inp.temperature, + top_k=inp.top_k, max_problems=inp.max_problems, + ) + results[task_name] = acc + except Exception as e: + results[task_name] = {"error": str(e)} + + await state_set("nanochat:evals", f"chat-{int(time.time())}", { + "type": "chat", "results": results, "model": source, + }) + return {"results": results, "model": source} + + +# --------------------------------------------------------------------------- +# Checkpoint handlers +# --------------------------------------------------------------------------- + +async def fn_checkpoint_save(data: CheckpointSaveInput) -> dict: + if not gpu.ready: + raise RuntimeError("No model loaded.") + _ensure_nanochat() + from nanochat.checkpoint_manager import save_checkpoint + from nanochat.common import get_base_dir + + model, _tok, _eng, meta, source, _dev, model_tag = gpu.snapshot() + inp = CheckpointSaveInput.model_validate(data) if isinstance(data, dict) else data + tag = inp.tag or model_tag or "manual" + step = inp.step or int(time.time()) + + base_dir = get_base_dir() + phase_dir = {"base": "checkpoints", "sft": "chatsft_checkpoints", "rl": "chatrl_checkpoints"}.get(source, "checkpoints") + checkpoint_dir = os.path.join(base_dir, phase_dir, tag) + + model_config = meta.get("model_config", {}) if meta else {} + save_checkpoint(checkpoint_dir, step, model.state_dict(), None, { + "step": step, "model_config": model_config, + }) + + await state_set("nanochat:checkpoints", f"{tag}-{step}", { + "tag": tag, "step": step, "source": source, "path": checkpoint_dir, + }) + logger.info("Checkpoint saved", {"tag": tag, "step": step}) + return {"tag": tag, "step": step, "path": checkpoint_dir} + + +async def fn_checkpoint_list(data: CheckpointListInput) -> dict: + _ensure_nanochat() + from nanochat.common import get_base_dir + + inp = CheckpointListInput.model_validate(data) if isinstance(data, dict) else data + base_dir = get_base_dir() + phase_dir = {"base": "checkpoints", "sft": "chatsft_checkpoints", "rl": "chatrl_checkpoints"}.get(inp.source, "checkpoints") + search_dir = os.path.join(base_dir, phase_dir) + + checkpoints = [] + if os.path.exists(search_dir): + for tag_dir in sorted(os.listdir(search_dir)): + tag_path = os.path.join(search_dir, tag_dir) + if os.path.isdir(tag_path): + steps = [] + for f in os.listdir(tag_path): + if f.startswith("model_") and f.endswith(".pt"): + try: + steps.append(int(f[6:-3])) + except ValueError: + continue + steps.sort() + checkpoints.append({"tag": tag_dir, "steps": steps, "path": tag_path}) + + return {"source": inp.source, "checkpoints": checkpoints} + + +# 
--------------------------------------------------------------------------- +# Health +# --------------------------------------------------------------------------- + +async def fn_health(data: dict) -> HealthOutput: + return HealthOutput( + status="ok", model_loaded=gpu.ready, device=gpu.device, source=gpu.source, + ).model_dump() + + +# --------------------------------------------------------------------------- +# Registration +# --------------------------------------------------------------------------- + +def register_all(iii): + iii.register_service({"id": "nanochat", "name": "nanochat", "description": "Full nanochat pipeline on iii-engine"}) + iii.register_service({"id": "nanochat.chat", "name": "Chat", "parent_service_id": "nanochat"}) + iii.register_service({"id": "nanochat.model", "name": "Model", "parent_service_id": "nanochat"}) + iii.register_service({"id": "nanochat.tokenizer", "name": "Tokenizer", "parent_service_id": "nanochat"}) + iii.register_service({"id": "nanochat.tools", "name": "Tools", "parent_service_id": "nanochat"}) + iii.register_service({"id": "nanochat.eval", "name": "Evaluation", "parent_service_id": "nanochat"}) + iii.register_service({"id": "nanochat.train", "name": "Training", "parent_service_id": "nanochat"}) + iii.register_service({"id": "nanochat.checkpoint", "name": "Checkpoints", "parent_service_id": "nanochat"}) + + functions = [ + # Chat + ("nanochat.chat.complete", fn_chat_complete, "Generate chat completion", "http", {"api_path": "/nanochat/chat/completions", "http_method": "POST"}), + ("nanochat.chat.stream", fn_chat_stream, "Generate chat completion token-by-token", "http", {"api_path": "/nanochat/chat/stream", "http_method": "POST"}), + ("nanochat.chat.history", fn_chat_history, "Get conversation history from state", "http", {"api_path": "/nanochat/chat/history", "http_method": "GET"}), + # Model + ("nanochat.model.load", fn_model_load, "Load checkpoint into GPU memory", "http", {"api_path": "/nanochat/model/load", "http_method": "POST"}), + ("nanochat.model.status", fn_model_status, "Get loaded model status and config", "http", {"api_path": "/nanochat/model/status", "http_method": "GET"}), + ("nanochat.model.sample", fn_model_sample, "Generate raw text samples from loaded model", "http", {"api_path": "/nanochat/model/sample", "http_method": "POST"}), + # Tokenizer + ("nanochat.tokenizer.encode", fn_tokenizer_encode, "Encode text to BPE token IDs", "http", {"api_path": "/nanochat/tokenizer/encode", "http_method": "POST"}), + ("nanochat.tokenizer.decode", fn_tokenizer_decode, "Decode token IDs to text", "http", {"api_path": "/nanochat/tokenizer/decode", "http_method": "POST"}), + # Tools + ("nanochat.tools.execute", fn_tools_execute, "Execute Python code (in-process, not sandboxed)", "http", {"api_path": "/nanochat/tools/execute", "http_method": "POST"}), + # Training (HTTP triggers, long-running - caller sets timeout) + ("nanochat.train.tokenizer", fn_train_tokenizer, "Train BPE tokenizer from dataset", "http", {"api_path": "/nanochat/train/tokenizer", "http_method": "POST"}), + ("nanochat.train.base", fn_train_base, "Pretrain base GPT model from scratch", "http", {"api_path": "/nanochat/train/base", "http_method": "POST"}), + ("nanochat.train.sft", fn_train_sft, "Supervised fine-tuning with task mixture", "http", {"api_path": "/nanochat/train/sft", "http_method": "POST"}), + ("nanochat.train.rl", fn_train_rl, "RL fine-tuning with GRPO on GSM8K", "http", {"api_path": "/nanochat/train/rl", "http_method": "POST"}), + ("nanochat.train.status", 
fn_train_status, "Check training run status", "http", {"api_path": "/nanochat/train/status", "http_method": "GET"}), + # Evaluation + ("nanochat.eval.core", fn_eval_core, "Run CORE benchmark (DCLM)", "http", {"api_path": "/nanochat/eval/core", "http_method": "POST"}), + ("nanochat.eval.loss", fn_eval_loss, "Evaluate bits-per-byte on validation set", "http", {"api_path": "/nanochat/eval/loss", "http_method": "POST"}), + ("nanochat.eval.chat", fn_eval_chat, "Run ChatCORE evaluation (GSM8K, MMLU, ARC)", "http", {"api_path": "/nanochat/eval/chat", "http_method": "POST"}), + # Checkpoints + ("nanochat.checkpoint.save", fn_checkpoint_save, "Save current model to disk", "http", {"api_path": "/nanochat/checkpoint/save", "http_method": "POST"}), + ("nanochat.checkpoint.list", fn_checkpoint_list, "List available checkpoints", "http", {"api_path": "/nanochat/checkpoint/list", "http_method": "GET"}), + # Health + ("nanochat.health", fn_health, "Worker health check", "http", {"api_path": "/nanochat/health", "http_method": "GET"}), + ] + + for func_id, handler, description, trigger_type, trigger_config in functions: + iii.register_function(func_id, safe(handler), description=description) + iii.register_trigger({"type": trigger_type, "function_id": func_id, "config": trigger_config}) + + logger.info("Registered all functions and triggers", {"count": len(functions)}) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + global iii_client + parser = argparse.ArgumentParser(description="nanochat iii-engine worker") + parser.add_argument("--engine-url", default=os.environ.get("III_ENGINE_URL", "ws://localhost:49134")) + parser.add_argument("--source", default="sft", choices=["base", "sft", "rl"]) + parser.add_argument("--model-tag", default=None) + parser.add_argument("--step", type=int, default=None) + parser.add_argument("--device", default=None) + parser.add_argument("--no-autoload", action="store_true") + parser.add_argument("--nanochat-dir", default=None) + args = parser.parse_args() + + if args.nanochat_dir: + global NANOCHAT_DIR + NANOCHAT_DIR = args.nanochat_dir + parent = str(Path(NANOCHAT_DIR).parent) + if parent not in sys.path: + sys.path.insert(0, parent) + + _ensure_nanochat() + + global _launcher_conn + _launcher_conn = _start_launcher() + print("[nanochat] subprocess launcher forked") + + iii_client = register_worker( + args.engine_url, + InitOptions( + worker_name="nanochat", + invocation_timeout_ms=600000, + telemetry=TelemetryOptions(language="python", project_name="nanochat"), + ), + ) + register_all(iii_client) + + if not args.no_autoload: + from nanochat.common import autodetect_device_type + device = args.device or autodetect_device_type() + try: + gpu.load(args.source, device, model_tag=args.model_tag, step=args.step) + iii_client.trigger({"function_id": "state::set", "payload": { + "scope": "nanochat:models", "key": "current", + "value": {"source": gpu.source, "device": gpu.device, + "config": gpu.meta.get("model_config", {}) if gpu.meta else {}}, + }}) + except Exception as e: + logger.warn("Auto-load failed, use nanochat.model.load", {"error": str(e)}) + + n_funcs = 20 + print(f"[nanochat] connected to {args.engine_url}") + print(f"[nanochat] model: {'loaded (' + gpu.source + ' on ' + gpu.device + ')' if gpu.ready else 'none'}") + print(f"[nanochat] {n_funcs} functions, {n_funcs} triggers (all HTTP)") + + try: + signal.pause() + except AttributeError: 
+ while True: + time.sleep(1) + except KeyboardInterrupt: + pass + finally: + iii_client.shutdown() + + +if __name__ == "__main__": + main() diff --git a/proof/README.md b/proof/README.md new file mode 100644 index 0000000..6c85664 --- /dev/null +++ b/proof/README.md @@ -0,0 +1,255 @@ +# proof + +AI-powered browser testing for the [iii engine](https://github.com/iii-hq/iii). Scans your code changes, launches a real browser, and verifies everything works. + +proof registers browser tools as iii functions. Any agent connected to the engine — Claude Code, Codex, or the Anthropic API — can drive Chromium through snapshot-driven accessibility testing. No fragile CSS selectors. The AI reads the page structure, picks elements by ref, and acts. + +## Quick Start + +```bash +# Terminal 1: Start iii engine +iii --use-default-config + +# Terminal 2: Start proof worker +cd workers/proof +npm install +npm run dev +``` + +proof registers 25 functions with the engine. You're ready to test. + +## Usage + +### Interactive (Claude Code / Codex) + +With proof running, tell your agent: + +> "Test my changes at localhost:3000" + +The agent calls proof's browser functions through iii — no API key needed. + +Or call functions directly: + +```bash +# Scan for changes +iii trigger --function-id='proof::scan' \ + --payload='{"target":"unstaged","cwd":"/path/to/repo"}' + +# Launch browser +iii trigger --function-id='proof::browser::launch' \ + --payload='{"runId":"test-1","headed":true}' + +# Navigate +iii trigger --function-id='proof::browser::navigate' \ + --payload='{"url":"http://localhost:3000"}' + +# Snapshot — get accessibility tree with [ref=eN] markers +iii trigger --function-id='proof::browser::snapshot' --payload='{}' + +# Click by ref +iii trigger --function-id='proof::browser::click' --payload='{"ref":"e3"}' + +# Type into input +iii trigger --function-id='proof::browser::type' \ + --payload='{"ref":"e1","text":"user@example.com"}' + +# Screenshot +iii trigger --function-id='proof::browser::screenshot' --payload='{}' + +# Check console errors +iii trigger --function-id='proof::browser::console_logs' --payload='{}' + +# Check network requests +iii trigger --function-id='proof::browser::network' --payload='{}' + +# Performance metrics (FCP, TTFB, CLS) +iii trigger --function-id='proof::browser::performance' --payload='{}' + +# Raw Playwright execution +iii trigger --function-id='proof::browser::exec' \ + --payload='{"code":"return await page.title()"}' + +# Close browser +iii trigger --function-id='proof::browser::close' --payload='{"runId":"test-1"}' +``` + +### Automated (CI / API) + +For headless runs without an agent, proof drives Claude directly via the Anthropic API: + +```bash +ANTHROPIC_API_KEY=sk-... 
npm run dev +``` + +```bash +# Full pipeline: scan → plan → execute → report +curl -X POST localhost:3111/proof \ + -H 'Content-Type: application/json' \ + -d '{"target":"branch","base_url":"http://localhost:3000"}' + +# Queue-based run with auto-retry (uses iii Queue + DLQ) +curl -X POST localhost:3111/proof/enqueue \ + -d '{"target":"branch","base_url":"https://staging.myapp.com"}' +``` + +### Replay Saved Flows + +Successful runs save as replayable flows — no AI needed for reruns: + +```bash +# List saved flows +curl localhost:3111/proof/flows + +# Replay a flow +curl -X POST localhost:3111/proof/replay \ + -d '{"slug":"login-flow-m1abc","headed":true}' + +# Run history +curl localhost:3111/proof/history +``` + +## How It Works + +``` +proof::scan git diff → changed files, commits + ↓ +proof::coverage import graph → which files lack tests + ↓ +proof::execute agent loop with browser tools + ↓ ↕ proof::browser::navigate + ↓ ↕ proof::browser::snapshot + ↓ ↕ proof::browser::click + ↓ ↕ proof::browser::type + ↓ ↕ proof::browser::screenshot + ↓ ↕ proof::browser::assert + ↓ +proof::report results → iii State + Stream +``` + +The snapshot-driven approach: + +1. `proof::browser::snapshot` returns an ARIA accessibility tree with `[ref=eN]` markers on every interactive element +2. The agent reads the tree, identifies elements by ref — not CSS selectors +3. `proof::browser::click`, `proof::browser::type` etc. resolve refs to Playwright locators +4. After each action, a fresh snapshot is returned with updated refs + +This makes tests resilient to UI changes. Refs are structural, not visual. + +## Input Options + +```json +{ + "target": "unstaged | staged | branch | commit", + "base_url": "http://localhost:3000", + "instruction": "test the login flow", + "headed": true, + "cookies": true, + "cdp": "auto", + "cwd": "/path/to/repo", + "commit_hash": "abc123", + "main_branch": "main" +} +``` + +| Field | Default | Description | +|-------|---------|-------------| +| `target` | `unstaged` | What to scan: unstaged, staged, branch, or single commit | +| `base_url` | `http://localhost:3000` | URL of the app to test | +| `instruction` | — | Natural language instruction for what to test | +| `headed` | `false` | Show browser window | +| `cookies` | `false` | Extract and inject cookies from local Chrome/Firefox | +| `cdp` | — | CDP WebSocket URL or `"auto"` to discover running Chrome | +| `cwd` | worker cwd | Path to the git repository | +| `commit_hash` | `HEAD` | Specific commit hash (when target is `commit`) | + +## Functions + +### Browser Tools (12) + +| Function | Description | +|----------|-------------| +| `proof::browser::launch` | Launch Chromium (headed or headless, CDP optional) | +| `proof::browser::close` | Close browser session | +| `proof::browser::navigate` | Navigate to URL, return snapshot | +| `proof::browser::snapshot` | ARIA accessibility tree with `[ref=eN]` markers | +| `proof::browser::click` | Click element by ref | +| `proof::browser::type` | Type text into input by ref | +| `proof::browser::select` | Select dropdown option by ref | +| `proof::browser::press` | Press keyboard key on element | +| `proof::browser::screenshot` | Capture page as base64 PNG | +| `proof::browser::console_logs` | Read browser console messages | +| `proof::browser::network` | Read network request log | +| `proof::browser::performance` | Core Web Vitals (FCP, TTFB, CLS) | +| `proof::browser::exec` | Execute raw Playwright code | +| `proof::browser::assert` | Record a pass/fail assertion | + +### Pipeline (10) 
+ +| Function | Description | +|----------|-------------| +| `proof::scan` | Git diff scanning (4 target modes) | +| `proof::coverage` | Import graph analysis → test coverage | +| `proof::execute` | Agent loop with Claude API | +| `proof::report` | Results → iii State + Stream | +| `proof::run` | Full pipeline orchestration | +| `proof::replay` | Replay a saved flow without AI | +| `proof::flows` | List saved flows | +| `proof::history` | Run history with trends | +| `proof::enqueue` | Queue-based run with retries + DLQ | +| `proof::cleanup` | Close all browser sessions | +| `proof::cookies::inject` | Extract local browser cookies | +| `proof::cdp::discover` | Find running Chrome CDP endpoint | + +### HTTP Endpoints (8) + +| Method | Path | Function | +|--------|------|----------| +| POST | `/proof` | `proof::run` | +| POST | `/proof/enqueue` | `proof::enqueue` | +| POST | `/proof/replay` | `proof::replay` | +| POST | `/proof/coverage` | `proof::coverage` | +| POST | `/proof/cleanup` | `proof::cleanup` | +| GET | `/proof/flows` | `proof::flows` | +| GET | `/proof/history` | `proof::history` | +| GET | `/proof/cdp` | `proof::cdp::discover` | + +## iii Primitives Used + +| Primitive | How proof uses it | +|-----------|------------------| +| **Functions** | 25 registered — browser tools, pipeline, queries | +| **Triggers** | 8 HTTP endpoints for REST access | +| **State** | Reports persisted to `proof:reports`, flows to `proof:flows` | +| **Streams** | Real-time test progress pushed to `proof` stream | +| **Queue** | `proof::enqueue` for CI runs with auto-retry | +| **DLQ** | Failed test runs land in DLQ for inspection | +| **Logger** | Every action traced with OTel | + +## Architecture + +``` +┌──────────────────────────────────────────┐ +│ iii Engine │ +│ (ports 49134, 3111) │ +└──────────────────┬───────────────────────┘ + │ + ┌────────┴────────┐ + │ proof worker │ + │ │ + │ 25 functions │ + │ 8 HTTP routes │ + │ Playwright │ + │ simple-git │ + └─────────────────┘ + │ + ┌─────────────┼─────────────┐ + │ │ │ + Claude Code Codex Anthropic API + (interactive) (interactive) (CI/automated) +``` + +Any agent on the engine can call proof's functions. The worker handles browser lifecycle, snapshot generation, and session management. The agent handles test logic. 
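+A minimal sketch of driving a session from another script, assuming the same `registerWorker`/`trigger` API shown in the nanochat worker's README; payload shapes mirror the CLI examples above, and the run ID and `e3` ref are illustrative:
+
+```typescript
+import { registerWorker } from 'iii-sdk'
+
+const iii = registerWorker('ws://localhost:49134')
+
+// Launch a session and navigate; navigate returns a fresh snapshot.
+await iii.trigger({ function_id: 'proof::browser::launch', payload: { runId: 'smoke-1', headed: false } })
+await iii.trigger({ function_id: 'proof::browser::navigate', payload: { url: 'http://localhost:3000' } })
+
+// Read the accessibility tree, then act on an element by its [ref=eN] marker.
+const snapshot = await iii.trigger({ function_id: 'proof::browser::snapshot', payload: {} })
+console.log(snapshot)
+await iii.trigger({ function_id: 'proof::browser::click', payload: { ref: 'e3' } })
+
+// Always close the session when done.
+await iii.trigger({ function_id: 'proof::browser::close', payload: { runId: 'smoke-1' } })
+```
+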
+ +## License + +Apache-2.0 diff --git a/proof/package.json b/proof/package.json new file mode 100644 index 0000000..3de0d23 --- /dev/null +++ b/proof/package.json @@ -0,0 +1,27 @@ +{ + "name": "proof", + "version": "0.1.0", + "type": "module", + "description": "AI-powered browser testing worker for iii — scans code changes, generates test plans, runs them in a real browser", + "scripts": { + "dev": "npx tsx --watch src/worker.ts", + "build": "tsc", + "test": "vitest run", + "postinstall": "playwright install chromium" + }, + "dependencies": { + "iii-sdk": "^0.10.0", + "playwright": "^1.52.0", + "simple-git": "^3.27.0" + }, + "optionalDependencies": { + "@anthropic-ai/sdk": "^0.52.0" + }, + "devDependencies": { + "@types/node": "^22.0.0", + "tsx": "^4.0.0", + "typescript": "^5.0.0", + "vitest": "^2.1.0" + }, + "license": "Apache-2.0" +} diff --git a/proof/src/agent.ts b/proof/src/agent.ts new file mode 100644 index 0000000..8e578dd --- /dev/null +++ b/proof/src/agent.ts @@ -0,0 +1,191 @@ +import { SYSTEM_PROMPT, buildUserPrompt } from "./prompt.js"; +import { getAnthropicTools, toolNameToFunctionId } from "./tools.js"; +import type { StepResult, RunReport } from "./types.js"; +import type { CoverageReport } from "./context.js"; + +const MAX_ITERATIONS = 50; + +const STEP_MARKER_RE = + /^(STEP_START|STEP_DONE|ASSERTION_PASSED|ASSERTION_FAILED|RUN_COMPLETED)\|([^|]+)\|(.+)$/gm; + +type IIITrigger = (req: { function_id: string; payload: unknown }) => Promise; + +export async function runAgent( + trigger: IIITrigger, + diff: string, + files: string[], + baseUrl: string, + runId: string, + instruction?: string, + commits?: Array<{ hash: string; subject: string }>, + coverage?: CoverageReport, +): Promise { + if (!process.env.ANTHROPIC_API_KEY) { + throw new Error( + "ANTHROPIC_API_KEY required for automated runs. " + + "For interactive testing, use Claude Code or Codex directly — " + + "browser tools are registered as iii functions (proof::browser::*)." + ); + } + const { default: Anthropic } = await import("@anthropic-ai/sdk"); + const anthropic = new Anthropic(); + const startedAt = Date.now(); + const steps: StepResult[] = []; + let runTitle = "Proof run"; + let runStatus: "pass" | "fail" | "error" = "pass"; + const recordedActions: Array<{ tool: string; input: Record }> = []; + + const messages: any[] = [ + { + role: "user", + content: buildUserPrompt(diff, files, baseUrl, instruction, commits, coverage), + }, + ]; + + for (let iteration = 0; iteration < MAX_ITERATIONS; iteration++) { + const response = await anthropic.messages.create({ + model: "claude-sonnet-4-20250514", + max_tokens: 4096, + system: SYSTEM_PROMPT, + tools: getAnthropicTools() as any[], + messages, + }); + + const toolResults: any[] = []; + + for (const block of response.content) { + if (block.type === "text") { + parseStepMarkers(block.text, steps); + + const runMatch = block.text.match(/RUN_COMPLETED\|(passed|failed)\|(.+)/); + if (runMatch) { + runStatus = runMatch[1] === "passed" ? 
"pass" : "fail"; + runTitle = runMatch[2].trim(); + } + } + + if (block.type === "tool_use") { + const fnId = toolNameToFunctionId(block.name); + recordedActions.push({ + tool: block.name, + input: block.input as Record, + }); + + try { + const result = await trigger({ + function_id: fnId, + payload: block.input, + }); + + const isScreenshot = block.name === "browser_screenshot"; + if (isScreenshot && typeof result !== "string") { + throw new Error("Screenshot returned invalid data"); + } + toolResults.push({ + type: "tool_result", + tool_use_id: block.id, + content: isScreenshot + ? [{ type: "image", source: { type: "base64", media_type: "image/png", data: result as string } }] + : [{ type: "text", text: typeof result === "string" ? result : JSON.stringify(result) }], + } as any); + } catch (err: unknown) { + const errMsg = err instanceof Error ? err.message : String(err); + toolResults.push({ + type: "tool_result", + tool_use_id: block.id, + content: [{ type: "text", text: `Error: ${errMsg}` }], + is_error: true, + } as any); + } + } + } + + await pushStepProgress(trigger, runId, steps); + + if (response.stop_reason === "end_turn") break; + + if (toolResults.length > 0) { + messages.push({ role: "assistant", content: response.content as any[] }); + messages.push({ role: "user", content: toolResults }); + } else { + break; + } + } + + if (steps.length === 0 && recordedActions.length > 0) { + steps.push({ + id: "step-01", + description: "Browser test execution", + status: runStatus === "pass" ? "passed" : "failed", + assertions: [], + startedAt, + completedAt: Date.now(), + }); + } + + const passed = steps.filter((s) => s.status === "passed").length; + const total = steps.length; + + return { + runId, + title: runTitle, + steps, + status: runStatus, + passRate: total > 0 ? 
Math.round((passed / total) * 100) : 0, + files, + startedAt, + completedAt: Date.now(), + recordedActions, + }; +} + +function parseStepMarkers(text: string, steps: StepResult[]): void { + STEP_MARKER_RE.lastIndex = 0; + let match: RegExpExecArray | null; + while ((match = STEP_MARKER_RE.exec(text)) !== null) { + const [, marker, id, detail] = match; + + switch (marker) { + case "STEP_START": + steps.push({ id, description: detail, status: "running", assertions: [], startedAt: Date.now() }); + break; + case "STEP_DONE": { + const step = steps.find((s) => s.id === id); + if (step) { + if (step.status !== "failed") step.status = "passed"; + step.completedAt = Date.now(); + } + break; + } + case "ASSERTION_PASSED": + steps.find((s) => s.id === id)?.assertions.push({ text: detail, passed: true }); + break; + case "ASSERTION_FAILED": { + const step = steps.find((s) => s.id === id); + if (step) { step.status = "failed"; step.assertions.push({ text: detail, passed: false }); step.completedAt = Date.now(); } + break; + } + } + } +} + +async function pushStepProgress( + trigger: IIITrigger, + runId: string, + steps: StepResult[], +): Promise { + if (steps.length === 0) return; + try { + await trigger({ + function_id: "stream::set", + payload: { + stream_name: "proof", + group_id: runId, + item_id: `progress`, + data: { steps, updatedAt: Date.now() }, + }, + }); + } catch { + // stream push is best-effort + } +} diff --git a/proof/src/browser.ts b/proof/src/browser.ts new file mode 100644 index 0000000..c927d53 --- /dev/null +++ b/proof/src/browser.ts @@ -0,0 +1,273 @@ +import { chromium, type Browser, type Page } from "playwright"; +import type { BrowserSession, ConsoleEntry, NetworkEntry, RefEntry } from "./types.js"; + +const INTERACTIVE_ROLES = new Set([ + "button", "link", "textbox", "checkbox", "radio", "combobox", + "menuitem", "tab", "switch", "slider", "spinbutton", "searchbox", +]); + +const CONTENT_ROLES = new Set([ + "heading", "img", "cell", "row", "alert", "status", "banner", +]); + +const sessions = new Map(); +let sharedBrowser: Browser | null = null; + +async function getOrCreateBrowser(): Promise { + if (!sharedBrowser || !sharedBrowser.isConnected()) { + sharedBrowser = await chromium.launch({ headless: true }); + } + return sharedBrowser; +} + +export async function autoDiscoverCdp(): Promise { + const endpoints = [ + "http://localhost:9222/json/version", + "http://127.0.0.1:9222/json/version", + ]; + for (const url of endpoints) { + try { + const res = await fetch(url, { signal: AbortSignal.timeout(2000) }); + const data = await res.json() as { webSocketDebuggerUrl?: string }; + if (data.webSocketDebuggerUrl) return data.webSocketDebuggerUrl; + } catch { /* not running */ } + } + return null; +} + +function setupPageTracking(page: Page, session: BrowserSession): void { + page.on("console", (msg) => { + session.consoleMessages.push({ + type: msg.type(), + text: msg.text(), + timestamp: Date.now(), + }); + }); + + page.on("response", (response) => { + session.networkRequests.push({ + method: response.request().method(), + url: response.url(), + status: response.status(), + resourceType: response.request().resourceType(), + timestamp: Date.now(), + }); + }); +} + +export async function launchBrowser( + runId: string, + headed = false, + cdpUrl?: string, +): Promise { + const existing = sessions.get(runId); + if (existing) return existing; + + let browser: Browser; + if (cdpUrl) { + browser = await chromium.connectOverCDP(cdpUrl); + } else if (headed) { + browser = await 
chromium.launch({ headless: false }); + } else { + browser = await getOrCreateBrowser(); + } + + const context = await browser.newContext({ + viewport: { width: 1280, height: 720 }, + }); + const page = await context.newPage(); + + const session: BrowserSession = { + browser, + context, + page, + refMap: new Map(), + headed, + consoleMessages: [], + networkRequests: [], + replayEvents: [], + cdpUrl, + }; + + setupPageTracking(page, session); + sessions.set(runId, session); + return session; +} + +export function getSession(runId: string): BrowserSession | undefined { + return sessions.get(runId); +} + +const ARIA_LINE_RE = /^(\s*)- (\w+)(?: "([^"]*)")?(.*)$/; + +export async function buildSnapshot( + page: Page, + refMap: Map, +): Promise { + const ariaSnapshot = await page.locator("body").ariaSnapshot(); + if (!ariaSnapshot) return "(empty page)"; + + refMap.clear(); + let refCounter = 0; + const outputLines: string[] = []; + + for (const line of ariaSnapshot.split("\n")) { + const match = ARIA_LINE_RE.exec(line); + if (!match) { + outputLines.push(line); + continue; + } + + const [, indent, role, name, rest] = match; + const isInteractive = INTERACTIVE_ROLES.has(role); + const isContent = CONTENT_ROLES.has(role) && (name?.length ?? 0) > 0; + + let outputLine = `${indent}- ${role}`; + if (name) outputLine += ` "${name}"`; + if (rest) outputLine += rest; + + if (isInteractive || isContent) { + refCounter++; + const ref = `e${refCounter}`; + outputLine += ` [ref=${ref}]`; + refMap.set(ref, { role, name: name ?? "" }); + } + + outputLines.push(outputLine); + } + + return outputLines.join("\n"); +} + +export function resolveRef( + ref: string, + refMap: Map, + page: Page, +) { + const entry = refMap.get(ref); + if (!entry) throw new Error(`Ref "${ref}" not found in current snapshot. 
Take a new snapshot.`); + return page.getByRole(entry.role as any, { name: entry.name }).first(); +} + +export async function handleNavigate(url: string, session: BrowserSession): Promise { + await session.page.goto(url, { waitUntil: "domcontentloaded", timeout: 15_000 }); + return buildSnapshot(session.page, session.refMap); +} + +export async function handleClick(ref: string, session: BrowserSession): Promise { + const locator = resolveRef(ref, session.refMap, session.page); + await locator.click({ timeout: 10_000 }); + await session.page.waitForTimeout(300); + return buildSnapshot(session.page, session.refMap); +} + +export async function handleType(ref: string, text: string, session: BrowserSession): Promise { + const locator = resolveRef(ref, session.refMap, session.page); + await locator.fill(text, { timeout: 10_000 }); + return buildSnapshot(session.page, session.refMap); +} + +export async function handleSelect(ref: string, value: string, session: BrowserSession): Promise { + const locator = resolveRef(ref, session.refMap, session.page); + await locator.selectOption(value, { timeout: 10_000 }); + return buildSnapshot(session.page, session.refMap); +} + +export async function handlePress(ref: string, key: string, session: BrowserSession): Promise { + const locator = resolveRef(ref, session.refMap, session.page); + await locator.press(key, { timeout: 10_000 }); + await session.page.waitForTimeout(300); + return buildSnapshot(session.page, session.refMap); +} + +export async function handleScreenshot(session: BrowserSession): Promise { + const buffer = await session.page.screenshot({ type: "png" }); + return buffer.toString("base64"); +} + +export async function handleConsoleLogs( + session: BrowserSession, + filter?: { type?: string; clear?: boolean }, +): Promise { + let logs = session.consoleMessages; + if (filter?.type) { + logs = logs.filter((l) => l.type === filter.type); + } + if (filter?.clear) { + session.consoleMessages = []; + } + return logs; +} + +export async function handleNetworkRequests( + session: BrowserSession, + filter?: { method?: string; urlContains?: string; resourceType?: string; clear?: boolean }, +): Promise { + let reqs = session.networkRequests; + if (filter?.method) reqs = reqs.filter((r) => r.method === filter.method); + if (filter?.urlContains) reqs = reqs.filter((r) => r.url.includes(filter.urlContains!)); + if (filter?.resourceType) reqs = reqs.filter((r) => r.resourceType === filter.resourceType); + if (filter?.clear) { + session.networkRequests = []; + } + return reqs; +} + +export async function handlePerformanceMetrics(session: BrowserSession) { + return session.page.evaluate(() => { + const perf = performance.getEntriesByType("navigation")[0] as PerformanceNavigationTiming | undefined; + const paint = performance.getEntriesByType("paint"); + const fcp = paint.find((e) => e.name === "first-contentful-paint"); + + const cls = (performance as any).getEntriesByType?.("layout-shift") ?? []; + const clsValue = cls.reduce((sum: number, e: any) => sum + (e.hadRecentInput ? 0 : e.value), 0); + + return { + url: location.href, + fcp: fcp ? Math.round(fcp.startTime) : null, + domContentLoaded: perf ? Math.round(perf.domContentLoadedEventEnd - perf.startTime) : null, + load: perf ? Math.round(perf.loadEventEnd - perf.startTime) : null, + ttfb: perf ? Math.round(perf.responseStart - perf.requestStart) : null, + cls: Math.round(clsValue * 1000) / 1000, + transferSize: perf?.transferSize ?? 
null, + }; + }); +} + +export async function handlePlaywrightExec( + code: string, + session: BrowserSession, +): Promise { + const { page, context, browser } = session; + const ref = (id: string) => { + const entry = session.refMap.get(id); + if (!entry) throw new Error(`Ref "${id}" not found`); + return page.getByRole(entry.role as any, { name: entry.name }).first(); + }; + const AsyncFunction = Object.getPrototypeOf(async () => {}).constructor; + const fn = new AsyncFunction("page", "context", "browser", "ref", code); + return fn(page, context, browser, ref); +} + +export async function closeBrowser(runId: string): Promise<{ replayEvents: unknown[] }> { + const session = sessions.get(runId); + if (!session) return { replayEvents: [] }; + + const events = session.replayEvents; + await session.context.close(); + if (session.headed && session.browser !== sharedBrowser && !session.cdpUrl) { + await session.browser.close(); + } + sessions.delete(runId); + return { replayEvents: events }; +} + +export async function closeAll(): Promise { + for (const [runId] of sessions) { + await closeBrowser(runId); + } + if (sharedBrowser) { + await sharedBrowser.close(); + sharedBrowser = null; + } +} diff --git a/proof/src/context.ts b/proof/src/context.ts new file mode 100644 index 0000000..bf3d9a7 --- /dev/null +++ b/proof/src/context.ts @@ -0,0 +1,189 @@ +import { simpleGit, type SimpleGit } from "simple-git"; +import type { ScanResult } from "./types.js"; +import * as fs from "node:fs"; +import * as path from "node:path"; + +const MAX_DIFF_CHARS = 50_000; +const MAX_FILES = 12; +const MAX_COMMITS = 5; + +const SOURCE_EXTENSIONS = new Set([".ts", ".tsx", ".js", ".jsx", ".mts", ".mjs", ".cjs"]); +const SKIP_DIRS = new Set(["node_modules", "dist", "build", ".git", ".next", "coverage", "__pycache__", ".cache"]); +const TEST_PATTERN = /\.(test|spec|e2e)\.[tj]sx?$|__tests__/; + +export async function scanChanges( + target: "unstaged" | "staged" | "branch" | "commit" = "unstaged", + cwd?: string, + mainBranch?: string, + commitHash?: string, +): Promise { + const git: SimpleGit = simpleGit(cwd ?? process.cwd()); + + let diff: string; + let files: string[]; + let commits: Array<{ hash: string; subject: string }> = []; + + switch (target) { + case "branch": { + const main = mainBranch ?? (await detectMainBranch(git)); + diff = await git.diff([`${main}...HEAD`]); + const summary = await git.diffSummary([`${main}...HEAD`]); + files = summary.files.map((f) => f.file).slice(0, MAX_FILES); + const log = await git.log({ from: main, to: "HEAD", maxCount: MAX_COMMITS }); + commits = log.all.map((c) => ({ hash: c.hash, subject: c.message.split("\n")[0] })); + break; + } + case "commit": { + const hash = commitHash ?? 
"HEAD"; + diff = await git.diff([`${hash}^..${hash}`]); + const summary = await git.diffSummary([`${hash}^..${hash}`]); + files = summary.files.map((f) => f.file).slice(0, MAX_FILES); + const log = await git.log({ from: `${hash}^`, to: hash, maxCount: 1 }); + commits = log.all.map((c) => ({ hash: c.hash, subject: c.message.split("\n")[0] })); + break; + } + case "staged": { + diff = await git.diff(["--cached"]); + const summary = await git.diffSummary(["--cached"]); + files = summary.files.map((f) => f.file).slice(0, MAX_FILES); + break; + } + default: { + diff = await git.diff(); + const summary = await git.diffSummary(); + files = summary.files.map((f) => f.file).slice(0, MAX_FILES); + break; + } + } + + if (!diff.trim()) { + return { diff: "", files: [], commits: [], empty: true }; + } + + const truncatedDiff = + diff.length > MAX_DIFF_CHARS + ? diff.slice(0, MAX_DIFF_CHARS) + "\n... (truncated)" + : diff; + + return { diff: truncatedDiff, files, commits, empty: false }; +} + +export type CoverageEntry = { + path: string; + testFiles: string[]; + covered: boolean; +}; + +export type CoverageReport = { + entries: CoverageEntry[]; + coveredCount: number; + totalCount: number; + percent: number; +}; + +export async function analyzeTestCoverage( + changedFiles: string[], + cwd?: string, +): Promise { + const root = cwd ?? process.cwd(); + const sourceFiles = changedFiles.filter( + (f) => SOURCE_EXTENSIONS.has(path.extname(f)) && !TEST_PATTERN.test(f), + ); + + if (sourceFiles.length === 0) { + return { entries: [], coveredCount: 0, totalCount: 0, percent: 100 }; + } + + const testFiles = await findTestFiles(root); + const testImports = new Map>(); + + for (const testFile of testFiles) { + const imports = await extractImports(path.join(root, testFile)); + for (const imp of imports) { + const resolved = resolveImportPath(imp, testFile, root); + if (resolved) { + if (!testImports.has(resolved)) testImports.set(resolved, new Set()); + testImports.get(resolved)!.add(testFile); + } + } + } + + const entries: CoverageEntry[] = sourceFiles.map((f) => { + const tests = testImports.get(f); + return { + path: f, + testFiles: tests ? [...tests] : [], + covered: !!tests && tests.size > 0, + }; + }); + + const coveredCount = entries.filter((e) => e.covered).length; + return { + entries, + coveredCount, + totalCount: entries.length, + percent: entries.length > 0 ? 
Math.round((coveredCount / entries.length) * 100) : 100, + }; +} + +async function findTestFiles(root: string, dir = "", results: string[] = []): Promise { + const fullDir = path.join(root, dir); + let entries: fs.Dirent[]; + try { + entries = fs.readdirSync(fullDir, { withFileTypes: true }); + } catch { + return results; + } + + for (const entry of entries) { + if (SKIP_DIRS.has(entry.name)) continue; + const rel = path.join(dir, entry.name); + if (entry.isDirectory()) { + if (results.length < 200) await findTestFiles(root, rel, results); + } else if (TEST_PATTERN.test(entry.name)) { + results.push(rel); + } + } + return results; +} + +async function extractImports(filePath: string): Promise { + let content: string; + try { + content = fs.readFileSync(filePath, "utf-8"); + } catch { + return []; + } + + const imports: string[] = []; + const importRe = /from\s+['"]([^'"]+)['"]/g; + const requireRe = /require\s*\(\s*['"]([^'"]+)['"]\s*\)/g; + + let match: RegExpExecArray | null; + while ((match = importRe.exec(content)) !== null) imports.push(match[1]); + while ((match = requireRe.exec(content)) !== null) imports.push(match[1]); + + return imports.filter((i) => i.startsWith(".")); +} + +function resolveImportPath(importPath: string, fromFile: string, root: string): string | null { + const fromDir = path.dirname(fromFile); + const resolved = path.normalize(path.join(fromDir, importPath)); + + for (const ext of ["", ".ts", ".tsx", ".js", ".jsx", "/index.ts", "/index.js"]) { + const full = path.join(root, resolved + ext); + try { + if (fs.statSync(full).isFile()) return resolved + ext; + } catch { /* not found */ } + } + return resolved; +} + +async function detectMainBranch(git: SimpleGit): Promise { + try { + const ref = await git.raw(["symbolic-ref", "refs/remotes/origin/HEAD"]); + return ref.trim().replace("refs/remotes/origin/", ""); + } catch { + return "main"; + } +} diff --git a/proof/src/cookies.ts b/proof/src/cookies.ts new file mode 100644 index 0000000..53fa7a0 --- /dev/null +++ b/proof/src/cookies.ts @@ -0,0 +1,162 @@ +import * as fs from "node:fs"; +import * as path from "node:path"; +import * as os from "node:os"; +import { execFile } from "node:child_process"; +import { promisify } from "node:util"; +import type { BrowserSession } from "./types.js"; + +const execFileAsync = promisify(execFile); + +type ExtractedCookie = { + name: string; + value: string; + domain: string; + path: string; + expires?: number; + secure: boolean; + httpOnly: boolean; + sameSite?: "Strict" | "Lax" | "None"; +}; + +export async function extractAndInjectCookies( + session: BrowserSession, + targetUrl: string, +): Promise { + const hostname = new URL(targetUrl).hostname; + const cookies = await extractCookiesForDomain(hostname); + if (cookies.length === 0) return 0; + + const pwCookies = cookies.map((c) => ({ + name: c.name, + value: c.value, + domain: c.domain, + path: c.path, + expires: c.expires ?? -1, + secure: c.secure, + httpOnly: c.httpOnly, + sameSite: (c.sameSite ?? 
"Lax") as "Strict" | "Lax" | "None", + })); + + await session.context.addCookies(pwCookies); + return pwCookies.length; +} + +async function extractCookiesForDomain(domain: string): Promise { + const cookies = await extractChromeCookies(domain); + if (cookies.length > 0) return cookies; + return extractFirefoxCookies(domain); +} + +async function extractChromeCookies(domain: string): Promise { + const platform = os.platform(); + let cookieDbPath: string; + + if (platform === "darwin") { + cookieDbPath = path.join(os.homedir(), "Library/Application Support/Google/Chrome/Default/Cookies"); + } else if (platform === "linux") { + cookieDbPath = path.join(os.homedir(), ".config/google-chrome/Default/Cookies"); + } else { + return []; + } + + if (!fs.existsSync(cookieDbPath)) return []; + + try { + const { stdout } = await execFileAsync("sqlite3", [ + "-json", + cookieDbPath, + `SELECT name, value, host_key as domain, path, expires_utc, is_secure, is_httponly, samesite FROM cookies WHERE host_key LIKE '%${domain.replace(/'/g, "''")}'`, + ]); + + if (!stdout.trim()) return []; + + const rows = JSON.parse(stdout) as Array<{ + name: string; + value: string; + domain: string; + path: string; + expires_utc: number; + is_secure: number; + is_httponly: number; + samesite: number; + }>; + + return rows + .filter((r) => r.value) + .map((r) => ({ + name: r.name, + value: r.value, + domain: r.domain, + path: r.path, + expires: r.expires_utc > 0 ? Math.floor((r.expires_utc / 1_000_000) - 11644473600) : undefined, + secure: r.is_secure === 1, + httpOnly: r.is_httponly === 1, + sameSite: ([undefined, "Lax", "Strict", "None"] as const)[r.samesite] ?? undefined, + })); + } catch { + return []; + } +} + +async function extractFirefoxCookies(domain: string): Promise { + const platform = os.platform(); + let profilesDir: string; + + if (platform === "darwin") { + profilesDir = path.join(os.homedir(), "Library/Application Support/Firefox/Profiles"); + } else if (platform === "linux") { + profilesDir = path.join(os.homedir(), ".mozilla/firefox"); + } else { + return []; + } + + if (!fs.existsSync(profilesDir)) return []; + + let cookieDb: string | null = null; + try { + const profiles = fs.readdirSync(profilesDir); + const defaultProfile = profiles.find((p) => p.endsWith(".default-release") || p.endsWith(".default")); + if (defaultProfile) { + const dbPath = path.join(profilesDir, defaultProfile, "cookies.sqlite"); + if (fs.existsSync(dbPath)) cookieDb = dbPath; + } + } catch { + return []; + } + + if (!cookieDb) return []; + + try { + const { stdout } = await execFileAsync("sqlite3", [ + "-json", + cookieDb, + `SELECT name, value, host as domain, path, expiry, isSecure, isHttpOnly, sameSite FROM moz_cookies WHERE host LIKE '%${domain.replace(/'/g, "''")}'`, + ]); + + if (!stdout.trim()) return []; + + const rows = JSON.parse(stdout) as Array<{ + name: string; + value: string; + domain: string; + path: string; + expiry: number; + isSecure: number; + isHttpOnly: number; + sameSite: number; + }>; + + return rows.filter((r) => r.value).map((r) => ({ + name: r.name, + value: r.value, + domain: r.domain, + path: r.path, + expires: r.expiry > 0 ? r.expiry : undefined, + secure: r.isSecure === 1, + httpOnly: r.isHttpOnly === 1, + sameSite: (["None", "Lax", "Strict"] as const)[r.sameSite] ?? 
undefined, + })); + } catch { + return []; + } +} diff --git a/proof/src/prompt.ts b/proof/src/prompt.ts new file mode 100644 index 0000000..c13608b --- /dev/null +++ b/proof/src/prompt.ts @@ -0,0 +1,113 @@ +import type { CoverageReport } from "./context.js"; + +export const SYSTEM_PROMPT = `You are a QA engineer testing a web application in a real browser. You verify that code changes work correctly by interacting with the live app. + +## Workflow +1. Read the code diff to understand what changed. +2. Navigate to the base URL with browser_navigate. +3. Take a snapshot with browser_snapshot to see the page structure. +4. Execute test flows that verify the changes work. +5. Emit step markers to track progress. + +## Snapshot-First Pattern +- ALWAYS call browser_snapshot before interacting with elements. +- The snapshot shows an accessibility tree where interactive elements have [ref=eN] markers. +- Use ref IDs in browser_click, browser_type, browser_select, browser_press — never guess CSS selectors. +- After navigation or page changes, take a new snapshot to get fresh refs. +- For complex interactions, use browser_exec with ref() function for direct Playwright access. + +Example snapshot: + - heading "Login" [level=1] + - textbox "Email" [ref=e1] + - textbox "Password" [ref=e2] + - button "Sign In" [ref=e3] + - link "Forgot password?" [ref=e4] + +To click Sign In: use browser_click with ref "e3". + +## Available Tools +- browser_navigate: Go to a URL +- browser_snapshot: Get accessibility tree with refs +- browser_click, browser_type, browser_select, browser_press: Interact by ref +- browser_screenshot: Visual capture (use to verify visual state) +- browser_assert: Record pass/fail assertions +- browser_console_logs: Read browser console output (errors, warnings, logs) +- browser_network: Inspect network requests (API calls, resources) +- browser_performance: Get Core Web Vitals (FCP, TTFB, CLS) +- browser_exec: Run raw Playwright code with page, context, ref() available + +## Step Markers +Emit these markers in your text to track test progress: +- STEP_START|step-NN|Description of what is being tested +- STEP_DONE|step-NN|What was verified +- ASSERTION_PASSED|step-NN|What passed +- ASSERTION_FAILED|step-NN|What failed and why +- RUN_COMPLETED|passed|Summary of all tests +- RUN_COMPLETED|failed|What failed + +## Scope +- For unstaged changes: test 1-3 focused flows on the exact change. +- For staged changes: test 2-4 flows including related functionality. +- For branch changes: test 3-5 flows covering all modified features. +- For commit changes: test 2-4 flows covering the commit's intent. + +## Debugging +- Use browser_console_logs to check for JavaScript errors after interactions. +- Use browser_network to verify API calls are being made correctly. +- Use browser_performance to check page load performance. +- Use browser_screenshot when you need to see the visual layout. + +## Recovery +If something fails: +- Take a screenshot to see the visual state. +- Check console logs for errors. +- Categorize: app-bug (real issue), env-issue (server down), auth-blocked (needs login), selector-drift (ref not found). +- For app-bug: record as ASSERTION_FAILED — this is a real finding. +- For env-issue or auth-blocked: note it and skip the flow. +- For selector-drift: retake snapshot and retry with updated refs. + +## Rules +- Verify results with browser_assert after each meaningful action. +- Check browser_console_logs for errors after page loads and form submissions. 
+- If a page requires authentication you cannot provide, skip with STEP_DONE noting auth-blocked. +- Always finish with RUN_COMPLETED. +- Keep tests focused on what the diff actually changed.`; + +export function buildUserPrompt( + diff: string, + files: string[], + baseUrl: string, + instruction?: string, + commits?: Array<{ hash: string; subject: string }>, + coverage?: CoverageReport, +): string { + const parts: string[] = []; + + if (instruction) { + parts.push(`## Instruction\n${instruction}`); + } + + parts.push(`## Base URL\n${baseUrl}`); + parts.push(`## Changed Files (${files.length})\n${files.map((f) => `- ${f}`).join("\n")}`); + + if (commits?.length) { + parts.push( + `## Recent Commits\n${commits.map((c) => `- ${c.hash.slice(0, 7)} ${c.subject}`).join("\n")}`, + ); + } + + if (coverage && coverage.totalCount > 0) { + const lines = coverage.entries.map((e) => + e.covered + ? ` [covered] ${e.path}${e.testFiles.length ? ` (tested by: ${e.testFiles.join(", ")})` : ""}` + : ` [no test] ${e.path}`, + ); + parts.push( + `## Test Coverage (${coverage.percent}% — ${coverage.coveredCount}/${coverage.totalCount} files)\n${lines.join("\n")}\nPrioritize browser-testing files WITHOUT existing test coverage.`, + ); + } + + parts.push(`## Diff\n\`\`\`diff\n${diff}\n\`\`\``); + + return parts.join("\n\n"); +} diff --git a/proof/src/tools.ts b/proof/src/tools.ts new file mode 100644 index 0000000..c1d97ab --- /dev/null +++ b/proof/src/tools.ts @@ -0,0 +1,158 @@ +export type ToolDef = { + name: string; + function_id: string; + description: string; + input_schema: Record; +}; + +export const TOOLS: ToolDef[] = [ + { + name: "browser_navigate", + function_id: "proof::browser::navigate", + description: "Navigate to a URL. Returns the page accessibility snapshot after navigation.", + input_schema: { + type: "object", + properties: { url: { type: "string", description: "URL to navigate to" } }, + required: ["url"], + }, + }, + { + name: "browser_snapshot", + function_id: "proof::browser::snapshot", + description: "Get the current page accessibility tree. Interactive elements have [ref=eN] markers. Use these refs in click, type, select, and press tools.", + input_schema: { type: "object", properties: {} }, + }, + { + name: "browser_click", + function_id: "proof::browser::click", + description: "Click an element by ref ID from the snapshot. Returns updated snapshot.", + input_schema: { + type: "object", + properties: { ref: { type: "string", description: "Ref ID from snapshot (e.g. 'e3')" } }, + required: ["ref"], + }, + }, + { + name: "browser_type", + function_id: "proof::browser::type", + description: "Type text into an input by ref ID. Clears existing text first. Returns updated snapshot.", + input_schema: { + type: "object", + properties: { + ref: { type: "string", description: "Ref ID from snapshot" }, + text: { type: "string", description: "Text to type" }, + }, + required: ["ref", "text"], + }, + }, + { + name: "browser_select", + function_id: "proof::browser::select", + description: "Select an option in a dropdown by ref ID. Returns updated snapshot.", + input_schema: { + type: "object", + properties: { + ref: { type: "string", description: "Ref ID from snapshot" }, + value: { type: "string", description: "Option value to select" }, + }, + required: ["ref", "value"], + }, + }, + { + name: "browser_press", + function_id: "proof::browser::press", + description: "Press a keyboard key on an element. 
Returns updated snapshot.", + input_schema: { + type: "object", + properties: { + ref: { type: "string", description: "Ref ID from snapshot" }, + key: { type: "string", description: "Key to press (Enter, Tab, Escape, etc.)" }, + }, + required: ["ref", "key"], + }, + }, + { + name: "browser_screenshot", + function_id: "proof::browser::screenshot", + description: "Take a screenshot of the current page. Returns base64 PNG image.", + input_schema: { + type: "object", + properties: { + description: { type: "string", description: "What you expect to see" }, + }, + }, + }, + { + name: "browser_assert", + function_id: "proof::browser::assert", + description: "Record an assertion about the current page state.", + input_schema: { + type: "object", + properties: { + assertion: { type: "string", description: "What you are asserting" }, + passed: { type: "boolean", description: "Whether the assertion passed" }, + }, + required: ["assertion", "passed"], + }, + }, + { + name: "browser_console_logs", + function_id: "proof::browser::console_logs", + description: "Get console log messages from the page. Optionally filter by type and clear after reading.", + input_schema: { + type: "object", + properties: { + type: { type: "string", description: "Filter by type: log, error, warning, info" }, + clear: { type: "boolean", description: "Clear logs after reading" }, + }, + }, + }, + { + name: "browser_network", + function_id: "proof::browser::network", + description: "Get network requests made by the page. Filter by method, URL substring, or resource type.", + input_schema: { + type: "object", + properties: { + method: { type: "string", description: "Filter by HTTP method (GET, POST, etc.)" }, + url_contains: { type: "string", description: "Filter by URL substring" }, + resource_type: { type: "string", description: "Filter by type: xhr, fetch, document, script, stylesheet, image" }, + clear: { type: "boolean", description: "Clear request log after reading" }, + }, + }, + }, + { + name: "browser_performance", + function_id: "proof::browser::performance", + description: "Get performance metrics: FCP, DOM content loaded, TTFB, CLS, transfer size.", + input_schema: { type: "object", properties: {} }, + }, + { + name: "browser_exec", + function_id: "proof::browser::exec", + description: "Execute raw Playwright code. Has access to page, context, browser, and ref() function. Returns the result as JSON.", + input_schema: { + type: "object", + properties: { + code: { type: "string", description: "Playwright code to execute. Use ref('e3') to get locators from snapshot refs. Must return a value." 
}, + }, + required: ["code"], + }, + }, +]; + +const nameToFnId = new Map(TOOLS.map((t) => [t.name, t.function_id])); + +export function toolNameToFunctionId(name: string): string { + const fnId = nameToFnId.get(name); + if (!fnId) throw new Error(`Unknown tool: ${name}`); + return fnId; +} + +export function getAnthropicTools() { + return TOOLS.map((t) => ({ + name: t.name, + description: t.description, + input_schema: t.input_schema, + })); +} diff --git a/proof/src/types.ts b/proof/src/types.ts new file mode 100644 index 0000000..bf0d5db --- /dev/null +++ b/proof/src/types.ts @@ -0,0 +1,81 @@ +import type { Browser, BrowserContext, Page } from "playwright"; + +export type StepResult = { + id: string; + description: string; + status: "running" | "passed" | "failed"; + assertions: Array<{ text: string; passed: boolean }>; + startedAt: number; + completedAt?: number; +}; + +export type RunReport = { + runId: string; + title: string; + steps: StepResult[]; + status: "pass" | "fail" | "error"; + passRate: number; + files: string[]; + startedAt: number; + completedAt: number; + recordedActions: Array<{ tool: string; input: Record }>; +}; + +export type SavedFlow = { + slug: string; + title: string; + baseUrl: string; + actions: Array<{ tool: string; input: Record }>; + savedAt: number; +}; + +export type ScanResult = { + diff: string; + files: string[]; + commits: Array<{ hash: string; subject: string }>; + empty: boolean; +}; + +export type RefEntry = { + role: string; + name: string; + level?: number; +}; + +export type ConsoleEntry = { + type: string; + text: string; + timestamp: number; +}; + +export type NetworkEntry = { + method: string; + url: string; + status?: number; + resourceType: string; + timestamp: number; +}; + +export type BrowserSession = { + browser: Browser; + context: BrowserContext; + page: Page; + refMap: Map; + headed: boolean; + consoleMessages: ConsoleEntry[]; + networkRequests: NetworkEntry[]; + replayEvents: unknown[]; + cdpUrl?: string; +}; + +export type RunInput = { + target?: "unstaged" | "staged" | "branch" | "commit"; + main_branch?: string; + commit_hash?: string; + base_url?: string; + instruction?: string; + headed?: boolean; + cwd?: string; + cdp?: string; + cookies?: boolean; +}; diff --git a/proof/src/worker.ts b/proof/src/worker.ts new file mode 100644 index 0000000..76019dd --- /dev/null +++ b/proof/src/worker.ts @@ -0,0 +1,345 @@ +import { registerWorker, Logger, TriggerAction } from "iii-sdk"; +import { scanChanges, analyzeTestCoverage } from "./context.js"; +import { runAgent } from "./agent.js"; +import { + launchBrowser, getSession, buildSnapshot, autoDiscoverCdp, + handleNavigate, handleClick, handleType, handleSelect, + handlePress, handleScreenshot, handleConsoleLogs, + handleNetworkRequests, handlePerformanceMetrics, + handlePlaywrightExec, closeBrowser, closeAll, +} from "./browser.js"; +import { extractAndInjectCookies } from "./cookies.js"; +import type { BrowserSession, RunInput, SavedFlow } from "./types.js"; + +const iii = registerWorker(process.env.III_URL ?? "ws://localhost:49134"); +const logger = new Logger(); + +let activeRunId: string | null = null; + +function acquireRun(runId: string): void { + if (activeRunId) throw new Error("Another run is in progress. Wait or call proof::cleanup."); + activeRunId = runId; +} + +function releaseRun(): void { + activeRunId = null; +} + +function requireSession(): BrowserSession { + if (!activeRunId) throw new Error("No active browser session. 
Call proof::run first."); + const session = getSession(activeRunId); + if (!session) throw new Error("No browser session"); + return session; +} + +// --------------------------------------------------------------------------- +// Browser lifecycle — registered as iii functions +// --------------------------------------------------------------------------- + +iii.registerFunction({ id: "proof::browser::launch" }, async (input) => { + const { runId, headed, cdp } = input; + acquireRun(runId); + let cdpUrl: string | undefined; + if (cdp === "auto") { + cdpUrl = (await autoDiscoverCdp()) ?? undefined; + } else if (cdp) { + cdpUrl = cdp; + } + await launchBrowser(runId, headed, cdpUrl); + logger.info("Browser launched", { runId, headed, cdp: cdpUrl ?? "none" }); + return { runId, launched: true }; +}); + +iii.registerFunction({ id: "proof::browser::close" }, async (input) => { + const result = await closeBrowser(input.runId); + releaseRun(); + logger.info("Browser closed", { runId: input.runId }); + return result; +}); + +// --------------------------------------------------------------------------- +// Browser tools — 12 functions called by the agent via iii.trigger() +// --------------------------------------------------------------------------- + +iii.registerFunction({ id: "proof::browser::navigate" }, async (input) => + handleNavigate(input.url, requireSession())); + +iii.registerFunction({ id: "proof::browser::snapshot" }, async () => { + const s = requireSession(); + return buildSnapshot(s.page, s.refMap); +}); + +iii.registerFunction({ id: "proof::browser::click" }, async (input) => + handleClick(input.ref, requireSession())); + +iii.registerFunction({ id: "proof::browser::type" }, async (input) => + handleType(input.ref, input.text, requireSession())); + +iii.registerFunction({ id: "proof::browser::select" }, async (input) => + handleSelect(input.ref, input.value, requireSession())); + +iii.registerFunction({ id: "proof::browser::press" }, async (input) => + handlePress(input.ref, input.key, requireSession())); + +iii.registerFunction({ id: "proof::browser::screenshot" }, async () => + handleScreenshot(requireSession())); + +iii.registerFunction({ id: "proof::browser::assert" }, async (input) => { + logger.info("Assertion", { assertion: input.assertion, passed: input.passed }); + return { assertion: input.assertion, passed: input.passed }; +}); + +iii.registerFunction({ id: "proof::browser::console_logs" }, async (input) => + handleConsoleLogs(requireSession(), input)); + +iii.registerFunction({ id: "proof::browser::network" }, async (input) => + handleNetworkRequests(requireSession(), { + method: input.method, + urlContains: input.url_contains, + resourceType: input.resource_type, + clear: input.clear, + })); + +iii.registerFunction({ id: "proof::browser::performance" }, async () => + handlePerformanceMetrics(requireSession())); + +iii.registerFunction({ id: "proof::browser::exec" }, async (input) => + handlePlaywrightExec(input.code, requireSession())); + +iii.registerFunction({ id: "proof::cookies::inject" }, async (input) => { + const session = requireSession(); + const count = await extractAndInjectCookies(session, input.url); + logger.info("Cookies injected", { url: input.url, count }); + return { injected: count }; +}); + +iii.registerFunction({ id: "proof::cdp::discover" }, async () => { + const url = await autoDiscoverCdp(); + return { found: !!url, url }; +}); + +// --------------------------------------------------------------------------- +// Pipeline functions — all 
inter-function calls go through iii.trigger() +// --------------------------------------------------------------------------- + +iii.registerFunction({ id: "proof::scan" }, async (input) => { + logger.info("Scanning changes", { target: input.target ?? "unstaged" }); + return scanChanges(input.target, input.cwd, input.main_branch, input.commit_hash); +}); + +iii.registerFunction({ id: "proof::coverage" }, async (input) => { + logger.info("Analyzing test coverage", { files: input.files?.length }); + return analyzeTestCoverage(input.files ?? [], input.cwd); +}); + +iii.registerFunction({ id: "proof::execute" }, async (input) => { + const { diff, files, base_url, instruction, runId, headed, commits, coverage, cdp, cookies } = input; + logger.info("Executing agent loop", { runId, file_count: files?.length }); + + await iii.trigger({ + function_id: "proof::browser::launch", + payload: { runId, headed, cdp }, + }); + + if (cookies) { + await iii.trigger({ + function_id: "proof::cookies::inject", + payload: { url: base_url }, + }); + } + + try { + const trigger = iii.trigger.bind(iii); + return await runAgent(trigger, diff, files, base_url, runId, instruction, commits, coverage); + } finally { + await iii.trigger({ + function_id: "proof::browser::close", + payload: { runId }, + }); + } +}); + +iii.registerFunction({ id: "proof::report" }, async (input) => { + const { report, scan } = input; + logger.info("Test report", { + status: report.status, + pass_rate: `${report.passRate}%`, + steps: report.steps.length, + }); + + await iii.trigger({ + function_id: "state::set", + payload: { scope: "proof:reports", key: `report:${report.runId}`, data: report }, + }); + + if (report.status === "pass" && report.steps.length > 0) { + const base = report.title + .toLowerCase() + .replace(/[^a-z0-9]+/g, "-") + .replace(/^-|-$/g, "") + .slice(0, 50); + const slug = `${base}-${Date.now().toString(36)}`; + + const flow: SavedFlow = { + slug, + title: report.title, + baseUrl: scan?.base_url ?? "", + actions: report.recordedActions ?? [], + savedAt: Date.now(), + }; + + await iii.trigger({ + function_id: "state::set", + payload: { scope: "proof:flows", key: slug, data: flow }, + }); + logger.info("Flow saved", { slug }); + } + + await iii.trigger({ + function_id: "stream::set", + payload: { + stream_name: "proof", + group_id: "results", + item_id: report.runId, + data: { status: report.status, title: report.title, passRate: report.passRate, completedAt: report.completedAt }, + }, + }).catch(() => {}); + + return report; +}); + +iii.registerFunction({ id: "proof::run" }, async (input: RunInput) => { + const runId = `run-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`; + const baseUrl = input.base_url ?? "http://localhost:3000"; + logger.info("Starting proof run", { runId, target: input.target ?? 
"unstaged" }); + + const scan = await iii.trigger({ + function_id: "proof::scan", + payload: { target: input.target, cwd: input.cwd, main_branch: input.main_branch, commit_hash: input.commit_hash }, + }) as Awaited>; + + if (scan.empty) { + logger.info("No changes detected"); + return { status: "skip", reason: "No changes detected" }; + } + + const coverage = await iii.trigger({ + function_id: "proof::coverage", + payload: { files: scan.files, cwd: input.cwd }, + }); + + const report = await iii.trigger({ + function_id: "proof::execute", + payload: { + diff: scan.diff, files: scan.files, base_url: baseUrl, + instruction: input.instruction, runId, headed: input.headed, + commits: scan.commits, coverage, cdp: input.cdp, cookies: input.cookies, + }, + }); + + return iii.trigger({ + function_id: "proof::report", + payload: { report, scan: { ...scan, base_url: baseUrl } }, + }); +}); + +// --------------------------------------------------------------------------- +// Flow replay — all browser calls through iii.trigger() +// --------------------------------------------------------------------------- + +iii.registerFunction({ id: "proof::replay" }, async (input) => { + const { slug } = input; + const flow = (await iii.trigger({ + function_id: "state::get", + payload: { scope: "proof:flows", key: slug }, + })) as SavedFlow | null; + + if (!flow) return { status: "error", reason: `Flow "${slug}" not found` }; + + logger.info("Replaying flow", { slug, actions: flow.actions.length }); + const runId = `replay-${Date.now()}`; + + await iii.trigger({ + function_id: "proof::browser::launch", + payload: { runId, headed: input.headed ?? false }, + }); + + const results: Array<{ tool: string; status: string; error?: string }> = []; + + try { + for (const action of flow.actions) { + try { + await iii.trigger({ + function_id: `proof::browser::${action.tool.replace("browser_", "")}`, + payload: action.input, + }); + results.push({ tool: action.tool, status: "pass" }); + } catch (err: unknown) { + const msg = err instanceof Error ? err.message : String(err); + results.push({ tool: action.tool, status: "fail", error: msg }); + } + } + } finally { + await iii.trigger({ + function_id: "proof::browser::close", + payload: { runId }, + }); + } + + const failed = results.filter((r) => r.status === "fail").length; + return { slug, status: failed === 0 ? "pass" : "fail", total: results.length, failed, results }; +}); + +// --------------------------------------------------------------------------- +// State queries — all through iii.trigger() +// --------------------------------------------------------------------------- + +iii.registerFunction({ id: "proof::flows" }, async () => { + return iii.trigger({ function_id: "state::list", payload: { scope: "proof:flows" } }); +}); + +iii.registerFunction({ id: "proof::history" }, async (input) => { + const reports = await iii.trigger({ function_id: "state::list", payload: { scope: "proof:reports" } }) as any[]; + if (!Array.isArray(reports)) return []; + return reports + .sort((a: any, b: any) => (b.completedAt ?? 0) - (a.completedAt ?? 0)) + .slice(0, input?.limit ?? 20) + .map((r: any) => ({ + runId: r.runId, title: r.title, status: r.status, + passRate: r.passRate, steps: r.steps?.length ?? 
0, completedAt: r.completedAt, + })); +}); + +iii.registerFunction({ id: "proof::cleanup" }, async () => { + await closeAll(); + releaseRun(); + logger.info("All browsers closed"); + return { cleaned: true }; +}); + +// --------------------------------------------------------------------------- +// Queue-based runs — iii primitive, Expect can't do this +// --------------------------------------------------------------------------- + +iii.registerFunction({ id: "proof::enqueue" }, async (input: RunInput) => { + return iii.trigger({ + function_id: "proof::run", + payload: input, + action: TriggerAction.Enqueue({ queue: "proof" }), + }); +}); + +// --------------------------------------------------------------------------- +// HTTP triggers — every function accessible via REST +// --------------------------------------------------------------------------- + +iii.registerTrigger({ type: "http", function_id: "proof::run", config: { api_path: "/proof", http_method: "POST" } }); +iii.registerTrigger({ type: "http", function_id: "proof::replay", config: { api_path: "/proof/replay", http_method: "POST" } }); +iii.registerTrigger({ type: "http", function_id: "proof::flows", config: { api_path: "/proof/flows", http_method: "GET" } }); +iii.registerTrigger({ type: "http", function_id: "proof::history", config: { api_path: "/proof/history", http_method: "GET" } }); +iii.registerTrigger({ type: "http", function_id: "proof::cleanup", config: { api_path: "/proof/cleanup", http_method: "POST" } }); +iii.registerTrigger({ type: "http", function_id: "proof::coverage", config: { api_path: "/proof/coverage", http_method: "POST" } }); +iii.registerTrigger({ type: "http", function_id: "proof::enqueue", config: { api_path: "/proof/enqueue", http_method: "POST" } }); +iii.registerTrigger({ type: "http", function_id: "proof::cdp::discover", config: { api_path: "/proof/cdp", http_method: "GET" } }); + +console.log("proof worker started — listening for calls"); diff --git a/proof/tsconfig.json b/proof/tsconfig.json new file mode 100644 index 0000000..1d6524b --- /dev/null +++ b/proof/tsconfig.json @@ -0,0 +1,13 @@ +{ + "compilerOptions": { + "target": "ES2022", + "module": "ESNext", + "moduleResolution": "bundler", + "esModuleInterop": true, + "strict": true, + "outDir": "dist", + "rootDir": "src", + "skipLibCheck": true + }, + "include": ["src"] +} diff --git a/registry/index.json b/registry/index.json index e152a66..62638e7 100644 --- a/registry/index.json +++ b/registry/index.json @@ -20,6 +20,20 @@ } }, "version": "0.1.2" + }, + "nanochat": { + "description": "Karpathy's nanochat LLM worker — train, fine-tune, evaluate, and chat with GPT models", + "repo": "iii-hq/workers", + "tag_prefix": "nanochat", + "language": "python", + "supported_targets": ["any"], + "has_checksum": false, + "default_config": { + "source": "sft", + "device": null, + "engine_url": "ws://localhost:49134" + }, + "version": "0.1.0" } } }
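A quick smoke test of the proof worker's REST surface, once it is connected to the engine. This is a hedged sketch: it assumes the engine serves the registered HTTP triggers on port 3111 (the second port in the architecture diagram; 49134 is the WebSocket port the worker dials) and that the JSON body is forwarded as the `RunInput` payload.

```ts
// Hedged sketch: enqueue a proof run over the engine's HTTP API.
// The port and the body-to-payload mapping are assumptions, not confirmed by this diff.
const res = await fetch("http://localhost:3111/proof/enqueue", {
  method: "POST",
  headers: { "content-type": "application/json" },
  body: JSON.stringify({ target: "staged", base_url: "http://localhost:3000" }),
});
console.log(await res.json()); // response from the engine's proof queue
```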