Skip to content

Commit d1f3dcd

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent b6933b6 commit d1f3dcd

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

transformer_engine/jax/flax/transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,7 @@ def __post_init__(self):
633633
def _assert_dtypes(self, query: Array, key: Array, value: Array, qkv_layout: QKVLayout):
634634
"""Asserts that the dtypes of query, key, and value dtypes are consistent."""
635635
if qkv_layout.is_qkvpacked():
636-
pass # No need to check dtypes for key and value since it is packed
636+
pass # No need to check dtypes for key and value since it is packed
637637
elif qkv_layout.is_kvpacked():
638638
assert (
639639
key.dtype == query.dtype

0 commit comments

Comments
 (0)