Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions neural_lam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,3 +620,52 @@ def get_integer_time(tdelta) -> tuple[int, str]:
return int(total_seconds / unit_in_seconds), unit

return 1, "unknown"

def expand_ensemble_batch(
    tensor: torch.Tensor,
    n_members: int,
    has_ensemble_dim: "bool | None" = None,
) -> torch.Tensor:
    """
    Expand or flatten a batch tensor for concurrent ensemble processing.

    Deterministic inputs shared by all members (such as initial states)
    of shape (B, ...) are repeated to (B * n_members, ...). Inputs that
    already carry an ensemble dimension (for example, perturbed LBCs of
    shape (B, S, ...) with S == n_members) are flattened to
    (B * n_members, ...).

    Args:
        tensor: Tensor of shape (B, ...) or (B, S, ...)
        n_members: The ensemble size (S)
        has_ensemble_dim: Explicitly state whether ``tensor`` carries an
            ensemble dimension at dim 1. If None (default), this is
            inferred by checking ``tensor.shape[1] == n_members``. That
            heuristic is ambiguous when a non-ensemble dimension (e.g.
            time) coincidentally equals ``n_members`` — pass an explicit
            bool in that case.

    Returns:
        Tensor of shape (B * n_members, ...)
    """
    batch_size = tensor.shape[0]

    if has_ensemble_dim is None:
        # Heuristic: treat dim 1 as the ensemble dimension iff it matches S
        has_ensemble_dim = tensor.dim() >= 2 and tensor.shape[1] == n_members

    if has_ensemble_dim:
        # Flatten (B, S, ...) -> (B * S, ...). Use reshape (not view) so
        # non-contiguous inputs (e.g. results of transpose) also work.
        return tensor.reshape(batch_size * n_members, *tensor.shape[2:])

    # No ensemble dimension: repeat-interleave so elements stay grouped
    # by batch entry, e.g. [b1, b2] -> [b1, b1, b2, b2]
    return tensor.repeat_interleave(n_members, dim=0)

def fold_ensemble_batch(tensor: torch.Tensor, n_members: int) -> torch.Tensor:
    """
    Fold a flattened ensemble batch back into separate dimensions.

    (B * n_members, ...) -> (B, n_members, ...), the inverse of
    ``expand_ensemble_batch``.

    Args:
        tensor: Tensor of shape (B * n_members, ...)
        n_members: The ensemble size (S)

    Returns:
        Tensor of shape (B, n_members, ...)

    Raises:
        ValueError: If the leading dimension of ``tensor`` is not
            divisible by ``n_members``.
    """
    flat_batch = tensor.shape[0]

    if flat_batch % n_members != 0:
        raise ValueError(
            f"Batch dimension {flat_batch} is not divisible by "
            f"ensemble size {n_members}"
        )

    # Use reshape (not view) so non-contiguous inputs also fold correctly
    return tensor.reshape(flat_batch // n_members, n_members, *tensor.shape[1:])
38 changes: 38 additions & 0 deletions tests/test_ensemble_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import torch
from neural_lam.utils import expand_ensemble_batch, fold_ensemble_batch

def test_batched_ensemble_expansion_and_folding():
    """Round-trip check of ensemble batch expansion and folding."""
    batch_size = 2
    ens_size = 50
    time_steps, num_nodes, num_feats = 3, 10, 4

    # 1. Deterministic state expansion: (B, T, N, F) -> (B*S, T, N, F)
    state = torch.rand(batch_size, time_steps, num_nodes, num_feats)
    expanded = expand_ensemble_batch(state, n_members=ens_size)

    assert expanded.shape == (
        batch_size * ens_size,
        time_steps,
        num_nodes,
        num_feats,
    )
    # Members are grouped per batch entry: the first S rows all equal state[0]
    assert torch.allclose(expanded[0], state[0])
    assert torch.allclose(expanded[ens_size - 1], state[0])
    assert torch.allclose(expanded[ens_size], state[1])

    # 2. Folding back: (B*S, T, N, F) -> (B, S, T, N, F)
    folded = fold_ensemble_batch(expanded, n_members=ens_size)

    assert folded.shape == (
        batch_size,
        ens_size,
        time_steps,
        num_nodes,
        num_feats,
    )
    assert torch.allclose(folded[0, 0], state[0])
    assert torch.allclose(folded[1, 49], state[1])

    # 3. Probabilistic LBCs that already carry an ensemble dim: (B, S, T, N, F)
    lbc = torch.rand(batch_size, ens_size, time_steps, num_nodes, num_feats)
    flat_lbc = expand_ensemble_batch(lbc, n_members=ens_size)

    assert flat_lbc.shape == (
        batch_size * ens_size,
        time_steps,
        num_nodes,
        num_feats,
    )
    # First flattened element is member 0 of batch 0
    assert torch.allclose(flat_lbc[0], lbc[0, 0])
    # Element S is member 0 of batch 1
    assert torch.allclose(flat_lbc[ens_size], lbc[1, 0])
# Allow running this test module directly, without a pytest runner
if __name__ == "__main__":
    test_batched_ensemble_expansion_and_folding()
    print("Test passed successfully!")