diff --git a/src/maxdiffusion/common_types.py b/src/maxdiffusion/common_types.py
index f03864da..724e2313 100644
--- a/src/maxdiffusion/common_types.py
+++ b/src/maxdiffusion/common_types.py
@@ -33,7 +33,11 @@
 BlockSizes = splash_attention_kernel.BlockSizes

 AxisNames = tuple[str, ...]
-
+# Physical axis names for device meshes.
+DATA = "data"
+FSDP = "fsdp"
+TENSOR = "tensor"
+# Logical axis names for model parameters and activations.
 BATCH = "activation_batch"
 LENGTH = "activation_length"
 KV_LENGTH = "activation_kv_length"
@@ -44,4 +48,32 @@
 KEEP_2 = "activation_keep_2"
 CONV_OUT = "activation_conv_out_channels"

+# For setting self/cross attention independently in splash kernel
+SELF_ATTN_HEAD = "activation_self_attn_heads"
+SELF_ATTN_Q_LENGTH = "activation_self_attn_q_length"
+SELF_ATTN_KV_LENGTH = "activation_self_attn_kv_length"
+CROSS_ATTN_HEAD = "activation_cross_attn_heads"
+CROSS_ATTN_Q_LENGTH = "activation_cross_attn_q_length"
+CROSS_ATTN_KV_LENGTH = "activation_cross_attn_kv_length"
+
+
 WAN_MODEL = "Wan2.1"
+
+### Common axis rules for ring attention ###
+RING_ATTENTION_AXIS_RULES = [
+    [SELF_ATTN_HEAD, None],
+    [SELF_ATTN_Q_LENGTH, FSDP],
+    [SELF_ATTN_KV_LENGTH, FSDP],
+    [CROSS_ATTN_HEAD, None],
+    [CROSS_ATTN_Q_LENGTH, FSDP],
+    [CROSS_ATTN_KV_LENGTH, FSDP],
+]
+
+SEQUENCE_PARALLEL_AXIS_RULES = [
+    [SELF_ATTN_HEAD, None],
+    [SELF_ATTN_Q_LENGTH, FSDP],
+    [SELF_ATTN_KV_LENGTH, None],
+    [CROSS_ATTN_HEAD, None],
+    [CROSS_ATTN_Q_LENGTH, FSDP],
+    [CROSS_ATTN_KV_LENGTH, None],
+]
diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml
index 56fa47ca..c4dad191 100644
--- a/src/maxdiffusion/configs/base_wan_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_14b.yml
@@ -68,10 +68,21 @@ flash_block_sizes: {}
 # "block_kv" : 2048,
 # "block_q_dkv" : 3024,
 # "block_kv_dkv" : 2048,
-# "block_kv_dkv_compute" : 2048,
+# "block_kv_dkv_compute" : 1024,
 # "block_q_dq" : 3024,
 # "block_kv_dq" : 2048
 # }
+# Use on v5p
+flash_block_sizes: {
+  "block_q" : 1024,
+  "block_kv_compute" : 256,
+  "block_kv" : 3072,
+  "block_q_dkv" : 1024,
+  "block_kv_dkv" : 3072,
+  "block_kv_dkv_compute" : 256,
+  "block_q_dq" : 1024,
+  "block_kv_dq" : 3072
+}

 # GroupNorm groups
 norm_num_groups: 32
@@ -132,8 +143,9 @@
 mesh_axes: ['data', 'fsdp', 'tensor']
 logical_axis_rules: [
   ['batch', 'data'],
   ['activation_batch', 'data'],
+  ['activation_self_attn_heads', ['fsdp', 'tensor']],
+  ['activation_cross_attn_q_length', ['fsdp', 'tensor']],
   ['activation_length', 'fsdp'],
-  ['activation_heads', 'tensor'],
   ['mlp','tensor'],
   ['embed','fsdp'],
@@ -141,6 +153,7 @@ logical_axis_rules: [
   ['norm', 'tensor'],
   ['conv_batch', ['data','fsdp']],
   ['out_channels', 'tensor'],
+  ['conv_in', 'fsdp'],
   ['conv_out', 'fsdp'],
 ]
 data_sharding: [['data', 'fsdp', 'tensor']]
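The logical axis names introduced above only take effect once they are resolved against the mesh axes through flax's logical partitioning rules. Below is a minimal, standalone sketch of that resolution (not part of the patch); the rule entries simply mirror a few of the names from `RING_ATTENTION_AXIS_RULES` and `base_wan_14b.yml`.

```python
# Standalone sketch: resolve logical activation axes to mesh axes with flax.
from flax.linen import partitioning as nn_partitioning

rules = (
    ("activation_batch", "data"),
    ("activation_self_attn_heads", None),        # heads replicated
    ("activation_self_attn_q_length", "fsdp"),   # q sequence sharded over fsdp
    ("activation_self_attn_kv_length", "fsdp"),  # kv sequence sharded over fsdp
)

with nn_partitioning.axis_rules(rules):
  # Logical layout of a self-attention query (batch, heads, q_length).
  spec = nn_partitioning.logical_to_mesh_axes(
      ("activation_batch", "activation_self_attn_heads", "activation_self_attn_q_length")
  )
  # spec should resolve to PartitionSpec('data', None, 'fsdp').
```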
diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py
index 6638e0f8..47da450e 100644
--- a/src/maxdiffusion/max_utils.py
+++ b/src/maxdiffusion/max_utils.py
@@ -495,14 +495,14 @@ def get_flash_block_sizes(config):
   flash_block_sizes = None
   if len(config.flash_block_sizes.keys()) > 0:
     flash_block_sizes = splash_attention_kernel.BlockSizes(
-        block_q=config.flash_block_sizes["block_q"],
-        block_kv_compute=config.flash_block_sizes["block_kv_compute"],
-        block_kv=config.flash_block_sizes["block_kv"],
-        block_q_dkv=config.flash_block_sizes["block_q_dkv"],
-        block_kv_dkv=config.flash_block_sizes["block_kv_dkv"],
-        block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"],
-        block_q_dq=config.flash_block_sizes["block_q_dq"],
-        block_kv_dq=config.flash_block_sizes["block_kv_dq"],
+        block_q=int(config.flash_block_sizes["block_q"]),
+        block_kv_compute=int(config.flash_block_sizes["block_kv_compute"]),
+        block_kv=int(config.flash_block_sizes["block_kv"]),
+        block_q_dkv=int(config.flash_block_sizes["block_q_dkv"]),
+        block_kv_dkv=int(config.flash_block_sizes["block_kv_dkv"]),
+        block_kv_dkv_compute=int(config.flash_block_sizes["block_kv_dkv_compute"]),
+        block_q_dq=int(config.flash_block_sizes["block_q_dq"]),
+        block_kv_dq=int(config.flash_block_sizes["block_kv_dq"]),
     )
   return flash_block_sizes

diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py
index 5df5f334..b4bb5ed5 100644
--- a/src/maxdiffusion/models/attention_flax.py
+++ b/src/maxdiffusion/models/attention_flax.py
@@ -45,6 +45,13 @@ EMBED = common_types.EMBED

 Quant = quantizations.AqtQuantization

+SELF_ATTN_HEAD = common_types.SELF_ATTN_HEAD
+SELF_ATTN_Q_LENGTH = common_types.SELF_ATTN_Q_LENGTH
+SELF_ATTN_KV_LENGTH = common_types.SELF_ATTN_KV_LENGTH
+CROSS_ATTN_HEAD = common_types.CROSS_ATTN_HEAD
+CROSS_ATTN_Q_LENGTH = common_types.CROSS_ATTN_Q_LENGTH
+CROSS_ATTN_KV_LENGTH = common_types.CROSS_ATTN_KV_LENGTH
+

 def _maybe_aqt_einsum(quant: Quant):
   return jnp.einsum if quant is None else quant.einsum()
@@ -184,7 +191,8 @@ def _tpu_flash_attention(
     kv_max_block_size = key.shape[1]
   else:
     kv_max_block_size = q_max_block_size
-  if flash_block_sizes:
+  # Use the configured block sizes only when q and kv lengths match (self attention); cross attention falls back to computed sizes.
+  if flash_block_sizes and key.shape[1] == query.shape[1]:
     block_sizes = flash_block_sizes
   else:
     block_sizes = splash_attention_kernel.BlockSizes(
@@ -439,7 +447,16 @@ def _apply_attention(
     )
   elif attention_kernel == "flash":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
+        query,
+        key * scale,
+        value,
+        heads,
+        mesh,
+        axis_names_q,
+        axis_names_kv,
+        flash_block_sizes,
+        dtype,
+        attention_kernel,
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
@@ -701,6 +718,7 @@ def __init__(
       precision: jax.lax.Precision = None,
       qkv_bias: bool = False,
       quant: Quant = None,
+      is_self_attention: bool = True,
   ):
     if attention_kernel == "cudnn_flash_te":
       raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")
@@ -717,6 +735,13 @@
     self.value_axis_names = value_axis_names
     self.out_axis_names = out_axis_names

+    if is_self_attention:
+      axis_names_q = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_Q_LENGTH, D_KV)
+      axis_names_kv = (BATCH, SELF_ATTN_HEAD, SELF_ATTN_KV_LENGTH, D_KV)
+    else:
+      axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV)
+      axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV)
+
     self.attention_op = NNXAttentionOp(
         mesh=mesh,
         attention_kernel=attention_kernel,
@@ -726,6 +751,8 @@
         use_memory_efficient_attention=use_memory_efficient_attention,
         split_head_dim=split_head_dim,
         float32_qk_product=False,
+        axis_names_q=axis_names_q,
+        axis_names_kv=axis_names_kv,
         flash_min_seq_length=flash_min_seq_length,
         flash_block_sizes=flash_block_sizes,
         dtype=dtype,
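For reference, here is a hedged, standalone sketch of how the v5p `flash_block_sizes` from `base_wan_14b.yml` end up as a splash-attention `BlockSizes` object after the `int()` coercion added to `get_flash_block_sizes`. This is not the PR's code; the import path is the one maxdiffusion already uses and may move between JAX releases.

```python
from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel

# Values as parsed from YAML; the config layer may hand them over as
# non-int scalars, hence the explicit int() coercion mirrored from
# get_flash_block_sizes.
yaml_block_sizes = {
    "block_q": 1024,
    "block_kv_compute": 256,
    "block_kv": 3072,
    "block_q_dkv": 1024,
    "block_kv_dkv": 3072,
    "block_kv_dkv_compute": 256,
    "block_q_dq": 1024,
    "block_kv_dq": 3072,
}

block_sizes = splash_attention_kernel.BlockSizes(**{k: int(v) for k, v in yaml_block_sizes.items()})
```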
diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
index 48ed7b8e..f891bc83 100644
--- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py
+++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -282,6 +282,7 @@ def __init__(
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
+        is_self_attention=True,
     )

     # 1. Cross-attention
@@ -300,6 +301,7 @@
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
+        is_self_attention=False,
     )
     assert cross_attn_norm is True
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
@@ -351,7 +353,10 @@

     # 2. Cross-attention
     norm_hidden_states = self.norm2(hidden_states)
     attn_output = self.attn2(
-        hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs
+        hidden_states=norm_hidden_states,
+        encoder_hidden_states=encoder_hidden_states,
+        deterministic=deterministic,
+        rngs=rngs,
     )
     hidden_states = hidden_states + attn_output

diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py
index 3bb5bd13..14e7fcb3 100644
--- a/src/maxdiffusion/pyconfig.py
+++ b/src/maxdiffusion/pyconfig.py
@@ -27,7 +27,7 @@
 from . import max_logging
 from . import max_utils
 from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH
-from maxdiffusion.common_types import LENGTH, KV_LENGTH
+from maxdiffusion.common_types import LENGTH, KV_LENGTH, RING_ATTENTION_AXIS_RULES


 def string_to_bool(s: str) -> bool:
@@ -180,14 +180,22 @@ def user_init(raw_keys):
   raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
   # Verify qkv is sharded across sequence.
   if raw_keys["attention"] == "ring":
+    max_logging.log("Using ring attention, adding sequence sharding to q and kv if not already present.")
     logical_axis_rules = list(raw_keys["logical_axis_rules"])
+    max_logging.log(f"Initial logical axis rules: {logical_axis_rules}")
+    new_rules = []
     q_seq_sharding = (LENGTH, "fsdp")
     kv_seq_sharding = (KV_LENGTH, "fsdp")
     if q_seq_sharding not in logical_axis_rules:
       logical_axis_rules.append(q_seq_sharding)
     if kv_seq_sharding not in logical_axis_rules:
       logical_axis_rules.append(kv_seq_sharding)
-    raw_keys["logical_axis_rules"] = tuple(logical_axis_rules)
+    for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES:
+      if ring_attention_axis_rule not in logical_axis_rules:
+        max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}")
+        new_rules.append(ring_attention_axis_rule)
+    raw_keys["logical_axis_rules"] = tuple(new_rules) + tuple(logical_axis_rules)
+    max_logging.log(f"Final logical axis rules: {raw_keys['logical_axis_rules']}")

   raw_keys["data_sharding"] = _lists_to_tuples(raw_keys["data_sharding"])
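The merge performed in `user_init` above prepends any ring-attention rules that are not already present to the user's `logical_axis_rules`. A condensed, standalone illustration of that ordering (rule lists shortened for brevity; this is not the actual pyconfig code path):

```python
# Rules the config layer wants to guarantee for ring attention.
ring_rules = [
    ["activation_self_attn_q_length", "fsdp"],
    ["activation_self_attn_kv_length", "fsdp"],
]
# Rules the user already supplied; one ring rule is already covered.
user_rules = [("activation_batch", "data"), ["activation_self_attn_q_length", "fsdp"]]

missing = [rule for rule in ring_rules if rule not in user_rules]
merged = tuple(missing) + tuple(user_rules)
# merged == (['activation_self_attn_kv_length', 'fsdp'],
#            ('activation_batch', 'data'),
#            ['activation_self_attn_q_length', 'fsdp'])
```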
diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py
index 26ea0f02..bca9b747 100644
--- a/src/maxdiffusion/tests/wan_transformer_test.py
+++ b/src/maxdiffusion/tests/wan_transformer_test.py
@@ -23,7 +23,7 @@
 from absl.testing import absltest
 from flax import nnx
 from jax.sharding import Mesh
-
+from flax.linen import partitioning as nn_partitioning
 from .. import pyconfig
 from ..max_utils import (create_device_mesh, get_flash_block_sizes)
 from ..models.wan.transformers.transformer_wan import (
@@ -48,6 +48,18 @@ class WanTransformerTest(unittest.TestCase):

   def setUp(self):
     WanTransformerTest.dummy_data = {}
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+        ],
+        unittest=True,
+    )
+    config = pyconfig.config
+    self.config = config
+    devices_array = create_device_mesh(config)
+    self.mesh = Mesh(devices_array, config.mesh_axes)
+

   def test_rotary_pos_embed(self):
     batch_size = 1
@@ -65,18 +77,20 @@
   def test_nnx_pixart_alpha_text_projection(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
     dummy_caption = jnp.ones((1, 512, 4096))
-    layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
-    dummy_output = layer(dummy_caption)
-    dummy_output.shape == (1, 512, 5120)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
+      dummy_output = layer(dummy_caption)
+      dummy_output.shape == (1, 512, 5120)

   def test_nnx_timestep_embedding(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
     dummy_sample = jnp.ones((1, 256))
-    layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
-    dummy_output = layer(dummy_sample)
-    assert dummy_output.shape == (1, 5120)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
+      dummy_output = layer(dummy_sample)
+      assert dummy_output.shape == (1, 5120)

   def test_fp32_layer_norm(self):
     key = jax.random.key(0)
@@ -84,9 +98,10 @@
     rngs = nnx.Rngs(key)
     batch_size = 1
     dummy_hidden_states = jnp.ones((batch_size, 75600, 5120))
     # expected same output shape with same dtype
-    layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
-    dummy_output = layer(dummy_hidden_states)
-    assert dummy_output.shape == dummy_hidden_states.shape
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
+      dummy_output = layer(dummy_hidden_states)
+      assert dummy_output.shape == dummy_hidden_states.shape

   @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_wan_time_text_embedding(self):
@@ -97,20 +112,21 @@
     time_freq_dim = 256
     time_proj_dim = 30720
     text_embed_dim = 4096
-    layer = WanTimeTextImageEmbedding(
-        rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
-    )
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = WanTimeTextImageEmbedding(
+          rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
+      )

-    dummy_timestep = jnp.ones(batch_size)
+      dummy_timestep = jnp.ones(batch_size)

-    encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
-    dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
-    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
-        dummy_timestep, dummy_encoder_hidden_states
-    )
-    assert temb.shape == (batch_size, dim)
-    assert timestep_proj.shape == (batch_size, time_proj_dim)
-    assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
+      encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
+      dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
+      temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
+          dummy_timestep, dummy_encoder_hidden_states
+      )
+      assert temb.shape == (batch_size, dim)
+      assert timestep_proj.shape == (batch_size, time_proj_dim)
+      assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)

   def test_wan_block(self):
     key = jax.random.key(0)
@@ -158,20 +174,19 @@
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim))
     dummy_temb = jnp.ones((batch_size, 6, dim))
-
-    wan_block = WanTransformerBlock(
-        rngs=rngs,
-        dim=dim,
-        ffn_dim=ffn_dim,
-        num_heads=num_heads,
-        qk_norm=qk_norm,
-        cross_attn_norm=cross_attn_norm,
-        eps=eps,
-        attention="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-    with mesh:
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      wan_block = WanTransformerBlock(
+          rngs=rngs,
+          dim=dim,
+          ffn_dim=ffn_dim,
+          num_heads=num_heads,
+          qk_norm=qk_norm,
+          cross_attn_norm=cross_attn_norm,
+          eps=eps,
+          attention="flash",
+          mesh=mesh,
+          flash_block_sizes=flash_block_sizes,
+      )
       dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb)
     assert dummy_output.shape == dummy_hidden_states.shape

@@ -204,40 +219,39 @@
     mesh = Mesh(devices_array, config.mesh_axes)
     batch_size = 1
     query_dim = 5120
-    attention = FlaxWanAttention(
-        rngs=rngs,
-        query_dim=query_dim,
-        heads=40,
-        dim_head=128,
-        attention_kernel="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-
-    dummy_hidden_states_shape = (batch_size, 75600, query_dim)
-
-    dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    with mesh:
-      dummy_output = attention(
-          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
-      )
-      assert dummy_output.shape == dummy_hidden_states_shape
-
-    # dot product
-    try:
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       attention = FlaxWanAttention(
           rngs=rngs,
           query_dim=query_dim,
           heads=40,
           dim_head=128,
-          attention_kernel="dot_product",
-          split_head_dim=True,
+          attention_kernel="flash",
          mesh=mesh,
          flash_block_sizes=flash_block_sizes,
      )
-    except NotImplementedError:
-      pass
+      dummy_hidden_states_shape = (batch_size, 75600, query_dim)
+
+      dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_output = attention(
+          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
+      )
+      assert dummy_output.shape == dummy_hidden_states_shape
+
+      # dot product
+      try:
+        attention = FlaxWanAttention(
+            rngs=rngs,
+            query_dim=query_dim,
+            heads=40,
+            dim_head=128,
+            attention_kernel="dot_product",
+            split_head_dim=True,
+            mesh=mesh,
+            flash_block_sizes=flash_block_sizes,
+        )
+      except NotImplementedError:
+        pass

   @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_wan_model(self):
@@ -267,7 +281,8 @@
     mesh = Mesh(devices_array, config.mesh_axes)
     batch_size = 1
     num_layers = 1
-    wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
+    with nn_partitioning.axis_rules(config.logical_axis_rules):
+      wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)

     dummy_timestep = jnp.ones((batch_size))
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096))
diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py
index 7b131e7f..659c3992 100644
--- a/src/maxdiffusion/tests/wan_vae_test.py
+++ b/src/maxdiffusion/tests/wan_vae_test.py
@@ -22,6 +22,7 @@
 import jax
 import jax.numpy as jnp
 from flax import nnx
+from flax.linen import partitioning as nn_partitioning
 from jax.sharding import Mesh
 from .. import pyconfig
 from ..max_utils import (
@@ -160,6 +161,17 @@ class WanVaeTest(unittest.TestCase):

   def setUp(self):
     WanVaeTest.dummy_data = {}
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+        ],
+        unittest=True,
+    )
+    config = pyconfig.config
+    self.config = config
+    devices_array = create_device_mesh(config)
+    self.mesh = Mesh(devices_array, config.mesh_axes)

   def test_wanrms_norm(self):
     """Test against the Pytorch implementation"""
@@ -209,12 +221,13 @@
     output_torch = resample(input)
     assert output_torch.shape == (1, 96, 240, 360)

-    model = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2))
-    dummy_input = jnp.ones(input_shape)
-    dummy_input = jnp.transpose(dummy_input, (0, 2, 3, 1))
-    output = model(dummy_input)
-    output = jnp.transpose(output, (0, 3, 1, 2))
-    assert output.shape == (1, 96, 240, 360)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      model = ZeroPaddedConv2D(dim=dim, rngs=rngs, kernel_size=(1, 3, 3), stride=(1, 2, 2))
+      dummy_input = jnp.ones(input_shape)
+      dummy_input = jnp.transpose(dummy_input, (0, 2, 3, 1))
+      output = model(dummy_input)
+      output = jnp.transpose(output, (0, 3, 1, 2))
+      assert output.shape == (1, 96, 240, 360)

   def test_wan_upsample(self):
     batch_size = 1
@@ -246,13 +259,13 @@
     torch_wan_resample = TorchWanResample(dim=dim, mode=mode)
     torch_output = torch_wan_resample(dummy_input)
     assert torch_output.shape == (batch, dim, t, h // 2, w // 2)
-
-    wan_resample = WanResample(dim, mode=mode, rngs=rngs)
-    # channels is always last here
-    input_shape = (batch, t, h, w, dim)
-    dummy_input = jnp.ones(input_shape)
-    output = wan_resample(dummy_input)
-    assert output.shape == (batch, t, h // 2, w // 2, dim)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      wan_resample = WanResample(dim, mode=mode, rngs=rngs)
+      # channels is always last here
+      input_shape = (batch, t, h, w, dim)
+      dummy_input = jnp.ones(input_shape)
+      output = wan_resample(dummy_input)
+      assert output.shape == (batch, t, h // 2, w // 2, dim)

   def test_3d_conv(self):
     key = jax.random.key(0)
@@ -283,28 +296,29 @@
     dummy_cache = jnp.zeros((batch_size, cache_depth, in_height, in_width, in_channels))

     # Instantiate the module
-    causal_conv_layer = WanCausalConv3d(
-        in_channels=in_channels,
-        out_channels=out_channels,
-        kernel_size=(kernel_d, kernel_h, kernel_w),
-        padding=(padding_d, padding_h, padding_w),
-        rngs=rngs,  # Pass rngs for initialization,
-        mesh=mesh,
-    )
+    with self.mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+      causal_conv_layer = WanCausalConv3d(
+          in_channels=in_channels,
+          out_channels=out_channels,
+          kernel_size=(kernel_d, kernel_h, kernel_w),
+          padding=(padding_d, padding_h, padding_w),
+          rngs=rngs,  # Pass rngs for initialization,
+          mesh=mesh,
+      )

-    # --- Test Case 1: No Cache ---
-    output_no_cache = causal_conv_layer(dummy_input)
-    assert output_no_cache.shape == (1, 10, 32, 32, 16)
+      # --- Test Case 1: No Cache ---
+      output_no_cache = causal_conv_layer(dummy_input)
+      assert output_no_cache.shape == (1, 10, 32, 32, 16)

-    # --- Test Case 2: With Cache ---
-    output_with_cache = causal_conv_layer(dummy_input, cache_x=dummy_cache)
-    assert output_with_cache.shape == (1, 10, 32, 32, 16)
+      # --- Test Case 2: With Cache ---
+      output_with_cache = causal_conv_layer(dummy_input, cache_x=dummy_cache)
+      assert output_with_cache.shape == (1, 10, 32, 32, 16)

-    # --- Test Case 3: With Cache larger than padding ---
-    larger_cache_depth = 4  # Larger than needed padding (2*padding_d = 2)
-    dummy_larger_cache = jnp.zeros((batch_size, larger_cache_depth, in_height, in_width, in_channels))
-    output_with_larger_cache = causal_conv_layer(dummy_input, cache_x=dummy_larger_cache)
-    assert output_with_larger_cache.shape == (1, 10, 32, 32, 16)
+      # --- Test Case 3: With Cache larger than padding ---
+      larger_cache_depth = 4  # Larger than needed padding (2*padding_d = 2)
+      dummy_larger_cache = jnp.zeros((batch_size, larger_cache_depth, in_height, in_width, in_channels))
+      output_with_larger_cache = causal_conv_layer(dummy_input, cache_x=dummy_larger_cache)
+      assert output_with_larger_cache.shape == (1, 10, 32, 32, 16)

   def test_wan_residual(self):
     key = jax.random.key(0)
@@ -328,21 +342,20 @@
     dim = 96
     input_shape = (batch, t, height, width, dim)
     expected_output_shape = (batch, t, height, width, dim)
-
-    wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh)
-    dummy_input = jnp.ones(input_shape)
-    dummy_output = wan_residual_block(dummy_input)
-    assert dummy_output.shape == expected_output_shape
-
-    # --- Test Case 1: different in/out dim ---
-    in_dim = 96
-    out_dim = 196
-    expected_output_shape = (batch, t, height, width, out_dim)
-
-    wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh)
-    dummy_input = jnp.ones(input_shape)
-    dummy_output = wan_residual_block(dummy_input)
-    assert dummy_output.shape == expected_output_shape
+    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+      wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh)
+      dummy_input = jnp.ones(input_shape)
+      dummy_output = wan_residual_block(dummy_input)
+      assert dummy_output.shape == expected_output_shape
+      # --- Test Case 1: different in/out dim ---
+      in_dim = 96
+      out_dim = 196
+      expected_output_shape = (batch, t, height, width, out_dim)
+
+      wan_residual_block = WanResidualBlock(in_dim=in_dim, out_dim=out_dim, rngs=rngs, mesh=mesh)
+      dummy_input = jnp.ones(input_shape)
+      dummy_output = wan_residual_block(dummy_input)
+      assert dummy_output.shape == expected_output_shape

   def test_wan_attention(self):
     key = jax.random.key(0)
@@ -353,10 +366,11 @@
     height = 60
     width = 90
     input_shape = (batch, t, height, width, dim)
-    wan_attention = WanAttentionBlock(dim=dim, rngs=rngs)
-    dummy_input = jnp.ones(input_shape)
-    output = wan_attention(dummy_input)
-    assert output.shape == input_shape
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      wan_attention = WanAttentionBlock(dim=dim, rngs=rngs)
+      dummy_input = jnp.ones(input_shape)
+      output = wan_attention(dummy_input)
+      assert output.shape == input_shape

   def test_wan_midblock(self):
     key = jax.random.key(0)
@@ -377,10 +391,11 @@
     height = 60
     width = 90
     input_shape = (batch, t, height, width, dim)
-    wan_midblock = WanMidBlock(dim=dim, rngs=rngs, mesh=mesh)
-    dummy_input = jnp.ones(input_shape)
-    output = wan_midblock(dummy_input)
-    assert output.shape == input_shape
+    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+      wan_midblock = WanMidBlock(dim=dim, rngs=rngs, mesh=mesh)
+      dummy_input = jnp.ones(input_shape)
+      output = wan_midblock(dummy_input)
+      assert output.shape == input_shape

   def test_wan_decode(self):
     key = jax.random.key(0)
@@ -401,30 +416,31 @@
     num_res_blocks = 2
     attn_scales = []
     temperal_downsample = [False, True, True]
-    wan_vae = AutoencoderKLWan(
-        rngs=rngs,
-        base_dim=dim,
-        z_dim=z_dim,
-        dim_mult=dim_mult,
-        num_res_blocks=num_res_blocks,
-        attn_scales=attn_scales,
-        temperal_downsample=temperal_downsample,
-        mesh=mesh,
-    )
-    vae_cache = AutoencoderKLWanCache(wan_vae)
-    batch = 1
-    t = 13
-    channels = 16
-    height = 60
-    width = 90
-    input_shape = (batch, t, height, width, channels)
-    input = jnp.ones(input_shape)
-
-    latents_mean = jnp.array(wan_vae.latents_mean).reshape(1, 1, 1, 1, wan_vae.z_dim)
-    latents_std = 1.0 / jnp.array(wan_vae.latents_std).reshape(1, 1, 1, 1, wan_vae.z_dim)
-    input = input / latents_std + latents_mean
-    dummy_output = wan_vae.decode(input, feat_cache=vae_cache)
-    assert dummy_output.sample.shape == (batch, 49, 480, 720, 3)
+    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+      wan_vae = AutoencoderKLWan(
+          rngs=rngs,
+          base_dim=dim,
+          z_dim=z_dim,
+          dim_mult=dim_mult,
+          num_res_blocks=num_res_blocks,
+          attn_scales=attn_scales,
+          temperal_downsample=temperal_downsample,
+          mesh=mesh,
+      )
+      vae_cache = AutoencoderKLWanCache(wan_vae)
+      batch = 1
+      t = 13
+      channels = 16
+      height = 60
+      width = 90
+      input_shape = (batch, t, height, width, channels)
+      input = jnp.ones(input_shape)
+
+      latents_mean = jnp.array(wan_vae.latents_mean).reshape(1, 1, 1, 1, wan_vae.z_dim)
+      latents_std = 1.0 / jnp.array(wan_vae.latents_std).reshape(1, 1, 1, 1, wan_vae.z_dim)
+      input = input / latents_std + latents_mean
+      dummy_output = wan_vae.decode(input, feat_cache=vae_cache)
+      assert dummy_output.sample.shape == (batch, 49, 480, 720, 3)

   def test_wan_encode(self):
     key = jax.random.key(0)
@@ -445,26 +461,27 @@
     num_res_blocks = 2
     attn_scales = []
     temperal_downsample = [False, True, True]
-    wan_vae = AutoencoderKLWan(
-        rngs=rngs,
-        base_dim=dim,
-        z_dim=z_dim,
-        dim_mult=dim_mult,
-        num_res_blocks=num_res_blocks,
-        attn_scales=attn_scales,
-        temperal_downsample=temperal_downsample,
-        mesh=mesh,
-    )
-    vae_cache = AutoencoderKLWanCache(wan_vae)
-    batch = 1
-    channels = 3
-    t = 49
-    height = 480
-    width = 720
-    input_shape = (batch, channels, t, height, width)
-    input = jnp.ones(input_shape)
-    output = wan_vae.encode(input, feat_cache=vae_cache)
-    assert output.latent_dist.sample(key).shape == (1, 13, 60, 90, 16)
+    with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+      wan_vae = AutoencoderKLWan(
+          rngs=rngs,
+          base_dim=dim,
+          z_dim=z_dim,
+          dim_mult=dim_mult,
+          num_res_blocks=num_res_blocks,
+          attn_scales=attn_scales,
+          temperal_downsample=temperal_downsample,
+          mesh=mesh,
+      )
+      vae_cache = AutoencoderKLWanCache(wan_vae)
+      batch = 1
+      channels = 3
+      t = 49
+      height = 480
+      width = 720
+      input_shape = (batch, channels, t, height, width)
+      input = jnp.ones(input_shape)
+      output = wan_vae.encode(input, feat_cache=vae_cache)
+      assert output.latent_dist.sample(key).shape == (1, 13, 60, 90, 16)

   def test_load_checkpoint(self):
     def vae_encode(video, wan_vae, vae_cache, key):
@@ -484,9 +501,9 @@ def vae_encode(video, wan_vae, vae_cache, key):
     config = pyconfig.config
     devices_array = create_device_mesh(config)
     mesh = Mesh(devices_array, config.mesh_axes)
-
-    wan_vae = AutoencoderKLWan.from_config(config.pretrained_model_name_or_path, subfolder="vae", rngs=rngs, mesh=mesh)
-    vae_cache = AutoencoderKLWanCache(wan_vae)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      wan_vae = AutoencoderKLWan.from_config(config.pretrained_model_name_or_path, subfolder="vae", rngs=rngs, mesh=mesh)
+      vae_cache = AutoencoderKLWanCache(wan_vae)

     video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
     video = load_video(video_path)
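Taken together, the test changes follow a single pattern: build the device mesh from the config, then construct and call the NNX modules under both the mesh and the logical axis rules. A compressed, standalone sketch of that pattern outside the test harness (the mesh shape and rule list here are hypothetical; the real tests get both from `base_wan_14b.yml` via `pyconfig` and `create_device_mesh(config)`):

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh
from flax.linen import partitioning as nn_partitioning

# Hypothetical 1 x N x 1 mesh over ('data', 'fsdp', 'tensor').
devices = mesh_utils.create_device_mesh((1, jax.device_count(), 1))
mesh = Mesh(devices, ("data", "fsdp", "tensor"))

# Hypothetical rule list; the tests use config.logical_axis_rules.
logical_axis_rules = (("activation_batch", "data"), ("activation_length", "fsdp"))

with mesh, nn_partitioning.axis_rules(logical_axis_rules):
  # Construct NNX modules and run forward passes here, as the updated
  # setUp()/test_* methods do; logically annotated parameters and activations
  # are then sharded according to the rules.
  pass
```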