Merged
src/weathergen/train/trainer.py: 56 changes (43 additions, 13 deletions)
@@ -216,12 +216,14 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None):

         if cf.with_ddp and cf.with_fsdp:
             fsdp_kwargs = {
-                "mp_policy": MixedPrecisionPolicy(
-                    param_dtype=self.mixed_precision_dtype,
-                    reduce_dtype=torch.float32,
-                )
-                if cf.with_mixed_precision
-                else None,
+                "mp_policy": (
+                    MixedPrecisionPolicy(
+                        param_dtype=self.mixed_precision_dtype,
+                        reduce_dtype=torch.float32,
+                    )
+                    if cf.with_mixed_precision
+                    else None
+                ),
             }
             modules_to_shard = (
                 MLP,
@@ -252,12 +254,14 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None):
                     fully_shard(module, **fsdp_kwargs)

             full_precision_fsdp_kwargs = {
-                "mp_policy": MixedPrecisionPolicy(
-                    param_dtype=torch.float32,
-                    reduce_dtype=torch.float32,
-                )
-                if cf.with_mixed_precision
-                else None,
+                "mp_policy": (
+                    MixedPrecisionPolicy(
+                        param_dtype=torch.float32,
+                        reduce_dtype=torch.float32,
+                    )
+                    if cf.with_mixed_precision
+                    else None
+                ),
             }
             for module in self.model.pred_adapter_kv.modules():
                 if isinstance(module, modules_to_shard):
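An aside on the API being configured here: `fully_shard` and `MixedPrecisionPolicy` come from PyTorch's FSDP2. Below is a minimal sketch of the same per-module sharding pattern, not part of this PR; it assumes a recent PyTorch where both names are importable from `torch.distributed.fsdp`, a process group already initialized (e.g. via torchrun), and uses `nn.TransformerEncoderLayer` as a stand-in for this repo's `modules_to_shard`.

import torch
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

def shard_model(model: nn.Module, use_mixed_precision: bool) -> nn.Module:
    fsdp_kwargs = {}
    if use_mixed_precision:
        # Compute and communicate in bf16, but reduce gradients in fp32
        fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
        )
    # Shard each block individually so parameters are gathered layer by layer
    for module in model.modules():
        if isinstance(module, nn.TransformerEncoderLayer):
            fully_shard(module, **fsdp_kwargs)
    # Shard whatever parameters remain at the root
    fully_shard(model, **fsdp_kwargs)
    return model

In this sketch the key is only set when mixed precision is requested, so FSDP2 otherwise keeps its default policy.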
@@ -274,7 +278,6 @@ def run(self, cf, devices, run_id_contd=None, epoch_contd=None):
         for tensor in itertools.chain(self.model.parameters(), self.model.buffers()):
             assert tensor.device == torch.device("meta")

-        # load model if specified
         if run_id_contd is None:
             self.model.to_empty(device="cuda")
             self.model.reset_parameters()
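For context, this hunk leans on PyTorch's meta-device initialization flow: parameters are first created without storage, materialized with `to_empty()`, and only then given values. A tiny sketch, not part of the PR:

import torch
import torch.nn as nn

# Build the module on the meta device: shapes and dtypes exist, but no storage
with torch.device("meta"):
    layer = nn.Linear(8, 8)
assert layer.weight.device == torch.device("meta")

layer.to_empty(device="cpu")  # allocate real (uninitialized) storage
layer.reset_parameters()      # fill the storage with proper initial values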
@@ -695,6 +698,14 @@ def load_model(self, run_id: str, epoch=-1):
         params = torch.load(
             path_run / filename, map_location=torch.device("cpu"), mmap=True, weights_only=True
         )
+
+        model_state_dict = self.model.state_dict()
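        # (explanatory note, not in the diff: the comprehension below keeps only
        # checkpoint entries that exist in the current model with a matching
        # shape; parameters that remain missing are re-initialized further down
        # in load_model)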
+        params = {
[inline review comment, Contributor] Can you add a comment explaining this?
[inline review comment, Contributor] I wanna test this with a couple configs and then I will approve it
+            k: v
+            for k, v in params.items()
+            if k in model_state_dict and v.shape == model_state_dict[k].shape
+        }
+
         is_model_sharded = self.cf.with_ddp and self.cf.with_fsdp
         if is_model_sharded:
             meta_sharded_sd = self.model.state_dict()
@@ -714,6 +725,25 @@
         # choose `assign=True` for sharded model since we cannot call `copy_` on meta tensor
         mkeys, ukeys = self.model.load_state_dict(maybe_sharded_sd, strict=False, assign=True)

+        if mkeys:
+            # Get the unique parent modules for the missing parameters
+            new_modules_to_init = {key.rsplit(".", 1)[0] for key in mkeys}
+
+            # Find the highest-level "root" new modules to avoid redundant initializations
+            root_new_modules = set()
+            for path in sorted(list(new_modules_to_init)):
+                if not any(path.startswith(root + ".") for root in root_new_modules):
+                    root_new_modules.add(path)
+
+            # Get all modules for quick lookup and initialize the new ones
+            all_modules = dict(self.model.named_modules())
+            for path in root_new_modules:
+                if is_root():
+                    logger.info(f"Initializing new module not found in checkpoint: {path}")
+                module_to_init = all_modules[path]
+                module_to_init.to_empty(device="cuda")
+                module_to_init.reset_parameters()
+
         if not is_model_sharded:
             self.model = self.model.to(self.device)

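Taken together, the `load_model` changes make checkpoint loading tolerant of architecture drift: mismatched entries are filtered out, the remainder is loaded non-strictly, and modules the checkpoint never covered are re-initialized. A self-contained sketch of that flow on a toy model, not from the PR; names and shapes are invented for illustration, and every affected module is assumed to implement `reset_parameters()`:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

# Pretend checkpoint: covers module "0", plus one stale entry with a wrong shape
checkpoint = {
    "0.weight": torch.zeros(8, 4),
    "0.bias": torch.zeros(8),
    "1.weight": torch.zeros(3, 3),  # shape mismatch: dropped by the filter
}

model_sd = model.state_dict()
params = {
    k: v
    for k, v in checkpoint.items()
    if k in model_sd and v.shape == model_sd[k].shape
}

mkeys, _ = model.load_state_dict(params, strict=False)
# mkeys == ["1.weight", "1.bias"]

# Parent module path of each missing parameter, e.g. "1.weight" -> "1"
new_modules_to_init = {key.rsplit(".", 1)[0] for key in mkeys}

# Keep only the highest-level paths so nested children are not initialized twice
root_new_modules = set()
for path in sorted(new_modules_to_init):
    if not any(path.startswith(root + ".") for root in root_new_modules):
        root_new_modules.add(path)

all_modules = dict(model.named_modules())
for path in root_new_modules:
    all_modules[path].reset_parameters()

Sorting the candidate paths before the prefix check means a parent module path always appears before its children, so only the highest-level new modules are re-initialized.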