diff --git a/tests/test_model.py b/tests/test_model.py index be006e6..1ed9554 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,5 +1,7 @@ """Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" +from datetime import timedelta + import numpy as np import pytest import torch @@ -120,3 +122,48 @@ def test_aurora_small_lat_lon_matrices( pred_matrix.atmos_vars[k], rtol=1e-5, ) + + +def test_aurora_flags(test_input_output: tuple[Batch, SavedBatch]) -> None: + batch, test_output = test_input_output + + flag_collections: list[dict] = [ + {}, + {"stabilise_level_agg": True}, + {"timestep": timedelta(hours=12)}, + ] + + preds = [] + for flags in flag_collections: + model = AuroraSmall(use_lora=True, **flags) + model.load_checkpoint( + "microsoft/aurora", + "aurora-0.25-small-pretrained.ckpt", + strict=False, # LoRA parameters not available. + ) + model = model.double() + model.eval() + with torch.inference_mode(): + preds.append(model.forward(batch).normalise(model.surf_stats)) + + # Check that all predictions are different. + for i, pred1 in enumerate(preds): + for pred2 in preds[i + 1 :]: + for k in pred1.surf_vars: + assert not np.allclose( + pred1.surf_vars[k], + pred2.surf_vars[k], + rtol=5e-2, + ) + for k in pred1.static_vars: + np.testing.assert_allclose( + pred1.static_vars[k], + pred2.static_vars[k], + rtol=1e-5, + ) + for k in pred1.atmos_vars: + assert not np.allclose( + pred1.atmos_vars[k], + pred2.atmos_vars[k], + rtol=5e-2, + )