Commit
Fix pre-commit issues
scottcha committed Sep 22, 2024
1 parent 70c5d2c commit 872541b
Showing 2 changed files with 65 additions and 45 deletions.
68 changes: 39 additions & 29 deletions aurora/model/aurora.py
@@ -273,18 +273,18 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None:
if "encoder.surf_token_embeds.weight" in d:
weight = d["encoder.surf_token_embeds.weight"]
del d["encoder.surf_token_embeds.weight"]

assert weight.shape[1] == 4 + 3
for i, name in enumerate(("2t", "10u", "10v", "msl", "lsm", "z", "slt")):
d[f"encoder.surf_token_embeds.weights.{name}"] = weight[:, [i]]

if "encoder.atmos_token_embeds.weight" in d:
weight = d["encoder.atmos_token_embeds.weight"]
del d["encoder.atmos_token_embeds.weight"]

assert weight.shape[1] == 5
for i, name in enumerate(("z", "u", "v", "t", "q")):
d[f"encoder.atmos_token_embeds.weights.{name}"] = weight[:, [i]]
d[f"encoder.atmos_token_embeds.weights.{name}"] = weight[:, [i]]

if "decoder.surf_head.weight" in d:
weight = d["decoder.surf_head.weight"]
@@ -316,45 +316,55 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None:
d[f"decoder.atmos_heads.{name}.weight"] = weight[:, i]
d[f"decoder.atmos_heads.{name}.bias"] = bias[:, i]

-#check if history size is compatible and adjust weights if necessary
+# check if history size is compatible and adjust weights if necessary
if self.max_history_size > d["encoder.surf_token_embeds.weights.2t"].shape[2]:
d = self.adapt_checkpoint_max_history_size(d)
elif self.max_history_size < d["encoder.surf_token_embeds.weights.2t"].shape[2]:
-assert False, f"Cannot load checkpoint with max_history_size {d['encoder.surf_token_embeds.weights.2t'].shape[2]} \
-    into model with max_history_size {self.max_history_size}"
+raise AssertionError(f"Cannot load checkpoint with max_history_size \
+    {d['encoder.surf_token_embeds.weights.2t'].shape[2]} \
+    into model with max_history_size {self.max_history_size}")

self.load_state_dict(d, strict=strict)
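The `weight[:, [i]]` indexing above is what keeps each per-variable tensor five-dimensional. A minimal sketch of the split with made-up shapes (the embed dimension, patch size, and history length here are illustrative placeholders, not the real Aurora values):

    import torch

    embed_dim, history, patch = 8, 2, 4
    # Packed surface weight: one slot per variable (4 surface + 3 static).
    packed = torch.rand(embed_dim, 7, history, patch, patch)

    d = {}
    for i, name in enumerate(("2t", "10u", "10v", "msl", "lsm", "z", "slt")):
        # A list index [i] keeps the variable axis as size 1; a bare i would drop it.
        d[f"encoder.surf_token_embeds.weights.{name}"] = packed[:, [i]]

    assert d["encoder.surf_token_embeds.weights.2t"].shape == (embed_dim, 1, history, patch, patch)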

def adapt_checkpoint_max_history_size(self, checkpoint) -> Any:
"""Adapt a checkpoint with smaller max_history_size to a model with a larger max_history_size
than the current model.
If a checkpoint was trained with a larger max_history_size than the current model, this function
will assert fail to prevent loading the checkpoint. This is to prevent loading a checkpoint
which will likely cause the checkpoint to degrade is performance.
This implementation copies weights from the checkpoint to the model and fills 0 for the new history
width dimension.
"""Adapt a checkpoint with smaller max_history_size to a model with a larger
max_history_size than the current model.
If a checkpoint was trained with a larger max_history_size than the current model,
this function will assert fail to prevent loading the checkpoint. This is to
prevent loading a checkpoint which will likely cause the checkpoint to degrade is
performance.
This implementation copies weights from the checkpoint to the model and fills 0
for the new history width dimension.
"""
# Find all weights with prefix "encoder.surf_token_embeds.weights."
for name, weight in list(checkpoint.items()):
if name.startswith("encoder.surf_token_embeds.weights.") or name.startswith("encoder.atmos_token_embeds.weights."):
# This shouldn't get called with current logic but leaving here for future proofing and in cases where its called
# outside current context
assert weight.shape[2] <= self.max_history_size, f"Cannot load checkpoint with max_history_size {weight.shape[2]} \
into model with max_history_size {self.max_history_size} for weight {name}"

if name.startswith("encoder.surf_token_embeds.weights.") or name.startswith(
"encoder.atmos_token_embeds.weights."
):
# This shouldn't get called with current logic but leaving here for future proofing
# and in cases where its called outside current context
assert (
weight.shape[2] <= self.max_history_size
), f"Cannot load checkpoint with max_history_size {weight.shape[2]} \
into model with max_history_size {self.max_history_size} for weight {name}"

# Initialize the new weight tensor
-new_weight = torch.zeros((weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4]),
-    device=weight.device, dtype=weight.dtype)
+new_weight = torch.zeros(
+    (weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4]),
+    device=weight.device,
+    dtype=weight.dtype,
+)

-# Copy the existing weights to the new tensor by duplicating the histories provided into any new history dimensions
+# Copy the existing weights into the first history slots; any newly added
+# history slots stay zero.
for j in range(weight.shape[2]):
# only fill existing weights, others are zeros
new_weight[:, :, j, :, :] = weight[:, :, j, :, :]
checkpoint[name] = new_weight
return checkpoint

def configure_activation_checkpointing(self):
"""Configure activation checkpointing.
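Before the tests, here is the adaptation logic reduced to a standalone toy with the same zero-fill strategy; max_history_size=4 and the (2, 1, 2, 4, 4) shape are assumptions chosen to mirror the test fixtures below:

    import torch

    max_history_size = 4
    weight = torch.rand(2, 1, 2, 4, 4)  # history is dimension 2, size 2 here

    new_weight = torch.zeros(
        (weight.shape[0], 1, max_history_size, weight.shape[3], weight.shape[4]),
        device=weight.device,
        dtype=weight.dtype,
    )
    # Existing history slots are copied over; the added slots stay zero.
    for j in range(weight.shape[2]):
        new_weight[:, :, j, :, :] = weight[:, :, j, :, :]

    assert new_weight.shape[2] == max_history_size
    assert torch.equal(new_weight[:, :, 1], weight[:, :, 1])
    assert torch.equal(new_weight[:, :, 3], torch.zeros_like(new_weight[:, :, 3]))

Running this toy a second time on its own output changes nothing once the history dimension already matches, which is the idempotence that test_adapt_checkpoint_max_history_twice below verifies.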
42 changes: 26 additions & 16 deletions tests/test_checkpoint_adaptation.py
@@ -2,52 +2,62 @@

import pytest
import torch
from unittest.mock import patch

from aurora.model.aurora import AuroraSmall


@pytest.fixture
def model(request):
return AuroraSmall(max_history_size=request.param)


@pytest.fixture
def checkpoint():
return {
"encoder.surf_token_embeds.weights.0": torch.rand((2, 1, 2, 4, 4)),
"encoder.atmos_token_embeds.weights.0": torch.rand((2, 1, 2, 4, 4))
"encoder.atmos_token_embeds.weights.0": torch.rand((2, 1, 2, 4, 4)),
}

-#check both history sizes which are divisible by 2 (original shape) and not
-@pytest.mark.parametrize('model', [4, 5], indirect=True)
+# check history sizes that are divisible by 2 (the original history length) and ones that are not
+@pytest.mark.parametrize("model", [4, 5], indirect=True)
def test_adapt_checkpoint_max_history(model, checkpoint):
# checkpoint starts with history dim, shape[2], as size 2
assert checkpoint["encoder.surf_token_embeds.weights.0"].shape[2] == 2
adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)

for name, weight in adapted_checkpoint.items():
assert weight.shape[2] == model.max_history_size
for j in range(weight.shape[2]):
if j >= checkpoint[name].shape[2]:
assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :]))
else:
-assert torch.equal(weight[:, :, j, :, :], checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :])
+assert torch.equal(
+    weight[:, :, j, :, :],
+    checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :],
+)

-#check that assert is thrown when trying to load a larger checkpoint to a smaller history size
-@pytest.mark.parametrize('model', [1], indirect=True)
+# check that an AssertionError is raised when loading a larger checkpoint into a smaller history size
+@pytest.mark.parametrize("model", [1], indirect=True)
def test_adapt_checkpoint_max_history_fail(model, checkpoint):
with pytest.raises(AssertionError):
model.adapt_checkpoint_max_history_size(checkpoint)

-#test adapting the checkpoint twice to ensure that the second time should not change the weights
-@pytest.mark.parametrize('model', [4], indirect=True)
+# adapt the checkpoint twice to ensure that the second adaptation does not change the weights
+@pytest.mark.parametrize("model", [4], indirect=True)
def test_adapt_checkpoint_max_history_twice(model, checkpoint):
adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)
-adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)
+adapted_checkpoint = model.adapt_checkpoint_max_history_size(adapted_checkpoint)

for name, weight in adapted_checkpoint.items():
assert weight.shape[2] == model.max_history_size
for j in range(weight.shape[2]):
if j >= checkpoint[name].shape[2]:
assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :]))
else:
-assert torch.equal(weight[:, :, j, :, :], checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :])
+assert torch.equal(
+    weight[:, :, j, :, :],
+    checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :],
+)
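These tests rely on pytest's indirect parametrization: with indirect=True, each parametrized value is routed to the model fixture as request.param instead of being passed straight to the test function. A self-contained sketch of that pattern (DummyModel is a stand-in so the snippet runs without the aurora package; the real tests construct AuroraSmall):

    import pytest

    class DummyModel:
        def __init__(self, max_history_size):
            self.max_history_size = max_history_size

    @pytest.fixture
    def model(request):
        # request.param carries the value from @pytest.mark.parametrize(..., indirect=True).
        return DummyModel(max_history_size=request.param)

    @pytest.mark.parametrize("model", [4, 5], indirect=True)
    def test_model_history(model):
        assert model.max_history_size in (4, 5)

The real suite runs the same way as any pytest module, e.g. pytest tests/test_checkpoint_adaptation.py.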
