# Workflow Design: 1 - Grok-1 Inference and Sampling

## Overview

The "Grok-1 Inference and Sampling" workflow provides the machinery to load the Grok-1 model's 314 billion parameters from a checkpoint, initialize the decoder-only transformer architecture with Mixture-of-Experts (MoE) layers and Grouped Query Attention (GQA), set up distributed sharding across GPUs using JAX meshes and PJIT, tokenize prompts with SentencePiece, and generate text autoregressively. Sampling incorporates temperature-controlled softmax, nucleus (top-p) filtering for diversity control, and top-k logging. The design emphasizes correctness for validation, supporting batched multi-request handling via a generator that manages KV caches per request slot, padding for variable lengths, and efficient decode steps post-prefill.

- Key inputs: a checkpoint in `./checkpoints/ckpt-0/`, `tokenizer.model`, a GPU cluster, and prompts as `Request` objects (prompt str, temperature float, nucleus_p float, rng_seed int, max_len int).
- Outputs: generated text strings.
- Entry points: `run.py` for a test run, or the `InferenceRunner().run()` generator for streaming requests.
- Relevant files: `run.py`, `runners.py`, `model.py`, `checkpoint.py`, `tokenizer.model`.

The workflow orchestrates model loading, compilation of sharded compute functions, prompt processing (prefilling the KV cache while sampling the first token), and iterative single-token generation using cached attention keys/values until max length or EOS is reached.

## Components

### run.py
- Defines Grok-1 hyperparameters via `LanguageModelConfig` and `TransformerConfig` (e.g., vocab_size=131072, sequence_len=8192, emb_size=6144, num_layers=64, num_q_heads=48, num_kv_heads=8, num_experts=8, num_selected_experts=2, widening_factor=8, key_size=128, shard_activations=True).
- Instantiates `InferenceRunner` with a `ModelRunner`, checkpoint_path="./checkpoints/", tokenizer_path="./tokenizer.model", and mesh configs (local=(1, 8) for 8 GPUs, between_hosts=(1, 1)).
- Calls `initialize()` and the `run()` generator, then demonstrates sampling via `sample_from_model(gen, prompt, max_len=100, temperature=0.01)` on an example prompt.

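For concreteness, a condensed sketch of this entry-point flow is shown below. It is not the verbatim `run.py`; field names such as `name`, `load`, and the exact set of config arguments are assumptions based on the summary above, and several required config fields are omitted.

```python
from model import LanguageModelConfig, TransformerConfig
from runners import InferenceRunner, ModelRunner, sample_from_model

# Config abbreviated; see the hyperparameter list above for the full set.
grok_1_model = LanguageModelConfig(
    vocab_size=131072,
    sequence_len=8192,
    model=TransformerConfig(
        emb_size=6144,
        key_size=128,
        num_layers=64,
        num_q_heads=48,
        num_kv_heads=8,
        num_experts=8,
        num_selected_experts=2,
        widening_factor=8,
    ),
)

inference_runner = InferenceRunner(
    runner=ModelRunner(model=grok_1_model, bs_per_device=0.125),
    name="local",                      # assumed field
    load="./checkpoints/ckpt-0",       # assumed field
    tokenizer_path="./tokenizer.model",
    local_mesh_config=(1, 8),          # 1 data replica x 8 model shards
    between_hosts_config=(1, 1),       # single host
)
inference_runner.initialize()
gen = inference_runner.run()           # generator that serves Request objects

prompt = "The answer to life the universe and everything is of the form"
print(sample_from_model(gen, prompt, max_len=100, temperature=0.01))
```
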
### runners.py
- **ModelRunner**: Configures the model dtype (bfloat16), computes batch sizes from bs_per_device * devices * replicas, creates the hybrid JAX mesh, defines/transforms the Haiku forward functions (full pass and logits-only), applies partition rules for sharding, and loads or initializes params via checkpoint restore. Supports quantization and activation sharding.
- **InferenceRunner**: Loads the tokenizer, computes param sharding from shapes, and compiles PJIT functions:
  - `new_memory`: Initialize the KV cache for a given batch/seq_len.
  - `prefill_memory`: For a request slot, encode the prompt, pad it, run the full prompt forward (with new memory and length tracking), sample the first generated token, and update the global batch memory/settings/rngs/last_output.
  - `sample_step`: For all active slots, run a forward pass on the last token using the shared memory, sample the next token, and update the memory (donated for efficiency).
- `run()` generator: Precompiles with dummy prompts for the pad buckets, manages fixed batch slots (some free), yields to wait for requests, fills slots via prefill, loops stepping all active slots, appends tokens per slot on the host, yields decoded text when a slot finishes, and deactivates the slot. Concurrency is handled via a free_slots list (see the sketch below).

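The slot-management pattern of `run()` can be pictured as below. This is a deliberately simplified, single-slot view (the real generator prefills into free slots and steps *all* active slots on every iteration); helper names such as `_prefill` and `_sample_step` are placeholders, not the actual `runners.py` methods.

```python
def run(self):
    """Simplified, hypothetical sketch of the serving-generator pattern described above."""
    free_slots = list(range(self.batch_size))
    while True:
        request = yield None                    # wait for the next Request
        slot = free_slots.pop()
        self._prefill(request, slot)            # encode + pad prompt, sample first token
        tokens = []
        while len(tokens) < request.max_len:
            step_output = self._sample_step()   # one decode step (real code covers all active slots)
            tokens.append(int(step_output[slot]))
        yield self.tokenizer.decode(tokens)     # hand the finished text back to the caller
        free_slots.append(slot)                 # slot becomes available for a new request
```
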
### model.py
- Architecture: Embeddings → 64 transformer layers (RMSNorm → GQA MultiHeadAttention with RoPE & KV cache → residual → RMSNorm → MoELayer → residual) → output linear to logits.
- **MultiHeadAttention**: GQA (48 query / 8 KV heads, head_dim=128); supports caching via `Memory` (a list of per-layer `KVMemory` with k, v, step).
- **MoELayer**: The router selects the top 2 of 8 experts per token; each expert is a SwiGLU FFN. Dispatch uses shard_map/vmap (validation-focused, not optimized); a toy routing sketch follows below.
- Other: `RotaryEmbedding`, a custom `Linear` with quantization, `apply_rules` for sharding specs (P('data'), P('model'), etc.).
- The forward callable from `make(mesh)` integrates sharding and returns `LanguageModelOutput` (logits, model_state=Memory).

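To make the router behaviour concrete, here is a small, self-contained top-2 routing function in JAX. It illustrates the selection/gating idea only and is not the repo's `MoELayer` (which dispatches tokens to experts through shard_map/vmap).

```python
import jax
import jax.numpy as jnp

def route_top2(router_logits: jnp.ndarray):
    """router_logits: [num_tokens, num_experts] scores from the router projection."""
    gate_logits, expert_idx = jax.lax.top_k(router_logits, k=2)  # top-2 experts per token
    gates = jax.nn.softmax(gate_logits, axis=-1)                 # renormalize over the selected pair
    return expert_idx, gates

# 4 tokens routed over 8 experts.
logits = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
expert_idx, gates = route_top2(logits)   # shapes: [4, 2] and [4, 2]
```
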
### checkpoint.py
- `restore()`: Computes shapes, loads pickled sharded checkpoint files (handling `QuantizedWeight8bit`), copies them to shared memory (/dev/shm) for fast access, syncs across hosts via broadcast, and shards them into JAX arrays matching the specified sharding/mesh. Supports params_only, an init_state fallback, and rename/exclude rules.

### tokenizer.model & Others
- SentencePiece for subword tokenization (pad_token=0, eos_token=2); see the usage example below.
- Dependencies: JAX (distributed arrays, pjit, shard_map), Haiku (modules/transform), NumPy/jax.numpy, sentencepiece.
- checkpoints/: Directory for the downloaded weights (torrent or HF).

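A small usage example of the SentencePiece tokenizer described above; the pad/EOS ids follow this document, the rest is standard `sentencepiece` API.

```python
import sentencepiece

sp = sentencepiece.SentencePieceProcessor(model_file="./tokenizer.model")

PAD_ID, EOS_ID = 0, 2                      # per the notes above
tokens = sp.encode("The answer is 42")     # -> list of subword ids
text = sp.decode(tokens)                   # round-trips back to the string
```
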
## Initialization Sequence

```mermaid
sequenceDiagram
participant User
participant RunPy as run.py
participant IR as InferenceRunner
participant MR as ModelRunner
participant Model as model.py
participant Checkpoint as checkpoint.py
participant JAX as JAX Runtime
User->>RunPy: Execute main()
RunPy->>IR: Create with config, MR, paths, meshes
IR->>MR: initialize(dummy_data, meshes)
MR->>Model: model.initialize(), fprop_dtype=bf16
Note over MR,JAX: Calculate batch sizes, create mesh (data, model axes)
MR->>MR: hk.transform forward/logits_fn with pjit sharding
MR->>Checkpoint: load_or_init -> restore(shapes, mesh, sharding)
Checkpoint->>MR: Sharded params (TrainingState)
IR->>IR: Load tokenizer, compile pjit funcs (sample_step, prefill_memory, new_memory) with shardings
IR->>IR: Precompile with dummy prompts for pad_sizes
RunPy->>IR: gen = run() // generator setup with initial memory, settings, etc.
```

## Inference and Sampling Sequence

```mermaid
sequenceDiagram
participant Gen as Generator (run())
participant Req as Request
participant Tok as Tokenizer
participant Prefill as prefill_memory
participant Step as sample_step
participant LM as LM forward
participant Samp as sample_token
participant Mem as KV Memory
participant Out as Output

Note over Gen: Initial setup: memory, rngs, settings, last_output

Gen->>Req: yield (wait for input)
Req->>Gen: send Request(prompt, temp, p, seed, max_len)
Gen->>Tok: encode(prompt) -> tokens
Gen->>Gen: pad tokens, create settings, active=1
Gen->>Prefill: call prefill_memory(tokens, len, new_settings, slot)
Prefill->>LM: hk_forward(tokens, new_mem, length, active) // process prompt
LM->>Samp: sample_token from logits // first generated token
Prefill->>Mem: update KV cache with prompt tokens
Prefill->>Gen: updated rngs, last_output, memory, settings
loop Autoregressive Sampling (while active and < max_len)
Gen->>Step: sample_step(params, rngs, last_output, memory, settings)
Step->>LM: hk_forward(last_token, memory) // decode step
LM->>Samp: sample_token(logits, settings)
Step->>Mem: update memory with new KV (donate old)
Step->>Gen: new rngs, sample_output, memory
Gen->>Gen: append token to sequence, copy to host
alt Reached max_len or EOS?
Gen->>Out: decode all tokens -> yield text
Gen->>Gen: deactivate slot, free for new req
end
end
```

## Sharding and Distributed Execution

- **Mesh Configuration**: `make_mesh(local=(data_replicas, model_par), between_hosts=(data_hosts, model_hosts))` creates a hybrid mesh for SPMD parallelism; e.g., a local 1x8 config shards the model across 8 GPUs.
- **Sharding Specs**: Model params are sharded per the partition rules (e.g., embeddings over the model axis, attention QKV over data/model/head axes). Activations are optionally sharded. KV Memory is sharded over the data axis.
- **PJIT & Compilation**: Functions are wrapped in `hk.transform` then `pjit` with explicit in/out shardings, static args, and buffer donation for memory efficiency. Precompilation with dummy inputs reduces first-run latency.
- **Multi-Host**: Checkpoint loading syncs via `multihost_utils`; assumes processes are launched so that `jax.process_count()` matches the topology.
- **Memory Optimizations**: bfloat16 compute, 8-bit weight quantization (dequantized on the fly), KV cache management, activation checkpointing/sharding, padding truncation.

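The mesh/sharding vocabulary above maps onto current JAX APIs roughly as follows. This is a minimal sketch with the 1x8 single-host layout, not the repo's `make_mesh`; it assumes 8 local devices and toy array sizes.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1 x 8 layout: one data replica, model sharded over 8 local devices (assumed available).
devices = np.array(jax.devices()).reshape(1, 8)
mesh = Mesh(devices, axis_names=("data", "model"))

# A (toy-sized) weight matrix sharded along the 'model' axis, mirroring the partition rules.
weight = jax.device_put(
    np.zeros((1024, 4096), dtype=np.float32),
    NamedSharding(mesh, P(None, "model")),
)
print(weight.sharding)   # PartitionSpec(None, 'model') on the ('data', 'model') mesh
```
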
## Sampling Mechanism

- **sample_token**: Scales logits by 1/temperature, applies a mask (-inf for disallowed tokens), applies the nucleus filter (sort probs, threshold at cumsum >= 1-p, mask the rest to -inf), then draws a categorical sample from the softmax. Returns the token, its probability, and top-k tokens/probs.
- **top_p_filter**: Sorts logits in descending order, applies softmax to get probs, and finds the minimal set summing to p mass.
- **Batch Integration**: Settings (temperature, p, mask, active) are broadcast/vmapped across the batch. The active flag skips computation for inactive slots by resetting cache steps.
- **RNG**: Per-slot PRNG keys are split and updated each step.
- **Defaults**: nucleus_p=1.0 (full distribution), temperature=0.01 for low randomness in tests, top_k=8 for auxiliary logging.

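A compact reference implementation of the temperature plus nucleus step described above. This is a sketch, not the repo's `sample_token`; the per-slot masks and top-k logging are omitted.

```python
import jax
import jax.numpy as jnp

def sample_nucleus(rng, logits, temperature=0.8, nucleus_p=0.95):
    probs = jax.nn.softmax(logits / temperature)
    order = jnp.argsort(-probs)                   # descending by probability
    sorted_probs = probs[order]
    cum = jnp.cumsum(sorted_probs)
    keep = cum - sorted_probs < nucleus_p         # minimal prefix whose mass reaches p
    masked = jnp.where(keep, jnp.log(sorted_probs), -jnp.inf)
    pick = jax.random.categorical(rng, masked)    # index into the sorted order
    return order[pick]

vocab_logits = jax.random.normal(jax.random.PRNGKey(1), (131072,))
token = sample_nucleus(jax.random.PRNGKey(0), vocab_logits)
```
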
## Other Design Aspects

- **KV Caching**: `Memory` dataclass holding a list of per-layer `KVMemory(k: [batch, heads, seq_len, head_dim], v: ..., step: scalar)`. Updated via dynamic_update and pad_to_max_len; supports variable lengths per slot via the length param in the forward pass (see the sketch after this list).
- **Batching Strategy**: Fixed global batch_size with slots filled on demand. Pad buckets (e.g., 1024) group prompts of similar length; the code bisects prompts into a bucket and pads to the bucket size in prefill.
- **Error/Edge Cases**: Assumes sufficient memory/GPUs; handles long contexts by left-truncation/padding. No built-in EOS handling (relies on max_len or application logic). Quantized weights require custom unpickling.
- **Performance Notes**: The MoE router/experts use JAX vmap/shard_map (serial per token, inefficient for production). The focus is correctness and single-host validation.
- **Extensibility**: The modular Haiku design allows custom configs/modules. The generator interface suits serving multiple prompts concurrently.
- **Dependencies & Setup**: `requirements.txt` (jax[cuda12_pip], haiku, etc.). Download the checkpoint via torrent/HF and place it in checkpoints/.

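For the KV Caching bullet above, a stripped-down sketch of the cache structure and a single-position update; the field names and the scalar-step update are assumptions, and the real code additionally tracks per-slot lengths and pads to the maximum length.

```python
from typing import List, NamedTuple, Optional
import jax
import jax.numpy as jnp

class KVMemory(NamedTuple):
    k: Optional[jnp.ndarray]        # [batch, heads, seq_len, head_dim], layout as stated above
    v: Optional[jnp.ndarray]
    step: Optional[jnp.ndarray]     # number of positions already filled
class Memory(NamedTuple):
    layers: List[KVMemory]          # one cache per transformer layer

def write_step(mem: KVMemory, new_k: jnp.ndarray, new_v: jnp.ndarray) -> KVMemory:
    """Append one decode step's keys/values at the current position (toy single-step case)."""
    k = jax.lax.dynamic_update_slice_in_dim(mem.k, new_k, mem.step, axis=2)
    v = jax.lax.dynamic_update_slice_in_dim(mem.v, new_v, mem.step, axis=2)
    return KVMemory(k=k, v=v, step=mem.step + 1)
```
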
This document captures the high-level design, derived from code analysis.
# Design: Workflow #2 - Model Loading and Initialization

## Overview

This workflow defines the Grok-1 model architecture using JAX and Haiku, and handles loading model parameters from a sharded checkpoint or initializing them randomly. It supports advanced features like 8-bit weight quantization, activation sharding for memory efficiency, and distributed sharding across multiple GPUs and hosts via JAX's SPMD parallelism.

**Inputs:**
- Model configurations (`LanguageModelConfig`, `TransformerConfig`) specifying Grok-1 hyperparameters (e.g., 64 layers, 6144 embed dim, MoE with 8 experts / 2 selected, GQA with 48/8 heads).
- Checkpoint path (e.g., `./checkpoints/ckpt-0/` containing sharded tensor files).
- Mesh configurations: `local_mesh_config` (GPUs per host, e.g., (1, 8)) and `between_hosts_config` (replicas/hosts).
- Dummy init data for shape inference and initialization.

**Outputs:**
- `TrainingState` with sharded parameters (`params`), ready for use in forward passes or inference/training loops.

The process ensures efficient loading of the 314B parameters, correct mapping between checkpoint structure and model params (via rename/exclude rules), and proper distribution to devices.

**Entry Point:** `runners.ModelRunner.load_or_init()` or `checkpoint.restore()`.

## Components

### Model Definition (`model.py`)
- **Configurations:**
  - `TransformerConfig`: Core params including `emb_size=6144`, `key_size=128`, `num_layers=64`, `num_q_heads=48`, `num_kv_heads=8`, MoE settings (`num_experts=8`, `num_selected_experts=2`, `widening_factor=8`), sharding axes (`data_axis`, `model_axis`), and an activation-sharding flag.
  - `LanguageModelConfig`: Extends with `vocab_size=131072`, `sequence_len=8192`, embedding/output scales, and a `make()` method that instantiates the `LanguageModel` Haiku module (embeddings → transformer → output logits).
- **Architecture Modules:** Haiku-based decoder-only transformer with RMSNorm, Multi-Head Attention (GQA, RoPE, KV caching), MoE FFN (SwiGLU), and a custom Linear with quantization support.
- **Sharding:** `partition_rules()` returns specs like `P('model', None)` for weights, enabling data/model parallelism.
- **Initialization:** Uses Haiku initializers with config scales; supports `fprop_dtype=jnp.bfloat16`.

### Orchestration (`runners.py`)
- **`ModelRunner` dataclass:** Central coordinator.
  - `initialize(init_data, local_mesh_config, between_hosts_config)`: Computes batch sizes, creates the JAX mesh, defines `forward` and `logits_fn` via `hk.transform` and `pjit` for sharded execution, and derives `state_sharding` using `eval_shape` and the partition rules (see the sketch after this list).
  - `load_or_init(init_data, from_checkpoint=True)`: Branches to checkpoint loading or random init; wraps the call in the mesh context for sharding.
  - Supports a custom `init_fn`, RNG seeding, and transform flags for full state (params/optimizer state in the future).

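A hedged sketch of the shape-then-shard pattern summarized above: `jax.eval_shape` yields parameter shapes without allocating memory, and a rule is mapped over that shape tree to build a sharding tree. The toy model and the single rule are illustrative; the real code matches `partition_rules()` against parameter path names.

```python
import jax
import jax.numpy as jnp
import haiku as hk
from jax.sharding import PartitionSpec as P

# Toy stand-in for the real transformed forward function.
def forward_fn(x):
    return hk.Linear(16)(x)

forward = hk.transform(forward_fn)
dummy_inputs = jnp.zeros((1, 8))

# 1) Shapes only: no device memory is allocated for the parameters.
param_shapes = jax.eval_shape(forward.init, jax.random.PRNGKey(0), dummy_inputs)

# 2) Map a (hypothetical) rule over the shape tree to build the sharding tree.
state_sharding = jax.tree_util.tree_map(
    lambda leaf: P(None, "model") if len(leaf.shape) == 2 else P(),
    param_shapes,
)
```
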
### Checkpoint Handling (`checkpoint.py`)
- **`restore(checkpoint_path, state_shapes, mesh, between_hosts_config, state_sharding, params_only, init_state)`:** Loads and shards params.
  - `load_tensors()`: Multithreaded (32 workers) parallel unpickling of sharded files (`tensor{i:05d}_{idx:03d}`) based on the process index.
  - `replace_with_load_state()`: Maps checkpoint keys to the model structure using regex rename/exclude rules, filling missing entries with zeros or init values.
  - Assembly: Flattens/unflattens the trees and sanity-checks param keys.
  - Distribution: `multihost_utils.host_local_array_to_global_array` creates sharded global arrays.
- **Optimizations:** `fast_unpickle`/`fast_pickle` use `/dev/shm` temp files for I/O speed (see the sketch below); handles `QuantizedWeight8bit`.
- Per-rank logging for debugging.

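A minimal sketch of the `/dev/shm` staging trick mentioned above: copy the file into shared memory (a RAM-backed tmpfs) and unpickle from there. The function name mirrors the summary, but the copy strategy shown here is an assumption, not the repo's exact implementation.

```python
import pickle
import shutil
import tempfile

def fast_unpickle(path: str):
    # Stage the file in /dev/shm (RAM-backed) and unpickle from the staged copy.
    with tempfile.NamedTemporaryFile(dir="/dev/shm") as tmp:
        shutil.copyfile(path, tmp.name)
        with open(tmp.name, "rb") as f:
            return pickle.load(f)
```
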
## Sequence Diagram

```mermaid
sequenceDiagram
participant S as Script/User
participant MR as ModelRunner
participant MD as Model (model.py)
participant CL as Checkpoint (checkpoint.py)
participant JM as JAX Mesh
participant D as Devices

S->>+MR: new ModelRunner(config)
MR->>+MD: model.make(mesh) [in init]
Note right of MR: initialize(local_mesh_config, between_hosts_config, init_data)
MR->>+JM: make_mesh(configs)
JM-->>-MR: mesh
MR->>+MR: hk.transform(forward) & pjit
MR->>+MR: compute state_sharding via eval_shape & partition_rules

alt Load from Checkpoint
MR->>+MR: load_or_init(init_data, from_checkpoint=True)
MR->>+MR: eval_shape(init_fn) -> shapes
MR->>+CL: restore(path, shapes, mesh, sharding, params_only=True)
Note right of CL: load_tensors(): parallel unpickle sharded tensors<br/>from ckpt-0/tensorXXXX_YYY
CL->>+JM: host_local_to_global_array(state, mesh, sharding)
JM->>+D: Shard params across devices/hosts
D-->>-JM: sharded
JM-->>-CL: Sharded state
CL-->>-MR: params
else Random Init
MR->>+MR: load_or_init(init_data, from_checkpoint=False)
MR->>+MR: init_fn(rng, init_data) -> forward.init(rng, inputs)
Note right of MR: Generates random params matching shapes
MR->>+JM: Shard new params
JM-->>-MR: Sharded params
end

MR-->>-S: Sharded TrainingState(params)
```

## Additional Design Aspects

### Sharding Strategy
- **Mesh Axes:** Data (batch parallelism/replicas) and Model (parameter sharding).
- **Rules:** Explicit `PartitionSpec`s per component (e.g., QKV projections sharded over the heads/model axis, MoE experts replicated or sharded).
- **Activation Sharding:** Configurable; shards intermediates along the data axis, reducing per-device memory.
- **KV Memory:** Sharded to support caching during autoregressive generation.

### Quantization and Precision
- **8-bit Weights:** The checkpoint may contain `QuantizedWeight8bit` tensors; they are dequantized on the fly in the Linear layers (see the sketch below).
- **Compute:** bfloat16 for the forward pass to balance precision and speed.
- **Memory Management:** Sharding plus quantization make it possible to load the model on limited hardware (e.g., 8x H100s).

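A hedged sketch of on-the-fly dequantization of an 8-bit weight inside a matmul. The `QuantizedWeight8bit` fields (`weight`, `scales`) follow the summary above, but the scale layout and broadcasting shown here are assumptions.

```python
from typing import NamedTuple
import jax.numpy as jnp

class QuantizedWeight8bit(NamedTuple):
    weight: jnp.ndarray   # int8 values, e.g. [in_features, out_features]
    scales: jnp.ndarray   # per-column float scales (layout assumed)

def dequant_matmul(x: jnp.ndarray, w: QuantizedWeight8bit) -> jnp.ndarray:
    # Dequantize to bfloat16 just before the matmul; the full-precision
    # weight is never stored outside this op.
    w_bf16 = w.weight.astype(jnp.bfloat16) * w.scales.astype(jnp.bfloat16)
    return x @ w_bf16

x = jnp.ones((2, 4), dtype=jnp.bfloat16)
w = QuantizedWeight8bit(
    weight=jnp.ones((4, 8), dtype=jnp.int8),
    scales=jnp.full((1, 8), 0.02, dtype=jnp.float32),
)
y = dequant_matmul(x, w)   # [2, 8] in bfloat16
```
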
### Error Handling and Validation
- A param key mismatch raises `ValueError` with details.
- Exclusion/rename rules add flexibility (e.g., adapting external checkpoints).
- Per-rank logging aids distributed debugging.
- Shape consistency is verified via `eval_shape` before loading.

### Trade-offs
- **Performance vs. Simplicity:** Uses standard JAX ops; the MoE path is inefficient (no fused kernels or expert parallelism), reflecting the validation focus.
- **Resource Intensive:** Requires fast storage/network for multi-host loading and assumes high-end GPUs.
- **Extensibility:** Modular configs allow variants; custom `init_fn`s integrate easily.

### Relevant Files
- `model.py`: Architecture, configs, partition rules.
- `runners.py`: `ModelRunner`, mesh setup, `load_or_init`.
- `checkpoint.py`: `restore`, tensor loading, sharding utils.
- `run.py`: Example config instantiation and runner usage.

This design prioritizes correctness and distributed scalability for the massive Grok-1 model.
