diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index 742ef982..ceb5259d 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -620,3 +620,64 @@ def get_integer_time(tdelta) -> tuple[int, str]:
         return int(total_seconds / unit_in_seconds), unit
 
     return 1, "unknown"
+
+
+def expand_ensemble_batch(tensor: torch.Tensor, n_members: int, has_ensemble_dim: bool = False) -> torch.Tensor:
+    """
+    Expand a batch tensor so all ensemble members can be processed concurrently.
+
+    If the tensor lacks an ensemble dimension (e.g. initial states or forcing
+    shared by all members), it is repeated: (B, ...) -> (B * n_members, ...).
+    If it already carries an ensemble dimension (e.g. perturbed LBCs of shape
+    (B, S, ...)), the batch and ensemble dimensions are flattened together.
+
+    Args:
+        tensor: Tensor of shape (B, ...) or (B, S, ...)
+        n_members: The ensemble size (S)
+        has_ensemble_dim: Whether the tensor already contains the ensemble
+            dimension as its second dimension.
+
+    Returns:
+        Tensor of shape (B * n_members, ...)
+    """
+    B = tensor.shape[0]
+
+    if has_ensemble_dim:
+        if tensor.dim() < 2 or tensor.shape[1] != n_members:
+            raise ValueError(
+                f"Expected ensemble dimension of size {n_members} at index 1, "
+                f"but got shape {tensor.shape}"
+            )
+        # Already has ensemble dimension, flatten B and S.
+        # reshape (not view) so non-contiguous inputs are handled too.
+        return tensor.reshape(B * n_members, *tensor.shape[2:])
+
+    # No ensemble dimension, repeat interleave so elements stay grouped by batch
+    # e.g., [b1, b2] -> [b1, b1, b2, b2]
+    return tensor.repeat_interleave(n_members, dim=0)
+
+
+def fold_ensemble_batch(tensor: torch.Tensor, n_members: int) -> torch.Tensor:
+    """
+    Fold a batched ensemble tensor back into separate batch and ensemble dims.
+
+    (B * n_members, ...) -> (B, n_members, ...)
+
+    Args:
+        tensor: Tensor of shape (B * n_members, ...)
+        n_members: The ensemble size (S)
+
+    Returns:
+        Tensor of shape (B, n_members, ...)
+    """
+    B_times_S = tensor.shape[0]
+
+    if B_times_S % n_members != 0:
+        raise ValueError(
+            f"Batch dimension {B_times_S} is not divisible by "
+            f"ensemble size {n_members}"
+        )
+
+    B = B_times_S // n_members
+    # reshape (not view) so non-contiguous model output is handled too.
+    return tensor.reshape(B, n_members, *tensor.shape[1:])
diff --git a/tests/test_ensemble_utils.py b/tests/test_ensemble_utils.py
new file mode 100644
index 00000000..a442b1a3
--- /dev/null
+++ b/tests/test_ensemble_utils.py
@@ -0,0 +1,53 @@
+import torch
+
+from neural_lam.utils import expand_ensemble_batch, fold_ensemble_batch
+
+
+def test_batched_ensemble_expansion_and_folding():
+    B = 2  # Batch size
+    S = 50  # Ensemble size
+    T, N, F = 3, 10, 4  # Time, Nodes, Features
+
+    # 1. Test deterministic state expansion (B, T, N, F) -> (B*S, T, N, F)
+    init_state = torch.rand(B, T, N, F)
+    expanded_state = expand_ensemble_batch(init_state, n_members=S)
+
+    assert expanded_state.shape == (B * S, T, N, F)
+    # Ensure grouping is correct: first S elements should all equal init_state[0]
+    assert torch.allclose(expanded_state[0], init_state[0])
+    assert torch.allclose(expanded_state[S - 1], init_state[0])
+    assert torch.allclose(expanded_state[S], init_state[1])
+
+    # 2. Test folding back (B*S, T, N, F) -> (B, S, T, N, F)
+    folded_state = fold_ensemble_batch(expanded_state, n_members=S)
+
+    assert folded_state.shape == (B, S, T, N, F)
+    assert torch.allclose(folded_state[0, 0], init_state[0])
+    assert torch.allclose(folded_state[1, S - 1], init_state[1])
+
+    # 3. Test with Probabilistic Lateral Boundary Conditions (B, S, T, N, F)
+    prob_lbc = torch.rand(B, S, T, N, F)
+    expanded_lbc = expand_ensemble_batch(prob_lbc, n_members=S, has_ensemble_dim=True)
+
+    assert expanded_lbc.shape == (B * S, T, N, F)
+    # The first element of B*S should be the first member of the first batch
+    assert torch.allclose(expanded_lbc[0], prob_lbc[0, 0])
+    # The (S)th element should be the first member of the second batch
+    assert torch.allclose(expanded_lbc[S], prob_lbc[1, 0])
+
+    # 4. Test ambiguous case where S == T (e.g., T=3 timesteps, S=3 members)
+    S_ambig = 3
+    init_state_ambig = torch.rand(B, S_ambig, N, F)  # Shape (B, T, N, F), T == S
+    # With has_ensemble_dim=False it must repeat B -> B*S, not flatten B*T
+    expanded_ambig = expand_ensemble_batch(
+        init_state_ambig, n_members=S_ambig, has_ensemble_dim=False
+    )
+    assert expanded_ambig.shape == (B * S_ambig, S_ambig, N, F)
+    assert torch.allclose(expanded_ambig[0], init_state_ambig[0])
+    assert torch.allclose(expanded_ambig[1], init_state_ambig[0])
+    assert torch.allclose(expanded_ambig[S_ambig], init_state_ambig[1])
+
+
+if __name__ == "__main__":
+    test_batched_ensemble_expansion_and_folding()
+    print("Test passed successfully!")