From 872541b8c756cb9d6380893e441041f27d26d581 Mon Sep 17 00:00:00 2001
From: scottcha
Date: Sun, 22 Sep 2024 08:22:42 -0600
Subject: [PATCH] Fix pre-commit issues

---
 aurora/model/aurora.py              | 68 +++++++++++++++++------------
 tests/test_checkpoint_adaptation.py | 42 +++++++++++-------
 2 files changed, 65 insertions(+), 45 deletions(-)

diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py
index efe5541..1fe4220 100644
--- a/aurora/model/aurora.py
+++ b/aurora/model/aurora.py
@@ -273,18 +273,18 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None:
         if "encoder.surf_token_embeds.weight" in d:
             weight = d["encoder.surf_token_embeds.weight"]
             del d["encoder.surf_token_embeds.weight"]
-
+
             assert weight.shape[1] == 4 + 3
             for i, name in enumerate(("2t", "10u", "10v", "msl", "lsm", "z", "slt")):
                 d[f"encoder.surf_token_embeds.weights.{name}"] = weight[:, [i]]
-
+
         if "encoder.atmos_token_embeds.weight" in d:
             weight = d["encoder.atmos_token_embeds.weight"]
             del d["encoder.atmos_token_embeds.weight"]
-
+
             assert weight.shape[1] == 5
             for i, name in enumerate(("z", "u", "v", "t", "q")):
-                d[f"encoder.atmos_token_embeds.weights.{name}"] = weight[:, [i]]
+                d[f"encoder.atmos_token_embeds.weights.{name}"] = weight[:, [i]]
 
         if "decoder.surf_head.weight" in d:
             weight = d["decoder.surf_head.weight"]
@@ -316,45 +316,55 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None:
             d[f"decoder.atmos_heads.{name}.weight"] = weight[:, i]
             d[f"decoder.atmos_heads.{name}.bias"] = bias[:, i]
 
-        #check if history size is compatible and adjust weights if necessary
+        # check if history size is compatible and adjust weights if necessary
         if self.max_history_size > d["encoder.surf_token_embeds.weights.2t"].shape[2]:
             d = self.adapt_checkpoint_max_history_size(d)
         elif self.max_history_size < d["encoder.surf_token_embeds.weights.2t"].shape[2]:
-            assert False, f"Cannot load checkpoint with max_history_size {d['encoder.surf_token_embeds.weights.2t'].shape[2]} \
-                into model with max_history_size {self.max_history_size}"
-
+            raise AssertionError(f"Cannot load checkpoint with max_history_size \
+                {d['encoder.surf_token_embeds.weights.2t'].shape[2]} \
+                into model with max_history_size {self.max_history_size}")
+
         self.load_state_dict(d, strict=strict)
 
     def adapt_checkpoint_max_history_size(self, checkpoint) -> Any:
-        """Adapt a checkpoint with smaller max_history_size to a model with a larger max_history_size
-        than the current model.
-
-        If a checkpoint was trained with a larger max_history_size than the current model, this function
-        will assert fail to prevent loading the checkpoint. This is to prevent loading a checkpoint
-        which will likely cause the checkpoint to degrade is performance.
-
-        This implementation copies weights from the checkpoint to the model and fills 0 for the new history
-        width dimension.
+        """Adapt a checkpoint with a smaller max_history_size to a model with a larger
+        max_history_size than the checkpoint was trained with.
+
+        If a checkpoint was trained with a larger max_history_size than the current model,
+        this function raises an AssertionError to prevent loading the checkpoint, since
+        loading such a checkpoint would likely degrade the model's
+        performance.
+
+        This implementation copies weights from the checkpoint to the model and fills 0
+        for the new history width dimension.
         """
-        # Find all weights with prefix "encoder.surf_token_embeds.weights."
+        # Find all token embedding weights with a history dimension (surface and atmospheric)
         for name, weight in list(checkpoint.items()):
-            if name.startswith("encoder.surf_token_embeds.weights.") or name.startswith("encoder.atmos_token_embeds.weights."):
-                # This shouldn't get called with current logic but leaving here for future proofing and in cases where its called
-                # outside current context
-                assert weight.shape[2] <= self.max_history_size, f"Cannot load checkpoint with max_history_size {weight.shape[2]} \
-                    into model with max_history_size {self.max_history_size} for weight {name}"
-
+            if name.startswith("encoder.surf_token_embeds.weights.") or name.startswith(
+                "encoder.atmos_token_embeds.weights."
+            ):
+                # This shouldn't get called with current logic but leaving here for future proofing
+                # and in cases where it's called outside the current context
+                assert (
+                    weight.shape[2] <= self.max_history_size
+                ), f"Cannot load checkpoint with max_history_size {weight.shape[2]} \
+                    into model with max_history_size {self.max_history_size} for weight {name}"
+
                 # Initialize the new weight tensor
-                new_weight = torch.zeros((weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4]),
-                                         device=weight.device, dtype=weight.dtype)
-
-                # Copy the existing weights to the new tensor by duplicating the histories provided into any new history dimensions
+                new_weight = torch.zeros(
+                    (weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4]),
+                    device=weight.device,
+                    dtype=weight.dtype,
+                )
+
+                # Copy the existing weights into the new tensor; any newly added history
+                # slots are left as zeros
                 for j in range(weight.shape[2]):
                     # only fill existing weights, others are zeros
                     new_weight[:, :, j, :, :] = weight[:, :, j, :, :]
                 checkpoint[name] = new_weight
         return checkpoint
-
+
     def configure_activation_checkpointing(self):
         """Configure activation checkpointing.
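For illustration outside the diff, a minimal standalone sketch of the zero-fill expansion that adapt_checkpoint_max_history_size performs, assuming embedding weights shaped like the (2, 1, 2, 4, 4) tensors used in the tests, with the history size in dimension 2; the slice assignment below is equivalent to the per-index loop in the patch:

import torch

max_history_size = 4  # hypothetical target history size of the larger model

# A checkpoint weight trained with a history size of 2 (same shape as in the tests).
old_weight = torch.rand((2, 1, 2, 4, 4))

# Allocate the larger tensor and copy the existing history slices; the remaining
# slots along dim 2 stay zero.
new_weight = torch.zeros(
    (old_weight.shape[0], 1, max_history_size, old_weight.shape[3], old_weight.shape[4]),
    device=old_weight.device,
    dtype=old_weight.dtype,
)
new_weight[:, :, : old_weight.shape[2]] = old_weight

assert torch.equal(new_weight[:, :, :2], old_weight)
assert torch.equal(new_weight[:, :, 2:], torch.zeros_like(new_weight[:, :, 2:]))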
diff --git a/tests/test_checkpoint_adaptation.py b/tests/test_checkpoint_adaptation.py
index e43c48f..83793cc 100644
--- a/tests/test_checkpoint_adaptation.py
+++ b/tests/test_checkpoint_adaptation.py
@@ -2,52 +2,62 @@
 import pytest
 import torch
-from unittest.mock import patch
+
 from aurora.model.aurora import AuroraSmall
 
+
 @pytest.fixture
 def model(request):
     return AuroraSmall(max_history_size=request.param)
 
+
 @pytest.fixture
 def checkpoint():
     return {
         "encoder.surf_token_embeds.weights.0": torch.rand((2, 1, 2, 4, 4)),
-        "encoder.atmos_token_embeds.weights.0": torch.rand((2, 1, 2, 4, 4))
+        "encoder.atmos_token_embeds.weights.0": torch.rand((2, 1, 2, 4, 4)),
     }
 
-#check both history sizes which are divisible by 2 (original shape) and not
-@pytest.mark.parametrize('model', [4, 5], indirect=True)
+
+# check both history sizes which are divisible by 2 (original shape) and not
+@pytest.mark.parametrize("model", [4, 5], indirect=True)
 def test_adapt_checkpoint_max_history(model, checkpoint):
     # checkpoint starts with history dim, shape[2], as size 2
-    assert checkpoint["encoder.surf_token_embeds.weights.0"].shape[2] == 2
+    assert checkpoint["encoder.surf_token_embeds.weights.0"].shape[2] == 2
     adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)
-
+
     for name, weight in adapted_checkpoint.items():
         assert weight.shape[2] == model.max_history_size
         for j in range(weight.shape[2]):
             if j >= checkpoint[name].shape[2]:
                 assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :]))
             else:
-                assert torch.equal(weight[:, :, j, :, :], checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :])
+                assert torch.equal(
+                    weight[:, :, j, :, :],
+                    checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :],
+                )
 
-#check that assert is thrown when trying to load a larger checkpoint to a smaller history size
-@pytest.mark.parametrize('model', [1], indirect=True)
+
+# check that assert is thrown when trying to load a larger checkpoint to a smaller history size
+@pytest.mark.parametrize("model", [1], indirect=True)
 def test_adapt_checkpoint_max_history_fail(model, checkpoint):
     with pytest.raises(AssertionError):
         model.adapt_checkpoint_max_history_size(checkpoint)
-
-#test adapting the checkpoint twice to ensure that the second time should not change the weights
-@pytest.mark.parametrize('model', [4], indirect=True)
+
+
+# test adapting the checkpoint twice to ensure that the second time should not change the weights
+@pytest.mark.parametrize("model", [4], indirect=True)
 def test_adapt_checkpoint_max_history_twice(model, checkpoint):
-    adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)
+    adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)
     adapted_checkpoint = model.adapt_checkpoint_max_history_size(adapted_checkpoint)
-
+
     for name, weight in adapted_checkpoint.items():
         assert weight.shape[2] == model.max_history_size
        for j in range(weight.shape[2]):
            if j >= checkpoint[name].shape[2]:
                assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :]))
            else:
-                assert torch.equal(weight[:, :, j, :, :], checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :])
-
+                assert torch.equal(
+                    weight[:, :, j, :, :],
+                    checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :],
+                )
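A hedged usage sketch of the path this patch enables: a model constructed with a larger max_history_size loads a checkpoint trained with a smaller one, and load_checkpoint_local adapts the checkpoint instead of failing. The checkpoint path below is a placeholder, not a file shipped with the patch.

from aurora.model.aurora import AuroraSmall

# Model configured with a larger history window than the checkpoint was trained with.
model = AuroraSmall(max_history_size=4)

# load_checkpoint_local zero-fills the extra history slots of the token embedding weights;
# the reverse direction (checkpoint history larger than the model's) raises an AssertionError.
model.load_checkpoint_local("/path/to/checkpoint.ckpt")  # placeholder path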