1 change: 1 addition & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -234,6 +234,7 @@ global_batch_size: 0
# For creating tfrecords from dataset
tfrecords_dir: ''
no_records_per_shard: 0
enable_eval_timesteps: False

warmup_steps_fraction: 0.1
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
26 changes: 24 additions & 2 deletions src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py
@@ -55,13 +55,17 @@ def float_feature_list(value):
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def create_example(latent, hidden_states):
def create_example(latent, hidden_states, timestep=None):
  latent = tf.io.serialize_tensor(latent)
  hidden_states = tf.io.serialize_tensor(hidden_states)
  feature = {
      "latents": bytes_feature(latent),
      "encoder_hidden_states": bytes_feature(hidden_states),
  }
  # Add timestep feature if it is provided
  if timestep is not None:
    feature["timesteps"] = int64_feature(timestep)

  example = tf.train.Example(features=tf.train.Features(feature=feature))
  return example.SerializeToString()

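For reference, a minimal eager-mode sketch of the write side. It assumes the `bytes_feature` helper defined earlier in this file; the `int64_feature` definition shown here is the conventional one (it is not copied from this diff), and the tensor shapes are illustrative only.

```python
import tensorflow as tf

# Conventional int64 scalar feature helper (assumed shape of the existing one).
def int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# Illustrative shapes; real latents/embeddings come from the VAE and text
# encoder. With timestep set, the record gains a "timesteps" int64 feature;
# with timestep=None the old record layout is reproduced unchanged.
latent = tf.zeros((16, 8, 8), tf.float32)
hidden_states = tf.zeros((77, 4096), tf.float32)
serialized = create_example(latent, hidden_states, timestep=125)
```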
@@ -80,6 +84,11 @@ def generate_dataset(config):
  )
  shard_record_count = 0

  # Define timesteps and bucket configuration
  timesteps_list = [125, 250, 375, 500, 625, 750, 875]
  bucket_size = 60
  num_samples_to_process = 420

  # Load dataset
  metadata_path = os.path.join(config.train_data_dir, "metadata.csv")
  with open(metadata_path, "r", newline="") as file:
@@ -102,7 +111,20 @@
      # Save them as float32 because numpy cannot read bfloat16.
      latent = jnp.array(latent.float().numpy(), dtype=jnp.float32)
      prompt_embeds = jnp.array(prompt_embeds.float().numpy(), dtype=jnp.float32)
      writer.write(create_example(latent, prompt_embeds))

      current_timestep = None
      # Assign a fixed timestep to each of the first num_samples_to_process samples
      if config.enable_eval_timesteps:
        if global_record_count < num_samples_to_process:
          print(f"global_record_count: {global_record_count}")
          bucket_index = global_record_count // bucket_size
          current_timestep = timesteps_list[bucket_index]
        else:
          print(f"global_record_count {global_record_count} >= {num_samples_to_process}; stopping dataset generation.")
          return

      # Write the example, including the timestep if applicable
      writer.write(create_example(latent, prompt_embeds, timestep=current_timestep))
      shard_record_count += 1
      global_record_count += 1

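To make the bucketing concrete: each block of `bucket_size` consecutive records shares one fixed timestep, so 7 timesteps x 60-record buckets cover exactly the 420 eval samples. A standalone sketch:

```python
timesteps_list = [125, 250, 375, 500, 625, 750, 875]
bucket_size = 60
num_samples_to_process = len(timesteps_list) * bucket_size  # 420

# Record index -> fixed eval timestep, via integer division into buckets.
for global_record_count in (0, 59, 60, 419):
  bucket_index = global_record_count // bucket_size
  print(global_record_count, "->", timesteps_list[bucket_index])
# 0 -> 125, 59 -> 125, 60 -> 250, 419 -> 875
```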
81 changes: 57 additions & 24 deletions src/maxdiffusion/trainers/wan_trainer.py
@@ -156,6 +156,12 @@ def get_data_shardings(self, mesh):
    data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding}
    return data_sharding

  def get_eval_data_shardings(self, mesh):
    data_sharding = jax.sharding.NamedSharding(mesh, P(*self.config.data_sharding))
    timesteps_sharding = jax.sharding.NamedSharding(mesh, P("data"))
    data_sharding = {"latents": data_sharding, "encoder_hidden_states": data_sharding, "timesteps": timesteps_sharding}
    return data_sharding

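A standalone sketch of why `timesteps` gets its own spec, assuming a 1-D mesh axis named `data`: latents and hidden states are multi-dimensional and follow the config's `data_sharding` spec, while the per-example timesteps vector is rank-1, so it carries only the batch axis.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over all local devices, with the axis named "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
timesteps_sharding = NamedSharding(mesh, P("data"))

# One timestep per example; batch sized to divide evenly across devices.
# (int64 in the TFRecord; JAX canonicalizes to int32 unless x64 is enabled.)
batch = jax.device_count()
timesteps = jax.device_put(np.full((batch,), 125, np.int32), timesteps_sharding)
print(timesteps.sharding)
```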
  def load_dataset(self, mesh, is_training=True):
    # Stages of training as described in the Wan 2.1 paper - https://arxiv.org/pdf/2503.20314
    # Image pre-training - txt2img 256px
@@ -170,25 +176,35 @@ def load_dataset(self, mesh, is_training=True):
      raise ValueError(
          "Wan 2.1 training only supports config.dataset_type set to tfrecords and config.cache_latents_text_encoder_outputs set to True"
      )

    feature_description = {
        "latents": tf.io.FixedLenFeature([], tf.string),
        "encoder_hidden_states": tf.io.FixedLenFeature([], tf.string),
    }

    def prepare_sample(features):
    if not is_training:
      feature_description["timesteps"] = tf.io.FixedLenFeature([], tf.int64)

    def prepare_sample_train(features):
      latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32)
      encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32)
      return {"latents": latents, "encoder_hidden_states": encoder_hidden_states}

    def prepare_sample_eval(features):
      latents = tf.io.parse_tensor(features["latents"], out_type=tf.float32)
      encoder_hidden_states = tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32)
      timesteps = features["timesteps"]
      tf.print("timesteps in prepare_sample_eval:", timesteps)
      return {"latents": latents, "encoder_hidden_states": encoder_hidden_states, "timesteps": timesteps}

    data_iterator = make_data_iterator(
        config,
        jax.process_index(),
        jax.process_count(),
        mesh,
        config.global_batch_size_to_load,
        feature_description=feature_description,
        prepare_sample_fn=prepare_sample,
        prepare_sample_fn=prepare_sample_train if is_training else prepare_sample_eval,
        is_training=is_training,
    )
    return data_iterator
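A standalone sketch of the eval-side decode, assuming records written by `create_example` above with `enable_eval_timesteps: True` (so the `timesteps` feature is present); `paths` is a placeholder for the shard filenames.

```python
import tensorflow as tf

feature_description = {
    "latents": tf.io.FixedLenFeature([], tf.string),
    "encoder_hidden_states": tf.io.FixedLenFeature([], tf.string),
    "timesteps": tf.io.FixedLenFeature([], tf.int64),
}

def parse_eval_record(serialized):
  features = tf.io.parse_single_example(serialized, feature_description)
  return {
      "latents": tf.io.parse_tensor(features["latents"], out_type=tf.float32),
      "encoder_hidden_states": tf.io.parse_tensor(features["encoder_hidden_states"], out_type=tf.float32),
      "timesteps": features["timesteps"],
  }

# dataset = tf.data.TFRecordDataset(paths).map(parse_eval_record)
```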
@@ -197,7 +213,7 @@ def start_training(self):

    pipeline = self.load_checkpoint()
    # Generate a sample before training to compare against generated sample after training.
    pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
    # pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")

    if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval):
      # save some memory.
@@ -215,8 +231,8 @@
    # Returns pipeline with trained transformer state
    pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)

    posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
    print_ssim(pretrained_video_path, posttrained_video_path)
    # posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
    # print_ssim(pretrained_video_path, posttrained_video_path)

  def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
    mesh = pipeline.mesh
@@ -231,6 +247,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
    state = jax.lax.with_sharding_constraint(state, state_spec)
    state_shardings = nnx.get_named_sharding(state, mesh)
    data_shardings = self.get_data_shardings(mesh)
    eval_data_shardings = self.get_eval_data_shardings(mesh)

    writer = max_utils.initialize_summary_writer(self.config)
    writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True)
@@ -255,11 +272,12 @@
    )
    p_eval_step = jax.jit(
        functools.partial(eval_step, scheduler=pipeline.scheduler, config=self.config),
        in_shardings=(state_shardings, data_shardings, None, None),
        in_shardings=(state_shardings, eval_data_shardings, None, None),
        out_shardings=(None, None),
    )

    rng = jax.random.key(self.config.seed)
    rng, eval_rng_key = jax.random.split(rng)
    start_step = 0
    last_step_completion = datetime.datetime.now()
    local_metrics_file = open(self.config.metrics_file, "a", encoding="utf8") if self.config.metrics_file else None
@@ -305,24 +323,40 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
          # Re-create the iterator each time you start evaluation to reset it
          # This assumes your data loading logic can be called to get a fresh iterator.
          eval_data_iterator = self.load_dataset(mesh, is_training=False)
          eval_rng = jax.random.key(self.config.seed + step)
          eval_metrics = []
          eval_rng = eval_rng_key
          eval_losses_by_timestep = {}
          # Loop until the iterator is exhausted
          while True:
            try:
              with mesh:
                eval_batch = load_next_batch(eval_data_iterator, None, self.config)
                eval_batch["timesteps"] = jax.device_put(
                    eval_batch["timesteps"], eval_data_shardings["timesteps"]
                )
                metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
                eval_metrics.append(metrics["scalar"]["learning/eval_loss"])
                loss = metrics["scalar"]["learning/eval_loss"]
                timestep = int(eval_batch["timesteps"][0])
                jax.debug.print("timesteps in eval_step: {x}", x=timestep)
                if timestep not in eval_losses_by_timestep:
                  eval_losses_by_timestep[timestep] = []
                eval_losses_by_timestep[timestep].append(loss)
            except StopIteration:
              # This block is executed when the iterator has no more data
              break
          # Check if any evaluation was actually performed
          if eval_metrics:
            eval_loss = jnp.mean(jnp.array(eval_metrics))
            max_logging.log(f"Step {step}, Eval loss: {eval_loss:.4f}")
          if eval_losses_by_timestep:
            mean_per_timestep = []
            max_logging.log(f"Step {step}, calculating mean loss per timestep...")
            for timestep, losses in sorted(eval_losses_by_timestep.items()):
              losses = jnp.array(losses)
              losses = losses[: min(60, len(losses))]
              mean_loss = jnp.mean(losses)
              max_logging.log(f"  Mean eval loss for timestep {timestep}: {mean_loss:.4f}, num of losses: {len(losses)}")
              mean_per_timestep.append(mean_loss)
            final_eval_loss = jnp.mean(jnp.array(mean_per_timestep))
            max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}")
            if writer:
              writer.add_scalar("learning/eval_loss", eval_loss, step)
              writer.add_scalar("learning/eval_loss", final_eval_loss, step)
          else:
            max_logging.log(f"Step {step}, evaluation dataset was empty.")
          example_batch = next_batch_future.result()
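A standalone sketch of the aggregation above with synthetic losses: each timestep bucket is capped at 60 losses and averaged, and the final metric is the unweighted mean of the per-timestep means, so every noise level counts equally even if the buckets contributed different numbers of batches.

```python
import jax.numpy as jnp

eval_losses_by_timestep = {125: [0.31, 0.29], 500: [0.42, 0.40], 875: [0.55]}

mean_per_timestep = []
for timestep, losses in sorted(eval_losses_by_timestep.items()):
  losses = jnp.array(losses)[:60]  # same cap as losses[: min(60, len(losses))]
  mean_per_timestep.append(jnp.mean(losses))

final_eval_loss = jnp.mean(jnp.array(mean_per_timestep))
print(f"final eval loss: {final_eval_loss:.4f}")
```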
@@ -394,12 +428,17 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
"""
Computes the evaluation loss for a single batch without updating model weights.
"""
_, new_rng, timestep_rng = jax.random.split(rng, num=3)
_, new_rng = jax.random.split(rng, num=2)

# This ensures the batch size is consistent, though it might be redundant
# if the evaluation dataloader is already configured correctly.
jax.debug.print("timesteps before clip: {x}", x=data["timesteps"])
for k, v in data.items():
data[k] = v[: config.global_batch_size_to_train_on, :]
if k != "timesteps":
data[k] = v[: config.global_batch_size_to_train_on, :]
else:
data[k] = v[: config.global_batch_size_to_train_on]
jax.debug.print("timesteps after clip: {x}", x=data["timesteps"])

  # The loss function logic is identical to training. We are evaluating the model's
  # ability to perform its core training objective (e.g., denoising).
@@ -410,15 +449,8 @@ def loss_fn(params):
    # Prepare inputs
    latents = data["latents"].astype(config.weights_dtype)
    encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
    bsz = latents.shape[0]
    timesteps = data["timesteps"].astype("int64")

    # Sample random timesteps and noise, just as in a training step
    timesteps = jax.random.randint(
        timestep_rng,
        (bsz,),
        0,
        scheduler.config.num_train_timesteps,
    )
    noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
    noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)

@@ -427,6 +459,7 @@ def loss_fn(params):
        hidden_states=noisy_latents,
        timestep=timesteps,
        encoder_hidden_states=encoder_hidden_states,
        deterministic=True,
    )

    # Calculate the loss against the target
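Net effect of this hunk: eval timesteps now come from the batch instead of `timestep_rng`, and since `eval_rng` is reset to the same `eval_rng_key` at the start of every evaluation, successive evals see identical noise levels and noise draws. A standalone sketch with a stand-in noising rule (the real `scheduler.add_noise` lives on the pipeline's scheduler and is not reproduced here):

```python
import jax
import jax.numpy as jnp

def add_noise(latents, noise, timesteps, num_train_timesteps=1000):
  # Illustrative interpolation-style noising; stands in for scheduler.add_noise.
  t = (timesteps.astype(jnp.float32) / num_train_timesteps)[:, None, None]
  return (1.0 - t) * latents + t * noise

latents = jnp.ones((2, 4, 4))
timesteps = jnp.array([125, 500])  # read from the batch, not sampled
rng = jax.random.key(0)            # reset to the same key for every eval
noise = jax.random.normal(rng, latents.shape)
noisy = add_noise(latents, noise, timesteps)
# With timesteps and rng fixed, loss changes across evals reflect only the
# model weights, making eval curves comparable step to step.
```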