def expand_ensemble_batch(
    tensor: torch.Tensor, n_members: int, has_ensemble_dim: bool = False
) -> torch.Tensor:
    """
    Expand a batch tensor for concurrent (batched) ensemble processing.

    If the tensor has no ensemble dimension (e.g. initial states or forcings
    shared by all members), it is expanded (B, ...) -> (B * n_members, ...)
    by repeating each batch element ``n_members`` times, so members stay
    grouped by batch element.

    If ``has_ensemble_dim`` is True, the tensor is expected to already carry
    the ensemble dimension at index 1 (e.g. perturbed LBCs of shape
    (B, S, ...)) and is flattened to (B * S, ...).

    Args:
        tensor: Tensor of shape (B, ...) or (B, S, ...)
        n_members: The ensemble size (S)
        has_ensemble_dim: Whether dim 1 of ``tensor`` already holds the
            ensemble members. Passing this explicitly avoids shape
            heuristics, which are ambiguous when another dimension
            (e.g. time) happens to equal S.

    Returns:
        Tensor of shape (B * n_members, ...)

    Raises:
        ValueError: If ``has_ensemble_dim`` is True but dim 1 does not have
            size ``n_members``.
    """
    batch_size = tensor.shape[0]

    if has_ensemble_dim:
        if tensor.dim() < 2 or tensor.shape[1] != n_members:
            raise ValueError(
                f"Expected ensemble dimension of size {n_members} at "
                f"index 1, but got shape {tensor.shape}"
            )
        # Flatten (B, S, ...) -> (B * S, ...). Use reshape (not view) so
        # non-contiguous inputs (e.g. permuted tensors) are handled too.
        return tensor.reshape(batch_size * n_members, *tensor.shape[2:])

    # No ensemble dimension: repeat-interleave so members stay grouped by
    # batch element, e.g. [b1, b2] -> [b1, b1, b2, b2]
    return tensor.repeat_interleave(n_members, dim=0)
def fold_ensemble_batch(tensor: torch.Tensor, n_members: int) -> torch.Tensor:
    """
    Fold a batched ensemble tensor back into separate batch/ensemble dims.

    (B * n_members, ...) -> (B, n_members, ...)

    Inverse of ``expand_ensemble_batch``: the leading dimension is assumed
    to be grouped by batch element (all members of batch element 0 first,
    then all members of batch element 1, ...).

    Args:
        tensor: Tensor of shape (B * n_members, ...)
        n_members: The ensemble size (S)

    Returns:
        Tensor of shape (B, n_members, ...)

    Raises:
        ValueError: If the leading dimension is not divisible by
            ``n_members``.
    """
    B_times_S = tensor.shape[0]

    if B_times_S % n_members != 0:
        raise ValueError(
            f"Batch dimension {B_times_S} is not divisible by "
            f"ensemble size {n_members}"
        )

    B = B_times_S // n_members
    # Use reshape (not view) so non-contiguous inputs do not raise
    return tensor.reshape(B, n_members, *tensor.shape[1:])
def test_batched_ensemble_expansion_and_folding():
    """End-to-end check of expand_ensemble_batch / fold_ensemble_batch."""
    B = 2  # Batch size
    S = 50  # Ensemble size
    T, N, F = 3, 10, 4  # Time, Nodes, Features

    # 1. Deterministic state expansion (B, T, N, F) -> (B*S, T, N, F)
    init_state = torch.rand(B, T, N, F)
    expanded_state = expand_ensemble_batch(init_state, n_members=S)

    assert expanded_state.shape == (B * S, T, N, F)
    # Members must stay grouped by batch element:
    # the first S entries all equal init_state[0]
    assert torch.allclose(expanded_state[0], init_state[0])
    assert torch.allclose(expanded_state[S - 1], init_state[0])
    assert torch.allclose(expanded_state[S], init_state[1])

    # 2. Folding back (B*S, T, N, F) -> (B, S, T, N, F)
    folded_state = fold_ensemble_batch(expanded_state, n_members=S)

    assert folded_state.shape == (B, S, T, N, F)
    assert torch.allclose(folded_state[0, 0], init_state[0])
    assert torch.allclose(folded_state[1, S - 1], init_state[1])

    # 3. Probabilistic LBCs that already carry the ensemble dimension
    # (B, S, T, N, F) -> flag must be passed explicitly
    prob_lbc = torch.rand(B, S, T, N, F)
    expanded_lbc = expand_ensemble_batch(
        prob_lbc, n_members=S, has_ensemble_dim=True
    )

    assert expanded_lbc.shape == (B * S, T, N, F)
    # First element of B*S is the first member of the first batch element
    assert torch.allclose(expanded_lbc[0], prob_lbc[0, 0])
    # The (S)th element is the first member of the second batch element
    assert torch.allclose(expanded_lbc[S], prob_lbc[1, 0])

    # 4. Ambiguous case where a non-ensemble dim equals S (here T == S == 3):
    # with has_ensemble_dim=False the tensor must be repeated along the
    # batch dim, never flattened over the time dim
    S_ambig = 3
    init_state_ambig = torch.rand(B, S_ambig, N, F)
    expanded_ambig = expand_ensemble_batch(
        init_state_ambig, n_members=S_ambig, has_ensemble_dim=False
    )
    assert expanded_ambig.shape == (B * S_ambig, S_ambig, N, F)
    assert torch.allclose(expanded_ambig[0], init_state_ambig[0])
    assert torch.allclose(expanded_ambig[1], init_state_ambig[0])
    assert torch.allclose(expanded_ambig[S_ambig], init_state_ambig[1])


if __name__ == "__main__":
    test_batched_ensemble_expansion_and_folding()
    print("Test passed successfully!")