@@ -45,6 +45,7 @@ def liger_cross_entropy_kernel(
     BLOCK_SIZE: tl.constexpr,
     HAS_WEIGHT: tl.constexpr,
     HAS_SOFTCAPPING: tl.constexpr,
+    HAS_GRADIENTS: tl.constexpr,
 ):
     """
     This kernel computes both cross entropy loss and the gradient of the input.
@@ -72,6 +73,7 @@ def liger_cross_entropy_kernel(
         BLOCK_SIZE (int): The block size for Triton operations.
         HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes.
         HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not.
+        HAS_GRADIENTS (bool): The boolean value to determine whether calculating gradients in the forward pass.
     """

     # https://github.com/triton-lang/triton/issues/1058
@@ -155,58 +157,58 @@ def liger_cross_entropy_kernel(
     # For 'sum' reduction, no normalization is applied:
     # dx_y = softmax(x_y) - 1
     # dx_i = softmax(x_i), for i ≠ y
-
-    for i in range(0, n_cols, BLOCK_SIZE):
-        X_offsets = i + tl.arange(0, BLOCK_SIZE)
-        X_block = tl.load(
-            X_ptr + X_offsets,
-            mask=X_offsets < n_cols,
-            other=float("-inf"),
-            # Ensure float32 precision for softmax calculation
-        ).cast(tl.float32)
-        if HAS_SOFTCAPPING:
-            intermediate = tanh(X_block / softcap)
-            X_block = softcap * intermediate
-
-        if not HAS_WEIGHT:
-            # softmax(x_i)
-            X_block = tl.exp(X_block - m) / d
-            # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
-            X_block += 2 * lse_square_scale * lse * X_block
-            # smoothing term
-            X_block += -eps
-            # special handle dx_y
-            X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
-            # reduction scale
-            if reduction == "mean":
-                X_block = X_block / n_non_ignore
-        else:
-            weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
-            softmax_X = tl.exp(X_block - m) / d
-            # derivative of original_loss
-            dloss_ori = (1 - label_smoothing) * softmax_X
-            # specially handle dx_y
-            dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
-            dloss_ori = dloss_ori * weight_y
-            # derivative of smooth_loss
-            dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
-            # derivative of z-loss
-            dz_loss = 2 * lse_square_scale * lse * softmax_X
-            # reduction scale
-            if reduction == "mean":
-                dloss_ori = dloss_ori / sum_non_ignore_weight
-                dloss_smooth = dloss_smooth / sum_non_ignore_weight
-                # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
-                dz_loss = dz_loss / n_non_ignore
-            # derivative of total_loss
-            X_block = dloss_ori + dloss_smooth + dz_loss
-
-        # chain rule softcapping
-        # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
-        if HAS_SOFTCAPPING:
-            X_block = X_block * (1 - intermediate * intermediate)
-
-        tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)
+    if HAS_GRADIENTS:
+        for i in range(0, n_cols, BLOCK_SIZE):
+            X_offsets = i + tl.arange(0, BLOCK_SIZE)
+            X_block = tl.load(
+                X_ptr + X_offsets,
+                mask=X_offsets < n_cols,
+                other=float("-inf"),
+                # Ensure float32 precision for softmax calculation
+            ).cast(tl.float32)
+            if HAS_SOFTCAPPING:
+                intermediate = tanh(X_block / softcap)
+                X_block = softcap * intermediate
+
+            if not HAS_WEIGHT:
+                # softmax(x_i)
+                X_block = tl.exp(X_block - m) / d
+                # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
+                X_block += 2 * lse_square_scale * lse * X_block
+                # smoothing term
+                X_block += -eps
+                # special handle dx_y
+                X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing))
+                # reduction scale
+                if reduction == "mean":
+                    X_block = X_block / n_non_ignore
+            else:
+                weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols)
+                softmax_X = tl.exp(X_block - m) / d
+                # derivative of original_loss
+                dloss_ori = (1 - label_smoothing) * softmax_X
+                # specially handle dx_y
+                dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing))
+                dloss_ori = dloss_ori * weight_y
+                # derivative of smooth_loss
+                dloss_smooth = eps * (-weight_block + softmax_X * weight_sum)
+                # derivative of z-loss
+                dz_loss = 2 * lse_square_scale * lse * softmax_X
+                # reduction scale
+                if reduction == "mean":
+                    dloss_ori = dloss_ori / sum_non_ignore_weight
+                    dloss_smooth = dloss_smooth / sum_non_ignore_weight
+                    # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight.
+                    dz_loss = dz_loss / n_non_ignore
+                # derivative of total_loss
+                X_block = dloss_ori + dloss_smooth + dz_loss
+
+            # chain rule softcapping
+            # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap))
+            if HAS_SOFTCAPPING:
+                X_block = X_block * (1 - intermediate * intermediate)
+
+            tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

     # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
     # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34
@@ -332,6 +334,7 @@ def cross_entropy_forward(
         BLOCK_SIZE=BLOCK_SIZE,
         HAS_WEIGHT=True if weight is not None else False,
         HAS_SOFTCAPPING=True if softcap is not None else False,
+        HAS_GRADIENTS=_input.requires_grad,
         # TODO: 32 seems to give the best performance
         # Performance is quite sensitive to num_warps
         num_warps=32 if not is_hip() else 16,
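
For reference, the in-place gradient that the loop above writes over X follows the formulas in the kernel comments. A minimal PyTorch sketch of the unweighted case for a single row, ignoring label smoothing, z-loss, and softcapping (a reading aid only, not code from this PR):

import torch

def ce_grad_row(x: torch.Tensor, y: int, reduction: str = "mean", n_non_ignore: int = 1) -> torch.Tensor:
    # Softmax in float32, mirroring the kernel's .cast(tl.float32).
    dx = torch.softmax(x.float(), dim=-1)  # dx_i = softmax(x_i), for i != y
    dx[y] -= 1.0                           # dx_y = softmax(x_y) - 1
    if reduction == "mean":
        # 'mean' reduction divides by the number of non-ignored targets.
        dx = dx / n_non_ignore
    # 'sum' reduction applies no normalization.
    return dx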
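
Because HAS_GRADIENTS is a tl.constexpr, Triton specializes the kernel at compile time: when _input.requires_grad is False, the entire gradient loop is compiled out and the logits are left untouched. A hedged sketch of the inference path this change speeds up (it assumes the library's LigerCrossEntropyLoss wrapper, which is not part of this diff):

import torch
from liger_kernel.transformers import LigerCrossEntropyLoss

loss_fn = LigerCrossEntropyLoss()
logits = torch.randn(8, 32000, device="cuda")  # requires_grad is False
targets = torch.randint(0, 32000, (8,), device="cuda")

with torch.no_grad():
    # HAS_GRADIENTS resolves to False here, so the forward pass
    # computes only the loss and skips the in-place gradient store.
    loss = loss_fn(logits, targets)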