@@ -1527,6 +1527,54 @@ def f(q, k1, k2, v1, v2):
         tolerance = Tolerances(atol=2e-1, rtol=2e-1)
         torch.testing.assert_close(out, out2, atol=tolerance.atol, rtol=tolerance.rtol)

+    @supported_platform
+    def test_multiple_mask_calls(self):
+        if TEST_WITH_ROCM:
+            self.skipTest(
+                "ROCM BUG SEE: https://github.com/pytorch/pytorch/issues/140855"
+            )
+        # Create inputs
+        query = torch.randn(
+            (1, 4, 512, 64), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+        key = torch.randn(
+            (1, 4, 512, 64), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+        value = torch.randn(
+            (1, 4, 512, 64), dtype=torch.float32, device="cuda", requires_grad=True
+        )
+
+        window_size = 32
+
+        def causal_mask(b, h, q_idx, kv_idx):
+            return q_idx >= kv_idx
+
+        def causal_mask_slidewindow_mod(b, h, q_idx, kv_idx):
+            return (q_idx >= kv_idx) & (q_idx <= kv_idx + window_size)
+
+        mask1 = create_block_mask(causal_mask, 1, None, 512, 512, _compile=False)
+        mask2 = create_block_mask(
+            causal_mask_slidewindow_mod, 1, None, 512, 512, _compile=False
+        )
+
+        def f(q, k, v):
+            out1 = flex_attention(q, k, v, block_mask=mask1)
+            out2 = flex_attention(q, k, v, block_mask=mask2)
+            return out1 + out2
+
+        f_compiled = torch.compile(f, fullgraph=True)
+
+        out = f(query, key, value)
+        out_compiled = f_compiled(query, key, value)
+
+        grads = torch.autograd.grad((out,), (query, key, value), torch.ones_like(out))
+        grads_compile = torch.autograd.grad(
+            (out_compiled,), (query, key, value), torch.ones_like(out_compiled)
+        )
+
+        for grad, grad_compiled in zip(grads, grads_compile):
+            torch.testing.assert_close(grad, grad_compiled, atol=3e-2, rtol=3e-2)
+
     @supported_platform
     def test_multiple_score_mod_calls2(self):
         query = torch.randn((1, 8, 1024, 64), dtype=torch.float32, device="cuda")
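For context on what the new test exercises: two flex_attention calls with two different BlockMasks inside a single torch.compile'd function, with compiled gradients checked against eager. Below is a minimal standalone sketch of that pattern, not part of the diff; it assumes a CUDA device and a PyTorch build that ships torch.nn.attention.flex_attention, and the function names and 512/32 sizes are purely illustrative.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Illustrative shapes only; any (B, H, S, D) works as long as S matches the mask size.
q, k, v = (torch.randn(1, 4, 512, 64, device="cuda") for _ in range(3))

def causal(b, h, q_idx, kv_idx):
    # keep keys at or before the query position
    return q_idx >= kv_idx

def sliding_window(b, h, q_idx, kv_idx):
    # causal, but only within a 32-token window behind the query
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= 32)

causal_mask = create_block_mask(causal, B=None, H=None, Q_LEN=512, KV_LEN=512)
window_mask = create_block_mask(sliding_window, B=None, H=None, Q_LEN=512, KV_LEN=512)

@torch.compile(fullgraph=True)
def two_mask_attention(q, k, v):
    # each call carries its own BlockMask; both land in one compiled graph
    return flex_attention(q, k, v, block_mask=causal_mask) + flex_attention(
        q, k, v, block_mask=window_mask
    )

out = two_mask_attention(q, k, v)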
@@ -3184,29 +3232,29 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs):
 class GraphModule(torch.nn.Module):
     def forward(self, primals_1: "f64[2, 2, 128, 4]", primals_2: "f64[2, 2, 128, 4]", primals_3: "f64[2, 2, 128, 4]", full: "i32[1, 1, 1]", full_default: "i32[1, 1, 1, 1]", convert_element_type: "i32[1, 1, 1]", convert_element_type_1: "i32[1, 1, 1, 1]", getitem_2: "f64[2, 2, 128, 4]", getitem_3: "f32[2, 2, 128]", tangents_1: "f64[2, 2, 128, 4]"):
         full_default_4: "f32[2, 2, 128]" = torch.ops.aten.full.default([2, 2, 128], 0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
-        fw_graph = self.fw_graph
-        joint_graph = self.joint_graph
-        mask_graph = self.mask_graph
-        flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph, joint_graph, (full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph = joint_graph = full = full_default = convert_element_type = convert_element_type_1 = mask_graph = None
+        fw_graph0 = self.fw_graph0
+        joint_graph0 = self.joint_graph0
+        mask_graph0 = self.mask_graph0
+        flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem_2, getitem_3, tangents_1, full_default_4, fw_graph0, joint_graph0, (full, full_default, None, None, convert_element_type, convert_element_type_1, None, None, 1073741824, 1073741824, mask_graph0), 0.5, {'PRESCALE_QK': False, 'ROWS_GUARANTEED_SAFE': False, 'BLOCKS_ARE_CONTIGUOUS': False, 'OUTPUT_LOGSUMEXP': True}, (), ()); primals_1 = primals_2 = primals_3 = getitem_2 = getitem_3 = tangents_1 = full_default_4 = fw_graph0 = joint_graph0 = full = full_default = convert_element_type = convert_element_type_1 = mask_graph0 = None
         getitem_4: "f64[2, 2, 128, 4]" = flex_attention_backward[0]
         getitem_5: "f64[2, 2, 128, 4]" = flex_attention_backward[1]
         getitem_6: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None
         return (getitem_4, getitem_5, getitem_6)

-    class fw_graph(torch.nn.Module):
+    class fw_graph0(torch.nn.Module):
         def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]"):
             mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
             return mul

-    class joint_graph(torch.nn.Module):
+    class joint_graph0(torch.nn.Module):
         def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]", arg5_1: "f64[]"):
             mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); mul = None
             mul_1: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1)
             mul_2: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1); arg5_1 = arg0_1 = None
             add: "f64[]" = torch.ops.aten.add.Tensor(mul_2, mul_1); mul_2 = mul_1 = None
             return [add, None, None, None, None]

-    class mask_graph(torch.nn.Module):
+    class mask_graph0(torch.nn.Module):
         def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]"):
             full: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False)
             return full
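The expected-output change in this hunk only renames the captured subgraph attributes (fw_graph to fw_graph0, joint_graph to joint_graph0, mask_graph to mask_graph0); the graph structure itself is unchanged. Reading the hunk, fw_graph0 squares the score and mask_graph0 is the always-true mask, so a call of roughly the following shape would trace to a backward graph like the one shown. This is a hedged reconstruction, not the actual test body; the shapes, float64 dtype, and default 0.5 scale are taken from the tensor annotations in the expected graph, and the function names are made up for illustration.

import torch
from torch.nn.attention.flex_attention import flex_attention

# Shapes/dtype mirror the "f64[2, 2, 128, 4]" annotations in the expected graph.
q, k, v = (
    torch.randn(2, 2, 128, 4, dtype=torch.float64, device="cuda", requires_grad=True)
    for _ in range(3)
)

def squared_score(score, b, h, q_idx, kv_idx):
    # matches fw_graph0: mul = arg0_1 * arg0_1
    return score * score

out = torch.compile(flex_attention, fullgraph=True)(q, k, v, score_mod=squared_score)
# the backward pass is what lowers to torch.ops.higher_order.flex_attention_backward
out.sum().backward()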