6 changes: 6 additions & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -234,6 +234,9 @@ global_batch_size: 0
# For creating tfrecords from dataset
tfrecords_dir: ''
no_records_per_shard: 0
enable_eval_timesteps: False
considered_timesteps_list: [125, 250, 375, 500, 625, 750, 875]
num_eval_samples: 420

warmup_steps_fraction: 0.1
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
@@ -315,3 +318,6 @@ quantization_calibration_method: "absmax"
eval_every: -1
eval_data_dir: ""
enable_generate_video_for_eval: False # This will increase the used TPU memory.
eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(considered_timesteps_list).

enable_ssim: True
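
For reference, the bucket arithmetic that `eval_max_number_of_samples_in_bucket` encodes can be checked with a toy computation (a sketch using this config's default values, not code from the PR):

```python
# Toy check of the eval bucket sizing, using the defaults in this config.
considered_timesteps_list = [125, 250, 375, 500, 625, 750, 875]
num_eval_samples = 420

# num_eval_samples must split evenly into one bucket per considered timestep.
assert num_eval_samples % len(considered_timesteps_list) == 0
bucket_size = num_eval_samples // len(considered_timesteps_list)
print(bucket_size)  # 60, matching eval_max_number_of_samples_in_bucket
```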
27 changes: 25 additions & 2 deletions src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py
@@ -55,13 +55,17 @@ def float_feature_list(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))


-def create_example(latent, hidden_states):
def create_example(latent, hidden_states, timestep=None):
latent = tf.io.serialize_tensor(latent)
hidden_states = tf.io.serialize_tensor(hidden_states)
feature = {
"latents": bytes_feature(latent),
"encoder_hidden_states": bytes_feature(hidden_states),
}
# Add timestep feature if it is provided
if timestep is not None:
feature["timesteps"] = int64_feature(timestep)

example = tf.train.Example(features=tf.train.Features(feature=feature))
return example.SerializeToString()

@@ -80,6 +84,12 @@ def generate_dataset(config):
)
shard_record_count = 0

# Define timesteps and bucket configuration
num_eval_samples = config.num_eval_samples
  timesteps_list = config.considered_timesteps_list
  assert num_eval_samples % len(timesteps_list) == 0, "num_eval_samples must be divisible by the number of considered timesteps"
bucket_size = num_eval_samples // len(timesteps_list)

# Load dataset
metadata_path = os.path.join(config.train_data_dir, "metadata.csv")
with open(metadata_path, "r", newline="") as file:
@@ -102,7 +112,20 @@
# Save them as float32 because numpy cannot read bfloat16.
latent = jnp.array(latent.float().numpy(), dtype=jnp.float32)
prompt_embeds = jnp.array(prompt_embeds.float().numpy(), dtype=jnp.float32)
-    writer.write(create_example(latent, prompt_embeds))

current_timestep = None
    # Assign a fixed timestep to each of the first num_eval_samples records
if config.enable_eval_timesteps:
if global_record_count < num_eval_samples:
print(f"global_record_count: {global_record_count}")
bucket_index = global_record_count // bucket_size
current_timestep = timesteps_list[bucket_index]
else:
print(f"value {global_record_count} is greater than or equal to {num_eval_samples}")
return

# Write the example, including the timestep if applicable
writer.write(create_example(latent, prompt_embeds, timestep=current_timestep))
shard_record_count += 1
global_record_count += 1

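For reference, a toy sketch (not code from the PR, assuming the default `num_eval_samples=420` and seven timesteps) of how `global_record_count // bucket_size` assigns a fixed timestep to each eval record:

```python
# Each block of bucket_size consecutive records shares one eval timestep.
timesteps_list = [125, 250, 375, 500, 625, 750, 875]
bucket_size = 60  # num_eval_samples // len(timesteps_list)

for global_record_count in (0, 59, 60, 419):
  bucket_index = global_record_count // bucket_size
  print(global_record_count, timesteps_list[bucket_index])
# 0 -> 125, 59 -> 125, 60 -> 250, 419 -> 875
```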
148 changes: 96 additions & 52 deletions src/maxdiffusion/trainers/wan_trainer.py
@@ -38,6 +38,7 @@
from skimage.metrics import structural_similarity as ssim
from flax.training import train_state
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
from jax.experimental import multihost_utils


class TrainState(train_state.TrainState):
@@ -156,6 +157,11 @@ def get_data_shardings(self, mesh):
data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding}
return data_sharding

def get_eval_data_shardings(self, mesh):
data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding))
data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": data_sharding}
return data_sharding

def load_dataset(self, mesh, is_training=True):
# Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
# Image pre-training - txt2img 256px
@@ -170,34 +176,43 @@ def load_dataset(self, mesh, is_training=True):
raise ValueError(
"Wan 2.1 training only supports config.dataset_type set to tfrecords and config.cache_latents_text_encoder_outputs set to True"
)

feature_description = {
"latents": tf.io.FixedLenFeature([], tf.string),
"encoder_hidden_states": tf.io.FixedLenFeature([], tf.string),
}

-    def prepare_sample(features):
if not is_training:
feature_description["timesteps"] = tf.io.FixedLenFeature([], tf.int64)

def prepare_sample_train(features):
latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32)
encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32)
return {"latents": latents, "encoder_hidden_states": encoder_hidden_states}

def prepare_sample_eval(features):
latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32)
encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32)
timesteps = features["timesteps"]
return {"latents": latents, "encoder_hidden_states": encoder_hidden_states, "timesteps": timesteps}

data_iterator = make_data_iterator(
config,
jax.process_index(),
jax.process_count(),
mesh,
config.global_batch_size_to_load,
feature_description=feature_description,
-        prepare_sample_fn=prepare_sample,
prepare_sample_fn=prepare_sample_train if is_training else prepare_sample_eval,
is_training=is_training,
)
return data_iterator

def start_training(self):

pipeline = self.load_checkpoint()
-    # Generate a sample before training to compare against generated sample after training.
-    pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
if self.config.enable_ssim:
# Generate a sample before training to compare against generated sample after training.
pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")

if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval):
# save some memory.
@@ -215,8 +230,57 @@ def start_training(self):
# Returns pipeline with trained transformer state
pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)

-    posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
-    print_ssim(pretrained_video_path, posttrained_video_path)
if self.config.enable_ssim:
posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
print_ssim(pretrained_video_path, posttrained_video_path)
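
For reference, a minimal sketch of the kind of comparison `print_ssim` performs, using the `structural_similarity` import already present in this file; the frames here are synthetic stand-ins for decoded video frames:

```python
import numpy as np
from skimage.metrics import structural_similarity as ssim

# Compare one synthetic frame against a slightly perturbed copy.
frame_a = np.random.rand(64, 64)
frame_b = np.clip(frame_a + 0.05 * np.random.rand(64, 64), 0.0, 1.0)
score = ssim(frame_a, frame_b, data_range=1.0)
print(f"SSIM: {score:.4f}")  # near 1.0 for nearly identical frames
```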

def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer):
eval_data_iterator = self.load_dataset(mesh, is_training=False)
eval_rng = eval_rng_key
eval_losses_by_timestep = {}
    # Loop until the eval iterator is exhausted
while True:
try:
eval_start_time = datetime.datetime.now()
eval_batch = load_next_batch(eval_data_iterator, None, self.config)
        with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
metrics["scalar"]["learning/eval_loss"].block_until_ready()
losses = metrics["scalar"]["learning/eval_loss"]
timesteps = eval_batch["timesteps"]
gathered_losses = multihost_utils.process_allgather(losses)
gathered_losses = jax.device_get(gathered_losses)
gathered_timesteps = multihost_utils.process_allgather(timesteps)
gathered_timesteps = jax.device_get(gathered_timesteps)
if jax.process_index() == 0:
for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()):
timestep = int(t)
if timestep not in eval_losses_by_timestep:
eval_losses_by_timestep[timestep] = []
eval_losses_by_timestep[timestep].append(l)
eval_end_time = datetime.datetime.now()
eval_duration = eval_end_time - eval_start_time
max_logging.log(f"Eval time: {eval_duration.total_seconds():.2f} seconds.")
except StopIteration:
# This block is executed when the iterator has no more data
break
    # Aggregate only if evaluation actually ran; losses were collected on process 0.
    if eval_losses_by_timestep and jax.process_index() == 0:
      max_logging.log(f"Step {step}, calculating mean loss per timestep...")
      mean_per_timestep = []
      for timestep, losses in sorted(eval_losses_by_timestep.items()):
        # Cap each bucket so every timestep contributes equally to the final mean.
        losses = jnp.array(losses)[: self.config.eval_max_number_of_samples_in_bucket]
        mean_loss = jnp.mean(losses)
        max_logging.log(f"  Mean eval loss for timestep {timestep}: {mean_loss:.4f}")
        mean_per_timestep.append(mean_loss)
      final_eval_loss = jnp.mean(jnp.array(mean_per_timestep))
      max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}")
      if writer:
        writer.add_scalar("learning/eval_loss", final_eval_loss, step)
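
For reference, a single-host toy version of the aggregation in `eval` above (made-up loss values; on one process `process_allgather` simply stacks the local array):

```python
import jax
import jax.numpy as jnp
from jax.experimental import multihost_utils

# Gather per-sample losses and timesteps, then average per timestep bucket.
losses = jnp.array([0.12, 0.30, 0.25, 0.41])
timesteps = jnp.array([125, 125, 250, 250])

gathered_losses = jax.device_get(multihost_utils.process_allgather(losses))
gathered_timesteps = jax.device_get(multihost_utils.process_allgather(timesteps))

by_timestep = {}
for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()):
  by_timestep.setdefault(int(t), []).append(float(l))

# Mean per bucket, then mean of bucket means, as logged in eval().
mean_per_timestep = [sum(v) / len(v) for _, v in sorted(by_timestep.items())]
final_eval_loss = sum(mean_per_timestep) / len(mean_per_timestep)
print(final_eval_loss)  # (0.21 + 0.33) / 2 = 0.27
```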

def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
mesh = pipeline.mesh
@@ -231,6 +295,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
state = jax.lax.with_sharding_constraint(state, state_spec)
state_shardings = nnx.get_named_sharding(state, mesh)
data_shardings = self.get_data_shardings(mesh)
eval_data_shardings = self.get_eval_data_shardings(mesh)

writer = max_utils.initialize_summary_writer(self.config)
writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True)
@@ -255,11 +320,12 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
)
p_eval_step = jax.jit(
functools.partial(eval_step, scheduler=pipeline.scheduler, config=self.config),
-        in_shardings=(state_shardings, data_shardings, None, None),
in_shardings=(state_shardings, eval_data_shardings, None, None),
out_shardings=(None, None),
)

rng = jax.random.key(self.config.seed)
rng, eval_rng_key = jax.random.split(rng)
start_step = 0
last_step_completion = datetime.datetime.now()
local_metrics_file = open(self.config.metrics_file, "a", encoding="utf8") if self.config.metrics_file else None
@@ -304,27 +370,8 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
inference_generate_video(self.config, pipeline, filename_prefix=f"{step+1}-train_steps-")
# Re-create the iterator each time you start evaluation to reset it
# This assumes your data loading logic can be called to get a fresh iterator.
-        eval_data_iterator = self.load_dataset(mesh, is_training=False)
-        eval_rng = jax.random.key(self.config.seed + step)
-        eval_metrics = []
-        # Loop indefinitely until the iterator is exhausted
-        while True:
-          try:
-            with mesh:
-              eval_batch = load_next_batch(eval_data_iterator, None, self.config)
-              metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
-              eval_metrics.append(metrics["scalar"]["learning/eval_loss"])
-          except StopIteration:
-            # This block is executed when the iterator has no more data
-            break
-        # Check if any evaluation was actually performed
-        if eval_metrics:
-          eval_loss = jnp.mean(jnp.array(eval_metrics))
-          max_logging.log(f"Step {step}, Eval loss: {eval_loss:.4f}")
-          if writer:
-            writer.add_scalar("learning/eval_loss", eval_loss, step)
-        else:
-          max_logging.log(f"Step {step}, evaluation dataset was empty.")
self.eval(mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, writer)

example_batch = next_batch_future.result()
if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0:
max_logging.log(f"Saving checkpoint for step {step}")
@@ -394,57 +441,54 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
"""
Computes the evaluation loss for a single batch without updating model weights.
"""
-  _, new_rng, timestep_rng = jax.random.split(rng, num=3)

-  # This ensures the batch size is consistent, though it might be redundant
-  # if the evaluation dataloader is already configured correctly.
-  for k, v in data.items():
-    data[k] = v[: config.global_batch_size_to_train_on, :]

# The loss function logic is identical to training. We are evaluating the model's
# ability to perform its core training objective (e.g., denoising).
-  def loss_fn(params):
@jax.jit
def loss_fn(params, latents, encoder_hidden_states, timesteps, rng):
# Reconstruct the model from its definition and parameters
model = nnx.merge(state.graphdef, params, state.rest_of_state)

# Prepare inputs
latents = data["latents"].astype(config.weights_dtype)
encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
bsz = latents.shape[0]

# Sample random timesteps and noise, just as in a training step
-    timesteps = jax.random.randint(
-        timestep_rng,
-        (bsz,),
-        0,
-        scheduler.config.num_train_timesteps,
-    )
-    noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
noise = jax.random.normal(key=rng, shape=latents.shape, dtype=latents.dtype)
noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)

# Get the model's prediction
model_pred = model(
hidden_states=noisy_latents,
timestep=timesteps,
encoder_hidden_states=encoder_hidden_states,
deterministic=True,
)

# Calculate the loss against the target
training_target = scheduler.training_target(latents, noise, timesteps)
training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
loss = (training_target - model_pred) ** 2
loss = loss * training_weight
-    loss = jnp.mean(loss)
# Calculate the mean loss per sample across all non-batch dimensions.
loss = loss.reshape(loss.shape[0], -1).mean(axis=1)

return loss

-  # --- Key Difference from train_step ---
-  # Directly compute the loss without calculating gradients.
-  # The model's state.params are used but not updated.
-  loss = loss_fn(state.params)
  # TODO(coolkp): Explore optimizing the creation of PRNGs in a vmap or statically outside of the loop
  bs = len(data["latents"])
  single_batch_size = config.global_batch_size_to_train_on
  losses = jnp.zeros(bs)
  for i in range(0, bs, single_batch_size):
    start = i
    end = min(i + single_batch_size, bs)
    latents = data["latents"][start:end, :].astype(config.weights_dtype)
    encoder_hidden_states = data["encoder_hidden_states"][start:end, :].astype(config.weights_dtype)
    timesteps = data["timesteps"][start:end].astype("int64")
    # Split a fresh key per chunk so each slice sees different noise.
    rng, new_rng = jax.random.split(rng)
    loss = loss_fn(state.params, latents, encoder_hidden_states, timesteps, new_rng)
    losses = losses.at[start:end].set(loss)

# Structure the metrics for logging and aggregation
metrics = {"scalar": {"learning/eval_loss": loss}}
metrics = {"scalar": {"learning/eval_loss": losses}}

# Return the computed metrics and the new RNG key for the next eval step
return metrics, new_rng
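
For reference, a toy version of the chunked scoring loop in `eval_step` (placeholder losses instead of a real model call), showing the `.at[...].set(...)` scatter and the per-chunk key split:

```python
import jax
import jax.numpy as jnp

# Fill a per-sample loss vector one fixed-size chunk at a time.
bs, chunk = 10, 4
rng = jax.random.key(0)
losses = jnp.zeros(bs)
for i in range(0, bs, chunk):
  end = min(i + chunk, bs)
  rng, step_rng = jax.random.split(rng)  # fresh key per chunk, as in eval_step
  chunk_losses = jax.random.uniform(step_rng, (end - i,))  # stand-in for loss_fn
  losses = losses.at[i:end].set(chunk_losses)
print(losses.shape)  # (10,)
```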