From 36d34fdd5838afb13d589108d2c77c8846c61a31 Mon Sep 17 00:00:00 2001
From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Date: Mon, 13 Jan 2025 22:34:24 +0000
Subject: [PATCH 1/2] fix enable memory efficient attention on ROCm while
 calling CK implementation

---
 src/diffusers/models/attention_processor.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 4d7ae6bef26e..2a3885bca173 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -399,11 +399,12 @@ def set_use_memory_efficient_attention_xformers(
             else:
                 try:
                     # Make sure we can run the memory efficient attention
-                    _ = xformers.ops.memory_efficient_attention(
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                    )
+                    dtype = None
+                    if attention_op is not None:
+                        op_fw, op_bw = attention_op
+                        dtype = list(op_fw.SUPPORTED_DTYPES)[0]
+                    q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+                    _ = xformers.ops.memory_efficient_attention(q, q, q)
                 except Exception as e:
                     raise e
 

From 2d2411e9315beb5c0b6fcf285d80d2a8a56679a4 Mon Sep 17 00:00:00 2001
From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Date: Wed, 15 Jan 2025 12:43:01 -0800
Subject: [PATCH 2/2] Update attention_processor.py

refactor of picking a set element
---
 src/diffusers/models/attention_processor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 2a3885bca173..1bc7913722e2 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -402,7 +402,7 @@ def set_use_memory_efficient_attention_xformers(
                     dtype = None
                     if attention_op is not None:
                         op_fw, op_bw = attention_op
-                        dtype = list(op_fw.SUPPORTED_DTYPES)[0]
+                        dtype, *_ = op_fw.SUPPORTED_DTYPES
                     q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
                     _ = xformers.ops.memory_efficient_attention(q, q, q)
                 except Exception as e: