diff --git a/src/flag_gems/fused/silu_and_mul.py b/src/flag_gems/fused/silu_and_mul.py
index 70b1b3261..dceb8a809 100644
--- a/src/flag_gems/fused/silu_and_mul.py
+++ b/src/flag_gems/fused/silu_and_mul.py
@@ -15,12 +15,34 @@ def silu_and_mul_kernel(x, y):
     return x_silu * y
 
 
+@pointwise_dynamic(
+    promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
+)
+@triton.jit
+def silu_and_mul_grad_kernel(x, y, dgrad):
+    x_fp32 = x.to(tl.float32)
+    sig = 1 / (1 + tl.exp(-x_fp32))
+    x_silu = x_fp32 * sig
+    # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
+    d_x_silu = sig * (1 + x_fp32 * (1 - sig))
+    dx = d_x_silu * dgrad * y
+    dy = dgrad * x_silu
+    return dx, dy
+
+
 class SiluAndMul(torch.autograd.Function):
     @staticmethod
     def forward(ctx, A, B):
+        ctx.save_for_backward(A, B)
         logging.debug("GEMS SILU AND MUL FORWARD")
         return silu_and_mul_kernel(A, B)
 
+    @staticmethod
+    def backward(ctx, grad_output):
+        A, B = ctx.saved_tensors
+        grad_A, grad_B = silu_and_mul_grad_kernel(A, B, grad_output)
+        return grad_A, grad_B
+
 
 def silu_and_mul(A, B):
     return SiluAndMul.apply(A, B)
diff --git a/tests/test_binary_pointwise_ops.py b/tests/test_binary_pointwise_ops.py
index 563e75233..37ba8acfa 100755
--- a/tests/test_binary_pointwise_ops.py
+++ b/tests/test_binary_pointwise_ops.py
@@ -1353,8 +1353,8 @@ def test_accuracy_rsub(shape, alpha, dtype):
 @pytest.mark.parametrize("shape", POINTWISE_SHAPES)
 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
 def test_accuracy_silu_and_mul(shape, dtype):
-    inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device)
-    inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device)
+    inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True)
+    inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True)
     ref_inp1 = to_reference(inp1, True)
     ref_inp2 = to_reference(inp2, True)
 
@@ -1362,7 +1362,20 @@
     with flag_gems.use_gems():
         res_out = flag_gems.silu_and_mul(inp1, inp2)
 
+    out_grad = torch.randn_like(res_out)
+    ref_grad = to_reference(out_grad, True)
+
+    (ref_inp1_grad, ref_inp2_grad) = torch.autograd.grad(
+        ref_out, (ref_inp1, ref_inp2), ref_grad
+    )
+
+    (res_inp1_grad, res_inp2_grad) = torch.autograd.grad(
+        res_out, (inp1, inp2), out_grad
+    )
+
     gems_assert_close(res_out, ref_out, dtype)
+    gems_assert_close(res_inp1_grad, ref_inp1_grad, dtype)
+    gems_assert_close(res_inp2_grad, ref_inp2_grad, dtype)
 
 
 @pytest.mark.sub
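
For reference, the closed form used by silu_and_mul_grad_kernel follows from out = silu(x) * y with silu(x) = x * sigmoid(x): d out/dx = sigmoid(x) * (1 + x * (1 - sigmoid(x))) * y and d out/dy = silu(x). The snippet below is a standalone sanity check of those formulas against PyTorch autograd; it is not part of the patch, and the helper names (silu_and_mul_ref, silu_and_mul_grad_ref) are illustrative only.

import torch

def silu_and_mul_ref(x, y):
    # reference forward: silu(x) * y
    return torch.nn.functional.silu(x) * y

def silu_and_mul_grad_ref(x, y, dgrad):
    # same closed-form gradients as the Triton kernel, in plain PyTorch
    sig = torch.sigmoid(x)
    x_silu = x * sig
    d_x_silu = sig * (1 + x * (1 - sig))  # d/dx [x * sigmoid(x)]
    return d_x_silu * dgrad * y, dgrad * x_silu

x = torch.randn(1024, dtype=torch.float64, requires_grad=True)
y = torch.randn(1024, dtype=torch.float64, requires_grad=True)
dgrad = torch.randn(1024, dtype=torch.float64)

out = silu_and_mul_ref(x, y)
auto_dx, auto_dy = torch.autograd.grad(out, (x, y), dgrad)
man_dx, man_dy = silu_and_mul_grad_ref(x.detach(), y.detach(), dgrad)
assert torch.allclose(auto_dx, man_dx) and torch.allclose(auto_dy, man_dy)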