Skip to content

Commit 90cd287

Browse files
authored
Fix DSv3 model weight initialization for PP (#234)
stack-info: PR: #234, branch: xmfan/stack/15
1 parent 68c8a07 commit 90cd287

File tree

1 file changed

+55
-31
lines changed
  • autoparallel/_testing/models

1 file changed

+55
-31
lines changed

autoparallel/_testing/models/dsv3.py

Lines changed: 55 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import math
77
from dataclasses import dataclass, field
8-
from typing import Callable, ClassVar, Literal, Optional, Tuple
8+
from typing import Callable, ClassVar, Literal, Optional, Tuple, Union
99

1010
import torch
1111
import torch.nn.functional as F
@@ -1534,27 +1534,9 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
15341534
self.model_args = model_args
15351535

15361536
def init_weights(self, buffer_device: torch.device | None = None) -> None:
    """Initialize every weight of the full (non-pipelined) model.

    Delegates to the shared module-level helpers so the full model and the
    pipeline-parallel stages initialize their parameters identically.

    Args:
        buffer_device: device on which to rebuild the ``freqs_cis`` buffer;
            defaults to the buffer's current device when ``None``.
    """
    # Order matters for RNG reproducibility: embeddings, then layers
    # (which also rebuilds freqs_cis), then the final norm/output head.
    _init_weights_tok_embeddings(self)
    _init_weights_layers(self, buffer_device)
    _init_weights_norm_and_output(self)
15581540

15591541
def forward(
15601542
self,
@@ -1593,12 +1575,13 @@ def forward(
15931575

15941576

15951577
class DeepSeekV3StageI(nn.Module):
    """Intermediate pipeline stage: holds only a slice of transformer layers."""

    def __init__(self, layers, model_args):
        super().__init__()
        self.layers = layers
        # Rotary-embedding table; non-persistent since init_weights recomputes it.
        freqs = precompute_freqs_cis(model_args)
        self.register_buffer("freqs_cis", freqs, persistent=False)
        self.model_args = model_args
16021585

16031586
def forward(self, h):
16041587
# intermediate stages only have layers
@@ -1607,14 +1590,12 @@ def forward(self, h):
16071590
return h
16081591

16091592
def init_weights(self, buffer_device: torch.device | None = None) -> None:
    """Initialize this stage's transformer layers (and its freqs_cis buffer)."""
    _init_weights_layers(self, buffer_device)
16131594

16141595

16151596
class DeepSeekV3Stage0(DeepSeekV3StageI):
    """First pipeline stage: token embedding followed by transformer layers."""

    def __init__(self, embed, layers, model_args):
        super().__init__(layers, model_args)
        self.tok_embeddings = embed
16191600

16201601
def forward(self, tokens):
@@ -1623,20 +1604,63 @@ def forward(self, tokens):
16231604
# torch.Size([1024, 1024, 2048])
16241605
return super().forward(h)
16251606

1607+
def init_weights(self, buffer_device: torch.device | None = None) -> None:
    """Initialize the embedding table, then this stage's layers.

    Embedding is initialized before the layers so the RNG consumption order
    matches DeepSeekV3Model.init_weights.
    """
    _init_weights_tok_embeddings(self)
    super().init_weights(buffer_device=buffer_device)
1610+
16261611

16271612
class DeepSeekV3StageN(DeepSeekV3StageI):
    """Last pipeline stage: transformer layers, final norm, and output head."""

    def __init__(self, layers, norm, output, model_args):
        super().__init__(layers, model_args)
        self.norm = norm
        self.output = output
        # Fix: dropped the redundant `self.model_args = model_args` — the
        # DeepSeekV3StageI constructor above already assigns it.
16321618

16331619
def forward(self, h):
    """Run the stage's layers, then the final norm and output projection.

    Either of norm/output may be None (e.g. sharded away under pipeline
    parallelism), in which case that step is skipped.
    """
    h = super().forward(h)
    if self.norm is not None:
        h = self.norm(h)
    if self.output is not None:
        return self.output(h)
    return h
16381624

1625+
def init_weights(self, buffer_device: torch.device | None = None) -> None:
    """Initialize the layers first, then the final norm / output head."""
    super().init_weights(buffer_device=buffer_device)
    _init_weights_norm_and_output(self)
1628+
16391629

16401630
######################
16411631
# Pipeline stuff end #
16421632
######################
1633+
1634+
1635+
def _init_weights_tok_embeddings(self: Union[DeepSeekV3Model, DeepSeekV3Stage0]):
1636+
if self.tok_embeddings is not None:
1637+
nn.init.normal_(self.tok_embeddings.weight)
1638+
1639+
1640+
def _init_weights_layers(
    self: Union[DeepSeekV3Model, DeepSeekV3StageI],
    buffer_device: torch.device | None,
):
    """Rebuild freqs_cis on the target device and init every transformer layer.

    When ``buffer_device`` is None, defaults to wherever the existing
    ``freqs_cis`` buffer currently lives.
    """
    device = buffer_device if buffer_device is not None else self.freqs_cis.device
    with torch.device(device):  # type: ignore[arg-type]
        self.freqs_cis = precompute_freqs_cis(self.model_args)
    for layer in self.layers.values():
        if layer is None:
            continue
        assert isinstance(layer, TransformerBlock)
        layer.init_weights(buffer_device=device)  # type: ignore[arg-type]
1652+
1653+
1654+
def _init_weights_norm_and_output(self: Union[DeepSeekV3Model, DeepSeekV3StageN]):
1655+
if self.norm is not None:
1656+
self.norm.reset_parameters()
1657+
if self.output is not None:
1658+
final_out_std = self.model_args.dim**-0.5
1659+
cutoff_factor = 3
1660+
nn.init.trunc_normal_(
1661+
self.output.weight,
1662+
mean=0.0,
1663+
std=final_out_std,
1664+
a=-cutoff_factor * final_out_std,
1665+
b=cutoff_factor * final_out_std,
1666+
)

0 commit comments

Comments
 (0)