 from fla.modules import GroupNorm
 from fla.modules.l2norm import l2_norm
 from fla.modules.token_shift import token_shift
-from fla.ops.rwkv7 import chunk_rwkv7, fused_recurrent_rwkv7
+from fla.ops.rwkv7 import chunk_rwkv7, fused_mul_recurrent_rwkv7
 from fla.ops.rwkv7.fused_addcmul import fused_addcmul_rwkv7
+from fla.ops.rwkv7.fused_k_update import fused_k_rwkv7

 if TYPE_CHECKING:
     from fla.models.utils import Cache
@@ -181,30 +182,26 @@ def forward(

         batch_size, seq_len, _ = hidden_states.shape

-        if self.training:
-            # if training, use chunk mode no matter how short the sequence is
-            mode = 'chunk'
-        else:
-            # launching the triton kernel for just one token will actually be slower
-            mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
-
         last_state = None
         if past_key_values is not None and len(past_key_values) > self.layer_idx:
             last_state = past_key_values[self.layer_idx]

         if attention_mask is not None:
-            hidden_states = hidden_states.mul(attention_mask[:, -hidden_states.shape[-2]:, None])
+            hidden_states = hidden_states.mul(attention_mask[:, -seq_len:, None])
         cu_seqlens = kwargs.get('cu_seqlens', None)
-        # [batch_size, seq_len, hidden_size]
-        if hidden_states.shape[1] == 1 and last_state is not None:
+        # delta [batch_size, seq_len, hidden_size]
+        if last_state is None:
+            delta = token_shift(hidden_states, cu_seqlens)
+            recurrent_state = None
+        elif hidden_states.shape[1] == 1:
             shifted = last_state['conv_state'].unsqueeze(1)
             delta = shifted - hidden_states
-        elif last_state is None:
-            delta = token_shift(hidden_states, cu_seqlens)
+            recurrent_state = last_state['recurrent_state']
         else:
             shifted = self.time_shift(hidden_states)
             shifted[:, 0] = last_state['conv_state']
             delta = shifted - hidden_states
+            recurrent_state = last_state['recurrent_state']

         xr, xw, xk, xv, xa, xg = fused_addcmul_rwkv7(hidden_states, delta, self.x_r, self.x_w,
                                                      self.x_k, self.x_v, self.x_a, self.x_g)
@@ -242,29 +239,44 @@ def forward(
         # 2. Mathematically equivalent to k*(1 + (a-1)*self.k_a)
         #    but with better precision preservation
         # 3. Particularly crucial for bf16 where intermediate values easily lose precision
-        k = k.addcmul(k * (a - 1), self.k_a)
+        # 4. Pure PyTorch equivalent: k = k.addcmul(k * (a - 1), self.k_a)
+        k = fused_k_rwkv7(k, a, self.k_a)

         # dealing with left-padding
         if attention_mask is not None:
-            v = v * attention_mask[:, -v.shape[-2]:, None]
+            v = v * attention_mask[:, -seq_len:, None]
+
         r, w, k, a = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_dim), (r, w, k, a))
         v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)

-        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
-
-        rwkv7_fn = chunk_rwkv7 if mode == 'chunk' else fused_recurrent_rwkv7
-        o, recurrent_state = rwkv7_fn(
-            r=r,
-            w=w,
-            k=k,
-            v=v,
-            a=-kk,
-            b=kk * a,
-            scale=1.,
-            initial_state=recurrent_state,
-            output_final_state=use_cache,
-            cu_seqlens=cu_seqlens,
-        )
+        if self.training or seq_len >= 64:
+            # if training, use chunk mode no matter how short the sequence is;
+            # at inference, chunk mode only pays off for longer sequences, since
+            # launching the triton kernel for just one token is actually slower
+            o, recurrent_state = chunk_rwkv7(
+                r=r,
+                w=w,
+                k=k,
+                v=v,
+                a=-kk,
+                b=kk * a,
+                scale=1.,
+                initial_state=recurrent_state,
+                output_final_state=use_cache,
+                cu_seqlens=cu_seqlens,
+            )
+        else:
+            o, recurrent_state = fused_mul_recurrent_rwkv7(
+                r=r,
+                w=w,
+                k=k,
+                v=v,
+                kk=kk,
+                a=a,
+                scale=1.,
+                initial_state=recurrent_state,
+                output_final_state=use_cache,
+                cu_seqlens=cu_seqlens,
+            )

         if past_key_values is not None:
             past_key_values.update(
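
Why the fused k update helps, concretely: fused_k_rwkv7(k, a, self.k_a) computes k * (1 + (a - 1) * self.k_a), which is algebraically identical to the removed addcmul, but a fused kernel can keep the intermediate k * (a - 1) in higher precision instead of materializing it in bf16. A small self-contained check of the stated equivalence in pure PyTorch (this sketches the math from comments 2 and 4 above, not the triton kernel itself):

import torch

def k_update_ref(k, a, k_a):
    # algebraic form from comment 2: k * (1 + (a - 1) * k_a)
    return k * (1 + (a - 1) * k_a)

k = torch.randn(2, 16, 64)
a = torch.rand(2, 16, 64)   # in-context learning rate, in [0, 1]
k_a = torch.randn(64)

# the addcmul form removed by this diff: k + (k * (a - 1)) * k_a
out = k.addcmul(k * (a - 1), k_a)
assert torch.allclose(out, k_update_ref(k, a, k_a), atol=1e-6)
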
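The renamed import pairs with the new else branch: fused_mul_recurrent_rwkv7 takes kk and a directly, where the old code (and chunk_rwkv7 still) expects the precomputed a=-kk and b=kk * a. Judging by the name and the call sites alone — the kernel body is not part of this diff — the fused variant folds those two elementwise multiplications into the recurrent kernel, saving extra memory round-trips during single-token decode. A hypothetical shim expressing that assumed relationship in terms of the old kernel:

from fla.ops.rwkv7 import fused_recurrent_rwkv7

def fused_mul_recurrent_ref(r, w, k, v, kk, a, scale=1.,
                            initial_state=None, output_final_state=False,
                            cu_seqlens=None):
    # Assumption: if fused_mul_recurrent_rwkv7 only fuses the elementwise
    # products, it should agree with the plain recurrent kernel invoked
    # exactly the way the removed code did.
    return fused_recurrent_rwkv7(
        r=r, w=w, k=k, v=v,
        a=-kk,        # removal direction: negated (normalized) key
        b=kk * a,     # replacement direction, scaled by the learning rate
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )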