Skip to content

Commit cc270b6

Browse files
authored
New Eval Pipeline as the MLPerf (#254)
1 parent 3a9d12b commit cc270b6

File tree

3 files changed

+127
-54
lines changed

3 files changed

+127
-54
lines changed

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,9 @@ global_batch_size: 0
234234
# For creating tfrecords from dataset
235235
tfrecords_dir: ''
236236
no_records_per_shard: 0
237+
enable_eval_timesteps: False
238+
considered_timesteps_list: [125, 250, 375, 500, 625, 750, 875]
239+
num_eval_samples: 420
237240

238241
warmup_steps_fraction: 0.1
239242
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
@@ -315,3 +318,6 @@ quantization_calibration_method: "absmax"
315318
eval_every: -1
316319
eval_data_dir: ""
317320
enable_generate_video_for_eval: False # This will increase the used TPU memory.
321+
eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(considered_timesteps_list).
322+
323+
enable_ssim: True

src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,17 @@ def float_feature_list(value):
5555
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
5656

5757

58-
def create_example(latent, hidden_states):
58+
def create_example(latent, hidden_states, timestep=None):
5959
latent = tf.io.serialize_tensor(latent)
6060
hidden_states = tf.io.serialize_tensor(hidden_states)
6161
feature = {
6262
"latents": bytes_feature(latent),
6363
"encoder_hidden_states": bytes_feature(hidden_states),
6464
}
65+
# Add timestep feature if it is provided
66+
if timestep is not None:
67+
feature["timesteps"] = int64_feature(timestep)
68+
6569
example = tf.train.Example(features=tf.train.Features(feature=feature))
6670
return example.SerializeToString()
6771

@@ -80,6 +84,12 @@ def generate_dataset(config):
8084
)
8185
shard_record_count = 0
8286

87+
# Define timesteps and bucket configuration
88+
num_eval_samples = config.num_eval_samples
89+
timesteps_list = config.timesteps_list
90+
assert num_eval_samples % len(timesteps_list) == 0
91+
bucket_size = num_eval_samples // len(timesteps_list)
92+
8393
# Load dataset
8494
metadata_path = os.path.join(config.train_data_dir, "metadata.csv")
8595
with open(metadata_path, "r", newline="") as file:
@@ -102,7 +112,20 @@ def generate_dataset(config):
102112
# Save them as float32 because numpy cannot read bfloat16.
103113
latent = jnp.array(latent.float().numpy(), dtype=jnp.float32)
104114
prompt_embeds = jnp.array(prompt_embeds.float().numpy(), dtype=jnp.float32)
105-
writer.write(create_example(latent, prompt_embeds))
115+
116+
current_timestep = None
117+
# Determine the timestep for the first 420 samples
118+
if config.enable_eval_timesteps:
119+
if global_record_count < num_eval_samples:
120+
print(f"global_record_count: {global_record_count}")
121+
bucket_index = global_record_count // bucket_size
122+
current_timestep = timesteps_list[bucket_index]
123+
else:
124+
print(f"value {global_record_count} is greater than or equal to {num_eval_samples}")
125+
return
126+
127+
# Write the example, including the timestep if applicable
128+
writer.write(create_example(latent, prompt_embeds, timestep=current_timestep))
106129
shard_record_count += 1
107130
global_record_count += 1
108131

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 96 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from skimage.metrics import structural_similarity as ssim
3939
from flax.training import train_state
4040
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
41+
from jax.experimental import multihost_utils
4142

4243

4344
class TrainState(train_state.TrainState):
@@ -156,6 +157,11 @@ def get_data_shardings(self, mesh):
156157
data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding}
157158
return data_sharding
158159

160+
def get_eval_data_shardings(self, mesh):
161+
data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding))
162+
data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": data_sharding}
163+
return data_sharding
164+
159165
def load_dataset(self, mesh, is_training=True):
160166
# Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
161167
# Image pre-training - txt2img 256px
@@ -170,34 +176,43 @@ def load_dataset(self, mesh, is_training=True):
170176
raise ValueError(
171177
"Wan 2.1 training only supports config.dataset_type set to tfrecords and config.cache_latents_text_encoder_outputs set to True"
172178
)
173-
174179
feature_description = {
175180
"latents": tf.io.FixedLenFeature([], tf.string),
176181
"encoder_hidden_states": tf.io.FixedLenFeature([], tf.string),
177182
}
178183

179-
def prepare_sample(features):
184+
if not is_training:
185+
feature_description["timesteps"] = tf.io.FixedLenFeature([], tf.int64)
186+
187+
def prepare_sample_train(features):
180188
latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32)
181189
encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32)
182190
return {"latents": latents, "encoder_hidden_states": encoder_hidden_states}
183191

192+
def prepare_sample_eval(features):
193+
latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32)
194+
encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32)
195+
timesteps = features["timesteps"]
196+
return {"latents": latents, "encoder_hidden_states": encoder_hidden_states, "timesteps": timesteps}
197+
184198
data_iterator = make_data_iterator(
185199
config,
186200
jax.process_index(),
187201
jax.process_count(),
188202
mesh,
189203
config.global_batch_size_to_load,
190204
feature_description=feature_description,
191-
prepare_sample_fn=prepare_sample,
205+
prepare_sample_fn=prepare_sample_train if is_training else prepare_sample_eval,
192206
is_training=is_training,
193207
)
194208
return data_iterator
195209

196210
def start_training(self):
197211

198212
pipeline = self.load_checkpoint()
199-
# Generate a sample before training to compare against generated sample after training.
200-
pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
213+
if self.config.enable_ssim:
214+
# Generate a sample before training to compare against generated sample after training.
215+
pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
201216

202217
if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval):
203218
# save some memory.
@@ -215,8 +230,57 @@ def start_training(self):
215230
# Returns pipeline with trained transformer state
216231
pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)
217232

218-
posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
219-
print_ssim(pretrained_video_path, posttrained_video_path)
233+
if self.config.enable_ssim:
234+
posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
235+
print_ssim(pretrained_video_path, posttrained_video_path)
236+
237+
def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer):
238+
eval_data_iterator = self.load_dataset(mesh, is_training=False)
239+
eval_rng = eval_rng_key
240+
eval_losses_by_timestep = {}
241+
# Loop indefinitely until the iterator is exhausted
242+
while True:
243+
try:
244+
eval_start_time = datetime.datetime.now()
245+
eval_batch = load_next_batch(eval_data_iterator, None, self.config)
246+
with mesh, nn_partitioning.axis_rules(
247+
self.config.logical_axis_rules
248+
):
249+
metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
250+
metrics["scalar"]["learning/eval_loss"].block_until_ready()
251+
losses = metrics["scalar"]["learning/eval_loss"]
252+
timesteps = eval_batch["timesteps"]
253+
gathered_losses = multihost_utils.process_allgather(losses)
254+
gathered_losses = jax.device_get(gathered_losses)
255+
gathered_timesteps = multihost_utils.process_allgather(timesteps)
256+
gathered_timesteps = jax.device_get(gathered_timesteps)
257+
if jax.process_index() == 0:
258+
for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()):
259+
timestep = int(t)
260+
if timestep not in eval_losses_by_timestep:
261+
eval_losses_by_timestep[timestep] = []
262+
eval_losses_by_timestep[timestep].append(l)
263+
eval_end_time = datetime.datetime.now()
264+
eval_duration = eval_end_time - eval_start_time
265+
max_logging.log(f"Eval time: {eval_duration.total_seconds():.2f} seconds.")
266+
except StopIteration:
267+
# This block is executed when the iterator has no more data
268+
break
269+
# Check if any evaluation was actually performed
270+
if eval_losses_by_timestep and jax.process_index() == 0:
271+
mean_per_timestep = []
272+
if jax.process_index() == 0:
273+
max_logging.log(f"Step {step}, calculating mean loss per timestep...")
274+
for timestep, losses in sorted(eval_losses_by_timestep.items()):
275+
losses = jnp.array(losses)
276+
losses = losses[: min(self.config.eval_max_number_of_samples_in_bucket, len(losses))]
277+
mean_loss = jnp.mean(losses)
278+
max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}")
279+
mean_per_timestep.append(mean_loss)
280+
final_eval_loss = jnp.mean(jnp.array(mean_per_timestep))
281+
max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}")
282+
if writer:
283+
writer.add_scalar("learning/eval_loss", final_eval_loss, step)
220284

221285
def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
222286
mesh = pipeline.mesh
@@ -231,6 +295,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
231295
state = jax.lax.with_sharding_constraint(state, state_spec)
232296
state_shardings = nnx.get_named_sharding(state, mesh)
233297
data_shardings = self.get_data_shardings(mesh)
298+
eval_data_shardings = self.get_eval_data_shardings(mesh)
234299

235300
writer = max_utils.initialize_summary_writer(self.config)
236301
writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True)
@@ -255,11 +320,12 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
255320
)
256321
p_eval_step = jax.jit(
257322
functools.partial(eval_step, scheduler=pipeline.scheduler, config=self.config),
258-
in_shardings=(state_shardings, data_shardings, None, None),
323+
in_shardings=(state_shardings, eval_data_shardings, None, None),
259324
out_shardings=(None, None),
260325
)
261326

262327
rng = jax.random.key(self.config.seed)
328+
rng, eval_rng_key = jax.random.split(rng)
263329
start_step = 0
264330
last_step_completion = datetime.datetime.now()
265331
local_metrics_file = open(self.config.metrics_file, "a", encoding="utf8") if self.config.metrics_file else None
@@ -304,27 +370,8 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
304370
inference_generate_video(self.config, pipeline, filename_prefix=f"{step+1}-train_steps-")
305371
# Re-create the iterator each time you start evaluation to reset it
306372
# This assumes your data loading logic can be called to get a fresh iterator.
307-
eval_data_iterator = self.load_dataset(mesh, is_training=False)
308-
eval_rng = jax.random.key(self.config.seed + step)
309-
eval_metrics = []
310-
# Loop indefinitely until the iterator is exhausted
311-
while True:
312-
try:
313-
with mesh:
314-
eval_batch = load_next_batch(eval_data_iterator, None, self.config)
315-
metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
316-
eval_metrics.append(metrics["scalar"]["learning/eval_loss"])
317-
except StopIteration:
318-
# This block is executed when the iterator has no more data
319-
break
320-
# Check if any evaluation was actually performed
321-
if eval_metrics:
322-
eval_loss = jnp.mean(jnp.array(eval_metrics))
323-
max_logging.log(f"Step {step}, Eval loss: {eval_loss:.4f}")
324-
if writer:
325-
writer.add_scalar("learning/eval_loss", eval_loss, step)
326-
else:
327-
max_logging.log(f"Step {step}, evaluation dataset was empty.")
373+
self.eval(mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer)
374+
328375
example_batch = next_batch_future.result()
329376
if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0:
330377
max_logging.log(f"Saving checkpoint for step {step}")
@@ -394,57 +441,54 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
394441
"""
395442
Computes the evaluation loss for a single batch without updating model weights.
396443
"""
397-
_, new_rng, timestep_rng = jax.random.split(rng, num=3)
398-
399-
# This ensures the batch size is consistent, though it might be redundant
400-
# if the evaluation dataloader is already configured correctly.
401-
for k, v in data.items():
402-
data[k] = v[: config.global_batch_size_to_train_on, :]
403444

404445
# The loss function logic is identical to training. We are evaluating the model's
405446
# ability to perform its core training objective (e.g., denoising).
406-
def loss_fn(params):
447+
@jax.jit
448+
def loss_fn(params, latents, encoder_hidden_states, timesteps, rng):
407449
# Reconstruct the model from its definition and parameters
408450
model = nnx.merge(state.graphdef, params, state.rest_of_state)
409451

410-
# Prepare inputs
411-
latents = data["latents"].astype(config.weights_dtype)
412-
encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
413-
bsz = latents.shape[0]
414-
415-
# Sample random timesteps and noise, just as in a training step
416-
timesteps = jax.random.randint(
417-
timestep_rng,
418-
(bsz,),
419-
0,
420-
scheduler.config.num_train_timesteps,
421-
)
422-
noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
452+
noise = jax.random.normal(key=rng, shape=latents.shape, dtype=latents.dtype)
423453
noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
424454

425455
# Get the model's prediction
426456
model_pred = model(
427457
hidden_states=noisy_latents,
428458
timestep=timesteps,
429459
encoder_hidden_states=encoder_hidden_states,
460+
deterministic=True,
430461
)
431462

432463
# Calculate the loss against the target
433464
training_target = scheduler.training_target(latents, noise, timesteps)
434465
training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
435466
loss = (training_target - model_pred) ** 2
436467
loss = loss * training_weight
437-
loss = jnp.mean(loss)
468+
# Calculate the mean loss per sample across all non-batch dimensions.
469+
loss = loss.reshape(loss.shape[0], -1).mean(axis=1)
438470

439471
return loss
440472

441473
# --- Key Difference from train_step ---
442474
# Directly compute the loss without calculating gradients.
443475
# The model's state.params are used but not updated.
444-
loss = loss_fn(state.params)
476+
# TODO(coolkp): Explore optimizing the creation of PRNGs in a vmap or statically outside of the loop
477+
bs = len(data["latents"])
478+
single_batch_size = config.global_batch_size_to_train_on
479+
losses = jnp.zeros(bs)
480+
for i in range(0, bs, single_batch_size):
481+
start = i
482+
end = min(i + single_batch_size, bs)
483+
latents= data["latents"][start:end, :].astype(config.weights_dtype)
484+
encoder_hidden_states = data["encoder_hidden_states"][start:end, :].astype(config.weights_dtype)
485+
timesteps = data["timesteps"][start:end].astype("int64")
486+
_, new_rng = jax.random.split(rng, num=2)
487+
loss = loss_fn(state.params, latents, encoder_hidden_states, timesteps, new_rng)
488+
losses = losses.at[start:end].set(loss)
445489

446490
# Structure the metrics for logging and aggregation
447-
metrics = {"scalar": {"learning/eval_loss": loss}}
491+
metrics = {"scalar": {"learning/eval_loss": losses}}
448492

449493
# Return the computed metrics and the new RNG key for the next eval step
450494
return metrics, new_rng

0 commit comments

Comments
 (0)