from absl.testing import absltest
from flax import nnx
from jax.sharding import Mesh
-
+from flax.linen import partitioning as nn_partitioning
from .. import pyconfig
from ..max_utils import (create_device_mesh, get_flash_block_sizes)
from ..models.wan.transformers.transformer_wan import (
@@ -48,6 +48,18 @@ class WanTransformerTest(unittest.TestCase):

  def setUp(self):
    WanTransformerTest.dummy_data = {}
+    pyconfig.initialize(
+        [
+            None,
+            os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"),
+        ],
+        unittest=True,
+    )
+    config = pyconfig.config
+    self.config = config
+    devices_array = create_device_mesh(config)
+    self.mesh = Mesh(devices_array, config.mesh_axes)
+

  def test_rotary_pos_embed(self):
    batch_size = 1
@@ -65,28 +77,31 @@ def test_nnx_pixart_alpha_text_projection(self):
    key = jax.random.key(0)
    rngs = nnx.Rngs(key)
    dummy_caption = jnp.ones((1, 512, 4096))
-    layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
-    dummy_output = layer(dummy_caption)
-    dummy_output.shape == (1, 512, 5120)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = NNXPixArtAlphaTextProjection(rngs=rngs, in_features=4096, hidden_size=5120)
+      dummy_output = layer(dummy_caption)
+      assert dummy_output.shape == (1, 512, 5120)

  def test_nnx_timestep_embedding(self):
    key = jax.random.key(0)
    rngs = nnx.Rngs(key)

    dummy_sample = jnp.ones((1, 256))
-    layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
-    dummy_output = layer(dummy_sample)
-    assert dummy_output.shape == (1, 5120)
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = NNXTimestepEmbedding(rngs=rngs, in_channels=256, time_embed_dim=5120)
+      dummy_output = layer(dummy_sample)
+      assert dummy_output.shape == (1, 5120)

  def test_fp32_layer_norm(self):
    key = jax.random.key(0)
    rngs = nnx.Rngs(key)
    batch_size = 1
    dummy_hidden_states = jnp.ones((batch_size, 75600, 5120))
    # expected same output shape with same dtype
-    layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
-    dummy_output = layer(dummy_hidden_states)
-    assert dummy_output.shape == dummy_hidden_states.shape
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = FP32LayerNorm(rngs=rngs, dim=5120, eps=1e-6, elementwise_affine=False)
+      dummy_output = layer(dummy_hidden_states)
+      assert dummy_output.shape == dummy_hidden_states.shape

  @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
  def test_wan_time_text_embedding(self):
@@ -97,20 +112,21 @@ def test_wan_time_text_embedding(self):
    time_freq_dim = 256
    time_proj_dim = 30720
    text_embed_dim = 4096
-    layer = WanTimeTextImageEmbedding(
-        rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
-    )
+    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      layer = WanTimeTextImageEmbedding(
+          rngs=rngs, dim=dim, time_freq_dim=time_freq_dim, time_proj_dim=time_proj_dim, text_embed_dim=text_embed_dim
+      )

-    dummy_timestep = jnp.ones(batch_size)
+      dummy_timestep = jnp.ones(batch_size)

-    encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
-    dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
-    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
-        dummy_timestep, dummy_encoder_hidden_states
-    )
-    assert temb.shape == (batch_size, dim)
-    assert timestep_proj.shape == (batch_size, time_proj_dim)
-    assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)
+      encoder_hidden_states_shape = (batch_size, time_freq_dim * 2, text_embed_dim)
+      dummy_encoder_hidden_states = jnp.ones(encoder_hidden_states_shape)
+      temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = layer(
+          dummy_timestep, dummy_encoder_hidden_states
+      )
+      assert temb.shape == (batch_size, dim)
+      assert timestep_proj.shape == (batch_size, time_proj_dim)
+      assert encoder_hidden_states.shape == (batch_size, time_freq_dim * 2, dim)

  def test_wan_block(self):
    key = jax.random.key(0)
@@ -158,20 +174,19 @@ def test_wan_block(self):
    dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim))

    dummy_temb = jnp.ones((batch_size, 6, dim))
-
-    wan_block = WanTransformerBlock(
-        rngs=rngs,
-        dim=dim,
-        ffn_dim=ffn_dim,
-        num_heads=num_heads,
-        qk_norm=qk_norm,
-        cross_attn_norm=cross_attn_norm,
-        eps=eps,
-        attention="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-    with mesh:
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
+      wan_block = WanTransformerBlock(
+          rngs=rngs,
+          dim=dim,
+          ffn_dim=ffn_dim,
+          num_heads=num_heads,
+          qk_norm=qk_norm,
+          cross_attn_norm=cross_attn_norm,
+          eps=eps,
+          attention="flash",
+          mesh=mesh,
+          flash_block_sizes=flash_block_sizes,
+      )
      dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb)
      assert dummy_output.shape == dummy_hidden_states.shape

@@ -204,40 +219,39 @@ def test_wan_attention(self):
    mesh = Mesh(devices_array, config.mesh_axes)
    batch_size = 1
    query_dim = 5120
-    attention = FlaxWanAttention(
-        rngs=rngs,
-        query_dim=query_dim,
-        heads=40,
-        dim_head=128,
-        attention_kernel="flash",
-        mesh=mesh,
-        flash_block_sizes=flash_block_sizes,
-    )
-
-    dummy_hidden_states_shape = (batch_size, 75600, query_dim)
-
-    dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
-    with mesh:
-      dummy_output = attention(
-          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
-      )
-      assert dummy_output.shape == dummy_hidden_states_shape
-
-    # dot product
-    try:
+    with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
      attention = FlaxWanAttention(
          rngs=rngs,
          query_dim=query_dim,
          heads=40,
          dim_head=128,
-          attention_kernel="dot_product",
-          split_head_dim=True,
+          attention_kernel="flash",
          mesh=mesh,
          flash_block_sizes=flash_block_sizes,
      )
-    except NotImplementedError:
-      pass
+      dummy_hidden_states_shape = (batch_size, 75600, query_dim)
+
+      dummy_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape)
+      dummy_output = attention(
+          hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb
+      )
+      assert dummy_output.shape == dummy_hidden_states_shape
+
+      # dot product
+      try:
+        attention = FlaxWanAttention(
+            rngs=rngs,
+            query_dim=query_dim,
+            heads=40,
+            dim_head=128,
+            attention_kernel="dot_product",
+            split_head_dim=True,
+            mesh=mesh,
+            flash_block_sizes=flash_block_sizes,
+        )
+      except NotImplementedError:
+        pass

  @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
  def test_wan_model(self):
@@ -267,7 +281,8 @@ def test_wan_model(self):
    mesh = Mesh(devices_array, config.mesh_axes)
    batch_size = 1
    num_layers = 1
-    wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)
+    with nn_partitioning.axis_rules(config.logical_axis_rules):
+      wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers)

    dummy_timestep = jnp.ones((batch_size))
    dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096))
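As background on the pattern the updated tests share, below is a minimal, self-contained sketch of building a device mesh and entering it together with flax.linen.partitioning.axis_rules. It is illustrative only: the mesh axis names ("data", "model") and the logical_axis_rules pairs are assumed placeholders, whereas the tests above take both from base_wan_14b.yml via pyconfig.

import jax
import numpy as np
from flax.linen import partitioning as nn_partitioning
from jax.sharding import Mesh

# Build a trivial 2-D mesh over whatever devices are available; with a single
# CPU device both axes have size 1. The axis names are placeholders.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, ("data", "model"))

# Hypothetical logical-to-mesh rules; the tests read config.logical_axis_rules.
logical_axis_rules = (("batch", "data"), ("embed", "model"))

with mesh, nn_partitioning.axis_rules(logical_axis_rules):
  # While both contexts are active, partitioning-aware layers can resolve
  # logical axis names such as "embed" to mesh axes; the tests above build
  # and call their Wan layers inside an equivalent block.
  pass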