Commit 36d34fd

Fix enabling memory-efficient attention on ROCm when calling the CK (Composable Kernel) implementation
1 parent b0c8973 · commit 36d34fd

1 file changed: +6 −5


src/diffusers/models/attention_processor.py

@@ -399,11 +399,12 @@ def set_use_memory_efficient_attention_xformers(
             else:
                 try:
                     # Make sure we can run the memory efficient attention
-                    _ = xformers.ops.memory_efficient_attention(
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                    )
+                    dtype = None
+                    if attention_op is not None:
+                        op_fw, op_bw = attention_op
+                        dtype = list(op_fw.SUPPORTED_DTYPES)[0]
+                    q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+                    _ = xformers.ops.memory_efficient_attention(q, q, q)
                 except Exception as e:
                     raise e
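The old probe always created fp32 test tensors, so on ROCm it failed whenever the selected backend (here the Composable Kernel implementation) only ships half-precision kernels, and memory-efficient attention could not be enabled at all. The fix derives the probe dtype from the forward op's SUPPORTED_DTYPES whenever an attention_op pair is passed in. Below is a minimal, self-contained sketch of the new probe logic; DummyFwOp is a hypothetical stand-in for a real xformers forward op (real ops expose a SUPPORTED_DTYPES set the same way):

```python
from typing import Optional, Tuple

import torch
import xformers.ops


class DummyFwOp:
    # Mimics a CK-style forward op on ROCm: only half precision is supported.
    SUPPORTED_DTYPES = {torch.float16, torch.bfloat16}


def probe_memory_efficient_attention(attention_op: Optional[Tuple] = None) -> None:
    # Pick a dtype the chosen op can actually run. With no op given, dtype
    # stays None and torch.randn falls back to fp32, the pre-fix behaviour.
    dtype = None
    if attention_op is not None:
        op_fw, op_bw = attention_op
        dtype = list(op_fw.SUPPORTED_DTYPES)[0]
    q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
    # The probe only checks that the call succeeds; the result is discarded.
    _ = xformers.ops.memory_efficient_attention(q, q, q)
```

Any entry of SUPPORTED_DTYPES works for this check, so the commit simply takes the first one.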

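For context, this is roughly how a caller on ROCm would exercise the fixed path through diffusers. The fmha.ck.FwOp / fmha.ck.BwOp names assume an xformers build that ships the Composable Kernel backend, and the checkpoint id is only illustrative:

```python
import torch
from diffusers import StableDiffusionPipeline
from xformers.ops import fmha

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative checkpoint
    torch_dtype=torch.float16,
).to("cuda")  # ROCm GPUs also enumerate under the "cuda" device string

# Passing the (forward, backward) op pair lets the probe pick fp16 from
# FwOp.SUPPORTED_DTYPES instead of defaulting to fp32.
pipe.enable_xformers_memory_efficient_attention(
    attention_op=(fmha.ck.FwOp, fmha.ck.BwOp)
)
```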