Skip to content

Commit 3141beb

Browse files
committed
tt.sync / tt.shared_store / tt.shared_load
1 parent b1cdb0e commit 3141beb

13 files changed

Lines changed: 617 additions & 15 deletions

File tree

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,9 @@ Expected: ~3x fewer kernel launches, ~3x speedup.
132132

133133
#### Layer 4 — Tiled matmul with shared memory (C++ MLIR change)
134134

135-
- [ ] Tiled matmul with shared memory + barriers — reuse each weight row across multiple threads via shared memory; impactful at n_embd >= 64
136-
- [ ] Requires adding `gpu.barrier` and shared memory ops to the MLIR pipeline
135+
- [x] Shared memory primitives — `tt.sync()`, `tt.shared_store(idx, val)`, `tt.shared_load(idx)`[`examples/tiled_matvec_test.py`](examples/tiled_matvec_test.py), [`docs/15-shared-memory.md`](docs/15-shared-memory.md)
136+
- [x] Tiled 2-row-per-block matvec demo using shared memory for x vector reuse
137+
- [ ] Full tiled GEMM (requires `tt.for_range` loop support → `scf.for` in MLIR)
137138

138139
#### Layer 5 — Flash Attention (algorithmic, longer sequences)
139140

bindings/python_bindings.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,20 @@ PYBIND11_MODULE(_tiny_ton_core, m) {
141141
[](tinyton::IRBuilder &self, PyValue cond, int64_t skip) {
142142
self.emitBranchZero(cond.val, skip);
143143
})
144+
.def("emit_sync", [](tinyton::IRBuilder &self) { self.emitSync(); })
145+
.def("emit_shared_store",
146+
[](tinyton::IRBuilder &self, PyValue idx, PyValue val,
147+
int64_t bufferSize) {
148+
self.emitSharedStore(idx.val, val.val, bufferSize);
149+
},
150+
py::arg("idx"), py::arg("val"), py::arg("buffer_size"))
151+
.def("emit_shared_load",
152+
[](tinyton::IRBuilder &self, PyValue idx, int64_t bufferSize,
153+
const std::string &dtype) {
154+
auto et = tinyton::elementTypeFromString(dtype);
155+
return PyValue{self.emitSharedLoad(idx.val, bufferSize, et)};
156+
},
157+
py::arg("idx"), py::arg("buffer_size"), py::arg("dtype") = "f32")
144158
.def("emit_ret", [](tinyton::IRBuilder &self) { self.emitRet(); })
145159
.def("dump_mlir",
146160
[](tinyton::IRBuilder &self) {

docs/15-shared-memory.md

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Shared Memory: tt.sync / tt.shared_store / tt.shared_load
2+
3+
## The problem
4+
5+
In the current `linear_kernel`, each block computes one output row of `y = W @ x`.
6+
Every block loads the full `x` vector from global memory independently:
7+
8+
```
9+
Block 0: load x[0..N-1] from global → dot with W[0,:]
10+
Block 1: load x[0..N-1] from global → dot with W[1,:]
11+
Block 2: load x[0..N-1] from global → dot with W[2,:]
12+
...
13+
```
14+
15+
The same `x` data is read `out_features` times from global memory. On real
16+
GPUs the L2 cache usually handles this for small vectors, but for larger data
17+
the redundant reads become a bottleneck.
18+
19+
Shared memory solves this by loading `x` once and letting all threads in a
20+
block reuse it from fast on-chip storage.
21+
22+
## New primitives
23+
24+
```python
25+
tt.sync() # barrier — all threads in the block wait here
26+
tt.shared_store(idx, val) # write val to shared memory at position idx
27+
val = tt.shared_load(idx) # read from shared memory at position idx
28+
```
29+
30+
Shared memory is **per-block**: each block has its own buffer, sized
31+
automatically to `BLOCK` (from `tt.arange(0, BLOCK)`).
32+
33+
## Execution model
34+
35+
```
36+
Thread 0: load x[0] from global → shared_store(0, x[0])
37+
Thread 1: load x[1] from global → shared_store(1, x[1])
38+
...
39+
Thread N-1: load x[N-1] from global → shared_store(N-1, x[N-1])
40+
41+
╔═══════════╗
42+
║ tt.sync() ║ ← all threads wait here
43+
╚═══════════╝
44+
45+
Thread 0: x_sh = shared_load(0..N-1) → compute dot with W row
46+
Thread 1: x_sh = shared_load(0..N-1) → compute dot with W row
47+
```
48+
49+
## Example: tiled 2-row-per-block matvec
50+
51+
Instead of 1 output row per block, each block computes 2 rows. The `x` vector
52+
is loaded into shared memory once and reused for both dot products:
53+
54+
```python
55+
@tt.jit
56+
def tiled_linear_kernel(W_ptr, x_ptr, y_ptr, in_features, BLOCK: tt.constexpr):
57+
pid = tt.program_id(0)
58+
tid = tt.arange(0, BLOCK)
59+
mask = tid < in_features
60+
61+
x_val = tt.load(x_ptr + tid, mask=mask)
62+
tt.shared_store(tid, x_val)
63+
tt.sync()
64+
x_sh = tt.shared_load(tid)
65+
66+
w0 = tt.load(W_ptr + (pid * 2) * in_features + tid, mask=mask)
67+
w1 = tt.load(W_ptr + (pid * 2 + 1) * in_features + tid, mask=mask)
68+
dot0 = tt.reduce_sum(w0 * x_sh)
69+
dot1 = tt.reduce_sum(w1 * x_sh)
70+
tt.store(y_ptr + pid * 2, dot0)
71+
tt.store(y_ptr + pid * 2 + 1, dot1)
72+
```
73+
74+
Launch with `grid = (out_features // 2,)` — half the blocks, each doing 2x work.
75+
76+
## MLIR lowering
77+
78+
| Python | TinyTon IR | GPU dialect |
79+
|---|---|---|
80+
| `tt.sync()` | `tinyton.sync` | `gpu.barrier` |
81+
| `tt.shared_store(idx, val)` | `tinyton.shared_store %idx, %val size 64` | `memref.store` to workgroup memref |
82+
| `tt.shared_load(idx)` | `tinyton.shared_load %idx size 64` | `memref.load` from workgroup memref |
83+
84+
The `size` attribute is the buffer size, baked in at compile time from
85+
`block_size` (captured via `tt.arange`). The GPU lowering allocates a
86+
`memref<size x f32, #gpu.address_space<workgroup>>` as a second workgroup
87+
attribution (separate from the 32-element buffer used by `reduce_sum`/
88+
`reduce_max`).
89+
90+
## Simulator
91+
92+
The simulator maintains a 256-element `sharedMem` vector per block. Instructions
93+
are distinguished from regular `LDR`/`STR` by flag bits in the encoding:
94+
95+
- `SHMEM_STR`: opcode 0x8 with rd=1 (global STR has rd=0)
96+
- `SHMEM_LDR`: opcode 0x7 with rt=1 (global LDR has rt=0)
97+
- `SYNC`: opcode 0xF with imm=1 (RET has imm=0); pauses all threads and
98+
resumes them in the next phase, matching GPU barrier semantics.
99+
100+
## Files changed
101+
102+
| File | Change |
103+
|---|---|
104+
| `include/tiny-ton/Dialect/TinyTon/TinyTonOps.td` | `SyncOp`, `SharedStoreOp`, `SharedLoadOp` |
105+
| `include/tiny-ton/IR/Builder.h` | `emitSync`, `emitSharedStore`, `emitSharedLoad` |
106+
| `lib/IR/Builder.cpp` | Implementation |
107+
| `lib/Conversion/TinyTonToGPU.cpp` | Pre-scan for buffer size, second workgroup memref, lowering |
108+
| `lib/Compiler/CodeGen.cpp` | Simulator instruction encoding |
109+
| `lib/Runtime/Simulator.cpp` | `sharedMem` buffer, `StepResult::Sync`, flag-based dispatch |
110+
| `bindings/python_bindings.cpp` | Python bindings |
111+
| `python/tiny_ton/jit.py` | `_BUILTINS`, `_eval_call` handlers |
112+
| `python/tiny_ton/__init__.py` | Stubs |
113+
| `examples/tiled_matvec_test.py` | Round-trip + tiled matvec tests |
114+
| `docs/15-shared-memory.md` | This design doc |
115+
116+
## What this does NOT include
117+
118+
Full tiled GEMM requires iterating over K tiles inside the kernel via
119+
`tt.for_range`, which generates `scf.for` in MLIR. That is a separate
120+
future addition. This plan provides all the shared memory building blocks
121+
that tiled GEMM needs.

examples/tiled_matvec_test.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
"""Test shared memory ops: tt.sync, tt.shared_store, tt.shared_load.
2+
3+
Verifies:
4+
- Basic shared memory round-trip: store then load
5+
- Tiled 2-row-per-block matvec using shared memory for x vector reuse
6+
- Correctness against NumPy for various matrix sizes
7+
"""
8+
9+
import sys
10+
import os
11+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'python'))
12+
13+
import numpy as np
14+
import tiny_ton as tt
15+
16+
17+
def compare(name, got, expected, atol=1e-4):
18+
ok = np.allclose(got, expected, atol=atol)
19+
print(f' {name}: {"PASS" if ok else "FAIL"}')
20+
if not ok:
21+
print(f' got: {got}')
22+
print(f' expected: {expected}')
23+
return ok
24+
25+
26+
# ---------------------------------------------------------------------------
27+
# Kernel 1: shared memory round-trip (store → sync → load)
28+
# ---------------------------------------------------------------------------
29+
30+
@tt.jit
31+
def shmem_roundtrip(src, dst, N, BLOCK: tt.constexpr):
32+
tid = tt.arange(0, BLOCK)
33+
mask = tid < N
34+
val = tt.load(src + tid, mask=mask)
35+
tt.shared_store(tid, val)
36+
tt.sync()
37+
out = tt.shared_load(tid)
38+
tt.store(dst + tid, out, mask=mask)
39+
40+
41+
# ---------------------------------------------------------------------------
42+
# Kernel 2: single-row linear (baseline, no shared memory)
43+
# ---------------------------------------------------------------------------
44+
45+
@tt.jit
46+
def linear_kernel(W_ptr, x_ptr, y_ptr, in_features, BLOCK: tt.constexpr):
47+
pid = tt.program_id(0)
48+
tid = tt.arange(0, BLOCK)
49+
mask = tid < in_features
50+
w = tt.load(W_ptr + pid * in_features + tid, mask=mask)
51+
x = tt.load(x_ptr + tid, mask=mask)
52+
dot = tt.reduce_sum(w * x)
53+
tt.store(y_ptr + pid, dot)
54+
55+
56+
# ---------------------------------------------------------------------------
57+
# Kernel 3: tiled 2-row-per-block linear (shared memory for x reuse)
58+
# ---------------------------------------------------------------------------
59+
60+
@tt.jit
61+
def tiled_linear_kernel(W_ptr, x_ptr, y_ptr, in_features, BLOCK: tt.constexpr):
62+
pid = tt.program_id(0)
63+
tid = tt.arange(0, BLOCK)
64+
mask = tid < in_features
65+
66+
x_val = tt.load(x_ptr + tid, mask=mask)
67+
tt.shared_store(tid, x_val)
68+
tt.sync()
69+
x_sh = tt.shared_load(tid)
70+
71+
w0 = tt.load(W_ptr + (pid * 2) * in_features + tid, mask=mask)
72+
w1 = tt.load(W_ptr + (pid * 2 + 1) * in_features + tid, mask=mask)
73+
dot0 = tt.reduce_sum(w0 * x_sh)
74+
dot1 = tt.reduce_sum(w1 * x_sh)
75+
tt.store(y_ptr + pid * 2, dot0)
76+
tt.store(y_ptr + pid * 2 + 1, dot1)
77+
78+
79+
# ---------------------------------------------------------------------------
80+
# Tests
81+
# ---------------------------------------------------------------------------
82+
83+
def test_shmem_roundtrip():
84+
print('--- shared memory round-trip ---')
85+
all_ok = True
86+
for N in [4, 16, 27]:
87+
x = np.random.randn(N).astype(np.float32)
88+
out = np.zeros(N, dtype=np.float32)
89+
shmem_roundtrip[(1,)](x.copy(), out, N, N)
90+
ok = compare(f'N={N}', out, x)
91+
all_ok = all_ok and ok
92+
return all_ok
93+
94+
95+
def test_single_row_linear():
96+
print('--- single-row linear (baseline) ---')
97+
all_ok = True
98+
for out_features, in_features in [(4, 4), (8, 16), (6, 27)]:
99+
W = np.random.randn(out_features, in_features).astype(np.float32)
100+
x = np.random.randn(in_features).astype(np.float32)
101+
expected = W @ x
102+
y = np.zeros(out_features, dtype=np.float32)
103+
BLOCK = max(in_features, 4)
104+
linear_kernel[(out_features,)](
105+
W.flatten().copy(), x.copy(), y, in_features, BLOCK)
106+
ok = compare(f'{out_features}x{in_features}', y, expected)
107+
all_ok = all_ok and ok
108+
return all_ok
109+
110+
111+
def test_tiled_linear():
112+
print('--- tiled 2-row-per-block linear (shared memory) ---')
113+
all_ok = True
114+
for out_features, in_features in [(4, 4), (8, 16), (6, 27)]:
115+
W = np.random.randn(out_features, in_features).astype(np.float32)
116+
x = np.random.randn(in_features).astype(np.float32)
117+
expected = W @ x
118+
y = np.zeros(out_features, dtype=np.float32)
119+
BLOCK = max(in_features, 4)
120+
n_blocks = out_features // 2
121+
tiled_linear_kernel[(n_blocks,)](
122+
W.flatten().copy(), x.copy(), y, in_features, BLOCK)
123+
ok = compare(f'{out_features}x{in_features} tiled', y, expected)
124+
all_ok = all_ok and ok
125+
return all_ok
126+
127+
128+
def test_tiled_vs_baseline():
129+
print('--- tiled vs baseline match ---')
130+
W = np.random.randn(8, 16).astype(np.float32)
131+
x = np.random.randn(16).astype(np.float32)
132+
133+
y_baseline = np.zeros(8, dtype=np.float32)
134+
linear_kernel[(8,)](W.flatten().copy(), x.copy(), y_baseline, 16, 16)
135+
136+
y_tiled = np.zeros(8, dtype=np.float32)
137+
tiled_linear_kernel[(4,)](W.flatten().copy(), x.copy(), y_tiled, 16, 16)
138+
139+
return compare('8x16 baseline==tiled', y_tiled, y_baseline)
140+
141+
142+
if __name__ == '__main__':
143+
np.random.seed(42)
144+
results = [
145+
test_shmem_roundtrip(),
146+
test_single_row_linear(),
147+
test_tiled_linear(),
148+
test_tiled_vs_baseline(),
149+
]
150+
print()
151+
if all(results):
152+
print('All tests PASSED')
153+
else:
154+
print('SOME TESTS FAILED')
155+
sys.exit(1)

0 commit comments

Comments
 (0)