Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions neural_lam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,3 +620,53 @@ def get_integer_time(tdelta) -> tuple[int, str]:
return int(total_seconds / unit_in_seconds), unit

return 1, "unknown"

def expand_ensemble_batch(tensor: torch.Tensor, n_members: int, has_ensemble_dim: bool = False) -> torch.Tensor:
    """
    Expand a batch tensor for concurrent ensemble processing.

    If the tensor lacks an ensemble dimension (such as initial states or
    conditions shared by all members), each batch element is repeated so
    that members stay grouped by batch: (B, ...) -> (B * n_members, ...).

    If the tensor already carries an ensemble dimension at index 1 (for
    example, perturbed LBCs with shape (B, S, ...)), the batch and
    ensemble dimensions are flattened together: (B, S, ...) -> (B * S, ...).

    Args:
        tensor: Tensor of shape (B, ...) or (B, n_members, ...)
        n_members: The ensemble size (S)
        has_ensemble_dim: Whether the tensor already contains the ensemble
            dimension as its second dimension.

    Returns:
        Tensor of shape (B * n_members, ...)

    Raises:
        ValueError: If has_ensemble_dim is True but dimension 1 does not
            have size n_members.
    """
    B = tensor.shape[0]

    if has_ensemble_dim:
        if tensor.dim() < 2 or tensor.shape[1] != n_members:
            raise ValueError(f"Expected ensemble dimension of size {n_members} at index 1, but got shape {tensor.shape}")
        # Already has ensemble dimension, flatten B and S.
        # Use reshape rather than view: view raises on non-contiguous
        # inputs (e.g. transposed or sliced tensors), while reshape is
        # identical for contiguous ones and copies only when needed.
        return tensor.reshape(B * n_members, *tensor.shape[2:])

    # No ensemble dimension, repeat interleave so elements stay grouped by batch
    # e.g., [b1, b2] -> [b1, b1, b2, b2]
    return tensor.repeat_interleave(n_members, dim=0)

def fold_ensemble_batch(tensor: torch.Tensor, n_members: int) -> torch.Tensor:
    """
    Folds a batched ensemble tensor back into separated batch and ensemble dimensions.

    (B * n_members, ...) -> (B, n_members, ...)

    Args:
        tensor: Tensor of shape (B * n_members, ...)
        n_members: The ensemble size (S)

    Returns:
        Tensor of shape (B, n_members, ...)

    Raises:
        ValueError: If the leading dimension is not divisible by n_members.
    """
    B_times_S = tensor.shape[0]

    if B_times_S % n_members != 0:
        raise ValueError(f"Batch dimension {B_times_S} is not divisible by ensemble size {n_members}")

    B = B_times_S // n_members
    # reshape instead of view: view raises on non-contiguous inputs
    # (e.g. tensors produced by transpose/permute), reshape handles both.
    return tensor.reshape(B, n_members, *tensor.shape[1:])
49 changes: 49 additions & 0 deletions tests/test_ensemble_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import torch
from neural_lam.utils import expand_ensemble_batch, fold_ensemble_batch

def test_batched_ensemble_expansion_and_folding():
    """Round-trip check of expand_ensemble_batch / fold_ensemble_batch."""
    batch_size = 2
    ens_size = 50
    n_time, n_nodes, n_feat = 3, 10, 4

    # 1. Deterministic state expansion: (B, T, N, F) -> (B*S, T, N, F)
    state = torch.rand(batch_size, n_time, n_nodes, n_feat)
    state_exp = expand_ensemble_batch(state, n_members=ens_size)

    assert state_exp.shape == (batch_size * ens_size, n_time, n_nodes, n_feat)
    # Grouping check: the first ens_size rows all come from batch element 0
    for row, ref in ((0, state[0]), (ens_size - 1, state[0]), (ens_size, state[1])):
        assert torch.allclose(state_exp[row], ref)

    # 2. Folding back: (B*S, T, N, F) -> (B, S, T, N, F)
    state_folded = fold_ensemble_batch(state_exp, n_members=ens_size)

    assert state_folded.shape == (batch_size, ens_size, n_time, n_nodes, n_feat)
    assert torch.allclose(state_folded[0, 0], state[0])
    assert torch.allclose(state_folded[1, 49], state[1])

    # 3. Probabilistic lateral boundary conditions already shaped (B, S, T, N, F)
    lbc = torch.rand(batch_size, ens_size, n_time, n_nodes, n_feat)
    lbc_exp = expand_ensemble_batch(lbc, n_members=ens_size, has_ensemble_dim=True)

    assert lbc_exp.shape == (batch_size * ens_size, n_time, n_nodes, n_feat)
    # Row 0 is member 0 of batch 0; row ens_size is member 0 of batch 1
    assert torch.allclose(lbc_exp[0], lbc[0, 0])
    assert torch.allclose(lbc_exp[ens_size], lbc[1, 0])

    # 4. Ambiguous case where the time dimension equals the ensemble size
    small = 3
    ambig = torch.rand(batch_size, small, n_nodes, n_feat)  # (B, T, N, F), T == S
    # With has_ensemble_dim=False the batch dim is repeated, not flattened
    ambig_exp = expand_ensemble_batch(ambig, n_members=small, has_ensemble_dim=False)
    assert ambig_exp.shape == (batch_size * small, small, n_nodes, n_feat)
    assert torch.allclose(ambig_exp[0], ambig[0])
    assert torch.allclose(ambig_exp[1], ambig[0])
    assert torch.allclose(ambig_exp[small], ambig[1])

# Allow running this test module directly without pytest.
if __name__ == "__main__":
    test_batched_ensemble_expansion_and_folding()
    print("Test passed successfully!")