diff --git a/src/flag_gems/fused/gelu_and_mul.py b/src/flag_gems/fused/gelu_and_mul.py
index 52a50e057..7803520b9 100644
--- a/src/flag_gems/fused/gelu_and_mul.py
+++ b/src/flag_gems/fused/gelu_and_mul.py
@@ -15,10 +15,38 @@
 @triton.jit
 def gelu_none_and_mul_kernel(x, y):
     x_fp32 = x.to(tl.float32)
-    x_gelu = 0.5 * x_fp32 * (1 + erf(x_fp32 * 0.7071067811))
+    RCP_SQRT_2: tl.constexpr = 0.7071067811
+    x_gelu = 0.5 * x_fp32 * (1 + erf(x_fp32 * RCP_SQRT_2))
     return x_gelu * y
 
 
+@pointwise_dynamic(
+    promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
+)
+@triton.jit
+def gelu_none_and_mul_grad_kernel(x, y, dgrad):
+    RCP_SQRT_2: tl.constexpr = 0.7071067811
+    COEFF: tl.constexpr = 0.7978845608028654
+
+    x_fp32 = x.to(tl.float32)
+    x_gelu = 0.5 * x_fp32 * (1 + erf(x_fp32 * RCP_SQRT_2))
+
+    d_gelu = dgrad * y
+    dx = (
+        d_gelu
+        * 0.5
+        * (
+            1.0
+            + erf(x_fp32 * RCP_SQRT_2)
+            + x_fp32 * COEFF * tl.exp(-0.5 * x_fp32 * x_fp32)
+        )
+    )
+
+    dy = dgrad * x_gelu
+
+    return dx, dy
+
+
 @pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
 @triton.jit
 def gelu_tanh_and_mul_kernel(x, y):
@@ -34,17 +62,57 @@ def gelu_tanh_and_mul_kernel(x, y):
     return x_gelu * y
 
 
+@pointwise_dynamic(
+    promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
+)
+@triton.jit
+def gelu_tanh_and_mul_grad_kernel(x, y, dgrad):
+    x_fp32 = x.to(tl.float32)
+    y_fp32 = y.to(tl.float32)
+
+    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
+    a_cubed = x_fp32 * x_fp32 * x_fp32
+    tanh_arg = sqrt_2_over_pi * (x_fp32 + 0.044715 * a_cubed)
+    tanh_result = tanh(tanh_arg)
+    geglu_a = 0.5 * x_fp32 * (1 + tanh_result)
+    dy = geglu_a * dgrad
+
+    term1 = 0.5 * (1 + tanh_result)
+    tanh_sq = tanh_result * tanh_result
+    term2 = (
+        0.5
+        * x_fp32
+        * (1 - tanh_sq)
+        * (sqrt_2_over_pi * (1 + 3 * 0.044715 * x_fp32 * x_fp32))
+    )
+    dx = dgrad * y_fp32 * (term1 + term2)
+
+    return dx, dy
+
+
 class GeluAndMul(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, A, B, approximate="none"):
+    def forward(ctx, x, y, approximate="none"):
         logging.debug("GEMS GELU AND MUL FORWARD")
+        ctx.save_for_backward(x, y)
+        ctx.approximate = approximate
         if approximate == "none":
-            return gelu_none_and_mul_kernel(A, B)
+            return gelu_none_and_mul_kernel(x, y)
         elif approximate == "tanh":
-            return gelu_tanh_and_mul_kernel(A, B)
+            return gelu_tanh_and_mul_kernel(x, y)
         else:
             raise ValueError(f"Invalid approximate value: {approximate}")
 
+    @staticmethod
+    def backward(ctx, dgrad):
+        logging.debug("GEMS GELU AND MUL BACKWARD")
+        x, y = ctx.saved_tensors
+        if ctx.approximate == "none":
+            dx, dy = gelu_none_and_mul_grad_kernel(x, y, dgrad)
+        else:
+            dx, dy = gelu_tanh_and_mul_grad_kernel(x, y, dgrad)
+        return dx, dy, None
+
 
-def gelu_and_mul(A, B, approximate="none"):
-    return GeluAndMul.apply(A, B, approximate)
+def gelu_and_mul(x, y, approximate="none"):
+    return GeluAndMul.apply(x, y, approximate)
diff --git a/tests/test_binary_pointwise_ops.py b/tests/test_binary_pointwise_ops.py
index 563e75233..5fa8f2517 100755
--- a/tests/test_binary_pointwise_ops.py
+++ b/tests/test_binary_pointwise_ops.py
@@ -948,8 +948,8 @@ def test_accuracy_ge_scalar(shape, dtype):
 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
 @pytest.mark.parametrize("approximate", ["none", "tanh"])
 def test_accuracy_gelu_and_mul(shape, approximate, dtype):
-    inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device)
-    inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device)
+    inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True)
+    inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True)
     ref_inp1 = to_reference(inp1, True)
     ref_inp2 = to_reference(inp2, True)
 
@@ -959,7 +959,20 @@ def test_accuracy_gelu_and_mul(shape, approximate, dtype):
     with flag_gems.use_gems():
         res_out = flag_gems.gelu_and_mul(inp1, inp2, approximate)
 
+    out_grad = torch.randn_like(res_out)
+    ref_grad = to_reference(out_grad, True)
+
+    (ref_inp1_grad, ref_inp2_grad) = torch.autograd.grad(
+        ref_out, (ref_inp1, ref_inp2), ref_grad
+    )
+
+    (res_inp1_grad, res_inp2_grad) = torch.autograd.grad(
+        res_out, (inp1, inp2), out_grad
+    )
+
     gems_assert_close(res_out, ref_out, dtype)
+    gems_assert_close(res_inp1_grad, ref_inp1_grad, dtype)
+    gems_assert_close(res_inp2_grad, ref_inp2_grad, dtype)
 
 
 @pytest.mark.gt
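
Note for reviewers (not part of the diff): the erf-branch backward implements gelu'(x) = 0.5 * (1 + erf(x/sqrt(2)) + x * sqrt(2/pi) * exp(-x^2/2)), so dx = dgrad * y * gelu'(x) and dy = dgrad * gelu(x). A quick standalone sanity check against PyTorch autograd in float64 is sketched below; the helper names gelu_mul_ref and dgelu_exact are illustrative only, and the tanh branch can be checked the same way by passing approximate="tanh" and differentiating the tanh approximation instead.

import torch

RCP_SQRT_2 = 0.7071067811865476      # 1 / sqrt(2)
SQRT_2_OVER_PI = 0.7978845608028654  # sqrt(2 / pi)


def gelu_mul_ref(x, y, approximate="none"):
    # eager reference for the fused op: gelu(x) * y
    return torch.nn.functional.gelu(x, approximate=approximate) * y


def dgelu_exact(x):
    # closed-form derivative used in gelu_none_and_mul_grad_kernel
    return 0.5 * (
        1 + torch.erf(x * RCP_SQRT_2)
        + x * SQRT_2_OVER_PI * torch.exp(-0.5 * x * x)
    )


x = torch.randn(4096, dtype=torch.float64, requires_grad=True)
y = torch.randn(4096, dtype=torch.float64, requires_grad=True)
g = torch.randn(4096, dtype=torch.float64)

out = gelu_mul_ref(x, y, approximate="none")
dx_auto, dy_auto = torch.autograd.grad(out, (x, y), g)

xd, yd = x.detach(), y.detach()
dx_manual = g * yd * dgelu_exact(xd)           # dx = dgrad * y * gelu'(x)
dy_manual = g * torch.nn.functional.gelu(xd)   # dy = dgrad * gelu(x)

assert torch.allclose(dx_auto, dx_manual)
assert torch.allclose(dy_auto, dy_manual)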