From 54f0f00f2fd297fee6a9e60ab2a9cfdc1801720b Mon Sep 17 00:00:00 2001 From: Wessel Bruinsma Date: Mon, 9 Sep 2024 17:55:33 +0200 Subject: [PATCH] Test decoder init --- tests/test_model.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_model.py b/tests/test_model.py index 4e6cdb9..c3df873 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -144,3 +144,13 @@ def assert_approx_equality(v_out: np.ndarray, v_ref: np.ndarray, tol: float) -> np.testing.assert_allclose(pred.metadata.lat, test_output["metadata"]["lat"]) assert pred.metadata.atmos_levels == tuple(test_output["metadata"]["atmos_levels"]) assert pred.metadata.time == tuple(test_output["metadata"]["time"]) + + +def test_aurora_small_decoder_init() -> None: + model = AuroraSmall(use_lora=True) + + # Check that the decoder heads are properly initialised. The biases should be zero, but the + # weights shouldn't. + for layer in [*model.decoder.surf_heads.values(), *model.decoder.atmos_heads.values()]: + assert not torch.all(layer.weight == 0) + assert torch.all(layer.bias == 0)