Add gelu and mul grad #501

Merged: 2 commits, Mar 24, 2025
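For context, a minimal usage sketch of the path this PR adds (illustrative only; it assumes an installed flag_gems build with a supported device, and the shapes/dtype are arbitrary):

import torch
import flag_gems

x = torch.randn(8, 128, dtype=torch.float16, device=flag_gems.device, requires_grad=True)
y = torch.randn(8, 128, dtype=torch.float16, device=flag_gems.device, requires_grad=True)

with flag_gems.use_gems():
    out = flag_gems.gelu_and_mul(x, y, approximate="tanh")  # forward: GELU(x) * y
    out.backward(torch.ones_like(out))  # backward: dispatches to gelu_tanh_and_mul_grad_kernel

print(x.grad.shape, y.grad.shape)

The added test below performs the same comparison against a PyTorch reference via torch.autograd.grad.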
80 changes: 74 additions & 6 deletions src/flag_gems/fused/gelu_and_mul.py
@@ -15,10 +15,38 @@
@triton.jit
def gelu_none_and_mul_kernel(x, y):
x_fp32 = x.to(tl.float32)
x_gelu = 0.5 * x_fp32 * (1 + erf(x_fp32 * 0.7071067811))
RCP_SQRT_2: tl.constexpr = 0.7071067811
x_gelu = 0.5 * x_fp32 * (1 + erf(x_fp32 * RCP_SQRT_2))
return x_gelu * y


@pointwise_dynamic(
promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
)
@triton.jit
def gelu_none_and_mul_grad_kernel(x, y, dgrad):
RCP_SQRT_2: tl.constexpr = 0.7071067811
COEFF: tl.constexpr = 0.7978845608028654

x_fp32 = x.to(tl.float32)
x_gelu = 0.5 * x_fp32 * (1 + erf(x_fp32 * RCP_SQRT_2))

d_gelu = dgrad * y
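    # d/dx GELU(x) = 0.5 * (1 + erf(x / sqrt(2))) + x / sqrt(2 * pi) * exp(-x^2 / 2);
    # with COEFF = sqrt(2 / pi), the second term appears below as 0.5 * x * COEFF * exp(-0.5 * x^2)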
dx = (
d_gelu
* 0.5
* (
1.0
+ erf(x_fp32 * RCP_SQRT_2)
+ x_fp32 * COEFF * tl.exp(-0.5 * x_fp32 * x_fp32)
)
)

dy = dgrad * x_gelu

return dx, dy


@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def gelu_tanh_and_mul_kernel(x, y):
@@ -34,17 +62,57 @@ def gelu_tanh_and_mul_kernel(x, y):
return x_gelu * y


@pointwise_dynamic(
promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
)
@triton.jit
def gelu_tanh_and_mul_grad_kernel(x, y, dgrad):
x_fp32 = x.to(tl.float32)
y_fp32 = y.to(tl.float32)

sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)
a_cubed = x_fp32 * x_fp32 * x_fp32
tanh_arg = sqrt_2_over_pi * (x_fp32 + 0.044715 * a_cubed)
tanh_result = tanh(tanh_arg)
geglu_a = 0.5 * x_fp32 * (1 + tanh_result)
dy = geglu_a * dgrad
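    # dx needs d/dx GELU_tanh(x), where GELU_tanh(x) = 0.5 * x * (1 + tanh(u)) and u = sqrt(2/pi) * (x + 0.044715 * x^3):
    #   term1 = 0.5 * (1 + tanh(u))
    #   term2 = 0.5 * x * (1 - tanh(u)^2) * du/dx, with du/dx = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)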

term1 = 0.5 * (1 + tanh_result)
tanh_sq = tanh_result * tanh_result
term2 = (
0.5
* x_fp32
* (1 - tanh_sq)
* (sqrt_2_over_pi * (1 + 3 * 0.044715 * x_fp32 * x_fp32))
)
dx = dgrad * y_fp32 * (term1 + term2)

return dx, dy


class GeluAndMul(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, approximate="none"):
def forward(ctx, x, y, approximate="none"):
logging.debug("GEMS GELU AND MUL FORWARD")
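        # keep the raw inputs; the grad kernels recompute GELU(x) during backward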
ctx.save_for_backward(x, y)
ctx.approximate = approximate
if approximate == "none":
return gelu_none_and_mul_kernel(A, B)
return gelu_none_and_mul_kernel(x, y)
elif approximate == "tanh":
return gelu_tanh_and_mul_kernel(A, B)
return gelu_tanh_and_mul_kernel(x, y)
else:
raise ValueError(f"Invalid approximate value: {approximate}")

@staticmethod
def backward(ctx, dgrad):
logging.debug("GEMS GELU AND MUL BACKWARD")
x, y = ctx.saved_tensors
if ctx.approximate == "none":
dx, dy = gelu_none_and_mul_grad_kernel(x, y, dgrad)
else:
dx, dy = gelu_tanh_and_mul_grad_kernel(x, y, dgrad)
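        # approximate is a non-tensor argument, so its gradient slot is None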
return dx, dy, None


def gelu_and_mul(A, B, approximate="none"):
return GeluAndMul.apply(A, B, approximate)
def gelu_and_mul(x, y, approximate="none"):
return GeluAndMul.apply(x, y, approximate)
17 changes: 15 additions & 2 deletions tests/test_binary_pointwise_ops.py
@@ -948,8 +948,8 @@ def test_accuracy_ge_scalar(shape, dtype):
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
@pytest.mark.parametrize("approximate", ["none", "tanh"])
def test_accuracy_gelu_and_mul(shape, approximate, dtype):
inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device)
inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device)
inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True)
inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True)
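    # requires_grad=True so the new backward kernels are exercised via autograd.grad below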
ref_inp1 = to_reference(inp1, True)
ref_inp2 = to_reference(inp2, True)

@@ -959,7 +959,20 @@ def test_accuracy_gelu_and_mul(shape, approximate, dtype):
with flag_gems.use_gems():
res_out = flag_gems.gelu_and_mul(inp1, inp2, approximate)

out_grad = torch.randn_like(res_out)
ref_grad = to_reference(out_grad, True)

(ref_inp1_grad, ref_inp2_grad) = torch.autograd.grad(
ref_out, (ref_inp1, ref_inp2), ref_grad
)

(res_inp1_grad, res_inp2_grad) = torch.autograd.grad(
res_out, (inp1, inp2), out_grad
)

gems_assert_close(res_out, ref_out, dtype)
gems_assert_close(res_inp1_grad, ref_inp1_grad, dtype)
gems_assert_close(res_inp2_grad, ref_inp2_grad, dtype)


@pytest.mark.gt