diff --git a/src/weather_model_graphs/save.py b/src/weather_model_graphs/save.py index 74ff5ef..92d9bb7 100644 --- a/src/weather_model_graphs/save.py +++ b/src/weather_model_graphs/save.py @@ -1,6 +1,6 @@ import pickle from pathlib import Path -from typing import List +from typing import Dict, List import networkx from loguru import logger @@ -158,6 +158,306 @@ def _concat_pyg_features( logger.info(f"Saved node features {node_features} to {fp_node_features}.") +def _graph_to_edge_tensors(graph, edge_features=None): + """Convert a single networkx DiGraph to edge_index and edge_features tensors. + + Parameters + ---------- + graph : networkx.DiGraph + Graph to convert, must have integer node labels. + edge_features : list of str, optional + Edge attribute names to include. Default: ["len", "vdiff"]. + + Returns + ------- + edge_index : torch.Tensor + Shape (2, num_edges). + features : torch.Tensor + Shape (num_edges, num_feature_cols). With default features + this is (num_edges, 3) for [len, vdiff_x, vdiff_y]. + """ + if not HAS_PYG: + raise RuntimeError( + "install weather-model-graphs[pytorch] to enable writing to torch files" + ) + + if edge_features is None: + edge_features = ["len", "vdiff"] + + # Strip node attributes to only "pos" and edge attributes to only the + # requested features so that from_networkx does not fail on heterogeneous + # attribute sets (e.g. g2m graphs with grid + mesh nodes). + clean = networkx.DiGraph() + for node, data in sorted(graph.nodes(data=True)): + clean.add_node(node, pos=data["pos"]) + for u, v, data in graph.edges(data=True): + edge_data = {k: data[k] for k in edge_features if k in data} + clean.add_edge(u, v, **edge_data) + + sorted_graph = sort_nodes_in_graph(clean) + pyg_graph = pyg_convert.from_networkx(sorted_graph) + + edge_index = pyg_graph.edge_index + + v_concat = [] + for f in edge_features: + v = pyg_graph[f] + if v.ndim == 1: + v = v.unsqueeze(1) + v_concat.append(v) + features = torch.cat(v_concat, dim=1).to(torch.float32) + + return edge_index, features + + +def _graph_to_node_features(graph, node_features=None): + """Extract node feature tensor from a networkx DiGraph. + + Parameters + ---------- + graph : networkx.DiGraph + Graph with integer node labels and node attributes. + node_features : list of str, optional + Node attribute names to include. Default: ["pos"]. + + Returns + ------- + torch.Tensor + Shape (num_nodes, num_feature_cols). With default features + this is (num_nodes, 2) for [pos_x, pos_y]. + """ + if not HAS_PYG: + raise RuntimeError( + "install weather-model-graphs[pytorch] to enable writing to torch files" + ) + + if node_features is None: + node_features = ["pos"] + + # Strip to only requested node attributes for clean PyG conversion + clean = networkx.DiGraph() + for node, data in sorted(graph.nodes(data=True)): + keep = {k: data[k] for k in node_features if k in data} + clean.add_node(node, **keep) + clean.add_edges_from(graph.edges()) + + sorted_graph = sort_nodes_in_graph(clean) + pyg_graph = pyg_convert.from_networkx(sorted_graph) + + v_concat = [] + for f in node_features: + v = pyg_graph[f] + if v.ndim == 1: + v = v.unsqueeze(1) + v_concat.append(v) + + return torch.cat(v_concat, dim=1).to(torch.float32) + + +def to_neural_lam( + graph_components: Dict[str, networkx.DiGraph], + output_directory: str, + hierarchical: bool = False, +): + """ + Save graph components to the neural-lam tensor-on-disk format. + + Takes graph components as returned by + ``wmg.create.archetype.*(..., return_components=True)`` and writes + ``.pt`` files matching the format expected by + ``neural_lam.utils.load_graph()``. + + Edge features are written **raw** (unnormalized) — neural-lam normalizes + at load time. Mesh node features (positions) are normalized by + ``max(abs(pos))`` before saving, matching the existing neural-lam convention. + + Parameters + ---------- + graph_components : dict of networkx.DiGraph + Dictionary with keys ``"g2m"``, ``"m2m"``, and ``"m2g"``, each mapping + to a directed graph. This is the output of + ``wmg.create.archetype.*(..., return_components=True)``. + output_directory : str + Directory where the ``.pt`` files will be saved. + hierarchical : bool, optional + If True, the m2m graph is expected to contain hierarchical edges + with ``"direction"`` attribute (``"same"``, ``"up"``, ``"down"``). + Additional mesh_up/mesh_down files are written. Default: False. + + Returns + ------- + None + + Notes + ----- + **Output files** (always produced): + + - ``g2m_edge_index.pt`` — ``torch.Tensor`` of shape ``(2, M_g2m)`` + - ``g2m_features.pt`` — ``torch.Tensor`` of shape ``(M_g2m, 3)`` + - ``m2g_edge_index.pt`` — ``torch.Tensor`` of shape ``(2, M_m2g)`` + - ``m2g_features.pt`` — ``torch.Tensor`` of shape ``(M_m2g, 3)`` + - ``m2m_edge_index.pt`` — ``List[torch.Tensor]``, each ``(2, M_l)`` + - ``m2m_features.pt`` — ``List[torch.Tensor]``, each ``(M_l, 3)`` + - ``mesh_features.pt`` — ``List[torch.Tensor]``, each ``(N_l, 2)`` + + **Additional files** (hierarchical only): + + - ``mesh_up_edge_index.pt`` — ``List[torch.Tensor]``, each ``(2, M_up_l)`` + - ``mesh_up_features.pt`` — ``List[torch.Tensor]``, each ``(M_up_l, 3)`` + - ``mesh_down_edge_index.pt`` — ``List[torch.Tensor]``, each ``(2, M_down_l)`` + - ``mesh_down_features.pt`` — ``List[torch.Tensor]``, each ``(M_down_l, 3)`` + + Edge features have 3 columns: ``[len, vdiff_x, vdiff_y]``. + Mesh node features have 2 columns: ``[pos_x, pos_y]`` (normalized). + """ + if not HAS_PYG: + raise RuntimeError( + "install weather-model-graphs[pytorch] to enable writing to torch files" + ) + + required_keys = {"g2m", "m2m", "m2g"} + missing = required_keys - set(graph_components.keys()) + if missing: + raise ValueError( + f"graph_components is missing required keys: {sorted(missing)}. " + f"Expected keys: {sorted(required_keys)}" + ) + + output_dir = Path(output_directory) + output_dir.mkdir(exist_ok=True, parents=True) + + # --- g2m (grid-to-mesh): single tensor --- + g2m_graph = graph_components["g2m"] + g2m_edge_index, g2m_features = _graph_to_edge_tensors(g2m_graph) + torch.save(g2m_edge_index, output_dir / "g2m_edge_index.pt") + torch.save(g2m_features, output_dir / "g2m_features.pt") + logger.info(f"Saved g2m edges: {g2m_edge_index.shape[1]} edges") + + # --- m2g (mesh-to-grid): single tensor --- + m2g_graph = graph_components["m2g"] + m2g_edge_index, m2g_features = _graph_to_edge_tensors(m2g_graph) + torch.save(m2g_edge_index, output_dir / "m2g_edge_index.pt") + torch.save(m2g_features, output_dir / "m2g_features.pt") + logger.info(f"Saved m2g edges: {m2g_edge_index.shape[1]} edges") + + # --- m2m (mesh-to-mesh): list of tensors per level --- + m2m_graph = graph_components["m2m"] + + if hierarchical: + # Split by direction: "same", "up", "down" + direction_subgraphs = split_graph_by_edge_attribute( + m2m_graph, attr="direction" + ) + + # --- Intra-level (same-level) m2m edges --- + same_graph = direction_subgraphs["same"] + try: + level_subgraphs = split_graph_by_edge_attribute( + same_graph, attr="level" + ) + except MissingEdgeAttributeError: + level_subgraphs = {0: same_graph} + sorted_levels = sorted(level_subgraphs.keys()) + + m2m_edge_indices = [] + m2m_features_list = [] + mesh_node_features_list = [] + for level_key in sorted_levels: + sub = level_subgraphs[level_key] + ei, ef = _graph_to_edge_tensors(sub) + nf = _graph_to_node_features(sub) + m2m_edge_indices.append(ei) + m2m_features_list.append(ef) + mesh_node_features_list.append(nf) + + # --- Inter-level up edges --- + up_graph = direction_subgraphs["up"] + try: + up_subgraphs = split_graph_by_edge_attribute(up_graph, attr="levels") + except MissingEdgeAttributeError: + up_subgraphs = {"0": up_graph} + sorted_up_keys = sorted(up_subgraphs.keys()) + + mesh_up_edge_indices = [] + mesh_up_features_list = [] + for key in sorted_up_keys: + ei, ef = _graph_to_edge_tensors(up_subgraphs[key]) + mesh_up_edge_indices.append(ei) + mesh_up_features_list.append(ef) + + # --- Inter-level down edges --- + down_graph = direction_subgraphs["down"] + try: + down_subgraphs = split_graph_by_edge_attribute( + down_graph, attr="levels" + ) + except MissingEdgeAttributeError: + down_subgraphs = {"0": down_graph} + sorted_down_keys = sorted(down_subgraphs.keys()) + + mesh_down_edge_indices = [] + mesh_down_features_list = [] + for key in sorted_down_keys: + ei, ef = _graph_to_edge_tensors(down_subgraphs[key]) + mesh_down_edge_indices.append(ei) + mesh_down_features_list.append(ef) + + # Save hierarchical-only files + torch.save( + mesh_up_edge_indices, output_dir / "mesh_up_edge_index.pt" + ) + torch.save( + mesh_up_features_list, output_dir / "mesh_up_features.pt" + ) + torch.save( + mesh_down_edge_indices, output_dir / "mesh_down_edge_index.pt" + ) + torch.save( + mesh_down_features_list, output_dir / "mesh_down_features.pt" + ) + logger.info( + f"Saved hierarchical mesh_up ({len(mesh_up_edge_indices)} levels) " + f"and mesh_down ({len(mesh_down_edge_indices)} levels)" + ) + + else: + # Non-hierarchical: split by "level" if available, otherwise single list + try: + level_subgraphs = split_graph_by_edge_attribute( + m2m_graph, attr="level" + ) + except MissingEdgeAttributeError: + level_subgraphs = {0: m2m_graph} + sorted_levels = sorted(level_subgraphs.keys()) + + m2m_edge_indices = [] + m2m_features_list = [] + mesh_node_features_list = [] + for level_key in sorted_levels: + sub = level_subgraphs[level_key] + ei, ef = _graph_to_edge_tensors(sub) + nf = _graph_to_node_features(sub) + m2m_edge_indices.append(ei) + m2m_features_list.append(ef) + mesh_node_features_list.append(nf) + + # Save m2m edge tensors (always as lists) + torch.save(m2m_edge_indices, output_dir / "m2m_edge_index.pt") + torch.save(m2m_features_list, output_dir / "m2m_features.pt") + logger.info(f"Saved m2m edges: {len(m2m_edge_indices)} level(s)") + + # --- mesh_features.pt: normalized mesh node positions --- + pos_max = max( + torch.max(torch.abs(nf)) for nf in mesh_node_features_list + ) + mesh_features_normalized = [nf / pos_max for nf in mesh_node_features_list] + torch.save(mesh_features_normalized, output_dir / "mesh_features.pt") + logger.info( + f"Saved mesh_features: {len(mesh_features_normalized)} level(s), " + f"normalized by pos_max={pos_max:.4f}" + ) + + def to_pickle(graph: networkx.DiGraph, output_directory: str, name: str): """ Save the networkx graph to a pickle file. diff --git a/tests/test_save.py b/tests/test_save.py index f889ab9..276ca8f 100644 --- a/tests/test_save.py +++ b/tests/test_save.py @@ -1,6 +1,9 @@ import tempfile +from pathlib import Path +import numpy as np import pytest +import torch from loguru import logger import tests.utils as test_utils @@ -40,3 +43,404 @@ def test_save_to_pyg(list_from_attribute): name=name, list_from_attribute=list_from_attribute, ) + + +# ─── to_neural_lam tests ─── + +# Files expected for all graph types +CORE_FILES = [ + "g2m_edge_index.pt", + "g2m_features.pt", + "m2g_edge_index.pt", + "m2g_features.pt", + "m2m_edge_index.pt", + "m2m_features.pt", + "mesh_features.pt", +] + +# Additional files for hierarchical graphs +HIERARCHICAL_FILES = [ + "mesh_up_edge_index.pt", + "mesh_up_features.pt", + "mesh_down_edge_index.pt", + "mesh_down_features.pt", +] + + +def _skip_if_no_pyg(): + if not HAS_PYG: + pytest.skip("weather-model-graphs[pytorch] not installed") + + +def _create_and_save(archetype, hierarchical, N=64): + """Helper: create graph components and save to temp dir, return path.""" + xy = test_utils.create_fake_xy(N=N) + + if archetype == "keisler": + components = wmg.create.archetype.create_keisler_graph( + coords=xy, return_components=True + ) + elif archetype == "graphcast": + components = wmg.create.archetype.create_graphcast_graph( + coords=xy, return_components=True + ) + elif archetype == "hierarchical": + components = wmg.create.archetype.create_oskarsson_hierarchical_graph( + coords=xy, return_components=True + ) + else: + raise ValueError(f"Unknown archetype: {archetype}") + + tmpdir = tempfile.mkdtemp() + wmg.save.to_neural_lam( + graph_components=components, + output_directory=tmpdir, + hierarchical=hierarchical, + ) + return tmpdir, components + + +class TestToNeuralLamKeisler: + """Tests for to_neural_lam with keisler (flat single-scale) archetype.""" + + def test_core_files_created(self): + _skip_if_no_pyg() + tmpdir, _ = _create_and_save("keisler", hierarchical=False) + for fname in CORE_FILES: + assert (Path(tmpdir) / fname).exists(), f"Missing: {fname}" + # Hierarchical files should NOT exist + for fname in HIERARCHICAL_FILES: + assert not (Path(tmpdir) / fname).exists(), f"Unexpected: {fname}" + + def test_g2m_m2g_are_single_tensors(self): + _skip_if_no_pyg() + tmpdir, _ = _create_and_save("keisler", hierarchical=False) + g2m_ei = torch.load(Path(tmpdir) / "g2m_edge_index.pt", weights_only=True) + g2m_f = torch.load(Path(tmpdir) / "g2m_features.pt", weights_only=True) + m2g_ei = torch.load(Path(tmpdir) / "m2g_edge_index.pt", weights_only=True) + m2g_f = torch.load(Path(tmpdir) / "m2g_features.pt", weights_only=True) + + # Must be plain tensors, not lists + assert isinstance(g2m_ei, torch.Tensor) + assert isinstance(g2m_f, torch.Tensor) + assert isinstance(m2g_ei, torch.Tensor) + assert isinstance(m2g_f, torch.Tensor) + + # Shape checks + assert g2m_ei.shape[0] == 2 + assert g2m_f.shape[1] == 3 + assert m2g_ei.shape[0] == 2 + assert m2g_f.shape[1] == 3 + + # Edge count consistency + assert g2m_ei.shape[1] == g2m_f.shape[0] + assert m2g_ei.shape[1] == m2g_f.shape[0] + + def test_m2m_is_list_of_one(self): + _skip_if_no_pyg() + tmpdir, _ = _create_and_save("keisler", hierarchical=False) + m2m_ei = torch.load(Path(tmpdir) / "m2m_edge_index.pt", weights_only=True) + m2m_f = torch.load(Path(tmpdir) / "m2m_features.pt", weights_only=True) + mesh_f = torch.load(Path(tmpdir) / "mesh_features.pt", weights_only=True) + + # Keisler is single-level, so lists of length 1 + assert isinstance(m2m_ei, list) + assert len(m2m_ei) == 1 + assert isinstance(m2m_f, list) + assert len(m2m_f) == 1 + assert isinstance(mesh_f, list) + assert len(mesh_f) == 1 + + # Shape checks for tensors inside lists + assert m2m_ei[0].shape[0] == 2 + assert m2m_f[0].shape[1] == 3 + assert mesh_f[0].shape[1] == 2 + + # Edge count consistency + assert m2m_ei[0].shape[1] == m2m_f[0].shape[0] + + def test_mesh_features_normalized(self): + _skip_if_no_pyg() + tmpdir, _ = _create_and_save("keisler", hierarchical=False) + mesh_f = torch.load(Path(tmpdir) / "mesh_features.pt", weights_only=True) + # All values should be in [-1, 1] after normalization + for level_f in mesh_f: + assert torch.max(torch.abs(level_f)) <= 1.0 + 1e-6 + + def test_edge_features_are_raw(self): + """Edge features should NOT be normalized (neural-lam normalizes at load time).""" + _skip_if_no_pyg() + tmpdir, _ = _create_and_save("keisler", hierarchical=False) + m2m_f = torch.load(Path(tmpdir) / "m2m_features.pt", weights_only=True) + # Column 0 is edge length — should be > 0 for all edges + for level_f in m2m_f: + assert torch.all(level_f[:, 0] > 0) + + def test_has_nonzero_edges(self): + _skip_if_no_pyg() + tmpdir, _ = _create_and_save("keisler", hierarchical=False) + g2m_ei = torch.load(Path(tmpdir) / "g2m_edge_index.pt", weights_only=True) + m2g_ei = torch.load(Path(tmpdir) / "m2g_edge_index.pt", weights_only=True) + m2m_ei = torch.load(Path(tmpdir) / "m2m_edge_index.pt", weights_only=True) + assert g2m_ei.shape[1] > 0 + assert m2g_ei.shape[1] > 0 + assert m2m_ei[0].shape[1] > 0 + + +class TestToNeuralLamGraphcast: + """Tests for to_neural_lam with graphcast (flat multiscale) archetype.""" + + def test_core_files_created(self): + _skip_if_no_pyg() + tmpdir, _ = _create_and_save("graphcast", hierarchical=False) + for fname in CORE_FILES: + assert (Path(tmpdir) / fname).exists(), f"Missing: {fname}" + for fname in HIERARCHICAL_FILES: + assert not (Path(tmpdir) / fname).exists(), f"Unexpected: {fname}" + + def test_m2m_has_multiple_levels(self): + _skip_if_no_pyg() + tmpdir, _ = _create_and_save("graphcast", hierarchical=False, N=64) + m2m_ei = torch.load(Path(tmpdir) / "m2m_edge_index.pt", weights_only=True) + m2m_f = torch.load(Path(tmpdir) / "m2m_features.pt", weights_only=True) + mesh_f = torch.load(Path(tmpdir) / "mesh_features.pt", weights_only=True) + + assert isinstance(m2m_ei, list) + assert isinstance(m2m_f, list) + assert isinstance(mesh_f, list) + + # All list lengths must match + assert len(m2m_ei) == len(m2m_f) == len(mesh_f) + + # Each level should have correct shapes + for i in range(len(m2m_ei)): + assert m2m_ei[i].shape[0] == 2 + assert m2m_f[i].shape[1] == 3 + assert mesh_f[i].shape[1] == 2 + assert m2m_ei[i].shape[1] == m2m_f[i].shape[0] + + def test_g2m_m2g_single_tensors(self): + _skip_if_no_pyg() + tmpdir, _ = _create_and_save("graphcast", hierarchical=False) + g2m_ei = torch.load(Path(tmpdir) / "g2m_edge_index.pt", weights_only=True) + m2g_ei = torch.load(Path(tmpdir) / "m2g_edge_index.pt", weights_only=True) + assert isinstance(g2m_ei, torch.Tensor) + assert isinstance(m2g_ei, torch.Tensor) + assert g2m_ei.shape[0] == 2 + assert m2g_ei.shape[0] == 2 + + +class TestToNeuralLamHierarchical: + """Tests for to_neural_lam with oskarsson hierarchical archetype.""" + + def test_all_files_created(self): + _skip_if_no_pyg() + tmpdir, _ = _create_and_save("hierarchical", hierarchical=True) + for fname in CORE_FILES + HIERARCHICAL_FILES: + assert (Path(tmpdir) / fname).exists(), f"Missing: {fname}" + + def test_m2m_shapes(self): + _skip_if_no_pyg() + tmpdir, _ = _create_and_save("hierarchical", hierarchical=True) + m2m_ei = torch.load(Path(tmpdir) / "m2m_edge_index.pt", weights_only=True) + m2m_f = torch.load(Path(tmpdir) / "m2m_features.pt", weights_only=True) + mesh_f = torch.load(Path(tmpdir) / "mesh_features.pt", weights_only=True) + + n_levels = len(m2m_ei) + assert n_levels > 1, "Hierarchical graph should have multiple levels" + assert len(m2m_f) == n_levels + assert len(mesh_f) == n_levels + + for i in range(n_levels): + assert m2m_ei[i].shape[0] == 2 + assert m2m_f[i].shape[1] == 3 + assert mesh_f[i].shape[1] == 2 + assert m2m_ei[i].shape[1] == m2m_f[i].shape[0] + + def test_up_down_shapes(self): + _skip_if_no_pyg() + tmpdir, _ = _create_and_save("hierarchical", hierarchical=True) + m2m_ei = torch.load(Path(tmpdir) / "m2m_edge_index.pt", weights_only=True) + up_ei = torch.load(Path(tmpdir) / "mesh_up_edge_index.pt", weights_only=True) + up_f = torch.load(Path(tmpdir) / "mesh_up_features.pt", weights_only=True) + down_ei = torch.load( + Path(tmpdir) / "mesh_down_edge_index.pt", weights_only=True + ) + down_f = torch.load( + Path(tmpdir) / "mesh_down_features.pt", weights_only=True + ) + + n_levels = len(m2m_ei) + # Up/down should have n_levels - 1 entries + assert len(up_ei) == n_levels - 1 + assert len(up_f) == n_levels - 1 + assert len(down_ei) == n_levels - 1 + assert len(down_f) == n_levels - 1 + + for i in range(n_levels - 1): + assert up_ei[i].shape[0] == 2 + assert up_f[i].shape[1] == 3 + assert down_ei[i].shape[0] == 2 + assert down_f[i].shape[1] == 3 + assert up_ei[i].shape[1] == up_f[i].shape[0] + assert down_ei[i].shape[1] == down_f[i].shape[0] + + def test_up_down_have_nonzero_edges(self): + _skip_if_no_pyg() + tmpdir, _ = _create_and_save("hierarchical", hierarchical=True) + up_ei = torch.load(Path(tmpdir) / "mesh_up_edge_index.pt", weights_only=True) + down_ei = torch.load( + Path(tmpdir) / "mesh_down_edge_index.pt", weights_only=True + ) + for i, (u, d) in enumerate(zip(up_ei, down_ei)): + assert u.shape[1] > 0, f"up level {i} has no edges" + assert d.shape[1] > 0, f"down level {i} has no edges" + + def test_mesh_features_normalized(self): + _skip_if_no_pyg() + tmpdir, _ = _create_and_save("hierarchical", hierarchical=True) + mesh_f = torch.load(Path(tmpdir) / "mesh_features.pt", weights_only=True) + for level_f in mesh_f: + assert torch.max(torch.abs(level_f)) <= 1.0 + 1e-6 + + +class TestToNeuralLamEdgeCases: + """Edge case and validation tests.""" + + def test_missing_component_raises(self): + _skip_if_no_pyg() + with tempfile.TemporaryDirectory() as tmpdir: + with pytest.raises(ValueError, match="missing required keys"): + wmg.save.to_neural_lam( + graph_components={"g2m": None, "m2m": None}, + output_directory=tmpdir, + ) + + def test_empty_components_dict_raises(self): + _skip_if_no_pyg() + with tempfile.TemporaryDirectory() as tmpdir: + with pytest.raises(ValueError, match="missing required keys"): + wmg.save.to_neural_lam( + graph_components={}, + output_directory=tmpdir, + ) + + def test_output_dir_created_if_missing(self): + _skip_if_no_pyg() + xy = test_utils.create_fake_xy(N=64) + components = wmg.create.archetype.create_keisler_graph( + coords=xy, return_components=True + ) + with tempfile.TemporaryDirectory() as tmpdir: + nested_dir = Path(tmpdir) / "deeply" / "nested" / "dir" + assert not nested_dir.exists() + wmg.save.to_neural_lam( + graph_components=components, + output_directory=str(nested_dir), + hierarchical=False, + ) + assert nested_dir.exists() + for fname in CORE_FILES: + assert (nested_dir / fname).exists() + + def test_edge_features_have_positive_lengths(self): + """Column 0 of edge features should be edge length > 0 for + intra-component edges. Inter-level (up/down) edges may have + zero-length when coarser nodes coincide with finer nodes.""" + _skip_if_no_pyg() + tmpdir, _ = _create_and_save("hierarchical", hierarchical=True) + + # Intra-component edges: strictly positive lengths + for fname in [ + "g2m_features.pt", + "m2g_features.pt", + "m2m_features.pt", + ]: + data = torch.load(Path(tmpdir) / fname, weights_only=True) + if isinstance(data, list): + for level_f in data: + assert torch.all( + level_f[:, 0] > 0 + ), f"{fname} has non-positive edge lengths" + else: + assert torch.all( + data[:, 0] > 0 + ), f"{fname} has non-positive edge lengths" + + # Inter-level edges: allow zero-length (coincident nodes) + for fname in [ + "mesh_up_features.pt", + "mesh_down_features.pt", + ]: + data = torch.load(Path(tmpdir) / fname, weights_only=True) + for level_f in data: + assert torch.all( + level_f[:, 0] >= 0 + ), f"{fname} has negative edge lengths" + + def test_all_tensors_are_float32(self): + """All feature tensors should be float32.""" + _skip_if_no_pyg() + tmpdir, _ = _create_and_save("keisler", hierarchical=False) + for fname in [ + "g2m_features.pt", + "m2g_features.pt", + "m2m_features.pt", + "mesh_features.pt", + ]: + data = torch.load(Path(tmpdir) / fname, weights_only=True) + if isinstance(data, list): + for t in data: + assert t.dtype == torch.float32, f"{fname} not float32" + else: + assert data.dtype == torch.float32, f"{fname} not float32" + + def test_edge_index_dtype_is_int64(self): + """Edge index tensors should be int64 (standard PyG format).""" + _skip_if_no_pyg() + tmpdir, _ = _create_and_save("keisler", hierarchical=False) + for fname in ["g2m_edge_index.pt", "m2g_edge_index.pt", "m2m_edge_index.pt"]: + data = torch.load(Path(tmpdir) / fname, weights_only=True) + if isinstance(data, list): + for t in data: + assert t.dtype == torch.int64, f"{fname} not int64" + else: + assert data.dtype == torch.int64, f"{fname} not int64" + + def test_rectangular_grid(self): + """Test with non-square grid coordinates.""" + _skip_if_no_pyg() + xy = test_utils.create_rectangular_fake_xy(Nx=40, Ny=80) + components = wmg.create.archetype.create_keisler_graph( + coords=xy, return_components=True + ) + with tempfile.TemporaryDirectory() as tmpdir: + wmg.save.to_neural_lam( + graph_components=components, + output_directory=tmpdir, + hierarchical=False, + ) + for fname in CORE_FILES: + assert (Path(tmpdir) / fname).exists(), f"Missing: {fname}" + + def test_overwrite_existing_files(self): + """Saving twice to same dir should overwrite without error.""" + _skip_if_no_pyg() + xy = test_utils.create_fake_xy(N=64) + components = wmg.create.archetype.create_keisler_graph( + coords=xy, return_components=True + ) + with tempfile.TemporaryDirectory() as tmpdir: + wmg.save.to_neural_lam( + graph_components=components, + output_directory=tmpdir, + hierarchical=False, + ) + # Save again — should not raise + wmg.save.to_neural_lam( + graph_components=components, + output_directory=tmpdir, + hierarchical=False, + ) + for fname in CORE_FILES: + assert (Path(tmpdir) / fname).exists()