@@ -219,12 +219,14 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None):
 
         if cf.with_ddp and cf.with_fsdp:
             fsdp_kwargs = {
-                "mp_policy": MixedPrecisionPolicy(
-                    param_dtype=self.mixed_precision_dtype,
-                    reduce_dtype=torch.float32,
-                )
-                if cf.with_mixed_precision
-                else None,
+                "mp_policy": (
+                    MixedPrecisionPolicy(
+                        param_dtype=self.mixed_precision_dtype,
+                        reduce_dtype=torch.float32,
+                    )
+                    if cf.with_mixed_precision
+                    else None
+                ),
             }
             modules_to_shard = (
                 MLP,
@@ -255,12 +257,14 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None):
                     fully_shard(module, **fsdp_kwargs)
 
             full_precision_fsdp_kwargs = {
-                "mp_policy": MixedPrecisionPolicy(
-                    param_dtype=torch.float32,
-                    reduce_dtype=torch.float32,
-                )
-                if cf.with_mixed_precision
-                else None,
+                "mp_policy": (
+                    MixedPrecisionPolicy(
+                        param_dtype=torch.float32,
+                        reduce_dtype=torch.float32,
+                    )
+                    if cf.with_mixed_precision
+                    else None
+                ),
             }
             for module in self.model.pred_adapter_kv.modules():
                 if isinstance(module, modules_to_shard):
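
Both hunks above are formatting-only: the conditional `mp_policy` value is wrapped in parentheses so the ternary reads as a single expression, with no behavioural change. For context, here is a minimal sketch of the underlying FSDP2 pattern: per-module `fully_shard` with a bf16-compute / fp32-reduce policy, plus a full-precision policy for a precision-sensitive submodule. It assumes PyTorch >= 2.6 (where `fully_shard` and `MixedPrecisionPolicy` are exported from `torch.distributed.fsdp`), a launch via `torchrun`, and a toy `Block` model that is not the model in this repo.

```python
# Minimal FSDP2 sketch; assumes a launch via `torchrun` and at least one CUDA device.
# `Block` is a toy stand-in, not the model in this repo.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard  # PyTorch >= 2.6


class Block(nn.Module):
    def __init__(self, dim: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.mlp(x)


dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
model = nn.Sequential(*[Block() for _ in range(4)]).cuda()

# Compute in bf16 but reduce gradients in fp32.
bf16_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
# Keep a precision-sensitive submodule entirely in fp32.
fp32_policy = MixedPrecisionPolicy(param_dtype=torch.float32, reduce_dtype=torch.float32)

for i, block in enumerate(model):
    # Shard block-by-block; as an example, the last block keeps full precision.
    fully_shard(block, mp_policy=fp32_policy if i == len(model) - 1 else bf16_policy)
fully_shard(model, mp_policy=bf16_policy)  # root wrap picks up any remaining parameters
```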
@@ -277,7 +281,6 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None):
             for tensor in itertools.chain(self.model.parameters(), self.model.buffers()):
                 assert tensor.device == torch.device("meta")
 
-        # load model if specified
         if run_id_contd is None:
             self.model.to_empty(device="cuda")
             self.model.reset_parameters()
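
This hunk only drops a stale comment; the surrounding code is the usual meta-device flow: the model is built with parameters on the `meta` device (shapes only, no storage) and is materialized on the GPU with `to_empty()` plus `reset_parameters()` only when no checkpoint is being resumed. Below is a minimal sketch of that flow, assuming a toy `nn.Sequential`; the `hasattr` guard stands in for the repo's model-level `reset_parameters()`, which plain `nn.Module` does not provide.

```python
# Meta-device construction followed by materialization; a sketch, not the repo's exact flow.
import itertools
import torch
import torch.nn as nn

with torch.device("meta"):
    # Parameters and buffers are allocated as "meta" tensors: shapes only, no storage.
    model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))

for tensor in itertools.chain(model.parameters(), model.buffers()):
    assert tensor.device == torch.device("meta")

# Allocate uninitialized storage on the GPU, then re-run per-layer weight init.
model.to_empty(device="cuda")
for m in model.modules():
    if hasattr(m, "reset_parameters"):  # nn.Linear defines it; plain nn.Module does not
        m.reset_parameters()
```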
@@ -701,6 +704,14 @@ def load_model(self, run_id: str, epoch=-1):
         params = torch.load(
             path_run / filename, map_location=torch.device("cpu"), mmap=True, weights_only=True
         )
+
+        model_state_dict = self.model.state_dict()
+        params = {
+            k: v
+            for k, v in params.items()
+            if k in model_state_dict and v.shape == model_state_dict[k].shape
+        }
+
         is_model_sharded = self.cf.with_ddp and self.cf.with_fsdp
         if is_model_sharded:
             meta_sharded_sd = self.model.state_dict()
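
The added comprehension drops every checkpoint entry whose key is absent from the current model or whose tensor shape no longer matches, so an older checkpoint can be loaded with `strict=False` instead of raising. A self-contained sketch of the same filter; the toy modules and the "stale checkpoint" are illustrative assumptions.

```python
# Filtering a stale checkpoint against the current model; toy modules for illustration.
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 4))    # current architecture
stale_params = nn.Sequential(nn.Linear(8, 16)).state_dict()  # checkpoint from an older, smaller model

model_state_dict = model.state_dict()
filtered = {
    k: v
    for k, v in stale_params.items()
    if k in model_state_dict and v.shape == model_state_dict[k].shape
}
missing, unexpected = model.load_state_dict(filtered, strict=False)
print(sorted(missing))     # ['1.bias', '1.weight'] -- layers absent from the checkpoint
print(sorted(unexpected))  # [] -- incompatible keys were filtered out up front
```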
@@ -720,6 +731,25 @@ def load_model(self, run_id: str, epoch=-1):
         # choose `assign=True` for sharded model since we cannot call `copy_` on meta tensor
         mkeys, ukeys = self.model.load_state_dict(maybe_sharded_sd, strict=False, assign=True)
 
+        if mkeys:
+            # Get the unique parent modules for the missing parameters
+            new_modules_to_init = {key.rsplit(".", 1)[0] for key in mkeys}
+
+            # Find the highest-level "root" new modules to avoid redundant initializations
+            root_new_modules = set()
+            for path in sorted(list(new_modules_to_init)):
+                if not any(path.startswith(root + ".") for root in root_new_modules):
+                    root_new_modules.add(path)
+
+            # Get all modules for quick lookup and initialize the new ones
+            all_modules = dict(self.model.named_modules())
+            for path in root_new_modules:
+                if is_root():
+                    logger.info(f"Initializing new module not found in checkpoint: {path}")
+                module_to_init = all_modules[path]
+                module_to_init.to_empty(device="cuda")
+                module_to_init.reset_parameters()
+
         if not is_model_sharded:
             self.model = self.model.to(self.device)
 
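The new `if mkeys:` block handles parameters that exist in the model but were not found in the checkpoint: each missing key is mapped to its parent module path, nested paths are collapsed onto their highest-level "root" parents so nothing is re-initialized twice, and only those root modules are materialized with `to_empty()` / `reset_parameters()`, leaving loaded weights untouched. A small dry run of the prefix reduction on hypothetical key names (only `pred_adapter_kv` appears in the diff; the other names are made up):

```python
# Prefix reduction from missing parameter keys to root modules; key names are hypothetical.
mkeys = [
    "pred_adapter_kv.weight",         # parameter registered directly on the adapter
    "pred_adapter_kv.norm.weight",    # nested inside the adapter
    "pred_adapter_kv.mlp.fc1.weight",
    "decoder.head.weight",
]

# Parent module path of every missing parameter.
new_modules_to_init = {key.rsplit(".", 1)[0] for key in mkeys}
# {'pred_adapter_kv', 'pred_adapter_kv.norm', 'pred_adapter_kv.mlp.fc1', 'decoder.head'}

# Keep only paths that are not nested inside an already-kept path.
root_new_modules = set()
for path in sorted(new_modules_to_init):
    if not any(path.startswith(root + ".") for root in root_new_modules):
        root_new_modules.add(path)

print(sorted(root_new_modules))  # ['decoder.head', 'pred_adapter_kv']
```

Sorting matters here: a parent path always sorts lexicographically before its children, so the parent is admitted to `root_new_modules` first and the nested paths are then skipped.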