# Workflow Design 1: Grok-1 Inference and Sampling

## Overview

The "Grok-1 Inference and Sampling" workflow provides the machinery to load the Grok-1 model's 314 billion parameters from a checkpoint, initialize the decoder-only transformer architecture with Mixture-of-Experts (MoE) layers and Grouped Query Attention (GQA), set up distributed sharding across GPUs using JAX meshes and PJIT, tokenize prompts with SentencePiece, and generate text autoregressively. Sampling combines temperature-controlled softmax, nucleus (top-p) filtering for diversity control, and top-k logging. The design emphasizes correctness for validation: a generator handles batched multi-request serving by managing KV caches per request slot, padding variable-length prompts, and running efficient single-token decode steps after prefill.

Key inputs: checkpoint in `./checkpoints/ckpt-0/`, `tokenizer.model`, a GPU cluster, and prompts as `Request` objects (prompt str, temperature float, nucleus_p float, rng_seed int, max_len int).
Outputs: generated text strings.
Entry points: `run.py` for a test run, or the `InferenceRunner().run()` generator for streaming requests.
Relevant files: `run.py`, `runners.py`, `model.py`, `checkpoint.py`, `tokenizer.model`.

The workflow orchestrates model loading, compilation of sharded compute functions, prompt processing (prefilling the KV cache while sampling the first token), and iterative single-token generation using cached attention keys/values until the requested maximum length is reached.

## Components

### run.py
- Defines Grok-1 hyperparameters via `LanguageModelConfig` and `TransformerConfig` (e.g., vocab_size=131072, sequence_len=8192, emb_size=6144, num_layers=64, num_q_heads=48, num_kv_heads=8, num_experts=8, num_selected_experts=2, widening_factor=8, key_size=128, shard_activations=True).
- Instantiates `InferenceRunner` with a `ModelRunner`, checkpoint_path="./checkpoints/", tokenizer_path="./tokenizer.model", and mesh configs (local=(1,8) for 8 GPUs, between_hosts=(1,1)).
- Calls `initialize()` and the `run()` generator, then demonstrates sampling via `sample_from_model(gen, prompt, max_len=100, temperature=0.01)` on an example prompt.
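
A minimal usage sketch of this entry point (hedged: `grok_1_model` stands in for the `LanguageModelConfig` defined earlier in `run.py`, and values such as `pad_sizes` and `bs_per_device` are illustrative rather than authoritative):

```python
from runners import InferenceRunner, ModelRunner, sample_from_model

CKPT_PATH = "./checkpoints/"

inference_runner = InferenceRunner(
    pad_sizes=(1024,),                       # prompt-length bucket(s) used for precompilation
    runner=ModelRunner(model=grok_1_model,   # the LanguageModelConfig defined above in run.py
                       bs_per_device=0.125,
                       checkpoint_path=CKPT_PATH),
    name="local",
    load=CKPT_PATH,
    tokenizer_path="./tokenizer.model",
    local_mesh_config=(1, 8),                # 1 data replica x 8-way model parallelism
    between_hosts_config=(1, 1),
)
inference_runner.initialize()
gen = inference_runner.run()                 # generator that owns batch slots and KV memory

prompt = "The answer to life the universe and everything is of course"
print(sample_from_model(gen, prompt, max_len=100, temperature=0.01))
```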

### runners.py
- **ModelRunner**: Configures model dtype (bfloat16), computes batch sizes from bs_per_device * devices * replicas, creates the hybrid JAX mesh, defines/transforms Haiku forward functions for the full pass and a logits-only pass, applies partition rules for sharding, and loads or initializes params via checkpoint restore. Supports quantization and activation sharding.
- **InferenceRunner**: Loads the tokenizer, computes param sharding from shapes, and compiles PJIT functions:
  - `new_memory`: Initialize the KV cache for a given batch/seq_len.
  - `prefill_memory`: For a request slot, encode the prompt, pad it, run the full prompt forward (with fresh memory and length tracking), sample the first generated token, and update the global batch memory/settings/rngs/last_output.
  - `sample_step`: For all active slots, run the forward pass on the last token using the shared memory, sample the next token, and update the memory (donated for efficiency).
- **InferenceRunner.run()** generator: Precompiles with dummy prompts for each pad bucket, manages a fixed set of batch slots (some free), yields to receive requests, fills slots via prefill, loops stepping all active slots, appends tokens per slot on the host, yields the decoded text when a slot finishes, and then deactivates that slot. Concurrency is handled via a free_slots list; a simplified sketch of this loop follows below.
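
A simplified, hedged sketch of that slot-management loop. Helper names such as `_initial_state` and the `token_id` field are hypothetical stand-ins; the real generator also precompiles pad buckets and interleaves request intake with stepping:

```python
# Simplified sketch of the run() slot loop; not the exact source.
def run(self):
    free_slots = list(range(self.batch_size))
    active = {}                                    # slot -> (generated tokens, request)
    params, rngs, memory, settings, last_output = self._initial_state()

    while True:
        while free_slots:                          # accept new requests while capacity remains
            request = yield                        # caller .send()s a Request
            slot = free_slots.pop()
            tokens = self.tokenizer.encode(request.prompt)
            rngs, last_output, memory, settings = self.prefill_memory(
                tokens, len(tokens), request, slot, rngs, memory, settings, last_output)
            active[slot] = ([], request)

        # One batched decode step advances every occupied slot at once.
        rngs, last_output, memory = self.sample_step(params, rngs, last_output, memory, settings)
        for slot in list(active):
            out, request = active[slot]
            out.append(int(last_output.token_id[slot]))   # copy the new token to the host
            if len(out) >= request.max_len:
                yield self.tokenizer.decode(out)          # finished: hand the text back
                free_slots.append(slot)
                del active[slot]
```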

### model.py
- Architecture: Embeddings → 64 Transformer layers (RMSNorm → GQA MultiHeadAttention with RoPE & KV cache → residual → RMSNorm → MoELayer → residual) → output linear to logits.
- **MultiHeadAttention**: GQA (48 query / 8 KV heads, head_dim=128), supports caching via `Memory` (a list of per-layer `KVMemory` with k, v, step).
- **MoELayer**: A router selects the top-2 of 8 experts per token; each expert is a SwiGLU FFN. Dispatch uses shard_map/vmap (validation-focused, not optimized).
- Other: `RotaryEmbedding`, a custom `Linear` with quantization support, and `apply_rules` for sharding specs (P('data'), P('model'), etc.).
- The forward callable built via `make(mesh)` integrates sharding and returns `LanguageModelOutput` (logits, model_state=Memory).
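
Because the 48-query/8-KV-head grouping is the defining feature of the attention module, here is a small, self-contained JAX illustration of the head-sharing idea (toy shapes and random inputs; not the `model.py` implementation, which also applies RoPE and the KV cache):

```python
import jax
import jax.numpy as jnp

# 48 query heads attend over keys/values produced by only 8 KV heads.
B, T, NQ, NKV, D = 1, 16, 48, 8, 128
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (B, T, NQ, D))
k = jax.random.normal(kk, (B, T, NKV, D))
v = jax.random.normal(kv, (B, T, NKV, D))

# Each group of 48 // 8 = 6 query heads shares one KV head.
k = jnp.repeat(k, NQ // NKV, axis=2)           # [B, T, 48, D]
v = jnp.repeat(v, NQ // NKV, axis=2)

logits = jnp.einsum("bthd,bshd->bhts", q, k) / jnp.sqrt(D)
causal = jnp.tril(jnp.ones((T, T), dtype=bool))
logits = jnp.where(causal[None, None], logits, -1e30)  # causal mask
attn = jax.nn.softmax(logits, axis=-1)
out = jnp.einsum("bhts,bshd->bthd", attn, v)   # [B, T, 48, D]
print(out.shape)
```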

### checkpoint.py
- `restore()`: Computes shapes, loads the pickled sharded checkpoint files (handling `QuantizedWeight8bit`), copies them to shared memory (/dev/shm) for fast access, syncs across hosts via broadcast, and shards the tensors into JAX arrays matching the specified sharding/mesh. Supports params_only, an init_state fallback, and rename/exclude rules.
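
A rough, hedged illustration of the final placement step only, assuming a single addressable mesh (hypothetical `place_tensor` helper; the real `restore()` additionally handles quantized weights, shared-memory staging, and multi-host synchronization):

```python
import pickle

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def place_tensor(path: str, mesh: Mesh, spec: P) -> jax.Array:
    """Load one host-side tensor from a pickled shard and place it on the mesh."""
    with open(path, "rb") as f:
        host_array = np.asarray(pickle.load(f))
    # device_put with a NamedSharding splits the array across devices per the spec.
    return jax.device_put(host_array, NamedSharding(mesh, spec))
```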

### tokenizer.model & Others
- SentencePiece for subword tokenization (pad_token=0, eos_token=2).
- Dependencies: JAX (distributed arrays, pjit, shard_map), Haiku (modules/transform), NumPy/jax.numpy, sentencepiece.
- checkpoints/: Directory for the downloaded weights (torrent or Hugging Face).
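
Basic SentencePiece usage as applied in the workflow, encoding a prompt and decoding generated ids (the prompt string is illustrative):

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="./tokenizer.model")
tokens = sp.encode("The answer to life the universe and everything is")  # prompt -> ids
text = sp.decode(tokens)                                                 # ids -> text
print(tokens[:8], text)
```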

## Initialization Sequence

```mermaid
sequenceDiagram
    participant User
    participant RunPy as run.py
    participant IR as InferenceRunner
    participant MR as ModelRunner
    participant Model as model.py
    participant Checkpoint as checkpoint.py
    participant JAX as JAX Runtime
    User->>RunPy: Execute main()
    RunPy->>IR: Create with config, MR, paths, meshes
    IR->>MR: initialize(dummy_data, meshes)
    MR->>Model: model.initialize(), fprop_dtype=bf16
    Note over MR,JAX: Calculate batch sizes, create mesh (data, model axes)
    MR->>MR: hk.transform forward/logits_fn with pjit sharding
    MR->>Checkpoint: load_or_init -> restore(shapes, mesh, sharding)
    Checkpoint->>MR: Sharded params (TrainingState)
    IR->>IR: Load tokenizer, compile pjit funcs (sample_step, prefill_memory, new_memory) with shardings
    IR->>IR: Precompile with dummy prompts for pad_sizes
    RunPy->>IR: gen = run() // generator setup with initial memory, settings, etc.
```

## Inference and Sampling Sequence

```mermaid
sequenceDiagram
    participant Gen as Generator (run())
    participant Req as Request
    participant Tok as Tokenizer
    participant Prefill as prefill_memory
    participant Step as sample_step
    participant LM as LM forward
    participant Samp as sample_token
    participant Mem as KV Memory
    participant Out as Output

    Note over Gen: Initial setup: memory, rngs, settings, last_output

    Gen->>Req: yield (wait for input)
    Req->>Gen: send Request(prompt, temp, p, seed, max_len)
    Gen->>Tok: encode(prompt) -> tokens
    Gen->>Gen: pad tokens, create settings, active=1
    Gen->>Prefill: call prefill_memory(tokens, len, new_settings, slot)
    Prefill->>LM: hk_forward(tokens, new_mem, length, active) // process full prompt
    LM->>Samp: sample_token from logits // first generated token
    Prefill->>Mem: update KV cache with prompt tokens
    Prefill->>Gen: updated rngs, last_output, memory, settings
    loop Autoregressive Sampling (while active and < max_len)
        Gen->>Step: sample_step(params, rngs, last_output, memory, settings)
        Step->>LM: hk_forward(last_token, memory) // decode step
        LM->>Samp: sample_token(logits, settings)
        Step->>Mem: update memory with new KV (donate old)
        Step->>Gen: new rngs, sample_output, memory
        Gen->>Gen: append token to sequence, copy to host
        alt Reached max_len
            Gen->>Out: decode all tokens -> yield text
            Gen->>Gen: deactivate slot, free for new request
        end
    end
```
| 106 | + |
| 107 | +## Sharding and Distributed Execution |
| 108 | + |
| 109 | +- **Mesh Configuration**: `make_mesh(local=(data_replicas, model_par), between_hosts=(data_hosts, model_hosts))` creates hybrid mesh for SPMD parallelism. E.g., local 1x8 shards model across 8 GPUs. |
| 110 | +- **Sharding Specs**: Model params sharded per rules (e.g., embedding over model, attention QKV over data/model/head). Activations optionally sharded. KV Memory sharded over data axis. |
| 111 | +- **PJIT & Compilation**: Functions wrapped in `hk.transform` then `pjit` with explicit in/out shardings, static args, donation for memory efficiency. Precompilation with dummies reduces first-run latency. |
| 112 | +- **Multi-Host**: Checkpoint loading syncs via `multihost_utils`, assumes launched with `jax process_count()` matching topology. |
| 113 | +- **Memory Optimizations**: bfloat16 compute, 8-bit weight quantization (dequant on fly), KV cache management, activation checkpointing/sharding, padding truncation. |
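
A minimal, self-contained illustration of the sharding primitives involved, assuming 8 local devices; the axis names follow the description above, but the sharded matmul is a toy, not the model forward pass:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1x8 ("data", "model") mesh, mirroring local_mesh_config=(1, 8).
devices = np.array(jax.devices()).reshape(1, 8)
mesh = Mesh(devices, axis_names=("data", "model"))

x_sharding = NamedSharding(mesh, P("data", None))    # activations split over the data axis
w_sharding = NamedSharding(mesh, P(None, "model"))   # weight columns split over the model axis

@jax.jit
def matmul(x, w):
    return x @ w

x = jax.device_put(np.ones((8, 6144), np.float32), x_sharding)
w = jax.device_put(np.ones((6144, 1024), np.float32), w_sharding)
y = matmul(x, w)                                     # output sharding inferred by the compiler
print(y.sharding)
```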
| 114 | + |
| 115 | +## Sampling Mechanism |
| 116 | + |
| 117 | +- **sample_token**: Scales logits by 1/temp, applies mask (-inf to disallowed), nucleus filter (sort probs, threshold at cumsum >=1-p, mask others to -inf), categorical sample from softmax. Returns token, prob, top-k tokens/probs. |
| 118 | +- **top_p_filter**: Sorts logits descending, soft max to probs, finds minimal set summing to p mass. |
| 119 | +- **Batch Integration**: Settings (temp, p, mask, active) broadcasted/vmap'ed across batch. Active flag skips inactive computations by resetting cache steps. |
| 120 | +- **RNG**: Per-slot PRNG keys split and updated each step. |
| 121 | +- **Defaults**: nucleus_p=1.0 (full dist), temp=0.01 for low randomness in tests, top_k=8 for auxiliary. |
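
A hedged sketch of that path (temperature scaling, nucleus filtering, categorical draw). It mirrors the behaviour described for `sample_token`/`top_p_filter` rather than reproducing the source, and omits the mask and top-k bookkeeping:

```python
import jax
import jax.numpy as jnp

def top_p_filter(logits: jax.Array, top_p: float) -> jax.Array:
    """Keep only the smallest set of top logits whose softmax mass reaches top_p."""
    sorted_logits = jax.lax.sort(logits)                      # ascending sort
    sorted_probs = jax.nn.softmax(sorted_logits)
    # Index of the smallest logit that still belongs to the nucleus.
    threshold_idx = jnp.argmax(jnp.cumsum(sorted_probs, axis=-1) >= 1 - top_p, axis=-1)
    threshold_logit = jnp.take_along_axis(sorted_logits, threshold_idx[..., None], axis=-1)
    return jnp.where(logits >= threshold_logit, logits, -1e10)

def sample_token(rng: jax.Array, logits: jax.Array, temperature: float, nucleus_p: float) -> jax.Array:
    logits = logits / temperature                             # temperature-controlled softmax
    logits = top_p_filter(logits, nucleus_p)
    return jax.random.categorical(rng, logits)                # draw from the filtered softmax

rng = jax.random.PRNGKey(0)
token = sample_token(rng, jnp.array([1.0, 2.0, 0.5, 3.0]), temperature=0.8, nucleus_p=0.9)
print(int(token))
```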
| 122 | + |
| 123 | +## Other Design Aspects |
| 124 | + |
| 125 | +- **KV Caching**: `Memory` dataclass with layers of `KVMemory(k: [batch,heads,seqlen,head_dim], v:..., step:scalar)`. Updated via dynamic_update and pad_to_max_len. Supports variable lengths per slot via length param in forward. |
| 126 | +- **Batching Strategy**: Fixed global batch_size, slots filled on-demand. Pad buckets (e.g., 1024) group similar lengths? Code uses bisect for bucket but pads to bucket size in prefill. |
| 127 | +- **Error/Edge Cases**: Assumes sufficient memory/GPUs; handles long contexts by left-truncation/padding. No built-in EOS handling (relies on max_len or app logic). Quantized weights require custom unpickling. |
| 128 | +- **Performance Notes**: MoE router/experts use JAX vmap/shard_map (serial per-token, inefficient for prod). Focus on correctness/single-host validation. |
| 129 | +- **Extensibility**: Modular Haiku design allows custom configs/modules. Generator interface suits serving multiple prompts concurrently. |
| 130 | +- **Dependencies & Setup**: `requirements.txt` (jax[cuda12_pip], haiku, etc.). Download ckpt via torrent/HF, place in checkpoints/. |
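
The sketch referenced in the KV Caching bullet above: a minimal, self-contained example of appending one decode step's keys/values into a fixed-size cache, using the layout named there (hypothetical `KVMemory`/`append_kv` names, not the `model.py` code):

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class KVMemory(NamedTuple):
    k: jax.Array      # [batch, num_kv_heads, seqlen, head_dim]
    v: jax.Array
    step: jax.Array   # [batch] number of tokens already written per sequence

def append_kv(mem: KVMemory, new_k: jax.Array, new_v: jax.Array) -> KVMemory:
    """Insert one new key/value per sequence at that sequence's current step."""
    def write_one(cache, new, step):
        # cache: [heads, seqlen, head_dim], new: [heads, 1, head_dim]
        return jax.lax.dynamic_update_slice_in_dim(cache, new, step, axis=1)
    k = jax.vmap(write_one)(mem.k, new_k[:, :, None, :], mem.step)
    v = jax.vmap(write_one)(mem.v, new_v[:, :, None, :], mem.step)
    return KVMemory(k=k, v=v, step=mem.step + 1)

# Example shapes: batch=2, kv_heads=8, max_seqlen=16, head_dim=4.
mem = KVMemory(k=jnp.zeros((2, 8, 16, 4)), v=jnp.zeros((2, 8, 16, 4)),
               step=jnp.zeros((2,), dtype=jnp.int32))
mem = append_kv(mem, jnp.ones((2, 8, 4)), jnp.ones((2, 8, 4)))
print(mem.step)   # [1 1]
```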

This document captures the high-level design, derived from code analysis.