@triton.jit
def gelu_none_and_mul_kernel(x, y):
    """Compute gelu(x) * y elementwise, using the exact (erf) GELU.

    gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))

    Inputs are promoted to fp32 for the transcendental math; the result
    is x_gelu * y in fp32 (downcast, if any, handled by the caller/wrapper).
    """
    x_fp32 = x.to(tl.float32)
    # 1 / sqrt(2), written to full double precision so fp32 rounding is exact.
    RCP_SQRT_2: tl.constexpr = 0.7071067811865476
    x_gelu = 0.5 * x_fp32 * (1 + erf(x_fp32 * RCP_SQRT_2))
    return x_gelu * y
2021
2122
@pointwise_dynamic(
    promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
)
@triton.jit
def gelu_none_and_mul_grad_kernel(x, y, dgrad):
    """Backward pass for out = gelu(x) * y (exact erf GELU).

    Returns (dx, dy) where
        dx = dgrad * y * gelu'(x)
        dy = dgrad * gelu(x)
    and gelu'(x) = Phi(x) + x * phi(x)
                 = 0.5 * (1 + erf(x/sqrt(2))) + x * exp(-x^2/2) / sqrt(2*pi).
    """
    # 1 / sqrt(2), full double precision.
    RCP_SQRT_2: tl.constexpr = 0.7071067811865476
    # sqrt(2/pi) == 2 / sqrt(2*pi); note 0.5 * COEFF == 1 / sqrt(2*pi),
    # which is exactly the Gaussian pdf coefficient needed below.
    COEFF: tl.constexpr = 0.7978845608028654

    x_fp32 = x.to(tl.float32)
    # Recompute the forward activation; cheaper than saving it.
    x_gelu = 0.5 * x_fp32 * (1 + erf(x_fp32 * RCP_SQRT_2))

    d_gelu = dgrad * y
    dx = (
        d_gelu
        * 0.5
        * (
            1.0
            + erf(x_fp32 * RCP_SQRT_2)
            + x_fp32 * COEFF * tl.exp(-0.5 * x_fp32 * x_fp32)
        )
    )

    dy = dgrad * x_gelu

    return dx, dy
49+
2250@pointwise_dynamic (promotion_methods = [(0 , 1 , "DEFAULT" )])
2351@triton .jit
2452def gelu_tanh_and_mul_kernel (x , y ):
@@ -34,17 +62,57 @@ def gelu_tanh_and_mul_kernel(x, y):
3462 return x_gelu * y
3563
3664
@pointwise_dynamic(
    promotion_methods=[(0, 1, 2, "DEFAULT"), (0, 1, 2, "DEFAULT")], num_outputs=2
)
@triton.jit
def gelu_tanh_and_mul_grad_kernel(x, y, dgrad):
    """Backward pass for out = gelu_tanh(x) * y (tanh-approximated GELU).

    gelu_tanh(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))

    Returns (dx, dy) where
        dx = dgrad * y * gelu_tanh'(x)
        dy = dgrad * gelu_tanh(x)
    """
    SQRT_2_OVER_PI: tl.constexpr = 0.7978845608028654  # sqrt(2 / pi)
    KAPPA: tl.constexpr = 0.044715  # cubic coefficient of the tanh approximation

    x_fp32 = x.to(tl.float32)
    y_fp32 = y.to(tl.float32)

    x_cubed = x_fp32 * x_fp32 * x_fp32
    inner = SQRT_2_OVER_PI * (x_fp32 + KAPPA * x_cubed)
    tanh_inner = tanh(inner)
    gelu_x = 0.5 * x_fp32 * (1 + tanh_inner)
    dy = gelu_x * dgrad

    # gelu_tanh'(x) = 0.5*(1 + t)
    #               + 0.5*x*(1 - t^2) * sqrt(2/pi)*(1 + 3*0.044715*x^2)
    # where t = tanh(inner) and (1 - t^2) = sech^2(inner) = d tanh / d inner.
    term1 = 0.5 * (1 + tanh_inner)
    sech_sq = 1 - tanh_inner * tanh_inner
    term2 = (
        0.5
        * x_fp32
        * sech_sq
        * (SQRT_2_OVER_PI * (1 + 3 * KAPPA * x_fp32 * x_fp32))
    )
    dx = dgrad * y_fp32 * (term1 + term2)

    return dx, dy
92+
class GeluAndMul(torch.autograd.Function):
    """Autograd function computing gelu(x, approximate) * y via fused kernels.

    `approximate` selects the GELU formulation: "none" (exact, erf-based)
    or "tanh" (cubic tanh approximation). It is a non-differentiable flag,
    so backward returns None for it.
    """

    @staticmethod
    def forward(ctx, x, y, approximate="none"):
        logging.debug("GEMS GELU AND MUL FORWARD")
        # Validate the flag first so nothing is saved on the error path.
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Invalid approximate value: {approximate}")
        ctx.save_for_backward(x, y)
        ctx.approximate = approximate
        if approximate == "none":
            return gelu_none_and_mul_kernel(x, y)
        return gelu_tanh_and_mul_kernel(x, y)

    @staticmethod
    def backward(ctx, dgrad):
        logging.debug("GEMS GELU AND MUL BACKWARD")
        x, y = ctx.saved_tensors
        # ctx.approximate was validated in forward, so "none"/"tanh" only.
        if ctx.approximate == "none":
            dx, dy = gelu_none_and_mul_grad_kernel(x, y, dgrad)
        else:
            dx, dy = gelu_tanh_and_mul_grad_kernel(x, y, dgrad)
        # No gradient for the `approximate` string argument.
        return dx, dy, None
48116
def gelu_and_mul(x, y, approximate="none"):
    """Functional entry point: return gelu(x, approximate=approximate) * y.

    Thin wrapper over the GeluAndMul autograd function so gradients flow
    through both x and y.
    """
    return GeluAndMul.apply(x, y, approximate)