Commit aacf017

zhiyuan1i and yzhangcs committed

[RWKV7] Add more fused kernels to accelerate both training and decoding (#379)

* Add `fused_k_rwkv7` kernel implementation
* Fuse WKV operations to accelerate decoding
* Rewrite critical paths to reduce CPU overhead
* Update test cases in test_rwkv7.py
* Add backward pass documentation in README

Co-authored-by: Yu Zhang <[email protected]>

1 parent 4a7a9f4 · commit aacf017

File tree: 7 files changed (+1263, -43 lines)


fla/layers/rwkv7.py (+42, -30)
@@ -14,8 +14,9 @@
 from fla.modules import GroupNorm
 from fla.modules.l2norm import l2_norm
 from fla.modules.token_shift import token_shift
-from fla.ops.rwkv7 import chunk_rwkv7, fused_recurrent_rwkv7
+from fla.ops.rwkv7 import chunk_rwkv7, fused_mul_recurrent_rwkv7
 from fla.ops.rwkv7.fused_addcmul import fused_addcmul_rwkv7
+from fla.ops.rwkv7.fused_k_update import fused_k_rwkv7
 
 if TYPE_CHECKING:
     from fla.models.utils import Cache
@@ -181,30 +182,26 @@ def forward(
 
         batch_size, seq_len, _ = hidden_states.shape
 
-        if self.training:
-            # if training, use chunk mode no matter how short the sequence is
-            mode = 'chunk'
-        else:
-            # launching the triton kernel for just one token will actually be slower
-            mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
-
         last_state = None
         if past_key_values is not None and len(past_key_values) > self.layer_idx:
             last_state = past_key_values[self.layer_idx]
 
         if attention_mask is not None:
-            hidden_states = hidden_states.mul(attention_mask[:, -hidden_states.shape[-2]:, None])
+            hidden_states = hidden_states.mul(attention_mask[:, -seq_len:, None])
         cu_seqlens = kwargs.get('cu_seqlens', None)
-        # [batch_size, seq_len, hidden_size]
-        if hidden_states.shape[1] == 1 and last_state is not None:
+        # delta [batch_size, seq_len, hidden_size]
+        if last_state is None:
+            delta = token_shift(hidden_states, cu_seqlens)
+            recurrent_state = None
+        elif hidden_states.shape[1] == 1:
             shifted = last_state['conv_state'].unsqueeze(1)
             delta = shifted - hidden_states
-        elif last_state is None:
-            delta = token_shift(hidden_states, cu_seqlens)
+            recurrent_state = last_state['recurrent_state']
         else:
             shifted = self.time_shift(hidden_states)
             shifted[:, 0] = last_state['conv_state']
             delta = shifted - hidden_states
+            recurrent_state = last_state['recurrent_state']
 
         xr, xw, xk, xv, xa, xg = fused_addcmul_rwkv7(hidden_states, delta, self.x_r, self.x_w,
                                                      self.x_k, self.x_v, self.x_a, self.x_g)
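
The reworked branch above builds the token-shift `delta` and picks up any cached `recurrent_state` in a single pass before both feed `fused_addcmul_rwkv7`. For readers who want the data flow without the Triton kernels, here is a minimal plain-PyTorch sketch. It assumes, based on how the branches above consume them, that `token_shift` returns the previous-token difference (zeros before the first token) and that `fused_addcmul_rwkv7` follows `torch.addcmul` semantics (`x + delta * mix`); the real ops live in `fla.ops` and may differ in signature and numerics.

from typing import Optional

import torch


def token_shift_delta_ref(x: torch.Tensor, conv_state: Optional[torch.Tensor] = None) -> torch.Tensor:
    # x: [B, T, D]; conv_state: [B, D] last token of the previous chunk, or None when there is no cache.
    shifted = torch.zeros_like(x)
    shifted[:, 1:] = x[:, :-1]        # each token sees the previous token
    if conv_state is not None:
        shifted[:, 0] = conv_state    # first token sees the cached last token instead of zeros
    return shifted - x                # the `delta` consumed by the branches above


def fused_addcmul_ref(x, delta, x_r, x_w, x_k, x_v, x_a, x_g):
    # Presumed per-op equivalent of fused_addcmul_rwkv7: six addcmul results from one kernel launch.
    return tuple(torch.addcmul(x, delta, mix) for mix in (x_r, x_w, x_k, x_v, x_a, x_g))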
@@ -242,29 +239,44 @@ def forward(
         # 2. Mathematically equivalent to k*(1 + (a-1)*self.k_a)
         #    but with better precision preservation
         # 3. Particularly crucial for bf16 where intermediate values easily lose precision
-        k = k.addcmul(k * (a - 1), self.k_a)
+        # 4. Pytorch method: k = k.addcmul(k * (a - 1), self.k_a)
+        k = fused_k_rwkv7(k, a, self.k_a)
 
         # dealing with left-padding
         if attention_mask is not None:
-            v = v * attention_mask[:, -v.shape[-2]:, None]
+            v = v * attention_mask[:, -seq_len:, None]
+
         r, w, k, a = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_dim), (r, w, k, a))
         v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim)
 
-        recurrent_state = last_state['recurrent_state'] if last_state is not None else None
-
-        rwkv7_fn = chunk_rwkv7 if mode == 'chunk' else fused_recurrent_rwkv7
-        o, recurrent_state = rwkv7_fn(
-            r=r,
-            w=w,
-            k=k,
-            v=v,
-            a=-kk,
-            b=kk * a,
-            scale=1.,
-            initial_state=recurrent_state,
-            output_final_state=use_cache,
-            cu_seqlens=cu_seqlens,
-        )
+        if self.training or seq_len >= 64:
+            # if training, use chunk mode no matter how short the sequence is
+            # launching the triton kernel for just one token will actually be slower
+            o, recurrent_state = chunk_rwkv7(
+                r=r,
+                w=w,
+                k=k,
+                v=v,
+                a=-kk,
+                b=kk * a,
+                scale=1.,
+                initial_state=recurrent_state,
+                output_final_state=use_cache,
+                cu_seqlens=cu_seqlens,
+            )
+        else:
+            o, recurrent_state = fused_mul_recurrent_rwkv7(
+                r=r,
+                w=w,
+                k=k,
+                v=v,
+                kk=kk,
+                a=a,
+                scale=1.,
+                initial_state=recurrent_state,
+                output_final_state=use_cache,
+                cu_seqlens=cu_seqlens,
+            )
 
         if past_key_values is not None:
             past_key_values.update(
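
Two equivalences make this hunk easier to read. The inline comment already gives the math behind `fused_k_rwkv7`: it computes `k * (1 + (a - 1) * k_a)`, exactly the replaced `k.addcmul(k * (a - 1), self.k_a)`, but in one Triton kernel with better bf16 precision handling. The old call site also shows how the new decode path relates to the old one: `fused_recurrent_rwkv7` was fed `a=-kk` and `b=kk * a` built in Python, whereas `fused_mul_recurrent_rwkv7` takes `kk` and `a` directly, which suggests those products are now formed inside the kernel. A hedged PyTorch sketch of both (reference math only, not the library's kernels):

import torch


def fused_k_rwkv7_ref(k: torch.Tensor, a: torch.Tensor, k_a: torch.Tensor) -> torch.Tensor:
    # Reference for the fused kernel, per the comment in the diff:
    # k * (1 + (a - 1) * k_a)  ==  k.addcmul(k * (a - 1), k_a)
    return k.addcmul(k * (a - 1), k_a)


# Presumed argument mapping between the two decode entry points (inferred from the old call,
# not a documented signature):
#   fused_mul_recurrent_rwkv7(r=r, w=w, k=k, v=v, kk=kk, a=a, ...)
#   ~= fused_recurrent_rwkv7(r=r, w=w, k=k, v=v, a=-kk, b=kk * a, ...)
# i.e. the -kk and kk * a tensors are no longer materialized on the Python side.

The surrounding `if self.training or seq_len >= 64` check folds the old `mode` selection into the call site: training and long prefills take the chunked kernel, while single-token and short decode steps take the fused recurrent one.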

fla/modules/token_shift.py (+1, -5)
@@ -6,7 +6,7 @@
 import triton
 import triton.language as tl
 
-from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
+from fla.utils import input_guard
 
 
 def token_shift_ref(
@@ -212,14 +212,12 @@ class TokenShift(torch.autograd.Function):
 
     @staticmethod
     @input_guard
-    @autocast_custom_fwd
     def forward(ctx, x: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None):
        ctx.cu_seqlens = cu_seqlens
        return token_shift_fwd(x, cu_seqlens)
 
    @staticmethod
    @input_guard
-    @autocast_custom_bwd
    def backward(ctx, dy: torch.Tensor):
        cu_seqlens = ctx.cu_seqlens
        dx = token_shift_bwd(dy, cu_seqlens)
@@ -232,11 +230,9 @@ def token_shift(
 ):
     """
     Implementation of token shift using Triton kernels
-
     Args:
         x: Input tensor of shape [B, T, D]
         cu_seqlens: Cumulative sequence lengths (optional)
-
     Returns:
         Tensor of same shape as input with token shift applied
     """
