104104 + xla_flags_library .DATA_PARALLEL_OVERLAP
105105 ),
106106 ),
107- )
107+ )
108+
# Benchmark entry "llama4_scout_dropless_v5p_256" (model_type "llama4-17b-16e").
# The MaxTextModel built here is passed, together with v5p_model_dict, to
# _add_to_model_dictionary; whatever that call returns is bound to the
# module-level name below.
llama4_scout_dropless_v5p_256 = _add_to_model_dictionary(
    v5p_model_dict,
    MaxTextModel(
        model_name="llama4_scout_dropless_v5p_256",
        model_type="llama4-17b-16e",
        tuning_params={
            # --- Batch size, sequence length and sharding ---
            "per_device_batch_size": 8,
            "max_target_length": 8192,
            # NOTE(review): -1 presumably lets the framework infer the FSDP
            # degree from the available devices - confirm.
            "ici_fsdp_parallelism": -1,
            # --- Run setup: no checkpoints, synthetic input data ---
            "enable_checkpointing": False,
            "dtype": "bfloat16",
            "weight_dtype": "float32",
            # --- MoE kernel selection ---
            "megablox": True,
            "sparse_matmul": True,
            "dataset_type": "synthetic",
            "opt_type": "adamw",
            # --- Profiling: xplane, skip 5 steps, then profile 3 ---
            "skip_first_n_steps_for_profiler": 5,
            "profiler_steps": 3,
            "profiler": "xplane",
            # --- Rematerialization: custom policy, layer input offloaded ---
            # NOTE(review): "offload" presumably pairs with HOST_OFFLOAD_FLAGS
            # in xla_flags below - confirm.
            "remat_policy": "custom",
            "decoder_layer_input": "offload",
            "reuse_example_batch": 1,
            # --- sa_block_* sizes: all set to 2048 ---
            "sa_block_q": 2048,
            "sa_block_kv": 2048,
            "sa_block_kv_compute": 2048,
            "sa_block_q_dkv": 2048,
            "sa_block_kv_dkv": 2048,
            "sa_block_kv_dkv_compute": 2048,
            "sa_block_q_dq": 2048,
            "sa_block_kv_dq": 2048,
            "tokenizer_path": "meta-llama/Llama-4-Scout-17B-16E",
        },
        # Flag groups are concatenated with "+" in this exact order.
        xla_flags=(
            xla_flags_library.MOE_VMEM_LIMIT_FLAG
            + xla_flags_library.CF_FOR_ALL_GATHER
            + xla_flags_library.DATA_PARALLEL_OVERLAP
            + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
            + xla_flags_library.HOST_OFFLOAD_FLAGS
        ),
    ),
)
150+
# Benchmark entry "llama4_maverick_dropless_v5p_256" (model_type
# "llama4-17b-128e"). The MaxTextModel built here is passed, together with
# v5p_model_dict, to _add_to_model_dictionary; whatever that call returns is
# bound to the module-level name below.
llama4_maverick_dropless_v5p_256 = _add_to_model_dictionary(
    v5p_model_dict,
    MaxTextModel(
        model_name="llama4_maverick_dropless_v5p_256",
        model_type="llama4-17b-128e",
        tuning_params={
            # --- Batch size, sequence length and sharding ---
            "per_device_batch_size": 4,
            "max_target_length": 8192,
            # Explicit mesh split: 32-way FSDP x 4-way tensor parallelism.
            "ici_fsdp_parallelism": 32,
            "ici_tensor_parallelism": 4,
            # --- Run setup: no checkpoints, synthetic input data ---
            "enable_checkpointing": False,
            "dtype": "bfloat16",
            "weight_dtype": "float32",
            # --- MoE kernel selection ---
            "megablox": True,
            "sparse_matmul": True,
            "dataset_type": "synthetic",
            "opt_type": "adamw",
            # --- Profiling: xplane, skip 5 steps, then profile 3 ---
            "skip_first_n_steps_for_profiler": 5,
            "profiler_steps": 3,
            "profiler": "xplane",
            # --- Rematerialization: custom policy ---
            # Layer input and the attention projections are all marked
            # "offload". NOTE(review): presumably pairs with
            # HOST_OFFLOAD_FLAGS in xla_flags below - confirm.
            "remat_policy": "custom",
            "decoder_layer_input": "offload",
            "out_proj": "offload",
            "query_proj": "offload",
            "key_proj": "offload",
            "value_proj": "offload",
            "reuse_example_batch": 1,
            # --- sa_block_* sizes: all set to 2048 ---
            "sa_block_q": 2048,
            "sa_block_kv": 2048,
            "sa_block_kv_compute": 2048,
            "sa_block_q_dkv": 2048,
            "sa_block_kv_dkv": 2048,
            "sa_block_kv_dkv_compute": 2048,
            "sa_block_q_dq": 2048,
            "sa_block_kv_dq": 2048,
            "tokenizer_path": "meta-llama/Llama-4-Maverick-17B-128E",
        },
        # Flag groups are concatenated with "+" in this exact order.
        xla_flags=(
            xla_flags_library.MOE_VMEM_LIMIT_FLAG
            + xla_flags_library.CF_FOR_ALL_GATHER
            + xla_flags_library.DATA_PARALLEL_OVERLAP
            + xla_flags_library.LAYOUT_FOR_ALL_REDUCE_SCATTER
            + xla_flags_library.HOST_OFFLOAD_FLAGS
        ),
    ),
)
0 commit comments