Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions neural_lam/models/base_hi_graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ def process_step(self, mesh_rep):
mesh_rep_levels = [mesh_rep] + [
self.expand_to_batch(emb(node_static_features), batch_size)
for emb, node_static_features in zip(
list(self.mesh_embedders)[1:],
list(self.mesh_static_features)[1:],
self.mesh_embedders[1:],
self.mesh_static_features[1:],
)
]

Expand Down
4 changes: 4 additions & 0 deletions neural_lam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ def __init__(self, buffer_tensors, persistent=True):
self.register_buffer(f"b{buffer_i}", tensor, persistent=persistent)

def __getitem__(self, key):
if isinstance(key, slice):
return [self[i] for i in range(*key.indices(len(self)))]
if key < 0:
key += len(self)
return getattr(self, f"b{key}")

def __len__(self):
Expand Down
58 changes: 58 additions & 0 deletions tests/test_graph_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from neural_lam.create_graph import create_graph_from_datastore
from neural_lam.datastore import DATASTORES
from neural_lam.datastore.base import BaseRegularGridDatastore
from neural_lam.utils import BufferList
from tests.conftest import init_datastore_example


Expand Down Expand Up @@ -117,3 +118,60 @@ def test_graph_creation(datastore_name, graph_name):
assert r.shape[0] == 2 # adjacency matrix uses two rows
elif file_id.endswith("_features"):
assert r.shape[1] == d_features


class TestBufferList:
"""Tests for BufferList slice and negative index support."""

@pytest.fixture
def buffer_list(self):
tensors = [torch.tensor([float(i)]) for i in range(5)]
return BufferList(tensors)

def test_integer_index(self, buffer_list):
"""Positive integer indexing returns the correct buffer."""
assert torch.equal(buffer_list[0], torch.tensor([0.0]))
assert torch.equal(buffer_list[4], torch.tensor([4.0]))

def test_negative_index(self, buffer_list):
"""Negative indexing follows Python sequence convention."""
assert torch.equal(buffer_list[-1], torch.tensor([4.0]))
assert torch.equal(buffer_list[-3], torch.tensor([2.0]))

def test_slice_full(self, buffer_list):
"""Full slice returns all buffers as a list."""
result = buffer_list[:]
assert len(result) == 5
for i, tensor in enumerate(result):
assert torch.equal(tensor, torch.tensor([float(i)]))

def test_slice_from_index(self, buffer_list):
"""Slice from index returns the correct subset."""
result = buffer_list[2:]
assert len(result) == 3
for i, tensor in enumerate(result):
assert torch.equal(tensor, torch.tensor([float(i + 2)]))

def test_slice_with_step(self, buffer_list):
"""Slice with step skips elements correctly."""
result = buffer_list[::2]
assert len(result) == 3
expected_vals = [0.0, 2.0, 4.0]
for tensor, expected in zip(result, expected_vals):
assert torch.equal(tensor, torch.tensor([expected]))

def test_slice_negative_bounds(self, buffer_list):
"""Slice with negative bounds follows Python convention."""
result = buffer_list[-2:]
assert len(result) == 2
assert torch.equal(result[0], torch.tensor([3.0]))
assert torch.equal(result[1], torch.tensor([4.0]))

def test_len(self, buffer_list):
"""Length reflects number of registered buffers."""
assert len(buffer_list) == 5

def test_iter(self, buffer_list):
"""Iteration yields all buffers in order."""
values = [t.item() for t in buffer_list]
assert values == [0.0, 1.0, 2.0, 3.0, 4.0]