Description
The last output equation in the paper does not match what the code computes.
The figure shows the last output equation as

[figure: $o_t' = \mathrm{RMSNorm}(h_t) * \sigma(g_t)$]
But based on the current code, what is actually executed is
$o_t' = \mathrm{RMSNorm}(g_t) * \sigma(h_t)$
instead of
$o_t' = \mathrm{RMSNorm}(h_t) * \sigma(g_t)$
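
To see that the two orders genuinely differ, here is a minimal plain-PyTorch sketch; the unweighted `rms_norm` and the random tensors are illustrative stand-ins, not the repo's fused Triton kernel:

```python
import torch

def rms_norm(x, eps=1e-6):
    # Unweighted RMSNorm, for illustration only; the repo's g_norm also
    # carries a learned weight and fuses the gate into a Triton kernel.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

h_t = torch.randn(2, 8)  # stand-in for the recurrent output (`o` in the code)
g_t = torch.randn(2, 8)  # stand-in for the gate branch (`g_proj(hidden_states)`)

paper = rms_norm(h_t) * torch.sigmoid(g_t)   # o_t' = RMSNorm(h_t) * sigma(g_t)
code  = rms_norm(g_t) * torch.sigmoid(h_t)   # what the swapped call order computes

print(torch.allclose(paper, code))  # False: the two orders are not equivalent
```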
It seems this was fixed in a recent commit to HGRN in the flash-linear-attention repository:
```diff
  last_state = (recurrent_state,)
  past_key_values.update(last_state, self.layer_idx, i.shape[2])
- o = self.g_norm(self.g_proj(hidden_states), rearrange(o, 'b h l d -> b l (h d)'))
+ o = self.g_norm(rearrange(o, 'b h l d -> b l (h d)'), self.g_proj(hidden_states))
  o = self.o_proj(o)
  return o, None, past_key_values
```
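
If the same convention carries over (assuming `g_norm` here is the same `FusedRMSNormSwishGate`, whose `forward(self, x, o, ...)` normalizes `x` and gates with `o`), the corresponding one-line change in `mmfreelm/layers/hgrn_bit.py` would presumably be:

```diff
- o = self.g_norm(self.g_proj(hidden_states), rearrange(o, 'b h l d -> b l (h d)'))
+ o = self.g_norm(rearrange(o, 'b h l d -> b l (h d)'), self.g_proj(hidden_states))
```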
Existing code path of the current repository:

- `g_norm` is called with $g_t$ and $h_t$ (`mmfreelm/layers/hgrn_bit.py`, line 139 at `ec1c298`):

  ```python
  o = self.g_norm(self.g_proj(hidden_states), rearrange(o, 'b h l d -> b l (h d)'))
  ```

- In `FusedRMSNormSwishGate.forward(self, x, o, ...)` (invoked by the line above), `x` receives $g_t$ and `o` receives $h_t$.

- In `_layer_norm_fwd_1pass_kernel`, the sigmoid is applied to `o`, which is $h_t$ instead of $g_t$ (see the reference sketch after this list):

  ```python
  y = y * o * tl.sigmoid(o)
  ```
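
For reference, here is a minimal sketch of what the fused op quoted above computes per row, assuming the kernel's RMSNorm path with the learned weight/bias omitted. Note that the quoted line gates with `o * sigmoid(o)` (a swish of `o`), so under the current call order the executed output is $\mathrm{RMSNorm}(g_t) * \mathrm{swish}(h_t)$:

```python
import torch

def fused_rmsnorm_swish_gate_ref(x, o, eps=1e-6):
    # Per-row math of the fused op as quoted above (learned weight/bias and
    # Triton blocking omitted): normalize x, then gate with o * sigmoid(o).
    y = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return y * o * torch.sigmoid(o)

g_t = torch.randn(4, 16)  # g_proj(hidden_states)
h_t = torch.randn(4, 16)  # recurrent output

current = fused_rmsnorm_swish_gate_ref(g_t, h_t)  # normalizes g_t, gates with swish(h_t)
fixed   = fused_rmsnorm_swish_gate_ref(h_t, g_t)  # normalizes h_t, gates with swish(g_t)
```

Swapping the arguments, as in the flash-linear-attention fix, would restore $\mathrm{RMSNorm}(h_t)$ gated by $g_t$, matching the paper.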
Were the published results produced with the inverted equation or with the fixed one?