from skimage.metrics import structural_similarity as ssim
from flax.training import train_state
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
+ from jax.experimental import multihost_utils


class TrainState(train_state.TrainState):
@@ -156,6 +157,11 @@ def get_data_shardings(self, mesh):
    data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding}
    return data_sharding

+   def get_eval_data_shardings(self, mesh):
+     data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding))
+     data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": data_sharding}
+     return data_sharding
+
  def load_dataset(self, mesh, is_training=True):
    # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
    # Image pre-training - txt2img 256px
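For context, here is a minimal standalone sketch of what a dict of `NamedSharding`s like the one returned by the new `get_eval_data_shardings` expresses: each batch entry is split along its leading (batch) dimension across a mesh data axis. The mesh layout, axis name, and array shapes below are illustrative assumptions, not values from this PR, and the device count is assumed to divide the batch size.

```python
# Illustrative only: toy mesh and batch, not the shapes/config used in maxdiffusion.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))  # shard dim 0 over the "data" axis

batch = {
    "latents": jnp.zeros((8, 16, 32)),
    "encoder_hidden_states": jnp.zeros((8, 77, 512)),
    "timesteps": jnp.zeros((8,), dtype=jnp.int32),
}
# Placing each entry with its sharding splits the batch dimension across devices;
# the same per-key shardings can also be passed to jax.jit via in_shardings,
# which is what the later training_loop hunk does with eval_data_shardings.
sharded_batch = {k: jax.device_put(v, sharding) for k, v in batch.items()}
```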
@@ -170,34 +176,43 @@ def load_dataset(self, mesh, is_training=True):
      raise ValueError(
          "Wan 2.1 training only supports config.dataset_type set to tfrecords and config.cache_latents_text_encoder_outputs set to True"
      )
-
    feature_description = {
        "latents": tf.io.FixedLenFeature([], tf.string),
        "encoder_hidden_states": tf.io.FixedLenFeature([], tf.string),
    }

-     def prepare_sample(features):
+     if not is_training:
+       feature_description["timesteps"] = tf.io.FixedLenFeature([], tf.int64)
+
+     def prepare_sample_train(features):
      latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32)
      encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32)
      return {"latents": latents, "encoder_hidden_states": encoder_hidden_states}

+     def prepare_sample_eval(features):
+       latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32)
+       encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32)
+       timesteps = features["timesteps"]
+       return {"latents": latents, "encoder_hidden_states": encoder_hidden_states, "timesteps": timesteps}
+
    data_iterator = make_data_iterator(
        config,
        jax.process_index(),
        jax.process_count(),
        mesh,
        config.global_batch_size_to_load,
        feature_description=feature_description,
-         prepare_sample_fn=prepare_sample,
+         prepare_sample_fn=prepare_sample_train if is_training else prepare_sample_eval,
        is_training=is_training,
    )
    return data_iterator

  def start_training(self):

    pipeline = self.load_checkpoint()
-     # Generate a sample before training to compare against generated sample after training.
-     pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
+     if self.config.enable_ssim:
+       # Generate a sample before training to compare against generated sample after training.
+       pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")

    if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval):
      # save some memory.
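The eval branch of `load_dataset` above expects an int64 `timesteps` field alongside the serialized tensors. As a hypothetical writer-side sketch (not part of this PR; the helper name and input types are assumptions), an eval TFRecord example matching that `feature_description` could be produced like this:

```python
# Hypothetical writer-side helper; field names mirror the feature_description above.
import tensorflow as tf

def serialize_eval_example(latents, encoder_hidden_states, timestep):
  def bytes_feature(tensor):
    return tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[tf.io.serialize_tensor(tensor).numpy()]))

  feature = {
      "latents": bytes_feature(tf.convert_to_tensor(latents, dtype=tf.float32)),
      "encoder_hidden_states": bytes_feature(
          tf.convert_to_tensor(encoder_hidden_states, dtype=tf.float32)),
      # Scalar int64 so evaluation can bucket per-sample losses by timestep.
      "timesteps": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(timestep)])),
  }
  return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()
```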
@@ -215,8 +230,57 @@ def start_training(self):
    # Returns pipeline with trained transformer state
    pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)

-     posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
-     print_ssim(pretrained_video_path, posttrained_video_path)
+     if self.config.enable_ssim:
+       posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
+       print_ssim(pretrained_video_path, posttrained_video_path)
+
+   def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer):
+     eval_data_iterator = self.load_dataset(mesh, is_training=False)
+     eval_rng = eval_rng_key
+     eval_losses_by_timestep = {}
+     # Loop indefinitely until the iterator is exhausted
+     while True:
+       try:
+         eval_start_time = datetime.datetime.now()
+         eval_batch = load_next_batch(eval_data_iterator, None, self.config)
+         with mesh, nn_partitioning.axis_rules(
+             self.config.logical_axis_rules
+         ):
+           metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
+           metrics["scalar"]["learning/eval_loss"].block_until_ready()
+         losses = metrics["scalar"]["learning/eval_loss"]
+         timesteps = eval_batch["timesteps"]
+         gathered_losses = multihost_utils.process_allgather(losses)
+         gathered_losses = jax.device_get(gathered_losses)
+         gathered_timesteps = multihost_utils.process_allgather(timesteps)
+         gathered_timesteps = jax.device_get(gathered_timesteps)
+         if jax.process_index() == 0:
+           for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()):
+             timestep = int(t)
+             if timestep not in eval_losses_by_timestep:
+               eval_losses_by_timestep[timestep] = []
+             eval_losses_by_timestep[timestep].append(l)
+         eval_end_time = datetime.datetime.now()
+         eval_duration = eval_end_time - eval_start_time
+         max_logging.log(f"Eval time: {eval_duration.total_seconds():.2f} seconds.")
+       except StopIteration:
+         # This block is executed when the iterator has no more data
+         break
+     # Check if any evaluation was actually performed
+     if eval_losses_by_timestep and jax.process_index() == 0:
+       mean_per_timestep = []
+       if jax.process_index() == 0:
+         max_logging.log(f"Step {step}, calculating mean loss per timestep...")
+       for timestep, losses in sorted(eval_losses_by_timestep.items()):
+         losses = jnp.array(losses)
+         losses = losses[: min(self.config.eval_max_number_of_samples_in_bucket, len(losses))]
+         mean_loss = jnp.mean(losses)
+         max_logging.log(f"  Mean eval loss for timestep {timestep}: {mean_loss:.4f}")
+         mean_per_timestep.append(mean_loss)
+       final_eval_loss = jnp.mean(jnp.array(mean_per_timestep))
+       max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}")
+       if writer:
+         writer.add_scalar("learning/eval_loss", final_eval_loss, step)

  def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
    mesh = pipeline.mesh
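A small self-contained illustration of the aggregation done in the `eval` method added above: per-sample losses are grouped by diffusion timestep, each bucket is capped at a maximum sample count, and the reported metric is the mean of the per-bucket means, so every timestep bucket contributes equally regardless of how many samples landed in it. The helper name and toy values below are hypothetical.

```python
# Toy numbers only; mirrors the bucketing logic in eval() above.
import jax.numpy as jnp

def final_eval_loss(timesteps, losses, max_samples_per_bucket):
  buckets = {}
  for t, l in zip(timesteps, losses):
    buckets.setdefault(int(t), []).append(float(l))
  # Mean per timestep bucket, truncated to the configured sample cap.
  per_bucket_means = [
      jnp.mean(jnp.array(vals[:max_samples_per_bucket]))
      for _, vals in sorted(buckets.items())
  ]
  # Final metric: mean over bucket means, not over raw samples.
  return jnp.mean(jnp.array(per_bucket_means))

print(final_eval_loss([10, 10, 500, 900], [0.2, 0.4, 0.1, 0.3], max_samples_per_bucket=2))
# -> mean of [0.3, 0.1, 0.3] ≈ 0.2333
```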
@@ -231,6 +295,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
    state = jax.lax.with_sharding_constraint(state, state_spec)
    state_shardings = nnx.get_named_sharding(state, mesh)
    data_shardings = self.get_data_shardings(mesh)
+     eval_data_shardings = self.get_eval_data_shardings(mesh)

    writer = max_utils.initialize_summary_writer(self.config)
    writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True)
@@ -255,11 +320,12 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
    )
    p_eval_step = jax.jit(
        functools.partial(eval_step, scheduler=pipeline.scheduler, config=self.config),
-         in_shardings=(state_shardings, data_shardings, None, None),
+         in_shardings=(state_shardings, eval_data_shardings, None, None),
        out_shardings=(None, None),
    )

    rng = jax.random.key(self.config.seed)
+     rng, eval_rng_key = jax.random.split(rng)
    start_step = 0
    last_step_completion = datetime.datetime.now()
    local_metrics_file = open(self.config.metrics_file, "a", encoding="utf8") if self.config.metrics_file else None
@@ -304,27 +370,8 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
            inference_generate_video(self.config, pipeline, filename_prefix=f"{step+1}-train_steps-")
          # Re-create the iterator each time you start evaluation to reset it
          # This assumes your data loading logic can be called to get a fresh iterator.
-           eval_data_iterator = self.load_dataset(mesh, is_training=False)
-           eval_rng = jax.random.key(self.config.seed + step)
-           eval_metrics = []
-           # Loop indefinitely until the iterator is exhausted
-           while True:
-             try:
-               with mesh:
-                 eval_batch = load_next_batch(eval_data_iterator, None, self.config)
-                 metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
-                 eval_metrics.append(metrics["scalar"]["learning/eval_loss"])
-             except StopIteration:
-               # This block is executed when the iterator has no more data
-               break
-           # Check if any evaluation was actually performed
-           if eval_metrics:
-             eval_loss = jnp.mean(jnp.array(eval_metrics))
-             max_logging.log(f"Step {step}, Eval loss: {eval_loss:.4f}")
-             if writer:
-               writer.add_scalar("learning/eval_loss", eval_loss, step)
-           else:
-             max_logging.log(f"Step {step}, evaluation dataset was empty.")
+           self.eval(mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer)
+
        example_batch = next_batch_future.result()
        if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0:
          max_logging.log(f"Saving checkpoint for step {step}")
@@ -394,57 +441,54 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
  """
  Computes the evaluation loss for a single batch without updating model weights.
  """
-   _, new_rng, timestep_rng = jax.random.split(rng, num=3)
-
-   # This ensures the batch size is consistent, though it might be redundant
-   # if the evaluation dataloader is already configured correctly.
-   for k, v in data.items():
-     data[k] = v[: config.global_batch_size_to_train_on, :]

  # The loss function logic is identical to training. We are evaluating the model's
  # ability to perform its core training objective (e.g., denoising).
-   def loss_fn(params):
+   @jax.jit
+   def loss_fn(params, latents, encoder_hidden_states, timesteps, rng):
    # Reconstruct the model from its definition and parameters
    model = nnx.merge(state.graphdef, params, state.rest_of_state)

-     # Prepare inputs
-     latents = data["latents"].astype(config.weights_dtype)
-     encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
-     bsz = latents.shape[0]
-
-     # Sample random timesteps and noise, just as in a training step
-     timesteps = jax.random.randint(
-         timestep_rng,
-         (bsz,),
-         0,
-         scheduler.config.num_train_timesteps,
-     )
-     noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
+     noise = jax.random.normal(key=rng, shape=latents.shape, dtype=latents.dtype)
    noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)

    # Get the model's prediction
    model_pred = model(
        hidden_states=noisy_latents,
        timestep=timesteps,
        encoder_hidden_states=encoder_hidden_states,
+         deterministic=True,
    )

    # Calculate the loss against the target
    training_target = scheduler.training_target(latents, noise, timesteps)
    training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
    loss = (training_target - model_pred) ** 2
    loss = loss * training_weight
-     loss = jnp.mean(loss)
+     # Calculate the mean loss per sample across all non-batch dimensions.
+     loss = loss.reshape(loss.shape[0], -1).mean(axis=1)

    return loss

  # --- Key Difference from train_step ---
  # Directly compute the loss without calculating gradients.
  # The model's state.params are used but not updated.
-   loss = loss_fn(state.params)
+   # TODO(coolkp): Explore optimizing the creation of PRNGs in a vmap or statically outside of the loop
+   bs = len(data["latents"])
+   single_batch_size = config.global_batch_size_to_train_on
+   losses = jnp.zeros(bs)
+   for i in range(0, bs, single_batch_size):
+     start = i
+     end = min(i + single_batch_size, bs)
+     latents = data["latents"][start:end, :].astype(config.weights_dtype)
+     encoder_hidden_states = data["encoder_hidden_states"][start:end, :].astype(config.weights_dtype)
+     timesteps = data["timesteps"][start:end].astype("int64")
+     _, new_rng = jax.random.split(rng, num=2)
+     loss = loss_fn(state.params, latents, encoder_hidden_states, timesteps, new_rng)
+     losses = losses.at[start:end].set(loss)

  # Structure the metrics for logging and aggregation
-   metrics = {"scalar": {"learning/eval_loss": loss}}
+   metrics = {"scalar": {"learning/eval_loss": losses}}

  # Return the computed metrics and the new RNG key for the next eval step
  return metrics, new_rng
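For reference, here is the chunking pattern used in the rewritten `eval_step`, shown on a toy array: a jitted per-sample loss is applied to fixed-size slices of the batch and the results are scattered back with `.at[...].set(...)`. The loss function, shapes, and chunk size below are placeholders, not the model or config used in this PR. One design note: `jax.jit` retraces when a slice shape changes, so an uneven final chunk costs one extra compilation.

```python
# Placeholder loss and data; only the slicing/scatter pattern matches eval_step.
import jax
import jax.numpy as jnp

@jax.jit
def per_sample_loss(x):
  # Mean squared value per sample, reduced over all non-batch dimensions.
  return (x ** 2).reshape(x.shape[0], -1).mean(axis=1)

data = jnp.arange(24.0).reshape(6, 4)  # 6 samples of 4 features
chunk_size = 2
losses = jnp.zeros(data.shape[0])
for start in range(0, data.shape[0], chunk_size):
  end = min(start + chunk_size, data.shape[0])
  losses = losses.at[start:end].set(per_sample_loss(data[start:end]))
print(losses)  # one loss per sample
```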