
Commit 9b847f4

drisspg authored and pobin6 committed
[FlexAttention] Fix multiple calls to flex bug (pytorch#140761)
# Summary

Fixes a long-standing bug we've had in the backward pass for flex attention. See pytorch#135161 for details.

Pull Request resolved: pytorch#140761
Approved by: https://github.com/Chillee, https://github.com/zou3519
1 parent 1bd70cf commit 9b847f4
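
Illustration (not part of the commit): per the linked issue and the new test below, the failure mode appears when a single compiled function calls flex_attention more than once with different block masks and then runs a backward pass. A minimal sketch of that pattern, with illustrative shapes and window size:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Sketch of the pattern this commit fixes (mirrors the new test below; shapes,
# window size, and names are illustrative): two flex_attention calls with
# different block masks inside one compiled function, followed by backward.
q = torch.randn(1, 4, 512, 64, device="cuda", requires_grad=True)
k = torch.randn(1, 4, 512, 64, device="cuda", requires_grad=True)
v = torch.randn(1, 4, 512, 64, device="cuda", requires_grad=True)

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def sliding_window(b, h, q_idx, kv_idx):
    return (q_idx >= kv_idx) & (q_idx <= kv_idx + 32)

mask1 = create_block_mask(causal, 1, None, 512, 512)
mask2 = create_block_mask(sliding_window, 1, None, 512, 512)

def f(q, k, v):
    # Each call traces its own fw/joint/mask subgraphs; previously the second
    # call's subgraphs could overwrite the first call's in the backward graph.
    return flex_attention(q, k, v, block_mask=mask1) + flex_attention(
        q, k, v, block_mask=mask2
    )

out = torch.compile(f, fullgraph=True)(q, k, v)
out.sum().backward()  # gradients should now match eager (the test checks this within tolerance)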

File tree

2 files changed: +63 -10 lines changed

test/inductor/test_flex_attention.py (+55, -7)
torch/_higher_order_ops/flex_attention.py (+8, -3)


test/inductor/test_flex_attention.py

Lines changed: 55 additions & 7 deletions
@@ -1527,6 +1527,54 @@ def f(q, k1, k2, v1, v2):
         tolerance = Tolerances(atol=2e-1, rtol=2e-1)
         torch.testing.assert_close(out, out2, atol=tolerance.atol, rtol=tolerance.rtol)
 
+    @supported_platform
+    def test_multiple_mask_calls(self):
+        if TEST_WITH_ROCM:
+            self.skipTest(
+                "ROCM BUG SEE: https://github.com/pytorch/pytorch/issues/140855"
+            )
+        # Create inputs
+        query = torch.randn(
+            (1, 4, 512, 64), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+        key = torch.randn(
+            (1, 4, 512, 64), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+        value = torch.randn(
+            (1, 4, 512, 64), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+
+        window_size = 32
+
+        def causal_mask(b, h, q_idx, kv_idx):
+            return q_idx >= kv_idx
+
+        def causal_mask_slidewindow_mod(b, h, q_idx, kv_idx):
+            return (q_idx >= kv_idx) & (q_idx <= kv_idx + window_size)
+
+        mask1 = create_block_mask(causal_mask, 1, None, 512, 512, _compile=False)
+        mask2 = create_block_mask(
+            causal_mask_slidewindow_mod, 1, None, 512, 512, _compile=False
+        )
+
+        def f(q, k, v):
+            out1 = flex_attention(q, k, v, block_mask=mask1)
+            out2 = flex_attention(q, k, v, block_mask=mask2)
+            return out1 + out2
+
+        f_compiled = torch.compile(f, fullgraph=True)
+
+        out = f(query, key, value)
+        out_compiled = f_compiled(query, key, value)
+
+        grads = torch.autograd.grad((out,), (query, key, value), torch.ones_like(out))
+        grads_compile = torch.autograd.grad(
+            (out_compiled,), (query, key, value), torch.ones_like(out_compiled)
+        )
+
+        for grad, grad_compiled in zip(grads, grads_compile):
+            torch.testing.assert_close(grad, grad_compiled, atol=3e-2, rtol=3e-2)
+
     @supported_platform
     def test_multiple_score_mod_calls2(self):
         query = torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
@@ -3184,29 +3232,29 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs):
 class GraphModule(torch.nn.Module):
     def forward(self, primals_1: "f64[2, 2, 128, 4]", primals_2: "f64[2, 2, 128, 4]", primals_3: "f64[2, 2, 128, 4]", full: "i32[1, 1, 1]", full_default: "i32[1, 1, 1, 1]", convert_element_type: "i32[1, 1, 1]", convert_element_type_1: "i32[1, 1, 1, 1]", getitem_2: "f64[2, 2, 128, 4]", getitem_3: "f32[2, 2, 128]", tangents_1: "f64[2, 2, 128, 4]"):
         full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
-        fw_graph = self.fw_graph
-        joint_graph = self.joint_graph
-        mask_graph = self.mask_graph
-        flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph, joint_graph, (full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph = joint_graph = full = full_default = convert_element_type = convert_element_type_1 = mask_graph = None
+        fw_graph0 = self.fw_graph0
+        joint_graph0 = self.joint_graph0
+        mask_graph0 = self.mask_graph0
+        flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
         getitem_4: "f64[2, 2, 128, 4]" = flex_attention_backward[0]
         getitem_5: "f64[2, 2, 128, 4]" = flex_attention_backward[1]
         getitem_6: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None
         return (getitem_4, getitem_5, getitem_6)
 
-    class fw_graph(torch.nn.Module):
+    class fw_graph0(torch.nn.Module):
         def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]"):
             mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
             return mul
 
-    class joint_graph(torch.nn.Module):
+    class joint_graph0(torch.nn.Module):
         def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]", arg5_1: "f64[]"):
             mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); mul = None
             mul_1: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1)
             mul_2: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1); arg5_1 = arg0_1 = None
             add: "f64[]" = torch.ops.aten.add.Tensor(mul_2, mul_1); mul_2 = mul_1 = None
             return [add, None, None, None, None]
 
-    class mask_graph(torch.nn.Module):
+    class mask_graph0(torch.nn.Module):
         def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]"):
             full: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
             return full

torch/_higher_order_ops/flex_attention.py

Lines changed: 8 additions & 3 deletions
@@ -946,9 +946,14 @@ def trace_flex_attention_backward(
     )
     assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
     block_mask = block_mask[:-1] + (mask_graph,)
-    proxy_mode.tracer.root.register_module("fw_graph", fw_graph)  # type: ignore[arg-type]
-    proxy_mode.tracer.root.register_module("joint_graph", joint_graph)
-    proxy_mode.tracer.root.register_module("mask_graph", mask_graph)
+
+    qualname = proxy_mode.tracer.get_fresh_qualname("fw_graph")
+    proxy_mode.tracer.root.register_module(qualname, fw_graph)  # type: ignore[arg-type]
+    qualname = proxy_mode.tracer.get_fresh_qualname("joint_graph")
+    proxy_mode.tracer.root.register_module(qualname, joint_graph)
+    qualname = proxy_mode.tracer.get_fresh_qualname("mask_graph")
+    proxy_mode.tracer.root.register_module(qualname, mask_graph)
+
     node_args = (
         query,
         key,
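
Why the fix uses fresh qualified names: registering a submodule on the tracer's root under a fixed name silently replaces any submodule already registered under that name, so when flex attention was traced twice into one graph the second call's fw_graph/joint_graph/mask_graph clobbered the first call's subgraphs, breaking the backward pass. A simplified sketch of the collision and of the numbered-name scheme (fresh_qualname below is a stand-in for the tracer's get_fresh_qualname, not the real implementation):

import torch.nn as nn

# Illustration (not the PR's code) of the name collision that fresh qualnames avoid.
# Two stand-ins for the mask subgraphs traced by two flex_attention calls:
first_mask_graph = nn.Identity()
second_mask_graph = nn.Identity()

# Old behavior: a fixed name, so the second registration silently replaces the first.
root = nn.Module()
root.register_module("mask_graph", first_mask_graph)
root.register_module("mask_graph", second_mask_graph)
assert root.mask_graph is second_mask_graph  # call #1's subgraph was clobbered

def fresh_qualname(mod: nn.Module, prefix: str) -> str:
    # Simplified stand-in for the tracer's get_fresh_qualname: pick the first
    # numbered suffix that is not already an attribute of the root module.
    i = 0
    while hasattr(mod, f"{prefix}{i}"):
        i += 1
    return f"{prefix}{i}"

# New behavior: each call registers its subgraphs under a fresh name, which is
# why the expected graph in the test now shows fw_graph0 / joint_graph0 / mask_graph0.
fixed_root = nn.Module()
fixed_root.register_module(fresh_qualname(fixed_root, "mask_graph"), first_mask_graph)   # mask_graph0
fixed_root.register_module(fresh_qualname(fixed_root, "mask_graph"), second_mask_graph)  # mask_graph1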
