-# Minimum effort to run this example:
-# $ pip install transformers
-# $ torchrun --nproc-per-node 2 pippy_llama.py
-
-import argparse
+# $ torchrun --nproc-per-node 4 pippy_llama.py
 import os
-
 import torch
-import torch.distributed as dist
-
 from transformers import AutoModelForCausalLM, AutoTokenizer
-
-from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points
-from pippy.PipelineStage import PipelineStage
-
-
-def add_split_points(llama, nranks):
-    # Cut model by equal number of layers per rank
-    layers_per_rank = (llama.config.num_hidden_layers + nranks - 1) // nranks
-    print(f"layers_per_rank = {layers_per_rank}")
-    for i in range(1, nranks):
-        annotate_split_points(
-            llama,
-            {f'model.layers.{i * layers_per_rank}': PipeSplitWrapper.SplitPoint.BEGINNING},
-        )
-
-
-def get_number_of_params(model):
-    return sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-
-def run(args):
-    # Create a blank model
-    llama = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True)
-    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
-
-    prompts = (
-        "How do you", "I like to", "Can I help", "You have to",
-        "The weather is", "I have a", "What is your", "You are a",
-    )  # bs = 8
-    tokenizer.pad_token = tokenizer.eos_token
-    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(args.device)
-
-    # Move model to `device` and set to evaluation
-    llama.to(args.device)
-    llama.eval()
-    print(llama)
-
-    # Annotate split points
-    add_split_points(llama, args.world_size)
-
-    # Create a pipeline stage from the model
-    llama_pipe = Pipe.from_tracing(
-        llama,
-        num_chunks=args.world_size,
-        example_args=(inputs['input_ids'],),
-    )
-
-    assert len(list(llama_pipe.split_gm.children())) == args.world_size
-    if args.rank == 0:
-        for i, sm in enumerate(llama_pipe.split_gm.children()):
-            print(f"Pipeline stage {i} {get_number_of_params(sm) // 10 ** 6}M params")
-
-    # Create schedule runtime
-    stage = PipelineStage(
-        llama_pipe,
-        args.rank,
-        device=args.device,
-    )
-
-    # Run
-    output = None
-    if args.rank == 0:
-        stage(inputs['input_ids'])
-    elif args.rank == args.world_size - 1:
-        output = stage()
-    else:
-        stage()
-
-    if output is not None:
-        next_token_logits = output[0][:, -1, :]
-        next_token = torch.argmax(next_token_logits, dim=-1)
-        print(tokenizer.batch_decode(next_token))
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--world_size', type=int, default=int(os.getenv("WORLD_SIZE", 4)))
-    parser.add_argument('--rank', type=int, default=int(os.getenv("RANK", -1)))
-    parser.add_argument('--master_addr', type=str, default=os.getenv('MASTER_ADDR', 'localhost'))
-    parser.add_argument('--master_port', type=str, default=os.getenv('MASTER_PORT', '29500'))
-    parser.add_argument('--schedule', type=str, default="FillDrain")
-    parser.add_argument('--cuda', type=int, default=int(torch.cuda.is_available()))
-
-    args = parser.parse_args()
-
-    if args.cuda:
-        dev_id = args.rank % torch.cuda.device_count()
-        args.device = torch.device(f"cuda:{dev_id}")
-    else:
-        args.device = torch.device("cpu")
-
-    # Init process group
-    backend = "nccl" if args.cuda else "gloo"
-    dist.init_process_group(
-        backend=backend,
-        rank=args.rank,
-        world_size=args.world_size,
-    )
-
-    run(args)
+from pippy import Pipe, PipeSplitWrapper, annotate_split_points, PipelineStage
+
+# Grab the model
+llama = AutoModelForCausalLM.from_pretrained(
+    "meta-llama/Llama-2-7b-chat-hf", low_cpu_mem_usage=True
+)
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+
+prompts = (
+    "How do you", "I like to", "Can I help", "You need to",
+    "The weather is", "I found a", "What is your", "You are so",
+)  # bs = 8
+tokenizer.pad_token = tokenizer.eos_token
+
+rank = int(os.environ["RANK"])
+world_size = int(os.environ["WORLD_SIZE"])
+device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
+llama.to(device).eval()
+inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
+
+# Cut model by equal number of layers per rank
+layers_per_rank = llama.config.num_hidden_layers // world_size
+for i in range(1, world_size):
+    annotate_split_points(llama,
+        {f"model.layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING})
+
+# Create a pipeline representation from the model
+llama_pipe = Pipe.from_tracing(llama, world_size, example_args=(inputs["input_ids"],))
+
+# Create pipeline stage for each rank
+torch.distributed.init_process_group(rank=rank, world_size=world_size)
+stage = PipelineStage(llama_pipe, rank, device=device)
+
+# Run
+if rank == 0:
+    args = inputs["input_ids"]
+else:
+    args = None
+output = stage(args)
+
+# Decode
+if output is not None:
+    next_token_logits = output[0][:, -1, :]
+    next_token = torch.argmax(next_token_logits, dim=-1)
+    print(tokenizer.batch_decode(next_token))
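
For reference, Llama-2-7b-chat-hf has 32 decoder layers (config.num_hidden_layers = 32), so the 4-rank launch shown in the header comment (torchrun --nproc-per-node 4 pippy_llama.py) gives layers_per_rank = 32 // 4 = 8 and split points at model.layers.8, model.layers.16, and model.layers.24, i.e. 8 layers per pipeline stage. Below is a minimal standalone sketch of that split-point arithmetic, assuming the same 32-layer model and 4 ranks; it is illustrative only and not part of the committed example.

# Sketch of the stage layout implied by the split-point loop in the new example.
# Assumes Llama-2-7B (32 hidden layers) and a 4-rank torchrun launch.
num_hidden_layers = 32   # llama.config.num_hidden_layers for Llama-2-7b-chat-hf
world_size = 4           # matches --nproc-per-node 4
layers_per_rank = num_hidden_layers // world_size  # 8
split_points = [f"model.layers.{i * layers_per_rank}" for i in range(1, world_size)]
print(split_points)  # ['model.layers.8', 'model.layers.16', 'model.layers.24']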