55
66import math
77from dataclasses import dataclass , field
8- from typing import Callable , ClassVar , Literal , Optional , Tuple
8+ from typing import Callable , ClassVar , Literal , Optional , Tuple , Union
99
1010import torch
1111import torch .nn .functional as F
@@ -1534,27 +1534,9 @@ def __init__(self, model_args: DeepSeekV3ModelArgs):
15341534 self .model_args = model_args
15351535
15361536 def init_weights (self , buffer_device : torch .device | None = None ) -> None :
1537- buffer_device = buffer_device or self .freqs_cis .device # type: ignore[has-type]
1538- with torch .device (buffer_device ):
1539- self .freqs_cis = precompute_freqs_cis (self .model_args )
1540- if self .tok_embeddings is not None :
1541- nn .init .normal_ (self .tok_embeddings .weight )
1542- for layer in self .layers .values ():
1543- if layer is not None :
1544- assert isinstance (layer , TransformerBlock )
1545- layer .init_weights (buffer_device = buffer_device )
1546- if self .norm is not None :
1547- self .norm .reset_parameters ()
1548- final_out_std = self .model_args .dim ** - 0.5
1549- cutoff_factor = 3
1550- if self .output is not None :
1551- nn .init .trunc_normal_ (
1552- self .output .weight ,
1553- mean = 0.0 ,
1554- std = final_out_std ,
1555- a = - cutoff_factor * final_out_std ,
1556- b = cutoff_factor * final_out_std ,
1557- )
1537+ _init_weights_tok_embeddings (self )
1538+ _init_weights_layers (self , buffer_device )
1539+ _init_weights_norm_and_output (self )
15581540
15591541 def forward (
15601542 self ,
@@ -1593,12 +1575,13 @@ def forward(
15931575
15941576
15951577class DeepSeekV3StageI (nn .Module ):
1596- def __init__ (self , layers , config ):
1578+ def __init__ (self , layers , model_args ):
15971579 super ().__init__ ()
15981580 self .layers = layers
15991581 self .register_buffer (
1600- "freqs_cis" , precompute_freqs_cis (config ), persistent = False
1582+ "freqs_cis" , precompute_freqs_cis (model_args ), persistent = False
16011583 )
1584+ self .model_args = model_args
16021585
16031586 def forward (self , h ):
16041587 # intermediate stages only have layers
@@ -1607,14 +1590,12 @@ def forward(self, h):
16071590 return h
16081591
16091592 def init_weights (self , buffer_device : torch .device | None = None ) -> None :
1610- for layer in self .layers .values ():
1611- if layer is not None :
1612- layer .init_weights (buffer_device = buffer_device )
1593+ _init_weights_layers (self , buffer_device )
16131594
16141595
16151596class DeepSeekV3Stage0 (DeepSeekV3StageI ):
1616- def __init__ (self , embed , layers , config ):
1617- super ().__init__ (layers , config )
1597+ def __init__ (self , embed , layers , model_args ):
1598+ super ().__init__ (layers , model_args )
16181599 self .tok_embeddings = embed
16191600
16201601 def forward (self , tokens ):
@@ -1623,20 +1604,63 @@ def forward(self, tokens):
16231604 # torch.Size([1024, 1024, 2048])
16241605 return super ().forward (h )
16251606
1607+ def init_weights (self , buffer_device : torch .device | None = None ) -> None :
1608+ _init_weights_tok_embeddings (self )
1609+ super ().init_weights (buffer_device = buffer_device )
1610+
16261611
16271612class DeepSeekV3StageN (DeepSeekV3StageI ):
1628- def __init__ (self , layers , norm , output , config ):
1629- super ().__init__ (layers , config )
1613+ def __init__ (self , layers , norm , output , model_args ):
1614+ super ().__init__ (layers , model_args )
16301615 self .norm = norm
16311616 self .output = output
1617+ self .model_args = model_args
16321618
16331619 def forward (self , h ):
16341620 h = super ().forward (h )
16351621 h = self .norm (h ) if self .norm is not None else h
16361622 output = self .output (h ) if self .output is not None else h
16371623 return output
16381624
1625+ def init_weights (self , buffer_device : torch .device | None = None ) -> None :
1626+ super ().init_weights (buffer_device = buffer_device )
1627+ _init_weights_norm_and_output (self )
1628+
16391629
16401630######################
16411631# Pipeline stuff end #
16421632######################
1633+
1634+
1635+ def _init_weights_tok_embeddings (self : Union [DeepSeekV3Model , DeepSeekV3Stage0 ]):
1636+ if self .tok_embeddings is not None :
1637+ nn .init .normal_ (self .tok_embeddings .weight )
1638+
1639+
1640+ def _init_weights_layers (
1641+ self : Union [DeepSeekV3Model , DeepSeekV3StageI ],
1642+ buffer_device : torch .device | None ,
1643+ ):
1644+ if buffer_device is None :
1645+ buffer_device = self .freqs_cis .device # type: ignore[assignment]
1646+ with torch .device (buffer_device ): # type: ignore[arg-type]
1647+ self .freqs_cis = precompute_freqs_cis (self .model_args )
1648+ for layer in self .layers .values ():
1649+ if layer is not None :
1650+ assert isinstance (layer , TransformerBlock )
1651+ layer .init_weights (buffer_device = buffer_device ) # type: ignore[arg-type]
1652+
1653+
1654+ def _init_weights_norm_and_output (self : Union [DeepSeekV3Model , DeepSeekV3StageN ]):
1655+ if self .norm is not None :
1656+ self .norm .reset_parameters ()
1657+ if self .output is not None :
1658+ final_out_std = self .model_args .dim ** - 0.5
1659+ cutoff_factor = 3
1660+ nn .init .trunc_normal_ (
1661+ self .output .weight ,
1662+ mean = 0.0 ,
1663+ std = final_out_std ,
1664+ a = - cutoff_factor * final_out_std ,
1665+ b = cutoff_factor * final_out_std ,
1666+ )
0 commit comments