Skip to content

Commit

Permalink
Test that changing the new flags changes the output
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Dec 12, 2024
1 parent 1529cb7 commit ecc7f7f
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Copyright (c) Microsoft Corporation. Licensed under the MIT license."""

from datetime import timedelta

import numpy as np
import pytest
import torch
Expand Down Expand Up @@ -120,3 +122,48 @@ def test_aurora_small_lat_lon_matrices(
pred_matrix.atmos_vars[k],
rtol=1e-5,
)


def test_aurora_flags(test_input_output: tuple[Batch, SavedBatch]) -> None:
batch, test_output = test_input_output

flag_collections: list[dict] = [
{},
{"stabilise_level_agg": True},
{"timestep": timedelta(hours=12)},
]

preds = []
for flags in flag_collections:
model = AuroraSmall(use_lora=True, **flags)
model.load_checkpoint(
"microsoft/aurora",
"aurora-0.25-small-pretrained.ckpt",
strict=False, # LoRA parameters not available.
)
model = model.double()
model.eval()
with torch.inference_mode():
preds.append(model.forward(batch).normalise(model.surf_stats))

# Check that all predictions are different.
for i, pred1 in enumerate(preds):
for pred2 in preds[i + 1 :]:
for k in pred1.surf_vars:
assert not np.allclose(
pred1.surf_vars[k],
pred2.surf_vars[k],
rtol=5e-2,
)
for k in pred1.static_vars:
np.testing.assert_allclose(
pred1.static_vars[k],
pred2.static_vars[k],
rtol=1e-5,
)
for k in pred1.atmos_vars:
assert not np.allclose(
pred1.atmos_vars[k],
pred2.atmos_vars[k],
rtol=5e-2,
)

0 comments on commit ecc7f7f

Please sign in to comment.