Conversation

@hyeygit hyeygit commented Sep 18, 2025

Benchmark results on v6e-8:

480p video generation (81 frames)

  • Baseline (main 3cd6a44): 78s
  • This PR: 66s (15% improvement)

720p video generation (81 frames)

  • Baseline (main 3cd6a44): 215s
  • This PR: 200s (7% improvement)

Benchmark command used:
480p:

HF_HUB_DISABLE_XET=True \
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \
HF_TOKEN=<redacted> \
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" \
HF_HUB_ENABLE_HF_TRANSFER=1 \
python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  attention="splash_wan" \
  num_inference_steps=30 \
  num_frames=81 \
  width=832 \
  height=480 \
  jax_cache_dir=/tmp/jax_cache/ \
  per_device_batch_size=.125 \
  ici_data_parallelism=2 \
  ici_fsdp_parallelism=4 \
  enable_profiler=False \
  run_name=wan-inference-testing \
  output_dir=/tmp/ \
  fps=16 \
  flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }'

720p:

HF_HUB_DISABLE_XET=True \
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \
HF_TOKEN=<redacted> \
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" \
HF_HUB_ENABLE_HF_TRANSFER=1 \
python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  attention="flash" \
  num_inference_steps=30 \
  num_frames=81 \
  width=1280 \
  height=720 \
  jax_cache_dir=/tmp/jax_cache/ \
  per_device_batch_size=.125 \
  ici_data_parallelism=2 \
  ici_fsdp_parallelism=4 \
  flow_shift=5.0 \
  enable_profiler=False \
  run_name=wan-inference-testing-720p \
  output_dir=/tmp/ \
  fps=16 \
  flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }'
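Since flash_block_sizes is passed as a JSON string through the shell, the quoting is easy to break; a small standalone check that the value parses as intended:

```python
import json

# The flash_block_sizes value exactly as passed on the command line above.
raw = ('{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, '
       '"block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, '
       '"block_q_dq" : 3024, "block_kv_dq" : 2048 }')
blocks = json.loads(raw)

# All eight block sizes should be positive integers.
assert len(blocks) == 8
assert all(isinstance(v, int) and v > 0 for v in blocks.values())
print(blocks["block_q"], blocks["block_kv"])  # 3024 2048
```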


google-cla bot commented Sep 18, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@entrpn
Collaborator

entrpn commented Sep 18, 2025

I noticed that the run command doesn't set flash_min_seq_length=0. By default this value is 4096, so I don't think you're actually running cross attention with splash attention; the perf improvement is mostly coming from sharding over heads. Do you know if there is a difference in perf if we just use tensor: 4 instead of fsdp: 4, which would accomplish something similar (sharding across heads)?
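For context, a minimal sketch of how a flash_min_seq_length threshold typically gates kernel selection (illustrative only; the helper name and exact logic are assumptions, not the actual maxdiffusion code):

```python
# Hypothetical sketch of the flash_min_seq_length gate described above.
FLASH_MIN_SEQ_LENGTH_DEFAULT = 4096  # default cited in the comment

def choose_attention_kernel(kv_seq_len: int,
                            flash_min_seq_length: int = FLASH_MIN_SEQ_LENGTH_DEFAULT) -> str:
    """Fall back to plain dot-product attention for short KV sequences."""
    if kv_seq_len < flash_min_seq_length:
        return "dot_product"
    return "flash"

# Cross-attention KV comes from the text encoder, whose sequence is short,
# so with the default threshold it never reaches the flash/splash kernel:
print(choose_attention_kernel(512))                          # dot_product
print(choose_attention_kernel(512, flash_min_seq_length=0))  # flash
```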

@hyeygit
Author

hyeygit commented Sep 19, 2025

I noticed that the run command doesn't set flash_min_seq_length=0. By default this value is 4096, so I don't think you're actually running cross attention with splash attention; the perf improvement is mostly coming from sharding over heads.

Good callout. I re-ran the benchmark with flash_min_seq_length=0 and got a slight additional improvement, so sequence sharding for cross attention does seem to help a bit.

720p: 197s (prev 200s, baseline 215s)

HF_HUB_DISABLE_XET=True \
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \
HF_TOKEN=<redacted> \
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" \
HF_HUB_ENABLE_HF_TRANSFER=1 \
python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  attention="flash" \
  num_inference_steps=30 \
  num_frames=81 \
  width=1280 \
  height=720 \
  jax_cache_dir=/tmp/jax_cache/ \
  per_device_batch_size=.125 \
  ici_data_parallelism=2 \
  ici_fsdp_parallelism=4 \
  flow_shift=5.0 \
  enable_profiler=False \
  run_name=wan-inference-testing-720p \
  output_dir=/tmp/ \
  fps=16 \
  flash_min_seq_length=0 \
  flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' \
  seed=118445

480p: 62s (prev 66s, baseline 78s)

HF_HUB_DISABLE_XET=True \
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \
HF_TOKEN=<redacted> \
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" \
HF_HUB_ENABLE_HF_TRANSFER=1 \
python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  attention="flash" \
  num_inference_steps=30 \
  num_frames=81 \
  width=832 \
  height=480 \
  jax_cache_dir=/tmp/jax_cache/ \
  per_device_batch_size=.125 \
  ici_data_parallelism=2 \
  ici_fsdp_parallelism=4 \
  enable_profiler=False \
  run_name=wan-inference-testing \
  output_dir=/tmp/ \
  fps=16 \
  flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' \
  seed=118445 \
  flash_min_seq_length=0

@hyeygit
Author

hyeygit commented Sep 19, 2025

Do you know if there is a difference in perf if we just use tensor: 4 instead of fsdp: 4, which would accomplish something similar (sharding across heads)?

I tried running with ici_tensor_parallelism=4 at HEAD and the performance was worse. Specifically,

HF_HUB_DISABLE_XET=True \
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ \
HF_TOKEN=<redacted> \
LIBTPU_INIT_ARGS="--xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_reduce=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_reduce=true" \
HF_HUB_ENABLE_HF_TRANSFER=1 \
python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  attention="flash" \
  num_inference_steps=30 \
  num_frames=81 \
  width=1280 \
  height=720 \
  jax_cache_dir=/tmp/jax_cache/ \
  per_device_batch_size=.125 \
  ici_data_parallelism=2 \
  ici_fsdp_parallelism=1 \
  ici_tensor_parallelism=4 \
  flow_shift=5.0 \
  enable_profiler=False \
  run_name=wan-inference-testing-720p \
  output_dir=/tmp/ \
  fps=16 \
  flash_min_seq_length=0 \
  flash_block_sizes='{"block_q" : 3024, "block_kv_compute" : 1024, "block_kv" : 2048, "block_q_dkv": 3024, "block_kv_dkv" : 2048, "block_kv_dkv_compute" : 2048, "block_q_dq" : 3024, "block_kv_dq" : 2048 }' \
  seed=118445

yielded a generation time of 240s, worse than the 215s baseline (where data=2, fsdp=4, tensor=1).
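For reference, both layouts cover the same v6e-8 slice; a small sketch of the mesh arithmetic behind the two configurations (times copied from this thread, config labels illustrative):

```python
# Both meshes must cover all 8 chips of the v6e-8 slice.
DEVICES = 8

configs = {
    "fsdp4":   {"data": 2, "fsdp": 4, "tensor": 1},  # 215s baseline / 197s with this PR
    "tensor4": {"data": 2, "fsdp": 1, "tensor": 4},  # 240s in the run above
}

for name, axes in configs.items():
    total = axes["data"] * axes["fsdp"] * axes["tensor"]
    assert total == DEVICES, f"{name}: mesh axes must multiply to {DEVICES}"

# per_device_batch_size=.125 on 8 chips means a single global video per step.
global_batch = 0.125 * DEVICES
print(global_batch)  # 1.0
```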

@entrpn
Collaborator

entrpn commented Sep 19, 2025

@hyeygit I verified the change on my side with a 197s step time. I made a PR with these changes that aligns the code flow better. PTAL: #251
