
Commit f27a184

[WIP][DeepSeek] integrate torchao _scaled_group_gemm, refactor all group_gemms into base class and class impl for modularity (#1142)
This PR:

1. Integrates the new FP8 rowwise _scaled_group_gemm from torchao. This path is fully integrated and working, but not yet performance-optimized.
2. Refactors the group GEMM implementations around a group GEMM base class that requires:
   a. init (set the activation function)
   b. arrange_expert_weights
   c. execute (run the group GEMM)
   d. is_available (confirm the requirements for this group GEMM are met)

Four implementations are built on this base class (a minimal interface sketch follows below):
- "torch" - native PyTorch bf16 (cutlass)
- "torchao" - Triton-based FBGemm derivative (bf16)
- "torchfp8" - torchao FP8 rowwise
- "dsgemm" - DeepSeek FP8 blockwise/groupwise

All group GEMM code has moved to "group_gemms.py".

Testing:
a. Ran all four group GEMMs with inference.
b. Ran "torch" with training.
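To make the base-class contract above concrete, here is a minimal, hypothetical sketch. Only the four required responsibilities (init with an activation function, arrange_expert_weights, execute, is_available) come from the commit message; the class names, method signatures, and the naive loop backend are illustrative assumptions, not the exact code in group_gemms.py.

```python
# Hypothetical sketch of the group GEMM base-class contract described above.
# Names other than the four required hooks (init / arrange_expert_weights /
# execute / is_available) are illustrative, not the identifiers in group_gemms.py.
from abc import ABC, abstractmethod
from typing import Callable, List

import torch


class GroupGEMMStrategy(ABC):
    """Shared contract a backend ('torch', 'torchao', 'torchfp8', 'dsgemm') must satisfy."""

    def __init__(self, custom_activation: Callable[[torch.Tensor], torch.Tensor]):
        # a) init: store the activation applied inside the expert MLP.
        self.activation_function = custom_activation

    @abstractmethod
    def arrange_expert_weights(self, all_weights: List[torch.Tensor]) -> torch.Tensor:
        # b) pack per-expert weights into the layout this backend's kernel expects.
        ...

    @abstractmethod
    def execute(
        self, tokens: torch.Tensor, m_sizes: List[int], stacked_weights: torch.Tensor
    ) -> torch.Tensor:
        # c) run the grouped GEMM over tokens that are already sorted by expert.
        ...

    @staticmethod
    @abstractmethod
    def is_available() -> bool:
        # d) report whether the required kernels/imports exist on this build.
        ...


class NaiveLoopGroupGEMM(GroupGEMMStrategy):
    """Simplified stand-in for the bf16 'torch' path: a per-expert matmul loop."""

    def arrange_expert_weights(self, all_weights):
        return torch.stack(all_weights)  # [num_experts, out_dim, in_dim]

    def execute(self, tokens, m_sizes, stacked_weights):
        outs, start = [], 0
        for e, size in enumerate(m_sizes):
            chunk = tokens[start : start + size]  # tokens routed to expert e
            outs.append(self.activation_function(chunk @ stacked_weights[e].T))
            start += size
        return torch.cat(outs, dim=0)

    @staticmethod
    def is_available() -> bool:
        return True  # pure PyTorch, no extra kernels required
```

In this shape, adding a new backend (for example the torchao FP8 rowwise path) means subclassing once: pack weights into the layout its kernel expects in arrange_expert_weights, run the kernel in execute, and gate it behind is_available so callers can fall back when the kernel or dtype support is missing.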
Parent: 1bff6a0

3 files changed: +462 −181 lines

torchtitan/experiments/deepseek_v3/generate.py (+2, −2)
@@ -127,7 +127,7 @@ def create_model(dist_config: DistConfig):
     model_args.ep_size = dist_config.ep_size
     model_args.num_stages = dist_config.pp_size
     model_args.stage_idx = dist_config.pp_rank
-    model_args.max_seq_len = 16384
+    model_args.max_seq_len = 4096 # 16384
 
     with dist_config.device, dist_config.mesh:
         model = DeepseekForCausalLM(model_args)
@@ -291,7 +291,7 @@ def generate(
 
         # Print progress indicator every 20 tokens
        if rank == 0 and tokens_generated % 20 == 0:
-            print(".", end="", flush=True)
+            print(f"{color.yellow}:{color.reset}", end="", flush=True)
 
     # Print newline after progress indicator
     if rank == 0:
