Skip to content

Commit f4db8e4

Browse files
author
maxtext authors
committed
Merge pull request #1798 from AI-Hypercomputer:moe_benchmark
PiperOrigin-RevId: 766679174
2 parents b3f4113 + 9ca35d7 commit f4db8e4

File tree

2 files changed

+91
-2
lines changed

2 files changed

+91
-2
lines changed

benchmarks/benchmark_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def add_on_device_runner_arguments(custom_parser: argparse.ArgumentParser):
201201
custom_parser.add_argument(
202202
'--model_name',
203203
type=str,
204-
choices=list(trillium_model_dict.keys()) + list(v5e_model_dict.keys()),
204+
choices=list(trillium_model_dict.keys()) + list(v5p_model_dict.keys()) + list(v5e_model_dict.keys()),
205205
default=list(trillium_model_dict.keys())[0],
206206
help=(
207207
'model to be benchmarked, supported models are the command choices.'

benchmarks/maxtext_v5p_model_configs.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,93 @@
104104
+ xla_flags_library.DATA_PARALLEL_OVERLAP
105105
),
106106
),
107-
)
107+
)
108+
109+
# Benchmark config: Llama4 Scout (17B, 16 experts), dropless MoE, on v5p-256.
# Registered into v5p_model_dict so benchmark_runner can select it by name.
llama4_scout_dropless_v5p_256 = _add_to_model_dictionary(
    v5p_model_dict,
    MaxTextModel(
        model_name="llama4_scout_dropless_v5p_256",
        model_type="llama4-17b-16e",
        tuning_params={
            "per_device_batch_size": 8,
            "max_target_length": 8192,
            # -1 presumably lets MaxText derive the FSDP degree from the mesh — confirm.
            "ici_fsdp_parallelism": -1,
            "enable_checkpointing": False,
            "dtype": "bfloat16",
            "weight_dtype": "float32",
            # Dropless MoE path: megablox kernels with sparse matmul enabled.
            "megablox": True,
            "sparse_matmul": True,
            "dataset_type": "synthetic",
            "opt_type": "adamw",
            # Profile a short window after warm-up.
            "skip_first_n_steps_for_profiler": 5,
            "profiler_steps": 3,
            "profiler": "xplane",
            # Custom remat with decoder-layer input offloaded to host.
            "remat_policy": "custom",
            "decoder_layer_input": "offload",
            "reuse_example_batch": 1,
            # Splash-attention block sizes (uniform 2048 for all phases).
            "sa_block_q": 2048,
            "sa_block_kv": 2048,
            "sa_block_kv_compute": 2048,
            "sa_block_q_dkv": 2048,
            "sa_block_kv_dkv": 2048,
            "sa_block_kv_dkv_compute": 2048,
            "sa_block_q_dq": 2048,
            "sa_block_kv_dq": 2048,
            "tokenizer_path": "meta-llama/Llama-4-Scout-17B-16E",
        },
        xla_flags=(
            xla_flags_library.MOE_VMEM_LIMIT_FLAG
            + xla_flags_library.CF_FOR_ALL_GATHER
            + xla_flags_library.DATA_PARALLEL_OVERLAP
            + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
            + xla_flags_library.HOST_OFFLOAD_FLAGS
        ),
    ),
)
150+
151+
# Benchmark config: Llama4 Maverick (17B, 128 experts), dropless MoE, on v5p-256.
# Registered into v5p_model_dict so benchmark_runner can select it by name.
llama4_maverick_dropless_v5p_256 = _add_to_model_dictionary(
    v5p_model_dict,
    MaxTextModel(
        model_name="llama4_maverick_dropless_v5p_256",
        model_type="llama4-17b-128e",
        tuning_params={
            "per_device_batch_size": 4,
            "max_target_length": 8192,
            # Explicit 32-way FSDP x 4-way tensor parallel mesh (vs. Scout's auto FSDP).
            "ici_fsdp_parallelism": 32,
            "ici_tensor_parallelism": 4,
            "enable_checkpointing": False,
            "dtype": "bfloat16",
            "weight_dtype": "float32",
            # Dropless MoE path: megablox kernels with sparse matmul enabled.
            "megablox": True,
            "sparse_matmul": True,
            "dataset_type": "synthetic",
            "opt_type": "adamw",
            # Profile a short window after warm-up.
            "skip_first_n_steps_for_profiler": 5,
            "profiler_steps": 3,
            "profiler": "xplane",
            # Custom remat: offload decoder-layer input plus all attention projections
            # (larger 128-expert model offloads more than the Scout config).
            "remat_policy": "custom",
            "decoder_layer_input": "offload",
            "out_proj": "offload",
            "query_proj": "offload",
            "key_proj": "offload",
            "value_proj": "offload",
            "reuse_example_batch": 1,
            # Splash-attention block sizes (uniform 2048 for all phases).
            "sa_block_q": 2048,
            "sa_block_kv": 2048,
            "sa_block_kv_compute": 2048,
            "sa_block_q_dkv": 2048,
            "sa_block_kv_dkv": 2048,
            "sa_block_kv_dkv_compute": 2048,
            "sa_block_q_dq": 2048,
            "sa_block_kv_dq": 2048,
            "tokenizer_path": "meta-llama/Llama-4-Maverick-17B-128E",
        },
        xla_flags=(
            xla_flags_library.MOE_VMEM_LIMIT_FLAG
            + xla_flags_library.CF_FOR_ALL_GATHER
            + xla_flags_library.DATA_PARALLEL_OVERLAP
            + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
            + xla_flags_library.HOST_OFFLOAD_FLAGS
        ),
    ),
)

0 commit comments

Comments
 (0)