Commit 6b673e6

[RWKV7] use fast fused_addcmul_rwkv7 op
1 parent cdc805e commit 6b673e6

File tree: 4 files changed, +65 −28 lines

4 files changed

+65
-28
lines changed

fla/layers/rwkv7.py (+10 −13)

@@ -14,6 +14,7 @@
 from fla.modules import GroupNorm
 from fla.modules.l2norm import l2_norm
 from fla.ops.rwkv7 import chunk_rwkv7, fused_recurrent_rwkv7
+from fla.ops.rwkv7.fused_addcmul import fused_addcmul_rwkv7

 if TYPE_CHECKING:
     from fla.models.utils import Cache
@@ -36,7 +37,6 @@ def __init__(
         layer_idx: int = None,
         fuse_norm: bool = False,
         value_dim: int = None,
-        wkv_precision: str = 'bfloat16',
         **kwargs
     ) -> RWKV7Attention:
         super().__init__()
@@ -65,8 +65,12 @@ def __init__(
         self.fuse_norm = fuse_norm

         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
-
-        self.x_x = nn.Parameter(torch.zeros(6, hidden_size))
+        self.x_r = nn.Parameter(torch.zeros(1, 1, hidden_size))
+        self.x_w = nn.Parameter(torch.zeros(1, 1, hidden_size))
+        self.x_k = nn.Parameter(torch.zeros(1, 1, hidden_size))
+        self.x_v = nn.Parameter(torch.zeros(1, 1, hidden_size))
+        self.x_a = nn.Parameter(torch.zeros(1, 1, hidden_size))
+        self.x_g = nn.Parameter(torch.zeros(1, 1, hidden_size))

         self.k_k = nn.Parameter(torch.zeros(self.key_dim))
         self.k_a = nn.Parameter(torch.zeros(self.key_dim))
@@ -99,15 +103,6 @@ def __init__(
             affine=elementwise_affine
         )

-        if wkv_precision == 'bfloat16':
-            self.precision = torch.bfloat16
-        elif wkv_precision == 'float16':
-            self.precision = torch.float16
-        elif wkv_precision == 'float32':
-            self.precision = torch.float32
-        else:
-            raise ValueError(f"""Unsupported wkv_precision `{wkv_precision}`.
-            Supported values are `bfloat16`, `float16`, and `float32`.""")
         self.apply(self._initialize_weights)

     def _initialize_weights(self, module: nn.Module):
@@ -162,7 +157,9 @@ def forward(

         # [batch_size, seq_len, hidden_size]
         delta = shifted - hidden_states
-        xr, xw, xk, xv, xa, xg = hidden_states.addcmul(delta, self.x_x.view(6, 1, 1, -1)).unbind(0)
+
+        xr, xw, xk, xv, xa, xg = fused_addcmul_rwkv7(hidden_states, delta, self.x_r, self.x_w,
+                                                     self.x_k, self.x_v, self.x_a, self.x_g)

         r = self.r_proj(xr)
         # w (-0.6065, 0)
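For reference, the removed line and the new call compute the same thing: each of the six branches is an elementwise `hidden_states + delta * x_*`, with the per-branch parameter broadcast over batch and time. The sketch below is not the fused kernel itself; it is a plain-PyTorch equivalent, assuming `fused_addcmul_rwkv7` keeps the argument order shown in the hunk above, and can serve as a toy numerical reference.

import torch

# Reference (unfused) version of the six addcmul branches replaced above.
# Shapes follow the hunk: hidden_states/delta are [B, T, H], each x_* is [1, 1, H].
def addcmul_rwkv7_reference(hidden_states, delta, x_r, x_w, x_k, x_v, x_a, x_g):
    # Each output equals hidden_states + delta * x_*, matching the old
    # `hidden_states.addcmul(delta, self.x_x.view(6, 1, 1, -1)).unbind(0)`.
    return tuple(hidden_states.addcmul(delta, x) for x in (x_r, x_w, x_k, x_v, x_a, x_g))

# Toy check under assumed shapes (B=2, T=3, H=4):
B, T, H = 2, 3, 4
hs, dt = torch.randn(B, T, H), torch.randn(B, T, H)
params = [torch.randn(1, 1, H) for _ in range(6)]
xr, xw, xk, xv, xa, xg = addcmul_rwkv7_reference(hs, dt, *params)
assert xr.shape == (B, T, H)

If the fused op exposes the same signature, comparing its outputs against this reference with `torch.allclose` is a quick sanity check.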

fla/models/rwkv7/configuration_rwkv7.py (−2)

@@ -39,7 +39,6 @@ def __init__(
         fuse_cross_entropy: bool = True,
         vocab_size: int = 32000,
         value_dim: Optional[Union[int, List[int]]] = None,
-        wkv_precision: Optional[str] = "bfloat16",
         **kwargs
     ):
         self.attn_mode = attn_mode
@@ -84,7 +83,6 @@ def __init__(
         self.fuse_norm = fuse_norm
         self.fuse_cross_entropy = fuse_cross_entropy
         self.vocab_size = vocab_size
-        self.wkv_precision = wkv_precision

         if attn is not None:
             if not isinstance(attn, Dict):

fla/models/rwkv7/modeling_rwkv7.py (+51 −1)

@@ -263,6 +263,56 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.embeddings = value

+    def load_state_dict(self, state_dict, strict=True, assign=False):
+        """
+        Override the load_state_dict method to handle migration from version 1 to version 2.
+        Handles hierarchical keys like 'model.layers.0.attn.x_x'.
+        """
+        # Collect all layer indices from the state_dict keys
+        layer_indices = set()
+        for key in state_dict.keys():
+            if key.startswith("model.layers."):
+                # Extract the layer index from the key
+                try:
+                    layer_idx = int(key.split(".")[2])  # Extract the number after 'model.layers.'
+                    layer_indices.add(layer_idx)
+                except ValueError:
+                    # Skip keys that don't match the expected format
+                    continue
+
+        # Sort the layer indices to process them in order
+        sorted_layer_indices = sorted(layer_indices)
+
+        # Migration logic for each layer
+        for layer_idx in sorted_layer_indices:
+            layer_prefix = f"model.layers.{layer_idx}"
+            attn_prefix = f"{layer_prefix}.attn"
+
+            # Check if the layer contains the old 'x_x' parameter
+            if f"{attn_prefix}.x_x" in state_dict:
+                logger.info(f"Migrating weights for layer {layer_idx} from RWKV7Attention version 1 to version 2...")
+                # Extract the x_x parameter
+                x_x = state_dict[f"{attn_prefix}.x_x"]
+                with torch.no_grad():
+                    # Create new parameters for version 2
+                    state_dict[f"{attn_prefix}.x_r"] = x_x[0].unsqueeze(0).unsqueeze(0)
+                    state_dict[f"{attn_prefix}.x_w"] = x_x[1].unsqueeze(0).unsqueeze(0)
+                    state_dict[f"{attn_prefix}.x_k"] = x_x[2].unsqueeze(0).unsqueeze(0)
+                    state_dict[f"{attn_prefix}.x_v"] = x_x[3].unsqueeze(0).unsqueeze(0)
+                    state_dict[f"{attn_prefix}.x_a"] = x_x[4].unsqueeze(0).unsqueeze(0)
+                    state_dict[f"{attn_prefix}.x_g"] = x_x[5].unsqueeze(0).unsqueeze(0)
+
+        # Call the parent method to load the modified state_dict
+        try:
+            super().load_state_dict(state_dict, strict=strict, assign=assign)
+        except TypeError:
+            # If the parent method does not support `assign`, fall back to strict loading
+            logger.warning(
+                "`assign` parameter is not supported by the parent `load_state_dict` method. "
+                "Falling back to default behavior."
+            )
+            super().load_state_dict(state_dict, strict=strict)
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -349,7 +399,7 @@ def forward(
         )


-class RWKV7ForCausalLM(RWKV7PreTrainedModel, GenerationMixin):
+class RWKV7ForCausalLM(RWKV7Model, GenerationMixin):

     _tied_weights_keys = ["lm_head.weight"]
utils/convert_from_rwkv7.py (+4 −12)

@@ -90,9 +90,7 @@ def translate_into_fla(name):
         'ln1': 'attn_norm',
         'ln2': 'ffn_norm'
     }[name_compo[2]]
-    if name_compo[2] == 'attn' and re.match("x_[rwkvag]", name_compo[3]):
-        name_compo[3] = 'x_x'
-    elif re.match("[wvag][012]", name_compo[3]):
+    if re.match("[wvag][012]", name_compo[3]):
         typ, num = name_compo[3]
         name_compo[3] = f'{typ}_lora.lora.' + {
             '0': '2.bias',
@@ -121,15 +119,9 @@ def translate_into_fla(name):
     if shape1 == [1, 1, config.hidden_size]:
         weight.squeeze_()

-    # fix: fusing x_[rwkvag] to x_x
-    if fla_name.endswith('attn.x_x'):
-        model_dict[fla_name].data['rwkvag'.find(name[-1])].copy_(weight)
-        if fla_name in model_names:
-            model_names.remove(fla_name)
-    else:
-        assert model_dict[fla_name].shape == weight.shape
-        model_dict[fla_name].data.copy_(weight)
-        model_names.remove(fla_name)
+    assert model_dict[fla_name].shape == weight.shape
+    model_dict[fla_name].data.copy_(weight)
+    model_names.remove(fla_name)

 print("uninitialized parameters: ", model_names)
 for n in model_names:
