@@ -93,73 +93,64 @@ def log_softmax_backward_kernel(
     tl.store(in_grad_ptrs, in_grad, mask=mask)


-class LogSoftmax(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, dim, dtype):
-        logging.debug("GEMS LOG_SOFTMAX")
-
-        assert dim >= -x.ndim and dim < x.ndim, "Invalid dim"
-        dim = dim % x.ndim
-        M = 1
-        N = x.shape[dim]
-        for i in range(dim):
-            M *= x.shape[i]
-        inp = x.contiguous()
-        if dtype is None:
-            dtype = x.dtype
-        out = torch.empty_like(inp, dtype=dtype)
-        K = inp.numel() // M // N
-
-        grid = lambda meta: (
-            triton.cdiv(M, meta["BLOCK_M"]),
+def log_softmax(self, dim, half_to_float=False):
+    logging.debug("GEMS LOG_SOFTMAX")
+
+    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"
+    dim = dim % self.ndim
+    M = 1
+    N = self.shape[dim]
+    for i in range(dim):
+        M *= self.shape[i]
+    inp = self.contiguous()
+    if half_to_float:
+        dtype = torch.float32
+    else:
+        dtype = self.dtype
+    out = torch.empty_like(inp, dtype=dtype)
+    K = inp.numel() // M // N
+
+    grid = lambda meta: (
+        triton.cdiv(M, meta["BLOCK_M"]),
+        K,
+    )
+    with torch_device_fn.device(inp.device):
+        log_softmax_kernel[grid](
+            out,
+            inp,
+            M,
+            N,
             K,
+            num_warps=8,
         )
-        with torch_device_fn.device(inp.device):
-            log_softmax_kernel[grid](
-                out,
-                inp,
-                M,
-                N,
-                K,
-                num_warps=8,
-            )
-        ctx.save_for_backward(out)
-        ctx.dim = dim
-        return out
-
-    @staticmethod
-    def backward(ctx, out_grad):
-        logging.debug("GEMS LOG_SOFTMAX VJP")
-
-        dim = ctx.dim
-        (out,) = ctx.saved_tensors
-
-        assert dim >= -out.ndim and dim < out.ndim, "Invalid dim"
-        dim = dim % out.ndim
-        M = 1
-        N = out.shape[dim]
-        for i in range(dim):
-            M *= out.shape[i]
-
-        out_grad = out_grad.contiguous()
-        in_grad = torch.empty_like(out)
-        K = out.numel() // M // N
-
-        grid = lambda meta: (
-            triton.cdiv(M, meta["BLOCK_M"]),
+    return out
+
+
+def log_softmax_backward(grad_output, output, dim, input_dtype):
+    logging.debug("GEMS LOG_SOFTMAX VJP")
+
+    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
+    dim = dim % output.ndim
+    M = 1
+    N = output.shape[dim]
+    for i in range(dim):
+        M *= output.shape[i]
+
+    grad_output = grad_output.contiguous()
+    in_grad = torch.empty_like(output, dtype=input_dtype)
+    K = output.numel() // M // N
+
+    grid = lambda meta: (
+        triton.cdiv(M, meta["BLOCK_M"]),
+        K,
+    )
+    with torch_device_fn.device(in_grad.device):
+        log_softmax_backward_kernel[grid](
+            output,
+            grad_output,
+            in_grad,
+            M,
+            N,
             K,
         )
-        with torch_device_fn.device(in_grad.device):
-            log_softmax_backward_kernel[grid](
-                out,
-                out_grad,
-                in_grad,
-                M,
-                N,
-                K,
-            )
-        return in_grad, None, None
-
-
-def log_softmax(x, dim=-1, dtype=None):
-    return LogSoftmax.apply(x, dim, dtype)
+    return in_grad
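
For reference, a minimal usage sketch of the new functional pair added by this diff, which replaces the old autograd.Function wrapper with explicit forward/backward entry points. The import path below is an assumption (adjust to wherever this module lives in the package); everything else only uses the two functions shown above.

import torch

# Hypothetical module path; the functions are the ones defined in this diff.
from log_softmax import log_softmax, log_softmax_backward

x = torch.randn(4, 128, device="cuda")

# Forward: the first parameter is named `self` to mirror the ATen-style
# signature, so the input tensor is simply passed positionally.
out = log_softmax(x, dim=-1, half_to_float=False)

# Backward: takes the upstream gradient, the forward output, the reduction
# dim, and the dtype the input gradient should be materialized in.
grad_out = torch.ones_like(out)
grad_in = log_softmax_backward(grad_out, out, dim=-1, input_dtype=x.dtype)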