Commit 8609f78

lint.

Parent: 8c4f774

File tree: 2 files changed (+16, -5 lines)

src/maxdiffusion/models/attention_flax.py (12 additions, 4 deletions)
@@ -53,7 +53,6 @@
 CROSS_ATTN_KV_LENGTH = common_types.CROSS_ATTN_KV_LENGTH


-
 def _maybe_aqt_einsum(quant: Quant):
   return jnp.einsum if quant is None else quant.einsum()

@@ -448,7 +447,16 @@ def _apply_attention(
     )
   elif attention_kernel == "flash":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel,
+        query,
+        key * scale,
+        value,
+        heads,
+        mesh,
+        axis_names_q,
+        axis_names_kv,
+        flash_block_sizes,
+        dtype,
+        attention_kernel,
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
@@ -733,7 +741,7 @@ def __init__(
     else:
       axis_names_q = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_Q_LENGTH, D_KV)
       axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV)
-
+
     self.attention_op = NNXAttentionOp(
         mesh=mesh,
         attention_kernel=attention_kernel,
@@ -1542,4 +1550,4 @@ def setup(self):
   def __call__(self, hidden_states, deterministic=True):
     hidden_states = self.proj(hidden_states)
     hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
-    return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
+    return self.dropout_layer(hidden_linear * nn.gelu(hidden_gelu), deterministic=deterministic)
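Note: the final hunk appears to change only trailing whitespace (the removed and re-added return lines are identical), but its context shows the GEGLU-style feed-forward used here: the projection output is split in two, one half gates the other through GELU, and dropout is applied to the product. A minimal, self-contained sketch of that pattern in plain Flax linen follows; the module name, projection width, and dropout rate are illustrative assumptions, not values taken from attention_flax.py.

import jax
import jax.numpy as jnp
from flax import linen as nn

class GEGLUSketch(nn.Module):
  """Illustrative stand-in for the feed-forward __call__ shown in the diff."""
  dim_out: int = 64  # assumed output width, for illustration only

  @nn.compact
  def __call__(self, hidden_states, deterministic=True):
    # Project to twice the output width, then split into a linear half and a
    # gate half; GELU of the gate multiplies the linear half, as in the
    # diffed return statement, and dropout is applied to the product.
    hidden_states = nn.Dense(self.dim_out * 2)(hidden_states)
    hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)
    gated = hidden_linear * nn.gelu(hidden_gelu)
    return nn.Dropout(rate=0.1, deterministic=deterministic)(gated)

x = jnp.ones((1, 4, 32))                      # (batch, tokens, features)
variables = GEGLUSketch().init(jax.random.PRNGKey(0), x)
y = GEGLUSketch().apply(variables, x)         # -> shape (1, 4, 64)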

src/maxdiffusion/models/wan/transformers/transformer_wan.py (4 additions, 1 deletion)
@@ -353,7 +353,10 @@ def __call__(
     # 2. Cross-attention
     norm_hidden_states = self.norm2(hidden_states)
     attn_output = self.attn2(
-        hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs,
+        hidden_states=norm_hidden_states,
+        encoder_hidden_states=encoder_hidden_states,
+        deterministic=deterministic,
+        rngs=rngs,
     )
     hidden_states = hidden_states + attn_output

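The reformatted call above is the norm, cross-attention, residual step of the WAN transformer block. The sketch below shows that pattern with stock Flax linen modules standing in for the block's own norm2/attn2 (the real module appears to use Flax NNX and maxdiffusion's own attention op, hence the rngs argument); the class name, head count, and shapes are illustrative assumptions.

import jax
import jax.numpy as jnp
from flax import linen as nn

class CrossAttnBlockSketch(nn.Module):
  """Illustrative stand-in for the norm2 -> attn2 -> residual step above."""
  num_heads: int = 4  # assumed head count, for illustration only

  @nn.compact
  def __call__(self, hidden_states, encoder_hidden_states, deterministic=True):
    # 2. Cross-attention: normalize the query tokens, attend over the encoder
    # (text) tokens, then add the attention output back as a residual.
    norm_hidden_states = nn.LayerNorm()(hidden_states)
    attn_output = nn.MultiHeadDotProductAttention(
        num_heads=self.num_heads,
        deterministic=deterministic,
    )(norm_hidden_states, encoder_hidden_states)
    return hidden_states + attn_output

x = jnp.ones((1, 16, 64))    # (batch, video tokens, features)
ctx = jnp.ones((1, 8, 64))   # (batch, text tokens, features)
block = CrossAttnBlockSketch()
variables = block.init(jax.random.PRNGKey(0), x, ctx)
y = block.apply(variables, x, ctx)   # same shape as x: (1, 16, 64)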