
Commit 63bdf94

Fix examples and add them to CI test (#28)

1 parent: 6ee7cdf

3 files changed: +8 -9 lines

.github/workflows/test_cuda.yml

Lines changed: 2 additions & 0 deletions
@@ -37,3 +37,5 @@ jobs:
         pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126
         pip install --quiet .
         pytest tests
+        python examples/example_autoparallel.py
+        python examples/example_llama3.py
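
The workflow now runs each example script directly after the unit tests, so a broken example fails CI. The same check could instead live in the test suite itself; the snippet below is a minimal sketch (not part of this commit) of a hypothetical tests/test_examples.py that invokes the example scripts as subprocesses, assuming tests/ sits one level below the repository root.

# Hypothetical tests/test_examples.py -- not part of this commit.
# Runs each example script in a subprocess and fails if it exits non-zero,
# mirroring the two `python examples/...` steps added to the CUDA workflow.
import subprocess
import sys
from pathlib import Path

import pytest

EXAMPLES = [
    "examples/example_autoparallel.py",
    "examples/example_llama3.py",
]


@pytest.mark.parametrize("script", EXAMPLES)
def test_example_runs(script):
    # Assumes this file lives in tests/ directly under the repo root.
    repo_root = Path(__file__).resolve().parents[1]
    result = subprocess.run(
        [sys.executable, str(repo_root / script)],
        capture_output=True,
        text=True,
    )
    assert result.returncode == 0, result.stderr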

examples/example_autoparallel.py

Lines changed: 4 additions & 8 deletions
@@ -77,17 +77,13 @@ def forward(self, x):
 dim2 = dim1 * 4


-def model_fn():
-    return Block(nheads, dim1, dim2)
-
-
 def input_fn():
     return torch.rand(bs, seq_len, dim1, device="cuda")


 # parallelize the model
 with torch.device("meta"):
-    model = model_fn()
+    model = Block(nheads, dim1, dim2)
 autop = AutoParallel(model, input_fn, mesh)
 autop.add_parameter_memory_constraint(low=None, high=None)

@@ -96,11 +92,11 @@ def input_fn():
 autop.add_input_constraints([x_sharding])
 autop.add_output_constraints([x_sharding])

-
 sharding_placement = autop.optimize_placement()
-parallel_mod = autop.apply_placement(sharding_placement)

-# run weight init on our sharded DTensor params
+# AutoParallel produces a module with meta-DTensor parameters that need to be initialized
+parallel_mod = autop.apply_placement(sharding_placement)
+parallel_mod.to_empty(device="cuda")
 parallel_mod.init_weights()

 # now let's run it
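
The key fix here is the initialization order: the model is built on the meta device, so the parameters that come out of apply_placement are still meta tensors and must be materialized with to_empty before init_weights can write real values into them. The snippet below is a standalone sketch of that same meta-device pattern using a plain nn.Linear (no AutoParallel involved), just to show why to_empty has to precede any weight initialization; the init calls are illustrative, not what the example's init_weights does.

# Standalone illustration of the meta-device init pattern used in the example
# (plain PyTorch, no AutoParallel): build on "meta", materialize, then init.
import torch
import torch.nn as nn

with torch.device("meta"):
    # Parameters are allocated as meta tensors: shapes only, no storage.
    mod = nn.Linear(16, 16)

assert mod.weight.is_meta

# Materialize the parameters on the target device. Their values are
# uninitialized at this point, which is why an explicit init pass must
# follow (the example calls parallel_mod.init_weights() for this).
mod.to_empty(device="cuda" if torch.cuda.is_available() else "cpu")

# Re-run the usual initialization now that real storage exists.
with torch.no_grad():
    nn.init.kaiming_uniform_(mod.weight)
    nn.init.zeros_(mod.bias)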

examples/example_llama3.py

Lines changed: 2 additions & 1 deletion
@@ -586,7 +586,7 @@ def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None)

 def model_fn():
     model_args = TransformerModelArgs(
-        n_layers=32, vocab_size=vocab_size, max_seq_len=seqlen
+        n_layers=8, vocab_size=vocab_size, max_seq_len=seqlen
     )
     m = Transformer(model_args)
     return m

@@ -628,6 +628,7 @@ def input_fn():
 parallel_mod = autop.apply_placement(sharding_placement)

 # run weight init on our sharded DTensor params
+parallel_mod.to_empty(device="cuda")
 parallel_mod.init_weights()

 # now let's run it
