
Commit 0dc5cb3

Add gelu and mul grad (#501)
1 parent a2ef017 commit 0dc5cb3

File tree

src/flag_gems/fused/gelu_and_mul.py
tests/test_binary_pointwise_ops.py

2 files changed: +89 -8 lines changed

src/flag_gems/fused/gelu_and_mul.py

Lines changed: 74 additions & 6 deletions
@@ -15,10 +15,38 @@
 @triton.jit
 def gelu_none_and_mul_kernel(x, y):
     x_fp32 = x.to(tl.float32)
-    x_gelu = 0.5 * x_fp32 * (1 + erf(x_fp32 * 0.7071067811))
+    RCP_SQRT_2: tl.constexpr = 0.7071067811
+    x_gelu = 0.5 * x_fp32 * (1 + erf(x_fp32 * RCP_SQRT_2))
     return x_gelu * y
 
 
+@pointwise_dynamic(
+    promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
+)
+@triton.jit
+def gelu_none_and_mul_grad_kernel(x, y, dgrad):
+    RCP_SQRT_2: tl.constexpr = 0.7071067811
+    COEFF: tl.constexpr = 0.7978845608028654
+
+    x_fp32 = x.to(tl.float32)
+    x_gelu = 0.5 * x_fp32 * (1 + erf(x_fp32 * RCP_SQRT_2))
+
+    d_gelu = dgrad * y
+    dx = (
+        d_gelu
+        * 0.5
+        * (
+            1.0
+            + erf(x_fp32 * RCP_SQRT_2)
+            + x_fp32 * COEFF * tl.exp(-0.5 * x_fp32 * x_fp32)
+        )
+    )
+
+    dy = dgrad * x_gelu
+
+    return dx, dy
+
+
 @pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
 @triton.jit
 def gelu_tanh_and_mul_kernel(x, y):
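
The backward kernel above implements the exact-GELU derivative. Writing GELU(x) = x * Phi(x), where Phi is the standard normal CDF, gives d/dx GELU(x) = Phi(x) + x * phi(x), with phi the standard normal pdf. The kernel keeps the common factor 0.5 outside the bracket, so phi's coefficient 1/sqrt(2*pi) is carried as COEFF = sqrt(2/pi), since 0.5 * sqrt(2/pi) = 1/sqrt(2*pi). A standalone sanity check against PyTorch's autograd (not part of this commit; plain PyTorch only):

import math
import torch
import torch.nn.functional as F

x = torch.randn(1000, dtype=torch.float64, requires_grad=True)
y = torch.randn(1000, dtype=torch.float64, requires_grad=True)
g = torch.randn(1000, dtype=torch.float64)

out = F.gelu(x, approximate="none") * y
dx_ref, dy_ref = torch.autograd.grad(out, (x, y), g)

# Mirror the kernel's grouping: 0.5 * (1 + erf + x * sqrt(2/pi) * exp(-x^2/2))
coeff = math.sqrt(2.0 / math.pi)
erf_term = torch.erf(x / math.sqrt(2.0))
dx = g * y * 0.5 * (1.0 + erf_term + x * coeff * torch.exp(-0.5 * x * x))
dy = g * 0.5 * x * (1.0 + erf_term)

assert torch.allclose(dx, dx_ref) and torch.allclose(dy, dy_ref)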
@@ -34,17 +62,57 @@ def gelu_tanh_and_mul_kernel(x, y):
     return x_gelu * y
 
 
+@pointwise_dynamic(
+    promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
+)
+@triton.jit
+def gelu_tanh_and_mul_grad_kernel(x, y, dgrad):
+    x_fp32 = x.to(tl.float32)
+    y_fp32 = y.to(tl.float32)
+
+    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
+    a_cubed = x_fp32 * x_fp32 * x_fp32
+    tanh_arg = sqrt_2_over_pi * (x_fp32 + 0.044715 * a_cubed)
+    tanh_result = tanh(tanh_arg)
+    geglu_a = 0.5 * x_fp32 * (1 + tanh_result)
+    dy = geglu_a * dgrad
+
+    term1 = 0.5 * (1 + tanh_result)
+    tanh_sq = tanh_result * tanh_result
+    term2 = (
+        0.5
+        * x_fp32
+        * (1 - tanh_sq)
+        * (sqrt_2_over_pi * (1 + 3 * 0.044715 * x_fp32 * x_fp32))
+    )
+    dx = dgrad * y_fp32 * (term1 + term2)
+
+    return dx, dy
+
+
 class GeluAndMul(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, A, B, approximate="none"):
+    def forward(ctx, x, y, approximate="none"):
         logging.debug("GEMS GELU AND MUL FORWARD")
+        ctx.save_for_backward(x, y)
+        ctx.approximate = approximate
         if approximate == "none":
-            return gelu_none_and_mul_kernel(A, B)
+            return gelu_none_and_mul_kernel(x, y)
         elif approximate == "tanh":
-            return gelu_tanh_and_mul_kernel(A, B)
+            return gelu_tanh_and_mul_kernel(x, y)
         else:
             raise ValueError(f"Invalid approximate value: {approximate}")
 
+    @staticmethod
+    def backward(ctx, dgrad):
+        logging.debug("GEMS GELU AND MUL BACKWARD")
+        x, y = ctx.saved_tensors
+        if ctx.approximate == "none":
+            dx, dy = gelu_none_and_mul_grad_kernel(x, y, dgrad)
+        else:
+            dx, dy = gelu_tanh_and_mul_grad_kernel(x, y, dgrad)
+        return dx, dy, None
+
 
-def gelu_and_mul(A, B, approximate="none"):
-    return GeluAndMul.apply(A, B, approximate)
+def gelu_and_mul(x, y, approximate="none"):
+    return GeluAndMul.apply(x, y, approximate)
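
The tanh-approximation backward follows the same pattern: differentiating 0.5 * x * (1 + t) with t = tanh(c * (x + 0.044715 * x^3)) and c = sqrt(2/pi) yields the kernel's term1 + term2, i.e. 0.5 * (1 + t) + 0.5 * x * (1 - t*t) * c * (1 + 3 * 0.044715 * x*x). Note also that backward returns (dx, dy, None): the trailing None corresponds to forward's non-tensor approximate argument, which takes no gradient. The analogous standalone check (again, not part of this commit):

import math
import torch
import torch.nn.functional as F

x = torch.randn(1000, dtype=torch.float64, requires_grad=True)
y = torch.randn(1000, dtype=torch.float64, requires_grad=True)
g = torch.randn(1000, dtype=torch.float64)

out = F.gelu(x, approximate="tanh") * y
dx_ref, dy_ref = torch.autograd.grad(out, (x, y), g)

c = math.sqrt(2.0 / math.pi)              # sqrt_2_over_pi in the kernel
t = torch.tanh(c * (x + 0.044715 * x**3))
term1 = 0.5 * (1.0 + t)
term2 = 0.5 * x * (1.0 - t * t) * (c * (1.0 + 3 * 0.044715 * x * x))
dx = g * y * (term1 + term2)              # dgrad * y_fp32 * (term1 + term2)
dy = g * 0.5 * x * (1.0 + t)              # dgrad * geglu_a

assert torch.allclose(dx, dx_ref) and torch.allclose(dy, dy_ref)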

tests/test_binary_pointwise_ops.py

Lines changed: 15 additions & 2 deletions
@@ -948,8 +948,8 @@ def test_accuracy_ge_scalar(shape, dtype):
 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
 @pytest.mark.parametrize("approximate", ["none", "tanh"])
 def test_accuracy_gelu_and_mul(shape, approximate, dtype):
-    inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device)
-    inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device)
+    inp1 = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True)
+    inp2 = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True)
     ref_inp1 = to_reference(inp1, True)
     ref_inp2 = to_reference(inp2, True)
@@ -959,7 +959,20 @@ def test_accuracy_gelu_and_mul(shape, approximate, dtype):
     with flag_gems.use_gems():
         res_out = flag_gems.gelu_and_mul(inp1, inp2, approximate)
 
+    out_grad = torch.randn_like(res_out)
+    ref_grad = to_reference(out_grad, True)
+
+    (ref_inp1_grad, ref_inp2_grad) = torch.autograd.grad(
+        ref_out, (ref_inp1, ref_inp2), ref_grad
+    )
+
+    (res_inp1_grad, res_inp2_grad) = torch.autograd.grad(
+        res_out, (inp1, inp2), out_grad
+    )
+
     gems_assert_close(res_out, ref_out, dtype)
+    gems_assert_close(res_inp1_grad, ref_inp1_grad, dtype)
+    gems_assert_close(res_inp2_grad, ref_inp2_grad, dtype)
 
 
 @pytest.mark.gt
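
With backward wired up, gelu_and_mul is differentiable end to end, which is what the extended test exercises via torch.autograd.grad. A minimal usage sketch, assuming flag_gems is installed on a supported device (shapes are arbitrary):

import torch
import flag_gems

x = torch.randn(8, 128, device=flag_gems.device, requires_grad=True)
y = torch.randn(8, 128, device=flag_gems.device, requires_grad=True)

with flag_gems.use_gems():
    out = flag_gems.gelu_and_mul(x, y, "tanh")  # or "none" for exact GELU
    out.sum().backward()                        # runs gelu_tanh_and_mul_grad_kernel

print(x.grad.shape, y.grad.shape)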
