Commit ea7d1d6 (authored Dec 20, 2023)

Simplify llama example (#921)

Add README. Remove unnecessary logic.

1 parent: c81a65d

File tree

3 files changed: +65 −108 lines

‎examples/llama/README.md

+17
```
$ torchrun --nproc-per-node 2 pippy_llama.py
```

```
$ torchrun --nproc-per-node 4 pippy_llama.py
```

```
$ torchrun --nproc-per-node 8 pippy_llama.py
```

```
prompts = (
    "How do you", "I like to", "Can I help", "You need to",
    "The weather is", "I found a", "What is your", "You are so",
)
Outputs:
['make', 'think', 'you', 'be', 'getting', 'great', 'favorite', 'right']
```
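
The `Outputs` line is the greedy next token predicted for each prompt in the batch. As a sanity check, the same decode step can be reproduced without any pipeline parallelism by running the unsplit model on one device. The sketch below is not part of this commit; it assumes the same Hugging Face checkpoint and enough memory to host the full model on a single device, and it mirrors the logits-argmax step from `pippy_llama.py`.

```python
# Single-device reference for the greedy next-token decode shown above.
# Not part of this commit; mirrors the last step of pippy_llama.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-chat-hf"
llama = AutoModelForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

prompts = (
    "How do you", "I like to", "Can I help", "You need to",
    "The weather is", "I found a", "What is your", "You are so",
)
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

with torch.no_grad():
    logits = llama(inputs["input_ids"]).logits        # [batch, seq_len, vocab]
next_token = torch.argmax(logits[:, -1, :], dim=-1)   # greedy pick per prompt
print(tokenizer.batch_decode(next_token))             # compare with Outputs above
```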

‎examples/llama/pippy_llama.py

+46 −108
```diff
@@ -1,111 +1,49 @@
-# Minimum effort to run this example:
-# $ pip install transformers
-# $ torchrun --nproc-per-node 2 pippy_llama.py
-
-import argparse
+# $ torchrun --nproc-per-node 4 pippy_llama.py
 import os
-
 import torch
-import torch.distributed as dist
-
 from transformers import AutoModelForCausalLM, AutoTokenizer
-
-
-from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
-from pippy.PipelineStage import PipelineStage
-
-
-def add_split_points(llama, nranks):
-    # Cut model by equal number of layers per rank
-    layers_per_rank = (llama.config.num_hidden_layers + nranks - 1) // nranks
-    print(f"layers_per_rank = {layers_per_rank}")
-    for i in range(1, nranks):
-        annotate_split_points(
-            llama,
-            {f'model.layers.{i * layers_per_rank}': PipeSplitWrapper.SplitPoint.BEGINNING},
-        )
-
-
-def get_number_of_params(model):
-    return sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-
-def run(args):
-    # Create a blank model
-    llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True)
-    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
-
-    prompts = (
-        "How do you", "I like to", "Can I help", "You have to",
-        "The weather is", "I have a", "What is your", "You are a",
-    )  # bs = 8
-    tokenizer.pad_token = tokenizer.eos_token
-    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(args.device)
-
-    # Move model to `device` and set to evaluation
-    llama.to(args.device)
-    llama.eval()
-    print(llama)
-
-    # Annotate split points
-    add_split_points(llama, args.world_size)
-
-    # Create a pipeline stage from the model
-    llama_pipe = Pipe.from_tracing(
-        llama,
-        num_chunks=args.world_size,
-        example_args=(inputs['input_ids'],),
-    )
-
-    assert len(list(llama_pipe.split_gm.children())) == args.world_size
-    if args.rank == 0:
-        for i, sm in enumerate(llama_pipe.split_gm.children()):
-            print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params")
-
-    # Create schedule runtime
-    stage = PipelineStage(
-        llama_pipe,
-        args.rank,
-        device=args.device,
-    )
-
-    # Run
-    output = None
-    if args.rank == 0:
-        stage(inputs['input_ids'])
-    elif args.rank == args.world_size - 1:
-        output = stage()
-    else:
-        stage()
-
-    if output is not None:
-        next_token_logits = output[0][:, -1, :]
-        next_token = torch.argmax(next_token_logits, dim=-1)
-        print(tokenizer.batch_decode(next_token))
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4)))
-    parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1)))
-    parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost'))
-    parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500'))
-    parser.add_argument('--schedule', type=str, default="FillDrain")
-    parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available()))
-
-    args = parser.parse_args()
-
-    if args.cuda:
-        dev_id = args.rank % torch.cuda.device_count()
-        args.device = torch.device(f"cuda:{dev_id}")
-    else:
-        args.device = torch.device("cpu")
-
-    # Init process group
-    backend = "nccl" if args.cuda else "gloo"
-    dist.init_process_group(
-        backend=backend,
-        rank=args.rank,
-        world_size=args.world_size,
-    )
-
-    run(args)
+from pippy import Pipe, PipeSplitWrapper, annotate_split_points, PipelineStage
+
+# Grab the model
+llama = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True
+)
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+
+prompts = (
+    "How do you", "I like to", "Can I help", "You need to",
+    "The weather is", "I found a", "What is your", "You are so",
+)  # bs = 8
+tokenizer.pad_token = tokenizer.eos_token
+
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
+llama.to(device).eval()
+inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
+
+# Cut model by equal number of layers per rank
+layers_per_rank = llama.config.num_hidden_layers // world_size
+for i in range(1, world_size):
+    annotate_split_points(llama,
+        {f"model.layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING})
+
+# Create a pipeline representation from the model
+llama_pipe = Pipe.from_tracing(llama, world_size, example_args=(inputs["input_ids"],))
+
+# Create pipeline stage for each rank
+torch.distributed.init_process_group(rank=rank, world_size=world_size)
+stage = PipelineStage(llama_pipe, rank, device=device)
+
+# Run
+if rank == 0:
+    args = inputs["input_ids"]
+else:
+    args = None
+output = stage(args)
+
+# Decode
+if output is not None:
+    next_token_logits = output[0][:, -1, :]
+    next_token = torch.argmax(next_token_logits, dim=-1)
+    print(tokenizer.batch_decode(next_token))
```
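
To make the layer split concrete: Llama-2-7B has 32 decoder layers, so a 4-rank run computes `layers_per_rank = 8` and annotates split points at `model.layers.8`, `model.layers.16`, and `model.layers.24`, each marking where a new pipeline stage begins. A minimal standalone sketch of that arithmetic (not part of the commit; the numbers assume the 7B checkpoint used in the example):

```python
# Worked example of the split-point arithmetic in pippy_llama.py,
# assuming Llama-2-7B (num_hidden_layers = 32) and a 4-rank run.
num_hidden_layers = 32
world_size = 4

layers_per_rank = num_hidden_layers // world_size   # 8 layers per stage
split_points = [
    f"model.layers.{i * layers_per_rank}"            # each new stage begins here
    for i in range(1, world_size)
]
print(split_points)  # ['model.layers.8', 'model.layers.16', 'model.layers.24']
```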

‎pippy/__init__.py

+2
```diff
@@ -9,13 +9,15 @@
     TrivialLossWrapper,
 )
 from pippy.ModelSplit import split_into_equal_size, split_on_size_threshold
+from pippy.PipelineStage import PipelineStage
 
 
 __all__ = [
     "PipeSequential",
     "LossWrapper",
     "TrivialLossWrapper",
     "Pipe",
+    "PipelineStage",
     "pipe_split",
     "PipeSplitWrapper",
     "annotate_split_points",
```