
Commit aae0b8a

Add materialisation of new modules before loading checkpoint (#1030)
* Add materialisation of new modules before loading checkpoint
* Initialize new modules in load_model
* Fix adding new embedding networks
1 parent 908dec4 commit aae0b8a
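
For context, the "materialisation" in the title refers to PyTorch's meta-device workflow: modules constructed on the meta device carry parameter shapes but no storage, and must be allocated on a real device and re-initialised before use. A minimal sketch of that call pair, using a plain `nn.Linear` on CPU (the trainer itself targets CUDA):

```python
import torch
import torch.nn as nn

# Modules built under the meta device have parameter shapes but no storage.
with torch.device("meta"):
    layer = nn.Linear(16, 16)

# to_empty() allocates uninitialised storage on a real device;
# reset_parameters() then fills it with the module's default initialisation.
layer.to_empty(device="cpu")
layer.reset_parameters()
```

The commit applies this same `to_empty()` / `reset_parameters()` pair in `load_model` to modules that exist in the current model but are missing from the checkpoint.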

File tree: 1 file changed (+43, -13 lines)


src/weathergen/train/trainer.py

Lines changed: 43 additions & 13 deletions
```diff
@@ -219,12 +219,14 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None):
 
         if cf.with_ddp and cf.with_fsdp:
             fsdp_kwargs = {
-                "mp_policy": MixedPrecisionPolicy(
-                    param_dtype=self.mixed_precision_dtype,
-                    reduce_dtype=torch.float32,
-                )
-                if cf.with_mixed_precision
-                else None,
+                "mp_policy": (
+                    MixedPrecisionPolicy(
+                        param_dtype=self.mixed_precision_dtype,
+                        reduce_dtype=torch.float32,
+                    )
+                    if cf.with_mixed_precision
+                    else None
+                ),
             }
             modules_to_shard = (
                 MLP,
```
```diff
@@ -255,12 +257,14 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None):
                     fully_shard(module, **fsdp_kwargs)
 
             full_precision_fsdp_kwargs = {
-                "mp_policy": MixedPrecisionPolicy(
-                    param_dtype=torch.float32,
-                    reduce_dtype=torch.float32,
-                )
-                if cf.with_mixed_precision
-                else None,
+                "mp_policy": (
+                    MixedPrecisionPolicy(
+                        param_dtype=torch.float32,
+                        reduce_dtype=torch.float32,
+                    )
+                    if cf.with_mixed_precision
+                    else None
+                ),
             }
             for module in self.model.pred_adapter_kv.modules():
                 if isinstance(module, modules_to_shard):
```
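The two hunks above are formatting-only: the conditional expression selecting the `MixedPrecisionPolicy` is wrapped in parentheses so the dict literal stays readable; behaviour is unchanged. As a standalone sketch of the pattern, assuming the FSDP2 `MixedPrecisionPolicy` import path of recent PyTorch and local stand-ins for `cf.with_mixed_precision` and `self.mixed_precision_dtype`:

```python
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy  # FSDP2, recent PyTorch

with_mixed_precision = True             # stand-in for cf.with_mixed_precision
mixed_precision_dtype = torch.bfloat16  # stand-in for self.mixed_precision_dtype

fsdp_kwargs = {
    # Parameters in reduced precision, gradient reduction kept in float32;
    # None leaves mixed precision disabled.
    "mp_policy": (
        MixedPrecisionPolicy(
            param_dtype=mixed_precision_dtype,
            reduce_dtype=torch.float32,
        )
        if with_mixed_precision
        else None
    ),
}
```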
```diff
@@ -277,7 +281,6 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None):
         for tensor in itertools.chain(self.model.parameters(), self.model.buffers()):
             assert tensor.device == torch.device("meta")
 
-        # load model if specified
         if run_id_contd is None:
             self.model.to_empty(device="cuda")
             self.model.reset_parameters()
```
```diff
@@ -701,6 +704,14 @@ def load_model(self, run_id: str, epoch=-1):
         params = torch.load(
             path_run / filename, map_location=torch.device("cpu"), mmap=True, weights_only=True
         )
+
+        model_state_dict = self.model.state_dict()
+        params = {
+            k: v
+            for k, v in params.items()
+            if k in model_state_dict and v.shape == model_state_dict[k].shape
+        }
+
         is_model_sharded = self.cf.with_ddp and self.cf.with_fsdp
         if is_model_sharded:
             meta_sharded_sd = self.model.state_dict()
```
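The added filter keeps only checkpoint entries whose key still exists in the current model and whose tensor shape matches, so the later `load_state_dict(..., strict=False)` never receives incompatible tensors. A self-contained illustration with a small hypothetical checkpoint:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8))

# Hypothetical checkpoint: one valid entry, one shape mismatch, one stale key.
params = {
    "0.weight": torch.zeros(8, 8),      # key and shape match: kept
    "0.bias": torch.zeros(4),           # wrong shape (model expects 8): dropped
    "old_head.weight": torch.zeros(2),  # key no longer in the model: dropped
}

model_state_dict = model.state_dict()
params = {
    k: v
    for k, v in params.items()
    if k in model_state_dict and v.shape == model_state_dict[k].shape
}

missing, unexpected = model.load_state_dict(params, strict=False)
print(missing)  # ['0.bias'] -- reported as missing, to be re-initialised
```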
```diff
@@ -720,6 +731,25 @@ def load_model(self, run_id: str, epoch=-1):
         # choose `assign=True` for sharded model since we cannot call `copy_` on meta tensor
         mkeys, ukeys = self.model.load_state_dict(maybe_sharded_sd, strict=False, assign=True)
 
+        if mkeys:
+            # Get the unique parent modules for the missing parameters
+            new_modules_to_init = {key.rsplit(".", 1)[0] for key in mkeys}
+
+            # Find the highest-level "root" new modules to avoid redundant initializations
+            root_new_modules = set()
+            for path in sorted(list(new_modules_to_init)):
+                if not any(path.startswith(root + ".") for root in root_new_modules):
+                    root_new_modules.add(path)
+
+            # Get all modules for quick lookup and initialize the new ones
+            all_modules = dict(self.model.named_modules())
+            for path in root_new_modules:
+                if is_root():
+                    logger.info(f"Initializing new module not found in checkpoint: {path}")
+                module_to_init = all_modules[path]
+                module_to_init.to_empty(device="cuda")
+                module_to_init.reset_parameters()
+
         if not is_model_sharded:
             self.model = self.model.to(self.device)
 
```
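The missing-key handling reduces each missing parameter to its parent module path, then discards any path nested under another new module so that each new subtree is materialised exactly once (`is_root()` and `logger` come from the surrounding trainer code). The reduction in isolation, with hypothetical key names:

```python
# Hypothetical missing keys reported by load_state_dict(strict=False).
mkeys = [
    "embeds.3.weight",
    "embeds.3.norm.weight",
    "embeds.3.norm.bias",
    "head.proj.weight",
]

# Parent module path of each missing parameter.
new_modules_to_init = {key.rsplit(".", 1)[0] for key in mkeys}

# Lexicographic order puts a parent before its children (a prefix sorts
# first), so nested paths are skipped once their root has been added.
root_new_modules = set()
for path in sorted(new_modules_to_init):
    if not any(path.startswith(root + ".") for root in root_new_modules):
        root_new_modules.add(path)

print(root_new_modules)  # {'embeds.3', 'head.proj'}
```

Each surviving root is then looked up in `named_modules()` and materialised with the same `to_empty()` / `reset_parameters()` pair used for fresh runs.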
