diff --git a/aurora/batch.py b/aurora/batch.py index eeda210..9f96356 100644 --- a/aurora/batch.py +++ b/aurora/batch.py @@ -42,15 +42,30 @@ class Metadata: rollout_step: int = 0 def __post_init__(self): - if not torch.all(self.lat[1:] - self.lat[:-1] < 0): - raise ValueError("Latitudes must be strictly decreasing.") if not (torch.all(self.lat <= 90) and torch.all(self.lat >= -90)): raise ValueError("Latitudes must be in the range [-90, 90].") - if not torch.all(self.lon[1:] - self.lon[:-1] > 0): - raise ValueError("Longitudes must be strictly increasing.") if not (torch.all(self.lon >= 0) and torch.all(self.lon < 360)): raise ValueError("Longitudes must be in the range [0, 360).") + # Validate vector-valued latitudes and longitudes: + if self.lat.dim() == self.lon.dim() == 1: + if not torch.all(self.lat[1:] - self.lat[:-1] < 0): + raise ValueError("Latitudes must be strictly decreasing.") + if not torch.all(self.lon[1:] - self.lon[:-1] > 0): + raise ValueError("Longitudes must be strictly increasing.") + + # Validate matrix-valued latitudes and longitudes: + elif self.lat.dim() == self.lon.dim() == 2: + if not torch.all(self.lat[1:, :] - self.lat[:-1, :]): + raise ValueError("Latitudes must be strictly decreasing along every column.") + if not torch.all(self.lon[:, 1:] - self.lon[:, :-1] > 0): + raise ValueError("Longitudes must be strictly increasing along every row.") + + else: + raise ValueError( + "The latitudes and longitudes must either both be vectors or both be matrices." + ) + @dataclasses.dataclass class Batch: diff --git a/docs/batch.md b/docs/batch.md index 22eadba..735edb5 100644 --- a/docs/batch.md +++ b/docs/batch.md @@ -101,9 +101,13 @@ The following atmospheric variables are allows: The latitudes must be _decreasing_. The latitudes can either include both endpoints, like `linspace(90, -90, 721)`, or not include the south pole, like `linspace(90, -90, 721)[:-1]`. + For curvilinear grids, this can also be a matrix, in which case the foregoing conditions + apply to every _column_. * `Metadata.lon` is the vector of longitudes. The longitudes must be _increasing_. The longitudes must be in the range `[0, 360)`, so they can include zero and cannot include 360. + For curvilinear grids, this can also be a matrix, in which case the foregoing conditions + apply to every _row_. * `Metadata.atmos_levels` is a `tuple` of the pressure levels of the atmospheric variables in hPa. Note that these levels must be in exactly correspond to the order of the atmospheric variables. Note also that `Metadata.atmos_levels` should be a `tuple`, not a `list`. diff --git a/tests/test_model.py b/tests/test_model.py index 7720d2d..be006e6 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,19 +1,17 @@ """Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" import numpy as np +import pytest import torch from tests.conftest import SavedBatch -from aurora import AuroraSmall, Batch +from aurora import Aurora, AuroraSmall, Batch -def test_aurora_small(test_input_output: tuple[Batch, SavedBatch]) -> None: - batch, test_output = test_input_output - +@pytest.fixture() +def aurora_small() -> Aurora: model = AuroraSmall(use_lora=True) - - # Load the checkpoint and run the model. model.load_checkpoint( "microsoft/aurora", "aurora-0.25-small-pretrained.ckpt", @@ -21,8 +19,14 @@ def test_aurora_small(test_input_output: tuple[Batch, SavedBatch]) -> None: ) model = model.double() model.eval() + return model + + +def test_aurora_small(aurora_small: Aurora, test_input_output: tuple[Batch, SavedBatch]) -> None: + batch, test_output = test_input_output + with torch.inference_mode(): - pred = model.forward(batch) + pred = aurora_small.forward(batch) def assert_approx_equality(v_out: np.ndarray, v_ref: np.ndarray, tol: float) -> None: err = np.abs(v_out - v_ref).mean() @@ -69,10 +73,50 @@ def assert_approx_equality(v_out: np.ndarray, v_ref: np.ndarray, tol: float) -> def test_aurora_small_decoder_init() -> None: - model = AuroraSmall(use_lora=True) + aurora_small = AuroraSmall(use_lora=True) # Check that the decoder heads are properly initialised. The biases should be zero, but the # weights shouldn't. - for layer in [*model.decoder.surf_heads.values(), *model.decoder.atmos_heads.values()]: + for layer in [ + *aurora_small.decoder.surf_heads.values(), + *aurora_small.decoder.atmos_heads.values(), + ]: assert not torch.all(layer.weight == 0) assert torch.all(layer.bias == 0) + + +def test_aurora_small_lat_lon_matrices( + aurora_small: Aurora, test_input_output: tuple[Batch, SavedBatch] +) -> None: + batch, test_output = test_input_output + + with torch.inference_mode(): + pred = aurora_small.forward(batch) + + # Modify the batch to have a latitude and longitude matrices. + n_lat = len(batch.metadata.lat) + n_lon = len(batch.metadata.lon) + batch.metadata.lat = batch.metadata.lat[:, None].expand(n_lat, n_lon) + batch.metadata.lon = batch.metadata.lon[None, :].expand(n_lat, n_lon) + + pred_matrix = aurora_small.forward(batch) + + # Check the outputs. + for k in pred.surf_vars: + np.testing.assert_allclose( + pred.surf_vars[k], + pred_matrix.surf_vars[k], + rtol=1e-5, + ) + for k in pred.static_vars: + np.testing.assert_allclose( + pred.static_vars[k], + pred_matrix.static_vars[k], + rtol=1e-5, + ) + for k in pred.atmos_vars: + np.testing.assert_allclose( + pred.atmos_vars[k], + pred_matrix.atmos_vars[k], + rtol=1e-5, + )