Commit fc17634

test fix

1 parent 7861e25

3 files changed: 198 additions, 165 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
@@ -153,6 +153,7 @@ logical_axis_rules: [
   ['norm', 'tensor'],
   ['conv_batch', ['data','fsdp']],
   ['out_channels', 'tensor'],
+  ['conv_in', 'fsdp'],
   ['conv_out', 'fsdp'],
 ]
 data_sharding: [['data', 'fsdp', 'tensor']]
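The new ['conv_in', 'fsdp'] rule maps the logical conv_in axis onto the fsdp mesh axis, matching the existing conv_out entry. For context, a minimal sketch of how Flax resolves such rules; the 1x1x1 toy mesh and the three-rule subset below are illustrative assumptions, not this repo's code:

import jax
import numpy as np
from jax.sharding import Mesh
from flax.linen import partitioning as nn_partitioning

# Toy mesh: a single device on each named axis. Real runs reshape many
# devices across ('data', 'fsdp', 'tensor').
devices = np.array(jax.devices()[:1]).reshape((1, 1, 1))
mesh = Mesh(devices, ("data", "fsdp", "tensor"))

# Subset of the config's logical_axis_rules, including the new entry.
rules = (
    ("conv_in", "fsdp"),
    ("conv_out", "fsdp"),
    ("out_channels", "tensor"),
)

with mesh, nn_partitioning.axis_rules(rules):
  # Logical axis names resolve to mesh axes via the active rules.
  spec = nn_partitioning.logical_to_mesh_axes(("conv_in", "out_channels"))
  print(spec)  # expected: PartitionSpec('fsdp', 'tensor')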

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 78 additions & 63 deletions
@@ -23,7 +23,7 @@
 from absl.testing import absltest
 from flax import nnx
 from jax.sharding import Mesh
-
+from flax.linen import partitioning as nn_partitioning
 from .. import pyconfig
 from ..max_utils import (create_device_mesh, get_flash_block_sizes)
 from ..models.wan.transformers.transformer_wan import (
@@ -48,6 +48,18 @@ class WanTransformerTest(unittest.TestCase):
 
   def setUp(self):
     WanTransformerTest.dummy_data = {}
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+        ],
+        unittest=True,
+    )
+    config = pyconfig.config
+    self.config = config
+    devices_array = create_device_mesh(config)
+    self.mesh = Mesh(devices_array, config.mesh_axes)
+
 
   def test_rotary_pos_embed(self):
     batch_size = 1
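setUp now loads base_wan_14b.yml through pyconfig and builds the device mesh once per test, so every test can open `with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules)`. A self-contained stand-in for that pattern, assuming mesh_axes is ('data', 'fsdp', 'tensor') with all local devices on the data axis; pyconfig and create_device_mesh are maxdiffusion internals not reproduced here:

import unittest

import jax
import numpy as np
from jax.sharding import Mesh


class MeshTestCase(unittest.TestCase):
  """Illustrative stand-in for the setUp above, not the repo's code."""

  def setUp(self):
    # All local devices on 'data'; 'fsdp' and 'tensor' stay size 1.
    devices = np.array(jax.devices()).reshape((jax.device_count(), 1, 1))
    self.mesh = Mesh(devices, ("data", "fsdp", "tensor"))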
@@ -65,28 +77,31 @@ def test_nnx_pixart_alpha_text_projection(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
     dummy_caption = jnp.ones((1, 512, 4096))
-    layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
-    dummy_output = layer(dummy_caption)
-    dummy_output.shape == (1, 512, 5120)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
+      dummy_output = layer(dummy_caption)
+      dummy_output.shape == (1, 512, 5120)
 
   def test_nnx_timestep_embedding(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
 
     dummy_sample = jnp.ones((1, 256))
-    layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
-    dummy_output = layer(dummy_sample)
-    assert dummy_output.shape == (1, 5120)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
+      dummy_output = layer(dummy_sample)
+      assert dummy_output.shape == (1, 5120)
 
   def test_fp32_layer_norm(self):
     key = jax.random.key(0)
     rngs = nnx.Rngs(key)
     batch_size = 1
     dummy_hidden_states = jnp.ones((batch_size, 75600, 5120))
     # expected same output shape with same dtype
-    layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
-    dummy_output = layer(dummy_hidden_states)
-    assert dummy_output.shape == dummy_hidden_states.shape
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
+      dummy_output = layer(dummy_hidden_states)
+      assert dummy_output.shape == dummy_hidden_states.shape
 
   @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_wan_time_text_embedding(self):
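Each rewritten test follows the same shape: construct the module and run the forward pass inside both the mesh scope and the axis-rules scope, so parameters annotated with logical axis names can be resolved against the mesh at init time. A hedged sketch of that pattern with a toy nnx module; TinyProj, its axis names, and the rules are invented for illustration:

import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
from flax.linen import partitioning as nn_partitioning
from jax.sharding import Mesh


class TinyProj(nnx.Module):
  """Toy layer whose kernel carries logical axis metadata."""

  def __init__(self, din, dout, rngs: nnx.Rngs):
    # Attach logical axis names ('embed', 'mlp') to the kernel.
    init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "mlp"))
    self.kernel = nnx.Param(init(rngs.params(), (din, dout)))

  def __call__(self, x):
    return x @ self.kernel.value


mesh = Mesh(np.array(jax.devices()[:1]).reshape((1, 1, 1)), ("data", "fsdp", "tensor"))
rules = (("embed", "fsdp"), ("mlp", "tensor"))  # toy rules, not the config's

# Build and call the layer under the mesh and the axis rules, mirroring
# the structure of the updated tests.
with mesh, nn_partitioning.axis_rules(rules):
  layer = TinyProj(16, 32, rngs=nnx.Rngs(jax.random.key(0)))
  out = layer(jnp.ones((1, 16)))
  assert out.shape == (1, 32)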
@@ -97,20 +112,21 @@ def test_wan_time_text_embedding(self):
     time_freq_dim = 256
     time_proj_dim = 30720
     text_embed_dim = 4096
-    layer = WanTimeTextImageEmbedding(
-        rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
-    )
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = WanTimeTextImageEmbedding(
+          rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
+      )
 
-    dummy_timestep = jnp.ones(batch_size)
+      dummy_timestep = jnp.ones(batch_size)
 
-    encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
-    dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
-    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
-        dummy_timestep, dummy_encoder_hidden_states
-    )
-    assert temb.shape == (batch_size, dim)
-    assert timestep_proj.shape == (batch_size, time_proj_dim)
-    assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
+      encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
+      dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
+      temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
+          dummy_timestep, dummy_encoder_hidden_states
+      )
+      assert temb.shape == (batch_size, dim)
+      assert timestep_proj.shape == (batch_size, time_proj_dim)
+      assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
 
   def test_wan_block(self):
     key = jax.random.key(0)
@@ -158,20 +174,19 @@ def test_wan_block(self):
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim))
 
     dummy_temb = jnp.ones((batch_size, 6, dim))
-
-    wan_block = WanTransformerBlock(
-        rngs=rngs,
-        dim=dim,
-        ffn_dim=ffn_dim,
-        num_heads=num_heads,
-        qk_norm=qk_norm,
-        cross_attn_norm=cross_attn_norm,
-        eps=eps,
-        attention="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-    with mesh:
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      wan_block = WanTransformerBlock(
+          rngs=rngs,
+          dim=dim,
+          ffn_dim=ffn_dim,
+          num_heads=num_heads,
+          qk_norm=qk_norm,
+          cross_attn_norm=cross_attn_norm,
+          eps=eps,
+          attention="flash",
+          mesh=mesh,
+          flash_block_sizes=flash_block_sizes,
+      )
       dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb)
       assert dummy_output.shape == dummy_hidden_states.shape
 
@@ -204,40 +219,39 @@ def test_wan_attention(self):
     mesh = Mesh(devices_array, config.mesh_axes)
     batch_size = 1
     query_dim = 5120
-    attention = FlaxWanAttention(
-        rngs=rngs,
-        query_dim=query_dim,
-        heads=40,
-        dim_head=128,
-        attention_kernel="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-
-    dummy_hidden_states_shape = (batch_size, 75600, query_dim)
-
-    dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    with mesh:
-      dummy_output = attention(
-          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
-      )
-    assert dummy_output.shape == dummy_hidden_states_shape
-
-    # dot product
-    try:
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       attention = FlaxWanAttention(
           rngs=rngs,
           query_dim=query_dim,
           heads=40,
           dim_head=128,
-          attention_kernel="dot_product",
-          split_head_dim=True,
+          attention_kernel="flash",
           mesh=mesh,
           flash_block_sizes=flash_block_sizes,
       )
-    except NotImplementedError:
-      pass
+      dummy_hidden_states_shape = (batch_size, 75600, query_dim)
+
+      dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_output = attention(
+          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
+      )
+      assert dummy_output.shape == dummy_hidden_states_shape
+
+      # dot product
+      try:
+        attention = FlaxWanAttention(
+            rngs=rngs,
+            query_dim=query_dim,
+            heads=40,
+            dim_head=128,
+            attention_kernel="dot_product",
+            split_head_dim=True,
+            mesh=mesh,
+            flash_block_sizes=flash_block_sizes,
+        )
+      except NotImplementedError:
+        pass
 
   @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
   def test_wan_model(self):
@@ -267,7 +281,8 @@ def test_wan_model(self):
     mesh = Mesh(devices_array, config.mesh_axes)
     batch_size = 1
     num_layers = 1
-    wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
+    with nn_partitioning.axis_rules(config.logical_axis_rules):
+      wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
 
     dummy_timestep = jnp.ones((batch_size))
     dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096))
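With these changes the suite can still be run locally in the usual way, e.g. python -m pytest src/maxdiffusion/tests/wan_transformer_test.py; the tests marked with skipif remain excluded on GitHub Actions.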
