From e0c8ef0aa167ad2ff513d9500d7e76a6a3ccefe9 Mon Sep 17 00:00:00 2001 From: Hong-Rong Hsu Date: Mon, 5 Aug 2024 09:05:29 +0000 Subject: [PATCH] Fix llama example split failure Missing `split_spec` in `pipeline` will cause an error like: "RuntimeError: Pipeline group size 4 cannot be larger than number of stages 1" --- examples/llama/pippy_llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/llama/pippy_llama.py b/examples/llama/pippy_llama.py index 168d47045..e06a30833 100644 --- a/examples/llama/pippy_llama.py +++ b/examples/llama/pippy_llama.py @@ -33,7 +33,7 @@ # Create a pipeline representation from the model mb_inputs = tokenizer(mb_prompts, return_tensors="pt", padding=True).to(device) -pipe = pipeline(llama, mb_args=(mb_inputs["input_ids"],)) +pipe = pipeline(llama, mb_args=(mb_inputs["input_ids"],), split_spec=split_spec) # Create pipeline stage for each rank stage = pipe.build_stage(rank, device=device)