diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b1d8e6..1d93fa2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,15 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [unreleased](https://github.com/mllam/weather-model-graphs/compare/v0.3.0...HEAD) + +### Added + +- Add `mesh_layout` argument to mesh graph creation functions, with `rectilinear` + as the first supported layout. Uses a two-step architecture separating coordinate + creation from connectivity creation, enabling future alternative layouts (e.g. triangular). + [\#78](https://github.com/mllam/weather-model-graphs/issues/78), @prajwal-tech07 + ## [v0.3.0](https://github.com/mllam/weather-model-graphs/releases/tag/v0.3.0) ### Added diff --git a/src/weather_model_graphs/create/archetype.py b/src/weather_model_graphs/create/archetype.py index b4d717e..d8ac806 100644 --- a/src/weather_model_graphs/create/archetype.py +++ b/src/weather_model_graphs/create/archetype.py @@ -57,18 +57,20 @@ def create_keisler_graph( """ return create_all_graph_components( coords=coords, - m2m_connectivity="flat", - m2m_connectivity_kwargs=dict(mesh_node_distance=mesh_node_distance), + coords_crs=coords_crs, + graph_crs=graph_crs, + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=mesh_node_distance), g2m_connectivity="within_radius", - m2g_connectivity="nearest_neighbours", g2m_connectivity_kwargs=dict( rel_max_dist=0.51, ), + m2m_connectivity="flat", + m2m_connectivity_kwargs=dict(pattern="8-star"), + m2g_connectivity="nearest_neighbours", m2g_connectivity_kwargs=dict( max_num_neighbours=4, ), - coords_crs=coords_crs, - graph_crs=graph_crs, decode_mask=decode_mask, return_components=return_components, ) @@ -132,22 +134,26 @@ def create_graphcast_graph( """ return create_all_graph_components( coords=coords, - m2m_connectivity="flat_multiscale", - m2m_connectivity_kwargs=dict( - mesh_node_distance=mesh_node_distance, - level_refinement_factor=level_refinement_factor, - max_num_levels=max_num_levels, + coords_crs=coords_crs, + graph_crs=graph_crs, + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=mesh_node_distance, + refinement_factor=level_refinement_factor, + max_num_refinement_levels=max_num_levels, ), g2m_connectivity="within_radius", - m2g_connectivity="nearest_neighbours", g2m_connectivity_kwargs=dict( rel_max_dist=0.51, ), + m2m_connectivity="flat_multiscale", + m2m_connectivity_kwargs=dict( + pattern="8-star", + ), + m2g_connectivity="nearest_neighbours", m2g_connectivity_kwargs=dict( max_num_neighbours=4, ), - coords_crs=coords_crs, - graph_crs=graph_crs, decode_mask=decode_mask, return_components=return_components, ) @@ -216,22 +222,27 @@ def create_oskarsson_hierarchical_graph( """ return create_all_graph_components( coords=coords, - m2m_connectivity="hierarchical", - m2m_connectivity_kwargs=dict( - mesh_node_distance=mesh_node_distance, - level_refinement_factor=level_refinement_factor, - max_num_levels=max_num_levels, + coords_crs=coords_crs, + graph_crs=graph_crs, + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=mesh_node_distance, + refinement_factor=level_refinement_factor, + max_num_refinement_levels=max_num_levels, ), g2m_connectivity="within_radius", - m2g_connectivity="nearest_neighbours", g2m_connectivity_kwargs=dict( rel_max_dist=0.51, ), + m2m_connectivity="hierarchical", + m2m_connectivity_kwargs=dict( + intra_level=dict(pattern="8-star"), + inter_level=dict(pattern="nearest", k=1), + ), + m2g_connectivity="nearest_neighbours", m2g_connectivity_kwargs=dict( max_num_neighbours=4, ), - coords_crs=coords_crs, - graph_crs=graph_crs, decode_mask=decode_mask, return_components=return_components, ) diff --git a/src/weather_model_graphs/create/base.py b/src/weather_model_graphs/create/base.py index f922b2d..891e137 100644 --- a/src/weather_model_graphs/create/base.py +++ b/src/weather_model_graphs/create/base.py @@ -8,7 +8,7 @@ function uses `connect_nodes_across_graphs` to connect nodes across the component graphs. """ - +import warnings from typing import Iterable import networkx @@ -25,10 +25,78 @@ ) from .grid import create_grid_graph_nodes from .mesh.kinds.flat import ( - create_flat_multiscale_mesh_graph, - create_flat_singlescale_mesh_graph, + create_flat_multiscale_from_coordinates, + create_flat_singlescale_from_coordinates, +) +from .mesh.kinds.hierarchical import ( + create_hierarchical_from_coordinates, +) +from .mesh.coords import ( + create_multirange_2d_mesh_primitives, + create_single_level_2d_mesh_primitive, ) -from .mesh.kinds.hierarchical import create_hierarchical_multiscale_mesh_graph +from .mesh.kinds.triangular import ( + create_flat_multiscale_from_triangular_coordinates, + create_multirange_2d_triangular_mesh_primitives, + create_single_level_2d_triangular_mesh_primitive, +) + + +def _migrate_deprecated_kwargs(mesh_layout_kwargs, m2m_connectivity_kwargs): + """Migrate old-style kwargs to the new mesh_layout_kwargs structure. + + In the old API, ``mesh_node_distance``, ``level_refinement_factor``, and + ``max_num_levels`` were passed via ``m2m_connectivity_kwargs``. In the new + design these belong in ``mesh_layout_kwargs`` (as ``mesh_node_spacing``, + ``refinement_factor``, and ``max_num_refinement_levels`` respectively). + + This helper emits ``DeprecationWarning`` for each migrated key and moves + the value into *mesh_layout_kwargs*. It is intended to be removed once the + old API is no longer supported. + + Parameters + ---------- + mesh_layout_kwargs : dict + Mutable dict of mesh layout keyword arguments. + m2m_connectivity_kwargs : dict + Mutable dict of m2m connectivity keyword arguments. + + Returns + ------- + tuple[dict, dict] + Updated (mesh_layout_kwargs, m2m_connectivity_kwargs). + """ + if "mesh_node_distance" in m2m_connectivity_kwargs and "mesh_node_spacing" not in mesh_layout_kwargs: + warnings.warn( + "Passing 'mesh_node_distance' in m2m_connectivity_kwargs is deprecated. " + "Use mesh_layout_kwargs=dict(mesh_node_spacing=...) instead.", + DeprecationWarning, + stacklevel=3, + ) + mesh_layout_kwargs["mesh_node_spacing"] = m2m_connectivity_kwargs.pop( + "mesh_node_distance" + ) + if "level_refinement_factor" in m2m_connectivity_kwargs and "refinement_factor" not in mesh_layout_kwargs: + warnings.warn( + "Passing 'level_refinement_factor' in m2m_connectivity_kwargs is deprecated. " + "Use mesh_layout_kwargs=dict(refinement_factor=...) instead.", + DeprecationWarning, + stacklevel=3, + ) + mesh_layout_kwargs["refinement_factor"] = ( + m2m_connectivity_kwargs.pop("level_refinement_factor") + ) + if "max_num_levels" in m2m_connectivity_kwargs and "max_num_refinement_levels" not in mesh_layout_kwargs: + warnings.warn( + "Passing 'max_num_levels' in m2m_connectivity_kwargs is deprecated. " + "Use mesh_layout_kwargs=dict(max_num_refinement_levels=...) instead.", + DeprecationWarning, + stacklevel=3, + ) + mesh_layout_kwargs["max_num_refinement_levels"] = m2m_connectivity_kwargs.pop( + "max_num_levels" + ) + return mesh_layout_kwargs, m2m_connectivity_kwargs def create_all_graph_components( @@ -36,9 +104,11 @@ def create_all_graph_components( m2m_connectivity: str, m2g_connectivity: str, g2m_connectivity: str, - m2m_connectivity_kwargs={}, - m2g_connectivity_kwargs={}, - g2m_connectivity_kwargs={}, + mesh_layout: str, + mesh_layout_kwargs: dict = None, + m2m_connectivity_kwargs: dict = None, + m2g_connectivity_kwargs: dict = None, + g2m_connectivity_kwargs: dict = None, coords_crs: pyproj.crs.CRS | None = None, graph_crs: pyproj.crs.CRS | None = None, decode_mask: Iterable[bool] | None = None, @@ -49,6 +119,14 @@ def create_all_graph_components( grid-to-mesh (g2m), mesh-to-mesh (m2m) and mesh-to-grid (m2g), representing the encode-process-decode respectively. + The mesh graph creation follows a two-step process: + 1. **Coordinate creation** (controlled by `mesh_layout` + `mesh_layout_kwargs`): + Creates an undirected graph (nx.Graph) with node positions and spatial + adjacency edges annotated with adjacency types. + 2. **Connectivity creation** (controlled by `m2m_connectivity` + `m2m_connectivity_kwargs`): + Converts the coordinate graph to directed connectivity (nx.DiGraph) + based on the specified pattern and connectivity method. + For each graph component, the method for connecting nodes across graphs should be specified (with the `*_connectivity` arguments, e.g. `m2g_connectivity`). And the method-specific arguments should be passed as keyword arguments using @@ -62,15 +140,34 @@ def create_all_graph_components( - "within_radius": Find all neighbours in grid within an absolute distance of `max_dist` or relative distance of `rel_max_dist` from each node in mesh + mesh_layout: + - "rectilinear": Uniform regular grid with ``mesh_node_spacing`` resolution. + Produces an undirected mesh primitive with 4-star (cardinal) and + 8-star (cardinal + diagonal) spatial adjacency edges. + - "triangular": Regular triangular lattice with ``mesh_node_spacing`` + resolution. Uses ``networkx.triangular_lattice_graph`` to produce + equilateral triangles with 6-connectivity. A CRS warning is emitted + if ``graph_crs`` is geographic (lat/lon). + + mesh_layout_kwargs (for mesh_layout="rectilinear" or "triangular"): + - mesh_node_spacing: float, distance between mesh nodes in coordinate units. + - refinement_factor: int, refinement factor between levels + (for multi-level and hierarchical mesh graphs, default: 3) + - max_num_refinement_levels: int, maximum number of mesh levels + (for multi-level and hierarchical mesh graphs) + + Wherever the ``pattern`` argument appears below it defines the spatial + neighbourhood connectivity: + - ``"4-star"``: only cardinal directions (horizontal and vertical neighbours) + - ``"8-star"``: cardinal plus diagonal neighbours (all 8 surrounding nodes) + m2m_connectivity: - - "flat": Create a single-level 2D mesh graph with `mesh_node_distance`, - similar to Keisler et al. (2022) - - "flat_multiscale": Create a flat multiscale mesh graph with `max_num_levels`, - `mesh_node_distance` and `level_refinement_factor`, - similar to GraphCast, Lam et al. (2023) - - "hierarchical": Create a hierarchical mesh graph with `max_num_levels`, - `mesh_node_distance` and `level_refinement_factor`, - similar to Oskarsson et al. (2023) + - "flat": Create a single-level directed mesh graph. + m2m_connectivity_kwargs: pattern (default: "8-star") + - "flat_multiscale": Create a flat multiscale mesh graph. + m2m_connectivity_kwargs: pattern (default: "8-star") + - "hierarchical": Create a hierarchical mesh graph with up/down connections. + m2m_connectivity_kwargs: intra_level=dict(pattern=...), inter_level=dict(pattern=..., k=...) m2g_connectivity: - "nearest_neighbour": Find the nearest neighbour in mesh for each node in grid @@ -97,6 +194,29 @@ def create_all_graph_components( """ graph_components: dict[networkx.DiGraph] = {} + # Initialize mutable default arguments (and copy to avoid mutating caller's dicts) + if mesh_layout_kwargs is None: + mesh_layout_kwargs = {} + else: + mesh_layout_kwargs = dict(mesh_layout_kwargs) + if m2m_connectivity_kwargs is None: + m2m_connectivity_kwargs = {} + else: + m2m_connectivity_kwargs = dict(m2m_connectivity_kwargs) + if m2g_connectivity_kwargs is None: + m2g_connectivity_kwargs = {} + else: + m2g_connectivity_kwargs = dict(m2g_connectivity_kwargs) + if g2m_connectivity_kwargs is None: + g2m_connectivity_kwargs = {} + else: + g2m_connectivity_kwargs = dict(g2m_connectivity_kwargs) + + # Migrate deprecated kwargs (to be removed in a future version) + mesh_layout_kwargs, m2m_connectivity_kwargs = _migrate_deprecated_kwargs( + mesh_layout_kwargs, m2m_connectivity_kwargs + ) + assert ( len(coords.shape) == 2 and coords.shape[1] == 2 ), "Grid node coordinates should be given as an array of shape [num_grid_nodes, 2]." @@ -125,28 +245,222 @@ def create_all_graph_components( xy_tuple = coord_transformer.transform(xx=coords[:, 0], yy=coords[:, 1]) xy = np.stack(xy_tuple, axis=1) - if m2m_connectivity == "flat": - graph_components["m2m"] = create_flat_singlescale_mesh_graph( - xy, - **m2m_connectivity_kwargs, + # CRS warning for triangular layout in geographic coordinates + if ( + mesh_layout == "triangular" + and graph_crs is not None + and graph_crs.is_geographic + ): + warnings.warn( + "mesh_layout='triangular' produces non-uniform physical spacing in " + "geographic coordinates. Consider mesh_layout='icosahedral' for " + "uniform coverage on a sphere.", + UserWarning, + stacklevel=2, ) - grid_connect_graph = graph_components["m2m"] + + if m2m_connectivity == "flat": + # --- Step 1: Coordinate creation based on mesh_layout --- + if mesh_layout == "rectilinear": + mesh_node_spacing = mesh_layout_kwargs.get("mesh_node_spacing") + # Backward compat: also check for old name "grid_spacing" + if mesh_node_spacing is None: + mesh_node_spacing = mesh_layout_kwargs.get("grid_spacing") + if mesh_node_spacing is None: + raise ValueError( + "mesh_layout='rectilinear' requires 'mesh_node_spacing' in " + "mesh_layout_kwargs (or 'mesh_node_distance' in " + "m2m_connectivity_kwargs for backward compatibility)." + ) + # Compute number of mesh nodes from mesh_node_spacing + range_x, range_y = np.ptp(xy, axis=0) + nx_mesh = int(range_x / mesh_node_spacing) + ny_mesh = int(range_y / mesh_node_spacing) + if nx_mesh == 0 or ny_mesh == 0: + raise ValueError( + "The given `mesh_node_spacing` is too large for the provided " + f"coordinates. Got mesh_node_spacing={mesh_node_spacing}, but the " + f"x-range is {range_x} and y-range is {range_y}. Maybe you " + "want to decrease the `mesh_node_spacing` so that the mesh nodes " + "are spaced closer together?" + ) + G_mesh_coords = create_single_level_2d_mesh_primitive( + xy, nx_mesh, ny_mesh + ) + + # --- Step 2: Connectivity creation --- + pattern = m2m_connectivity_kwargs.get("pattern", "8-star") + graph_components["m2m"] = create_flat_singlescale_from_coordinates( + G_mesh_coords, pattern=pattern + ) + grid_connect_graph = graph_components["m2m"] + + elif mesh_layout == "triangular": + # --- Step 1: Coordinate creation (triangular lattice) --- + mesh_node_spacing = mesh_layout_kwargs.get("mesh_node_spacing") + if mesh_node_spacing is None: + mesh_node_spacing = mesh_layout_kwargs.get("grid_spacing") + if mesh_node_spacing is None: + raise ValueError( + "mesh_layout='triangular' requires 'mesh_node_spacing' in " + "mesh_layout_kwargs (or 'mesh_node_distance' in " + "m2m_connectivity_kwargs for backward compatibility)." + ) + range_x, range_y = np.ptp(xy, axis=0) + nx_mesh = int(range_x / mesh_node_spacing) + ny_mesh = int(range_y / (mesh_node_spacing * np.sqrt(3) / 2)) + if nx_mesh == 0 or ny_mesh == 0: + raise ValueError( + "The given `mesh_node_spacing` is too large for the provided " + f"coordinates. Got mesh_node_spacing={mesh_node_spacing}, but the " + f"x-range is {range_x} and y-range is {range_y}. Maybe you " + "want to decrease the `mesh_node_spacing` so that the mesh nodes " + "are spaced closer together?" + ) + G_mesh_coords = create_single_level_2d_triangular_mesh_primitive( + xy, nx_mesh, ny_mesh + ) + + # --- Step 2: Connectivity creation --- + pattern = m2m_connectivity_kwargs.get("pattern", "4-star") + graph_components["m2m"] = create_flat_singlescale_from_coordinates( + G_mesh_coords, pattern=pattern + ) + grid_connect_graph = graph_components["m2m"] + + else: + raise NotImplementedError( + f"mesh_layout='{mesh_layout}' is not yet supported. " + "Currently supported: 'rectilinear', 'triangular'." + ) + elif m2m_connectivity == "hierarchical": + # --- Step 1: Coordinate creation based on mesh_layout --- + if mesh_layout == "rectilinear": + mesh_node_spacing = mesh_layout_kwargs.get("mesh_node_spacing") + if mesh_node_spacing is None: + mesh_node_spacing = mesh_layout_kwargs.get("grid_spacing") + refinement_factor = mesh_layout_kwargs.get("refinement_factor", 3) + max_num_refinement_levels = mesh_layout_kwargs.get( + "max_num_refinement_levels" + ) + if mesh_node_spacing is None: + raise ValueError( + "mesh_layout='rectilinear' with m2m_connectivity='hierarchical' " + "requires 'mesh_node_spacing' in mesh_layout_kwargs." + ) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=max_num_refinement_levels, + xy=xy, + mesh_node_spacing=mesh_node_spacing, + interlevel_refinement_factor=refinement_factor, + ) + + elif mesh_layout == "triangular": + mesh_node_spacing = mesh_layout_kwargs.get("mesh_node_spacing") + if mesh_node_spacing is None: + mesh_node_spacing = mesh_layout_kwargs.get("grid_spacing") + refinement_factor = mesh_layout_kwargs.get("refinement_factor", 3) + max_num_refinement_levels = mesh_layout_kwargs.get( + "max_num_refinement_levels" + ) + if mesh_node_spacing is None: + raise ValueError( + "mesh_layout='triangular' with m2m_connectivity='hierarchical' " + "requires 'mesh_node_spacing' in mesh_layout_kwargs." + ) + G_coords_list = create_multirange_2d_triangular_mesh_primitives( + max_num_levels=max_num_refinement_levels, + xy=xy, + mesh_node_spacing=mesh_node_spacing, + interlevel_refinement_factor=refinement_factor, + ) + + else: + raise NotImplementedError( + f"mesh_layout='{mesh_layout}' is not yet supported. " + "Currently supported: 'rectilinear', 'triangular'." + ) + + # --- Step 2: Connectivity creation --- # hierarchical mesh graph have three sub-graphs: - # `m2m` (mesh-to-mesh), `mesh_up` (up edge connections) and `mesh_down` (down edge connections) - graph_components["m2m"] = create_hierarchical_multiscale_mesh_graph( - xy=xy, - **m2m_connectivity_kwargs, + # `m2m` (mesh-to-mesh), `mesh_up` (up edge connections) and + # `mesh_down` (down edge connections) + intra_level = m2m_connectivity_kwargs.get("intra_level") + if intra_level is None and mesh_layout == "triangular": + intra_level = {"pattern": "4-star"} + graph_components["m2m"] = create_hierarchical_from_coordinates( + G_coords_list, + intra_level=intra_level, + inter_level=m2m_connectivity_kwargs.get("inter_level"), ) # Only connect grid to bottom level of hierarchy grid_connect_graph = split_graph_by_edge_attribute( graph_components["m2m"], "level" )[0] + elif m2m_connectivity == "flat_multiscale": - graph_components["m2m"] = create_flat_multiscale_mesh_graph( - xy=xy, - **m2m_connectivity_kwargs, - ) + # --- Step 1: Coordinate creation based on mesh_layout --- + if mesh_layout == "rectilinear": + mesh_node_spacing = mesh_layout_kwargs.get("mesh_node_spacing") + if mesh_node_spacing is None: + mesh_node_spacing = mesh_layout_kwargs.get("grid_spacing") + refinement_factor = mesh_layout_kwargs.get("refinement_factor", 3) + max_num_refinement_levels = mesh_layout_kwargs.get( + "max_num_refinement_levels" + ) + if mesh_node_spacing is None: + raise ValueError( + "mesh_layout='rectilinear' with m2m_connectivity='flat_multiscale' " + "requires 'mesh_node_spacing' in mesh_layout_kwargs." + ) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=max_num_refinement_levels, + xy=xy, + mesh_node_spacing=mesh_node_spacing, + interlevel_refinement_factor=refinement_factor, + ) + + elif mesh_layout == "triangular": + mesh_node_spacing = mesh_layout_kwargs.get("mesh_node_spacing") + if mesh_node_spacing is None: + mesh_node_spacing = mesh_layout_kwargs.get("grid_spacing") + refinement_factor = mesh_layout_kwargs.get("refinement_factor", 3) + max_num_refinement_levels = mesh_layout_kwargs.get( + "max_num_refinement_levels" + ) + if mesh_node_spacing is None: + raise ValueError( + "mesh_layout='triangular' with m2m_connectivity='flat_multiscale' " + "requires 'mesh_node_spacing' in mesh_layout_kwargs." + ) + G_coords_list = create_multirange_2d_triangular_mesh_primitives( + max_num_levels=max_num_refinement_levels, + xy=xy, + mesh_node_spacing=mesh_node_spacing, + interlevel_refinement_factor=refinement_factor, + ) + + else: + raise NotImplementedError( + f"mesh_layout='{mesh_layout}' is not yet supported. " + "Currently supported: 'rectilinear', 'triangular'." + ) + + # --- Step 2: Connectivity creation --- + pattern = m2m_connectivity_kwargs.get("pattern") + if pattern is None: + pattern = "4-star" if mesh_layout == "triangular" else "8-star" + if mesh_layout == "triangular": + graph_components["m2m"] = create_flat_multiscale_from_triangular_coordinates( + G_coords_list, + pattern=pattern, + ) + else: + graph_components["m2m"] = create_flat_multiscale_from_coordinates( + G_coords_list, + pattern=pattern, + ) grid_connect_graph = graph_components["m2m"] else: raise NotImplementedError(f"Kind {m2m_connectivity} not implemented") diff --git a/src/weather_model_graphs/create/mesh/__init__.py b/src/weather_model_graphs/create/mesh/__init__.py index 573d9f9..c2d4946 100644 --- a/src/weather_model_graphs/create/mesh/__init__.py +++ b/src/weather_model_graphs/create/mesh/__init__.py @@ -1 +1,16 @@ -from .mesh import create_single_level_2d_mesh_graph +from .coords import ( + create_directed_mesh_graph, + create_multirange_2d_mesh_primitives, + create_single_level_2d_mesh_graph, + create_single_level_2d_mesh_primitive, +) + +from .kinds.triangular import ( + create_flat_multiscale_from_triangular_coordinates, + create_flat_multiscale_triangular_mesh_graph, + create_flat_singlescale_triangular_mesh_graph, + create_hierarchical_triangular_mesh_graph, + create_multirange_2d_triangular_mesh_primitives, + create_single_level_2d_triangular_mesh_graph, + create_single_level_2d_triangular_mesh_primitive, +) diff --git a/src/weather_model_graphs/create/mesh/coords.py b/src/weather_model_graphs/create/mesh/coords.py new file mode 100644 index 0000000..47dd30c --- /dev/null +++ b/src/weather_model_graphs/create/mesh/coords.py @@ -0,0 +1,332 @@ +import networkx +import numpy as np +from loguru import logger + + +def create_single_level_2d_mesh_primitive(xy: np.ndarray, nx: int, ny: int): + """ + Create an undirected mesh primitive graph (nx.Graph) with node positions + and spatial adjacency edges, representing the coordinate creation step. + + A mesh primitive is an undirected graph that encodes all potential + neighbourhood connectivity edges. It serves as a blueprint from which + directed connectivity graphs can later be built by selecting a subset + of edges (e.g. 4-star or 8-star pattern). + + This produces a graph where: + - Nodes have a ``"pos"`` attribute (np.ndarray of shape [2,] with x and y + coordinates) and a ``"type"`` attribute (str, always ``"mesh"``). + - Edges have an ``"adjacency_type"`` attribute (str): ``"cardinal"`` for + horizontal/vertical neighbours (4-star connectivity) or ``"diagonal"`` + for diagonal neighbours (additional edges in 8-star connectivity). + + This is the first step in the two-step mesh creation process: + 1. Coordinate creation (this function) -> nx.Graph with spatial adjacency + 2. Connectivity creation (create_directed_mesh_graph) -> nx.DiGraph + + Parameters + ---------- + xy : np.ndarray + Grid point coordinates, shaped [N_grid_points, 2], with first column + representing x coordinates and second column y coordinates. + nx : int + Number of nodes in x direction + ny : int + Number of nodes in y direction + + Returns + ------- + networkx.Graph + Undirected mesh primitive graph with node positions and annotated + spatial adjacency edges. + """ + xm, xM = np.amin(xy[:, 0]), np.amax(xy[:, 0]) + ym, yM = np.amin(xy[:, 1]), np.amax(xy[:, 1]) + + # avoid nodes on border + dx = (xM - xm) / nx + dy = (yM - ym) / ny + lx = np.linspace(xm + dx / 2, xM - dx / 2, nx) + ly = np.linspace(ym + dy / 2, yM - dy / 2, ny) + + mg = np.meshgrid(lx, ly) + g = networkx.grid_2d_graph(len(lx), len(ly)) + + # Node name and `pos` attribute takes form (x, y) + for node in g.nodes: + node_xi, node_yi = node # Extract x and y index from node to index mx + g.nodes[node]["pos"] = np.array( + [mg[0][node_yi, node_xi], mg[1][node_yi, node_xi]] + ) + g.nodes[node]["type"] = "mesh" + + # Mark existing grid_2d_graph edges as cardinal (4-star adjacency) + for u, v in g.edges(): + g.edges[u, v]["adjacency_type"] = "cardinal" + + # Add diagonal edges (8-star adjacency) + diagonal_edges = [ + ((x, y), (x + 1, y + 1)) for x in range(nx - 1) for y in range(ny - 1) + ] + [((x + 1, y), (x, y + 1)) for x in range(nx - 1) for y in range(ny - 1)] + g.add_edges_from(diagonal_edges) + for u, v in diagonal_edges: + g.edges[u, v]["adjacency_type"] = "diagonal" + + g.graph["dx"] = dx + g.graph["dy"] = dy + + return g + + +def create_directed_mesh_graph( + G_undirected: networkx.Graph, pattern: str = "8-star" +): + """ + Convert an undirected mesh primitive graph with spatial adjacency edges to a + directed mesh graph (nx.DiGraph) based on the specified connectivity pattern. + + This is the second step in the two-step mesh creation process: + 1. Coordinate creation (create_single_level_2d_mesh_primitive) -> nx.Graph + 2. Connectivity creation (this function) -> nx.DiGraph + + The ``pattern`` argument defines the spatial neighbourhood connectivity: + - ``"4-star"``: only cardinal directions (horizontal and vertical neighbours) + - ``"8-star"``: cardinal directions plus diagonals (all 8 surrounding neighbours) + + Parameters + ---------- + G_undirected : networkx.Graph + Undirected mesh primitive graph. Expected node attributes: + - ``"pos"``: np.ndarray of shape [2,], spatial coordinates. + Expected edge attributes: + - ``"adjacency_type"``: str, either ``"cardinal"`` or ``"diagonal"``. + Additional edge attributes (e.g. ``"level"``) are preserved in the + output directed graph. + pattern : str + Connectivity pattern. Options: + - ``"4-star"``: only cardinal edges (horizontal/vertical neighbours) + - ``"8-star"``: all edges (cardinal + diagonal neighbours) + + Returns + ------- + networkx.DiGraph + Directed graph with bidirectional edges, each having ``"len"`` and + ``"vdiff"`` attributes. All original edge attributes from the + primitive graph are preserved. + """ + if pattern == "4-star": + # Filter to only cardinal edges, preserving edge data + edges_to_use = [ + (u, v, d) + for u, v, d in G_undirected.edges(data=True) + if d.get("adjacency_type") == "cardinal" + ] + elif pattern == "8-star": + # Use all edges with their data + edges_to_use = list(G_undirected.edges(data=True)) + else: + raise ValueError( + f"Unknown connectivity pattern: '{pattern}'. " + "Choose '4-star' or '8-star'." + ) + + # Create filtered undirected graph with only selected edges (preserving attrs) + g_filtered = networkx.Graph() + g_filtered.add_nodes_from(G_undirected.nodes(data=True)) + g_filtered.add_edges_from(edges_to_use) + + # Convert to directed graph (creates edges in both directions) + dg = networkx.DiGraph(g_filtered) + for u, v in g_filtered.edges(): + d = np.sqrt( + np.sum( + (G_undirected.nodes[u]["pos"] - G_undirected.nodes[v]["pos"]) ** 2 + ) + ) + dg.edges[u, v]["len"] = d + dg.edges[u, v]["vdiff"] = ( + G_undirected.nodes[u]["pos"] - G_undirected.nodes[v]["pos"] + ) + # Ensure reverse edge exists and has attributes + dg.edges[v, u]["len"] = d + dg.edges[v, u]["vdiff"] = ( + G_undirected.nodes[v]["pos"] - G_undirected.nodes[u]["pos"] + ) + + # Preserve graph-level attributes (dx, dy, level, etc.) + dg.graph.update(G_undirected.graph) + + return dg + + +def create_single_level_2d_mesh_graph(xy, nx, ny): + """ + Create directed graph with nx * ny nodes representing a 2D grid with + positions spanning the range of xy coordinate values (first dimension + is assumed to be x and y coordinate values respectively). Each nodes is + connected to its eight nearest neighbours, both horizontally, vertically + and diagonally as directed edges (which means that the graph contains two + edges between each pair of connected nodes). + + The nodes contain a "pos" attribute with the x and y + coordinates of the node, and an "type" attribute with the + type of the node (i.e. "mesh" for mesh nodes). + + The edges contain a "len" attribute with the length of the edge + and a "vdiff" attribute with the vector difference between the + nodes. + + Internally, this uses the two-step process: + 1. create_single_level_2d_mesh_primitive (coordinate creation) + 2. create_directed_mesh_graph (connectivity creation, pattern="8-star") + + Parameters + ---------- + xy : np.ndarray [N_grid_points, 2] + Grid point coordinates, with first column representing + x coordinates and second column y coordinates. N_grid_points is the + total number of grid points. + nx : int + Number of nodes in x direction + ny : int + Number of nodes in y direction + + Returns + ------- + networkx.DiGraph + Graph representing the 2D grid + """ + G_coords = create_single_level_2d_mesh_primitive(xy, nx, ny) + return create_directed_mesh_graph(G_coords, pattern="8-star") + + +def create_multirange_2d_mesh_primitives( + max_num_levels, xy, mesh_node_spacing=3, interlevel_refinement_factor=3 +): + """ + Create a list of undirected mesh primitive graphs (nx.Graph) representing + different levels of mesh resolution spanning the spatial domain of the + xy coordinates. + + This is the coordinate creation step for multi-level and hierarchical mesh + graphs. Each returned graph contains nodes with spatial positions and edges + annotated with adjacency type (``"cardinal"`` or ``"diagonal"``). + + The graphs can be consumed by connectivity creation functions to produce + directed mesh graphs for flat_multiscale or hierarchical architectures. + + Parameters + ---------- + max_num_levels : int + Number of edge-distance levels in mesh graph + xy : np.ndarray + Grid point coordinates, shaped [N_grid_points, 2] + mesh_node_spacing : float + Distance (in x- and y-direction) between created mesh nodes, + in coordinate system of xy + interlevel_refinement_factor : float + Refinement factor between grid points and bottom level of mesh hierarchy + + Returns + ------- + G_all_levels : list of networkx.Graph + List of undirected mesh primitive graphs for each level, each with + node positions and annotated spatial adjacency edges. + Each graph has ``"level"`` and ``"interlevel_refinement_factor"`` + graph attributes. + """ + # Compute the size along x and y direction of area to cover with graph + # This is measured in the Cartesian coordinates of xy + coord_extent = np.ptp(xy, axis=0) + # Number of nodes that would fit on bottom level of hierarchy, + # in both directions + max_nodes_bottom = (coord_extent / mesh_node_spacing).astype(int) + + # Find the number of mesh levels possible in x- and y-direction, + # and the number of leaf nodes that would correspond to + # max_nodes_bottom/(interlevel_refinement_factor^mesh_levels) = 1 + max_mesh_levels_float = np.log(max_nodes_bottom) / np.log( + interlevel_refinement_factor + ) + + max_mesh_levels = max_mesh_levels_float.astype(int) # (2,) + nleaf = interlevel_refinement_factor**max_mesh_levels + # leaves at the bottom in each direction, if using max_mesh_levels + + # As we can not instantiate different number of mesh levels in each + # direction, create mesh levels corresponding to the minimum of the two + mesh_levels_to_create = max_mesh_levels.min() + + if max_num_levels: + # Limit the levels in mesh graph + mesh_levels_to_create = min(mesh_levels_to_create, max_num_levels) + + logger.debug(f"mesh_levels: {mesh_levels_to_create}, nleaf: {nleaf}") + + # multi resolution tree levels + G_all_levels = [] + for lev in range(mesh_levels_to_create): # 0-index mesh levels + # Compute number of nodes on level separate for each direction + nodes_x, nodes_y = ( + nleaf / (interlevel_refinement_factor**lev) + ).astype(int) + g = create_single_level_2d_mesh_primitive(xy, nodes_x, nodes_y) + # Add level information to nodes, edges and full graph + for node in g.nodes: + g.nodes[node]["level"] = lev + for edge in g.edges: + g.edges[edge]["level"] = lev + g.graph["level"] = lev + # Store refinement factor so connectivity step can use it + g.graph["interlevel_refinement_factor"] = interlevel_refinement_factor + G_all_levels.append(g) + + return G_all_levels + + +def create_multirange_2d_mesh_graphs( + max_num_levels, xy, mesh_node_distance=3, level_refinement_factor=3 +): + """ + Create a list of 2D grid mesh graphs representing different levels of edge-length + scales spanning the spatial domain of the xy coordinates. + This list of graphs can then later be for example a) flattened into single graph + containing multiple ranges of connections or b) combined into a hierarchical graph. + + Each graph in the list contains a "level" attribute with the level index of the graph. + + Internally uses the two-step process: + 1. create_multirange_2d_mesh_primitives (coordinate creation) + 2. create_directed_mesh_graph (connectivity creation, pattern="8-star") + + Parameters + ---------- + max_num_levels : int + Number of edge-distance levels in mesh graph + xy : np.ndarray + Grid point coordinates, shaped [N_grid_points, 2] + mesh_node_distance: float + Distance (in x- and y-direction) between created mesh nodes, + in coordinate system of xy + level_refinement_factor: float + Refinement factor between grid points and bottom level of mesh hierarchy + + Returns + ------- + G_all_levels : list of networkx.Graph + List of networkx graphs for each level representing the connectivity + of the mesh within each level + """ + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=max_num_levels, + xy=xy, + mesh_node_spacing=mesh_node_distance, + interlevel_refinement_factor=level_refinement_factor, + ) + + G_all_levels = [] + for g_coords in G_coords_list: + g_directed = create_directed_mesh_graph(g_coords, pattern="8-star") + G_all_levels.append(g_directed) + + return G_all_levels diff --git a/src/weather_model_graphs/create/mesh/kinds/flat.py b/src/weather_model_graphs/create/mesh/kinds/flat.py index 92f47b0..4329389 100644 --- a/src/weather_model_graphs/create/mesh/kinds/flat.py +++ b/src/weather_model_graphs/create/mesh/kinds/flat.py @@ -1,57 +1,120 @@ +from typing import List + import networkx import numpy as np from ....networkx_utils import prepend_node_index -from .. import mesh as mesh_graph +from .. import coords as mesh_coords -def create_flat_multiscale_mesh_graph( - xy, mesh_node_distance: float, level_refinement_factor: int, max_num_levels: int +def _check_required_graph_attributes(G: networkx.Graph, context: str): + """Check that a coordinate graph has the required node and edge attributes. + + Parameters + ---------- + G : networkx.Graph + The coordinate graph to validate. + context : str + Description of where this check is being called, for error messages. + + Raises + ------ + ValueError + If required attributes are missing. + """ + # Check at least one node has required attributes + if len(G.nodes) > 0: + sample_node = next(iter(G.nodes)) + if "pos" not in G.nodes[sample_node]: + raise ValueError( + f"{context}: coordinate graph nodes must have a 'pos' attribute " + "(np.ndarray of shape [2,])." + ) + if "type" not in G.nodes[sample_node]: + raise ValueError( + f"{context}: coordinate graph nodes must have a 'type' attribute." + ) + # Check at least one edge has required attributes + if len(G.edges) > 0: + sample_edge = next(iter(G.edges)) + if "adjacency_type" not in G.edges[sample_edge]: + raise ValueError( + f"{context}: coordinate graph edges must have an 'adjacency_type' " + "attribute ('cardinal' or 'diagonal')." + ) + + +def create_flat_multiscale_from_coordinates( + G_coords_list: List[networkx.Graph], + pattern: str = "8-star", ): """ - Create flat mesh graph by merging the single-level mesh - graphs across all levels in `G_all_levels`. + Create flat multiscale mesh graph from a list of coordinate graphs. + + This is the connectivity creation step for flat multiscale meshes. + It takes undirected mesh primitive graphs (one per level) and produces a + single directed mesh graph where all levels are merged into one flat graph. + + In a flat multiscale graph, coarser levels are merged into the finer level + by coincident node positions (no separate inter-level connectivity needed). + + The ``pattern`` argument defines the spatial neighbourhood connectivity: + - ``"4-star"``: only cardinal directions (horizontal and vertical neighbours) + - ``"8-star"``: cardinal directions plus diagonals (all 8 surrounding neighbours) Parameters ---------- - xy : np.ndarray [N_grid_points, 2] - Grid point coordinates, with first column representing - x coordinates and second column y coordinates. N_grid_points is the - total number of grid points. - mesh_node_distance: float - Distance (in x- and y-direction) between created mesh nodes, - in coordinate system of xy - level_refinement_factor: int - Refinement factor between grid points and bottom level of mesh hierarchy - NOTE: Must be an odd integer >1 to create proper multiscale graph - max_num_levels : int - Maximum number of levels in the multi-scale graph + G_coords_list : list of networkx.Graph + List of undirected mesh primitive graphs, one per level. Each graph + must have: + - Node attributes: ``"pos"`` (np.ndarray of shape [2,]), ``"type"`` (str) + - Edge attributes: ``"adjacency_type"`` (str, ``"cardinal"`` or ``"diagonal"``) + - Graph attribute: ``"interlevel_refinement_factor"`` (int) + Created by ``create_multirange_2d_mesh_primitives``. + pattern : str + Connectivity pattern for intra-level edges: ``"4-star"`` or ``"8-star"`` + (default: ``"8-star"``) + Returns ------- - G_tot : networkx.Graph - The merged mesh graph + G_tot : networkx.DiGraph + The merged flat multiscale mesh graph """ - # Check that level_refinement_factor is an odd integer + # Validate required attributes on first graph + _check_required_graph_attributes( + G_coords_list[0], "create_flat_multiscale_from_coordinates" + ) + + # Assert interlevel_refinement_factor is set (no silent default) + if "interlevel_refinement_factor" not in G_coords_list[0].graph: + raise ValueError( + "The coordinate graphs must have an 'interlevel_refinement_factor' " + "graph attribute. This is set by create_multirange_2d_mesh_primitives." + ) + interlevel_refinement_factor = G_coords_list[0].graph[ + "interlevel_refinement_factor" + ] + + # Check that interlevel_refinement_factor is an odd integer if ( - int(level_refinement_factor) != level_refinement_factor - or level_refinement_factor % 2 != 1 + int(interlevel_refinement_factor) != interlevel_refinement_factor + or interlevel_refinement_factor % 2 != 1 ): raise ValueError( - "The `level_refinement_factor` must be an odd integer. " - f"Given value: {level_refinement_factor}." + "The `interlevel_refinement_factor` must be an odd integer. " + f"Given value: {interlevel_refinement_factor}." ) - G_all_levels: list[networkx.DiGraph] = mesh_graph.create_multirange_2d_mesh_graphs( - max_num_levels=max_num_levels, - xy=xy, - mesh_node_distance=mesh_node_distance, - level_refinement_factor=level_refinement_factor, - ) + # Convert each level's coordinate graph to directed graph with chosen pattern + G_all_levels = [ + mesh_coords.create_directed_mesh_graph(g_coords, pattern=pattern) + for g_coords in G_coords_list + ] # combine all levels to one graph G_tot = G_all_levels[0] # First node at level l+1 share position with node (offset, offset) at level l - level_offset = level_refinement_factor // 2 + level_offset = interlevel_refinement_factor // 2 first_level_nodes = list(G_all_levels[0].nodes) # Last nodes in first layer has pos (nx-1, ny-1) @@ -63,11 +126,18 @@ def create_flat_multiscale_mesh_graph( ij = ( np.array(nodes) .reshape((num_nodes_x, num_nodes_y, 2))[ - level_offset::level_refinement_factor, - level_offset::level_refinement_factor, + level_offset::interlevel_refinement_factor, + level_offset::interlevel_refinement_factor, :, ] - .reshape(int(num_nodes_x * num_nodes_y / (level_refinement_factor**2)), 2) + .reshape( + int( + num_nodes_x + * num_nodes_y + / (interlevel_refinement_factor**2) + ), + 2, + ) ) ij = [tuple(x) for x in ij] G_all_levels[lev] = networkx.relabel_nodes( @@ -75,9 +145,9 @@ def create_flat_multiscale_mesh_graph( ) G_tot = networkx.compose(G_tot, G_all_levels[lev]) - # Update number of nodes in x- and y-direction for next iteraion - num_nodes_x //= level_refinement_factor - num_nodes_y //= level_refinement_factor + # Update number of nodes in x- and y-direction for next iteration + num_nodes_x //= interlevel_refinement_factor + num_nodes_y //= interlevel_refinement_factor # Relabel mesh nodes to start with 0 G_tot = prepend_node_index(G_tot, 0) @@ -89,10 +159,92 @@ def create_flat_multiscale_mesh_graph( return G_tot +def create_flat_singlescale_from_coordinates( + G_coords: networkx.Graph, pattern: str = "8-star" +): + """ + Create a flat single-scale directed mesh graph from a mesh primitive graph. + + This is the connectivity creation step for flat single-scale meshes. + It converts an undirected mesh primitive graph to a directed mesh graph + using the specified connectivity pattern. + + The ``pattern`` argument defines the spatial neighbourhood connectivity: + - ``"4-star"``: only cardinal directions (horizontal and vertical neighbours) + - ``"8-star"``: cardinal directions plus diagonals (all 8 surrounding neighbours) + + Parameters + ---------- + G_coords : networkx.Graph + Undirected mesh primitive graph. Must have: + - Node attributes: ``"pos"`` (np.ndarray of shape [2,]), ``"type"`` (str) + - Edge attributes: ``"adjacency_type"`` (str, ``"cardinal"`` or ``"diagonal"``) + Created by ``create_single_level_2d_mesh_primitive``. + pattern : str + Connectivity pattern: ``"4-star"`` or ``"8-star"`` (default: ``"8-star"``) + + Returns + ------- + networkx.DiGraph + The flat single-scale directed mesh graph + """ + _check_required_graph_attributes( + G_coords, "create_flat_singlescale_from_coordinates" + ) + return mesh_coords.create_directed_mesh_graph(G_coords, pattern=pattern) + + +def create_flat_multiscale_mesh_graph( + xy, mesh_node_distance: float, level_refinement_factor: int, max_num_levels: int +): + """ + Create flat mesh graph by merging the single-level mesh + graphs across all levels in `G_all_levels`. + + Internally uses the two-step process: + 1. create_multirange_2d_mesh_primitives (coordinate creation) + 2. create_flat_multiscale_from_coordinates (connectivity creation) + + Parameters + ---------- + xy : np.ndarray [N_grid_points, 2] + Grid point coordinates, with first column representing + x coordinates and second column y coordinates. N_grid_points is the + total number of grid points. + mesh_node_distance: float + Distance (in x- and y-direction) between created mesh nodes, + in coordinate system of xy + level_refinement_factor: int + Refinement factor between grid points and bottom level of mesh hierarchy + NOTE: Must be an odd integer >1 to create proper multiscale graph + max_num_levels : int + Maximum number of levels in the multi-scale graph + Returns + ------- + G_tot : networkx.Graph + The merged mesh graph + """ + G_coords_list = mesh_coords.create_multirange_2d_mesh_primitives( + max_num_levels=max_num_levels, + xy=xy, + mesh_node_spacing=mesh_node_distance, + interlevel_refinement_factor=level_refinement_factor, + ) + + return create_flat_multiscale_from_coordinates( + G_coords_list, + pattern="8-star", + ) + + def create_flat_singlescale_mesh_graph(xy, mesh_node_distance: float): """ Create flat mesh graph of single level + Internally uses the two-step process: + 1. create_single_level_2d_mesh_primitive (coordinate creation) + 2. create_directed_mesh_graph (connectivity creation, pattern="8-star") + Parameters ---------- xy : np.ndarray [N_grid_points, 2] @@ -120,4 +272,4 @@ def create_flat_singlescale_mesh_graph(xy, mesh_node_distance: float): " so that the mesh nodes are spaced closer together?" ) - return mesh_graph.create_single_level_2d_mesh_graph(xy=xy, nx=nx, ny=ny) + return mesh_coords.create_single_level_2d_mesh_graph(xy, nx, ny) diff --git a/src/weather_model_graphs/create/mesh/kinds/hierarchical.py b/src/weather_model_graphs/create/mesh/kinds/hierarchical.py index b897693..d1a9ca3 100644 --- a/src/weather_model_graphs/create/mesh/kinds/hierarchical.py +++ b/src/weather_model_graphs/create/mesh/kinds/hierarchical.py @@ -1,67 +1,97 @@ +from typing import Dict, List, Optional + import networkx import numpy as np import scipy from ....networkx_utils import prepend_node_index -from .. import mesh as mesh_graph +from .. import coords as mesh_coords -def create_hierarchical_multiscale_mesh_graph( - xy, - mesh_node_distance: float, - level_refinement_factor: float, - max_num_levels: int, +def create_hierarchical_from_coordinates( + G_coords_list: List[networkx.Graph], + intra_level: Optional[Dict[str, object]] = None, + inter_level: Optional[Dict[str, object]] = None, ): """ - Create a hierarchical multiscale mesh graph with nearest neighbour - connections within each level (horizontally, vertically, and diagonally), and - connections between levels (coarse to fine and fine to coarse) using the - nearest neighbour connection. + Create a hierarchical multiscale mesh graph from a list of mesh primitive + graphs. + + This is the connectivity creation step for hierarchical meshes. + It takes undirected mesh primitive graphs (one per level) and produces a + directed mesh graph with intra-level connectivity and inter-level + up/down connections. + + The ``intra_level["pattern"]`` defines the spatial neighbourhood connectivity + within each mesh level: + - ``"4-star"``: only cardinal directions (horizontal and vertical neighbours) + - ``"8-star"``: cardinal directions plus diagonals (all 8 surrounding neighbours) Parameters ---------- - xy: np.ndarray - 2D array of mesh point positions. - Distance (in x- and y-direction) between created mesh nodes in bottom level, - in coordinate system of xy - mesh_node_distance: float - Distance (in x- and y-direction) between created mesh nodes in bottom level, - in coordinate system of xy - level_refinement_factor: float - Refinement factor between grid points and bottom level of mesh hierarchy - max_num_levels: int - The number of levels in the hierarchical mesh graph. + G_coords_list : list of networkx.Graph + List of undirected mesh primitive graphs, one per level. Each graph + must have: + - Node attributes: ``"pos"`` (np.ndarray of shape [2,]), ``"type"`` (str) + - Edge attributes: ``"adjacency_type"`` (str, ``"cardinal"`` or ``"diagonal"``) + Created by ``create_multirange_2d_mesh_primitives``. + intra_level : dict, optional + Configuration for intra-level connectivity. Keys: + - ``"pattern"`` (str): ``"4-star"`` or ``"8-star"``. + Default: ``{"pattern": "8-star"}`` + inter_level : dict, optional + Configuration for inter-level connectivity. Keys: + - ``"pattern"`` (str): Currently only ``"nearest"`` is supported. + - ``"k"`` (int): Number of nearest neighbours for inter-level connections. + Default: ``{"pattern": "nearest", "k": 1}`` Returns ------- - dict - A dictionary containing the hierarchical mesh graph, the mesh down graph, and - the mesh up graph, with keys "m2m", "mesh_down", and "mesh_up" respectively. + networkx.DiGraph + A directed graph containing the hierarchical mesh with intra-level + edges (direction="same"), inter-level down edges (direction="down"), + and inter-level up edges (direction="up"). """ - Gs_all_levels: list[networkx.DiGraph] = mesh_graph.create_multirange_2d_mesh_graphs( - max_num_levels=max_num_levels, - xy=xy, - mesh_node_distance=mesh_node_distance, - level_refinement_factor=level_refinement_factor, - ) + if intra_level is None: + intra_level = {"pattern": "8-star"} + if inter_level is None: + inter_level = {"pattern": "nearest", "k": 1} + + intra_level_pattern = intra_level.get("pattern", "8-star") + inter_level_pattern = inter_level.get("pattern", "nearest") + inter_level_k = inter_level.get("k", 1) + + if inter_level_pattern != "nearest": + raise NotImplementedError( + f"Inter-level pattern '{inter_level_pattern}' is not yet supported " + "for hierarchical graphs. Only 'nearest' is currently implemented." + ) + + # Convert each level's coordinate graph to directed graph with chosen pattern + Gs_all_levels = [ + mesh_coords.create_directed_mesh_graph( + g_coords, pattern=intra_level_pattern + ) + for g_coords in G_coords_list + ] + n_mesh_levels = len(Gs_all_levels) if n_mesh_levels < 2: raise ValueError( "At least two mesh levels are required for hierarchical mesh graph. " "You may need to reduce the level refinement factor " - f"or increase the max number of levels {max_num_levels} " - f"or number of grid points {xy.shape[0]}." + f"or increase the max number of levels " + f"or number of grid points." ) # Relabel nodes of each level with level index first - Gs_all_levels = [ prepend_node_index(graph, level_i) for level_i, graph in enumerate(Gs_all_levels) ] - # add `direction` attribute to all edges with value `same`` + # add `direction` attribute to all edges with value `same` for i, G in enumerate(Gs_all_levels): for u, v in G.edges: G.edges[u, v]["direction"] = "same" @@ -92,21 +122,29 @@ def create_hierarchical_multiscale_mesh_graph( v_from_xy = np.array([xy for _, xy in G_from.nodes.data("pos")]) kdt_m = scipy.spatial.KDTree(v_from_xy) - # add edges from mesh to grid + # add edges from coarser to finer level for v in v_to_list: - # find 1(?) nearest neighbours (index to vm_xy) - neigh_idx = kdt_m.query(G_down.nodes[v]["pos"], 1)[1] - u = v_from_list[neigh_idx] - - # add edge from mesh to grid - G_down.add_edge(u, v) - d = np.sqrt(np.sum((G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"]) ** 2)) - G_down.edges[u, v]["len"] = d - G_down.edges[u, v]["vdiff"] = ( - G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"] - ) - G_down.edges[u, v]["levels"] = f"{from_level}>{to_level}" - G_down.edges[u, v]["direction"] = "down" + # find k nearest neighbours (index to vm_xy) + neigh_idx = kdt_m.query(G_down.nodes[v]["pos"], inter_level_k)[1] + if inter_level_k == 1: + neigh_idx = [neigh_idx] + + for idx in neigh_idx: + u = v_from_list[idx] + + # add edge from coarser to finer + G_down.add_edge(u, v) + d = np.sqrt( + np.sum( + (G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"]) ** 2 + ) + ) + G_down.edges[u, v]["len"] = d + G_down.edges[u, v]["vdiff"] = ( + G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"] + ) + G_down.edges[u, v]["levels"] = f"{from_level}>{to_level}" + G_down.edges[u, v]["direction"] = "down" G_up = networkx.DiGraph() G_up.add_nodes_from(G_down.nodes(data=True)) @@ -130,3 +168,62 @@ def create_hierarchical_multiscale_mesh_graph( G_m2m.graph[prop] = {i: g.graph[prop] for i, g in enumerate(Gs_all_levels)} return G_m2m + + +def create_hierarchical_multiscale_mesh_graph( + xy: np.ndarray, + mesh_node_distance: float, + level_refinement_factor: float, + max_num_levels: int, + intra_level: Optional[Dict[str, object]] = None, + inter_level: Optional[Dict[str, object]] = None, +): + """ + Create a hierarchical multiscale mesh graph with nearest neighbour + connections within each level (horizontally, vertically, and diagonally), and + connections between levels (coarse to fine and fine to coarse) using the + nearest neighbour connection. + + Internally uses the two-step process: + 1. create_multirange_2d_mesh_primitives (coordinate creation) + 2. create_hierarchical_from_coordinates (connectivity creation) + + Parameters + ---------- + xy : np.ndarray + 2D array of mesh point positions, shaped [N_points, 2]. + mesh_node_distance : float + Distance (in x- and y-direction) between created mesh nodes in bottom level, + in coordinate system of xy + level_refinement_factor : float + Refinement factor between grid points and bottom level of mesh hierarchy + max_num_levels : int + The number of levels in the hierarchical mesh graph. + intra_level : dict, optional + Configuration for intra-level connectivity. Keys: + - ``"pattern"`` (str): ``"4-star"`` or ``"8-star"``. + Default: ``{"pattern": "8-star"}`` + inter_level : dict, optional + Configuration for inter-level connectivity. Keys: + - ``"pattern"`` (str): Currently only ``"nearest"`` is supported. + - ``"k"`` (int): Number of nearest neighbours. + Default: ``{"pattern": "nearest", "k": 1}`` + + Returns + ------- + networkx.DiGraph + A directed graph containing the hierarchical mesh with intra-level, + up, and down edges. + """ + G_coords_list = mesh_coords.create_multirange_2d_mesh_primitives( + max_num_levels=max_num_levels, + xy=xy, + mesh_node_spacing=mesh_node_distance, + interlevel_refinement_factor=level_refinement_factor, + ) + + return create_hierarchical_from_coordinates( + G_coords_list, + intra_level=intra_level, + inter_level=inter_level, + ) diff --git a/src/weather_model_graphs/create/mesh/kinds/triangular.py b/src/weather_model_graphs/create/mesh/kinds/triangular.py new file mode 100644 index 0000000..b6d2d8d --- /dev/null +++ b/src/weather_model_graphs/create/mesh/kinds/triangular.py @@ -0,0 +1,428 @@ +""" +Functions for creating regular triangular mesh graphs. + +Uses ``networkx.triangular_lattice_graph`` to produce an equilateral-triangle +lattice with 6-connectivity (each interior node has 6 neighbours). This +mirrors the rectilinear mesh functions (which use ``networkx.grid_2d_graph`` +with 8-connectivity) and plugs into the same two-step process: + +1. **Coordinate creation** (this module) → ``nx.Graph`` with ``pos``, ``type``, + and ``adjacency_type`` attributes. +2. **Connectivity creation** (``create_directed_mesh_graph``) → ``nx.DiGraph`` + with ``len`` and ``vdiff`` edge attributes. + +Supports flat, flat_multiscale, and hierarchical ``m2m_connectivity`` modes. +""" + +import networkx +import numpy as np +import scipy.spatial +from loguru import logger + +from ....networkx_utils import prepend_node_index +from .. import coords as mesh_coords + + +def create_single_level_2d_triangular_mesh_primitive( + xy: np.ndarray, nx: int, ny: int +): + """ + Create an undirected triangular mesh primitive graph (``nx.Graph``) with + node positions and spatial adjacency edges. + + This is analogous to ``create_single_level_2d_mesh_primitive`` but uses + ``networkx.triangular_lattice_graph`` instead of ``grid_2d_graph``. + + In a triangular lattice, each interior node has 6 neighbours (vs. 8 for + the rectilinear lattice with diagonals), providing more isotropic message + passing. + + The nodes form a grid of ``(ny + 1)`` rows and ``(nx + 1) // 2`` columns, + with odd-row nodes shifted horizontally. Positions are scaled and offset + so that the mesh spans the coordinate domain of *xy* (with nodes inset + from the border by half a cell width in each direction). + + Parameters + ---------- + xy : np.ndarray + Grid point coordinates, shaped ``[N_grid_points, 2]``. + nx : int + Number of triangle columns (passed as *n* to + ``triangular_lattice_graph``). + ny : int + Number of triangle rows (passed as *m* to + ``triangular_lattice_graph``). + + Returns + ------- + networkx.Graph + Undirected mesh primitive graph. Node attributes: ``pos`` + (np.ndarray[2,]), ``type`` (``"mesh"``). Edge attributes: + ``adjacency_type`` (always ``"cardinal"`` — triangular lattices have + only one class of edge). Graph attributes: ``dx``, ``dy``. + """ + xm, xM = np.amin(xy[:, 0]), np.amax(xy[:, 0]) + ym, yM = np.amin(xy[:, 1]), np.amax(xy[:, 1]) + + # Create the raw triangular lattice + g_raw = networkx.triangular_lattice_graph(ny, nx, with_positions=True) + + if g_raw.number_of_nodes() == 0: + raise ValueError( + f"triangular_lattice_graph({ny}, {nx}) produced 0 nodes. " + "Increase nx/ny or decrease mesh_node_spacing." + ) + + # Gather raw positions to compute extent + raw_positions = np.array([g_raw.nodes[n]["pos"] for n in g_raw.nodes()]) + raw_xmin, raw_ymin = raw_positions.min(axis=0) + raw_xmax, raw_ymax = raw_positions.max(axis=0) + raw_extent_x = raw_xmax - raw_xmin + raw_extent_y = raw_ymax - raw_ymin + + # Domain extent with half-cell inset + domain_x = xM - xm + domain_y = yM - ym + + # Scale factors — map raw lattice extent to domain extent (inset by half + # a cell in each direction, mirroring the rectilinear approach) + if raw_extent_x > 0: + scale_x = domain_x / (raw_extent_x + 1.0) # +1 for inset + else: + scale_x = domain_x # single column + if raw_extent_y > 0: + scale_y = domain_y / (raw_extent_y + np.sqrt(3) / 2) # +row_h for inset + else: + scale_y = domain_y # single row + + # Effective dx/dy for graph attributes + dx = scale_x + dy = scale_y * (np.sqrt(3) / 2) + + # Offset so mesh is centred within domain + offset_x = xm + (domain_x - raw_extent_x * scale_x) / 2 + offset_y = ym + (domain_y - raw_extent_y * scale_y) / 2 + + # Build output graph with scaled positions + g = networkx.Graph() + for node in g_raw.nodes(): + raw_pos = g_raw.nodes[node]["pos"] + pos = np.array([ + offset_x + (raw_pos[0] - raw_xmin) * scale_x, + offset_y + (raw_pos[1] - raw_ymin) * scale_y, + ]) + g.add_node(node, pos=pos, type="mesh") + + for u, v in g_raw.edges(): + g.add_edge(u, v, adjacency_type="cardinal") + + g.graph["dx"] = dx + g.graph["dy"] = dy + + return g + + +def create_multirange_2d_triangular_mesh_primitives( + max_num_levels, + xy, + mesh_node_spacing=3, + interlevel_refinement_factor=3, +): + """ + Create a list of undirected triangular mesh primitive graphs representing + different levels of mesh resolution. + + Mirrors ``create_multirange_2d_mesh_primitives`` but uses triangular + lattice topology at each level. + + Parameters + ---------- + max_num_levels : int + Maximum number of levels in the multi-scale graph. + xy : np.ndarray + Grid point coordinates, shaped ``[N_grid_points, 2]``. + mesh_node_spacing : float + Distance between mesh nodes at the finest level, in coordinate units. + interlevel_refinement_factor : float + Factor by which mesh node count decreases per level. + + Returns + ------- + list[networkx.Graph] + Triangular mesh primitive graphs, one per level. + """ + coord_extent = np.ptp(xy, axis=0) + # For triangular lattice, ny accounts for row spacing of sqrt(3)/2 + max_nx = int(coord_extent[0] / mesh_node_spacing) + max_ny = int(coord_extent[1] / (mesh_node_spacing * np.sqrt(3) / 2)) + + max_nodes_bottom = np.array([max_nx, max_ny]) + + max_mesh_levels_float = np.log(max_nodes_bottom) / np.log( + interlevel_refinement_factor + ) + max_mesh_levels = max_mesh_levels_float.astype(int) + nleaf = interlevel_refinement_factor ** max_mesh_levels + + mesh_levels_to_create = max_mesh_levels.min() + if max_num_levels: + mesh_levels_to_create = min(mesh_levels_to_create, max_num_levels) + + logger.debug( + f"triangular mesh_levels: {mesh_levels_to_create}, nleaf: {nleaf}" + ) + + G_all_levels = [] + for lev in range(mesh_levels_to_create): + nodes_x, nodes_y = ( + nleaf / (interlevel_refinement_factor ** lev) + ).astype(int) + g = create_single_level_2d_triangular_mesh_primitive( + xy, nodes_x, nodes_y + ) + for node in g.nodes: + g.nodes[node]["level"] = lev + for edge in g.edges: + g.edges[edge]["level"] = lev + g.graph["level"] = lev + g.graph["interlevel_refinement_factor"] = interlevel_refinement_factor + G_all_levels.append(g) + + return G_all_levels + + +# ── Convenience wrapper functions (mirror flat.py) ── + + +def create_single_level_2d_triangular_mesh_graph(xy, nx, ny): + """ + Create a directed triangular mesh graph from coordinates. + + Internally uses the two-step process: + 1. ``create_single_level_2d_triangular_mesh_primitive`` (coordinate creation) + 2. ``create_directed_mesh_graph`` (connectivity creation) + + For triangular lattices, *all* edges are ``"cardinal"`` so patterns + ``"4-star"`` and ``"8-star"`` produce the same result (6-connectivity). + + Parameters + ---------- + xy : np.ndarray + Grid point coordinates, shaped ``[N_grid_points, 2]``. + nx : int + Number of triangle columns. + ny : int + Number of triangle rows. + + Returns + ------- + networkx.DiGraph + Directed triangular mesh graph. + """ + G_coords = create_single_level_2d_triangular_mesh_primitive(xy, nx, ny) + return mesh_coords.create_directed_mesh_graph(G_coords, pattern="4-star") + + +def create_flat_singlescale_triangular_mesh_graph(xy, mesh_node_distance: float): + """ + Create a flat single-scale triangular mesh graph. + + Mirrors ``create_flat_singlescale_mesh_graph`` but with triangular + lattice topology (6-connectivity). + + Parameters + ---------- + xy : np.ndarray + Grid point coordinates, shaped ``[N_grid_points, 2]``. + mesh_node_distance : float + Approximate side length of equilateral triangles, in coordinate units. + + Returns + ------- + networkx.DiGraph + Flat single-scale directed triangular mesh graph. + """ + range_x, range_y = np.ptp(xy, axis=0) + nx = int(range_x / mesh_node_distance) + ny = int(range_y / (mesh_node_distance * np.sqrt(3) / 2)) + + if nx == 0 or ny == 0: + raise ValueError( + "The given `mesh_node_distance` is too large for the provided " + f"coordinates. Got mesh_node_distance={mesh_node_distance}, but " + f"the x-range is {range_x} and y-range is {range_y}. Maybe you " + "want to decrease the `mesh_node_distance` so that the mesh " + "nodes are spaced closer together?" + ) + + return create_single_level_2d_triangular_mesh_graph(xy, nx, ny) + + +def create_flat_multiscale_from_triangular_coordinates( + G_coords_list, + pattern="4-star", +): + """ + Create flat multiscale mesh graph from a list of triangular coordinate + graphs. + + Unlike the rectilinear variant (``create_flat_multiscale_from_coordinates``) + which relies on grid-index-based coincident-node detection, this function + uses position-based matching. For each coarser level, any node whose + position coincides (within floating-point tolerance) with an existing finer + level node is merged with it, so that multi-resolution edges share the + same node identity. + + Parameters + ---------- + G_coords_list : list[networkx.Graph] + One undirected triangular mesh primitive per level. + pattern : str + Connectivity pattern: ``"4-star"`` or ``"8-star"`` (default ``"4-star"``). + + Returns + ------- + networkx.DiGraph + Flat multiscale triangular mesh graph. + """ + # Convert each level to directed graph + G_directed = [ + mesh_coords.create_directed_mesh_graph(g, pattern=pattern) + for g in G_coords_list + ] + + # Prepend level index to make node labels unique across levels + G_directed = [ + prepend_node_index(g, level_i) + for level_i, g in enumerate(G_directed) + ] + + # Build merged graph, starting from finest level + G_tot = G_directed[0] + + for lev in range(1, len(G_directed)): + G_coarse = G_directed[lev] + + # KDTree of existing (finer) nodes for position matching + fine_nodes = list(G_tot.nodes()) + fine_positions = np.array( + [G_tot.nodes[n]["pos"] for n in fine_nodes] + ) + kdt = scipy.spatial.KDTree(fine_positions) + + # Find which coarse nodes coincide with existing fine nodes + relabel_map = {} + for node in G_coarse.nodes(): + pos = G_coarse.nodes[node]["pos"] + dist, idx = kdt.query(pos) + if dist < 1e-8: + relabel_map[node] = fine_nodes[idx] + + if relabel_map: + G_coarse = networkx.relabel_nodes(G_coarse, relabel_map) + + G_tot = networkx.compose(G_tot, G_coarse) + + # Re-index to sequential (0, i) labels + G_tot = prepend_node_index(G_tot, 0) + + # Preserve dx/dy as per-level dicts + G_tot.graph["dx"] = {i: g.graph["dx"] for i, g in enumerate(G_directed)} + G_tot.graph["dy"] = {i: g.graph["dy"] for i, g in enumerate(G_directed)} + + return G_tot + + +def create_flat_multiscale_triangular_mesh_graph( + xy, + mesh_node_distance: float, + level_refinement_factor: int, + max_num_levels: int, +): + """ + Create a flat multiscale triangular mesh graph. + + Mirrors ``create_flat_multiscale_mesh_graph`` but with triangular lattice + topology at each level. + + Parameters + ---------- + xy : np.ndarray + Grid point coordinates, shaped ``[N_grid_points, 2]``. + mesh_node_distance : float + Approximate side length of equilateral triangles at finest level. + level_refinement_factor : int + Refinement factor between levels (must be an odd integer > 1). + max_num_levels : int + Maximum number of levels in the multiscale graph. + + Returns + ------- + networkx.DiGraph + Flat multiscale directed triangular mesh graph. + """ + G_coords_list = create_multirange_2d_triangular_mesh_primitives( + max_num_levels=max_num_levels, + xy=xy, + mesh_node_spacing=mesh_node_distance, + interlevel_refinement_factor=level_refinement_factor, + ) + + return create_flat_multiscale_from_triangular_coordinates( + G_coords_list, pattern="4-star" + ) + + +def create_hierarchical_triangular_mesh_graph( + xy, + mesh_node_distance: float, + level_refinement_factor: float, + max_num_levels: int, + intra_level=None, + inter_level=None, +): + """ + Create a hierarchical multiscale triangular mesh graph. + + Mirrors ``create_hierarchical_multiscale_mesh_graph`` but with triangular + lattice topology at each level. + + Parameters + ---------- + xy : np.ndarray + Grid point coordinates, shaped ``[N_grid_points, 2]``. + mesh_node_distance : float + Approximate side length of equilateral triangles at finest level. + level_refinement_factor : float + Refinement factor between levels. + max_num_levels : int + Maximum number of levels. + intra_level : dict, optional + Intra-level connectivity config. Default: ``{"pattern": "4-star"}``. + inter_level : dict, optional + Inter-level connectivity config. Default: ``{"pattern": "nearest", "k": 1}``. + + Returns + ------- + networkx.DiGraph + Hierarchical directed triangular mesh graph. + """ + from .hierarchical import create_hierarchical_from_coordinates + + if intra_level is None: + intra_level = {"pattern": "4-star"} + if inter_level is None: + inter_level = {"pattern": "nearest", "k": 1} + + G_coords_list = create_multirange_2d_triangular_mesh_primitives( + max_num_levels=max_num_levels, + xy=xy, + mesh_node_spacing=mesh_node_distance, + interlevel_refinement_factor=level_refinement_factor, + ) + + return create_hierarchical_from_coordinates( + G_coords_list, + intra_level=intra_level, + inter_level=inter_level, + ) diff --git a/src/weather_model_graphs/create/mesh/mesh.py b/src/weather_model_graphs/create/mesh/mesh.py deleted file mode 100644 index 75c82d6..0000000 --- a/src/weather_model_graphs/create/mesh/mesh.py +++ /dev/null @@ -1,150 +0,0 @@ -import networkx -import numpy as np -from loguru import logger - - -def create_single_level_2d_mesh_graph(xy, nx, ny): - """ - Create directed graph with nx * ny nodes representing a 2D grid with - positions spanning the range of xy coordinate values (first dimension - is assumed to be x and y coordinate values respectively). Each nodes is - connected to its eight nearest neighbours, both horizontally, vertically - and diagonally as directed edges (which means that the graph contains two - edges between each pair of connected nodes). - - The nodes contain a "pos" attribute with the x and y - coordinates of the node, and an "type" attribute with the - type of the node (i.e. "mesh" for mesh nodes). - - The edges contain a "len" attribute with the length of the edge - and a "vdiff" attribute with the vector difference between the - nodes. - - Parameters - ---------- - xy : np.ndarray [N_grid_points, 2] - Grid point coordinates, with first column representing - x coordinates and second column y coordinates. N_grid_points is the - total number of grid points. - nx : int - Number of nodes in x direction - ny : int - Number of nodes in y direction - - Returns - ------- - networkx.DiGraph - Graph representing the 2D grid - """ - xm, xM = np.amin(xy[:, 0]), np.amax(xy[:, 0]) - ym, yM = np.amin(xy[:, 1]), np.amax(xy[:, 1]) - - # avoid nodes on border - dx = (xM - xm) / nx - dy = (yM - ym) / ny - lx = np.linspace(xm + dx / 2, xM - dx / 2, nx) - ly = np.linspace(ym + dy / 2, yM - dy / 2, ny) - - mg = np.meshgrid(lx, ly) - g = networkx.grid_2d_graph(len(lx), len(ly)) - - # Node name and `pos` attribute takes form (x, y) - for node in g.nodes: - node_xi, node_yi = node # Extract x and y index from node to index mx - g.nodes[node]["pos"] = np.array( - [mg[0][node_yi, node_xi], mg[1][node_yi, node_xi]] - ) - g.nodes[node]["type"] = "mesh" - - # add diagonal edges - g.add_edges_from( - [((x, y), (x + 1, y + 1)) for x in range(nx - 1) for y in range(ny - 1)] - + [((x + 1, y), (x, y + 1)) for x in range(nx - 1) for y in range(ny - 1)] - ) - - # turn into directed graph - dg = networkx.DiGraph(g) - for u, v in g.edges(): - d = np.sqrt(np.sum((g.nodes[u]["pos"] - g.nodes[v]["pos"]) ** 2)) - dg.edges[u, v]["len"] = d - dg.edges[u, v]["vdiff"] = g.nodes[u]["pos"] - g.nodes[v]["pos"] - dg.add_edge(v, u) - dg.edges[v, u]["len"] = d - dg.edges[v, u]["vdiff"] = g.nodes[v]["pos"] - g.nodes[u]["pos"] - - dg.graph["dx"] = dx - dg.graph["dy"] = dy - - return dg - - -def create_multirange_2d_mesh_graphs( - max_num_levels, xy, mesh_node_distance=3, level_refinement_factor=3 -): - """ - Create a list of 2D grid mesh graphs representing different levels of edge-length - scales spanning the spatial domain of the xy coordinates. - This list of graphs can then later be for example a) flattened into single graph - containing multiple ranges of connections or b) combined into a hierarchical graph. - - Each graph in the list contains a "level" attribute with the level index of the graph. - - Parameters - ---------- - max_num_levels : int - Number of edge-distance levels in mesh graph - xy : np.ndarray - Grid point coordinates, shaped [N_grid_points, 2] - mesh_node_distance: float - Distance (in x- and y-direction) between created mesh nodes, - in coordinate system of xy - level_refinement_factor: float - Refinement factor between grid points and bottom level of mesh hierarchy - - Returns - ------- - G_all_levels : list of networkx.Graph - List of networkx graphs for each level representing the connectivity - of the mesh within each level - """ - # Compute the size along x and y direction of area to cover with graph - # This is measured in the Cartesian coordiantes of xy - coord_extent = np.ptp(xy, axis=0) - # Number of nodes that would fit on bottom level of hierarchy, - # in both directions - max_nodes_bottom = (coord_extent / mesh_node_distance).astype(int) - - # Find the number of mesh levels possible in x- and y-direction, - # and the number of leaf nodes that would correspond to - # max_nodes_bottom/(level_refinement_factor^mesh_levels) = 1 - max_mesh_levels_float = np.log(max_nodes_bottom) / np.log(level_refinement_factor) - - max_mesh_levels = max_mesh_levels_float.astype(int) # (2,) - nleaf = level_refinement_factor**max_mesh_levels - # leaves at the bottom in each direction, if using max_mesh_levels - - # As we can not instantiate different number of mesh levels in each - # direction, create mesh levels corresponding to the minimum of the two - mesh_levels_to_create = max_mesh_levels.min() - - if max_num_levels: - # Limit the levels in mesh graph - mesh_levels_to_create = min(mesh_levels_to_create, max_num_levels) - - logger.debug(f"mesh_levels: {mesh_levels_to_create}, nleaf: {nleaf}") - - # multi resolution tree levels - G_all_levels = [] - for lev in range(mesh_levels_to_create): # 0-index mesh levels - # Compute number of nodes on level separate for each direction - nodes_x, nodes_y = (nleaf / (level_refinement_factor**lev)).astype(int) - g = create_single_level_2d_mesh_graph(xy, nodes_x, nodes_y) - # Add level information to nodes, edges and full graph - for node in g.nodes: - g.nodes[node]["level"] = lev - for edge in g.edges: - g.edges[edge]["level"] = lev - g.graph["level"] = lev - G_all_levels.append(g) - - return G_all_levels diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py index b45d486..858299e 100644 --- a/tests/test_graph_creation.py +++ b/tests/test_graph_creation.py @@ -75,6 +75,7 @@ def test_create_graph_generic(m2g_connectivity, g2m_connectivity, m2m_connectivi graph = wmg.create.create_all_graph_components( coords=xy, m2m_connectivity=m2m_connectivity, + mesh_layout="rectilinear", m2m_connectivity_kwargs=m2m_kwargs, g2m_connectivity=g2m_connectivity, g2m_connectivity_kwargs=g2m_kwargs, diff --git a/tests/test_graph_plots.py b/tests/test_graph_plots.py index c5b1e10..eba23b5 100644 --- a/tests/test_graph_plots.py +++ b/tests/test_graph_plots.py @@ -14,6 +14,7 @@ def test_plot(): graph = wmg.create.create_all_graph_components( m2m_connectivity="flat_multiscale", coords=xy, + mesh_layout="rectilinear", m2m_connectivity_kwargs=dict( max_num_levels=3, mesh_node_distance=2, diff --git a/tests/test_mesh_layout.py b/tests/test_mesh_layout.py new file mode 100644 index 0000000..a37b341 --- /dev/null +++ b/tests/test_mesh_layout.py @@ -0,0 +1,1325 @@ +""" +Tests for the mesh_layout parameter and two-step coordinate/connectivity +architecture introduced in Issue #78. + +These tests verify: +1. The new API (mesh_layout, mesh_layout_kwargs with refinement_factor and + max_num_refinement_levels, m2m_connectivity_kwargs with pattern for flat/ + flat_multiscale and intra_level/inter_level sub-dicts for hierarchical) +2. The two-step process (coordinate creation → connectivity creation) +3. The 4-star vs 8-star pattern functionality +4. Backward compatibility with old-style kwargs +5. Edge annotations on coordinate graphs +6. Error handling for invalid inputs +""" + +import warnings + +import networkx as nx +import numpy as np +import pytest + +import tests.utils as test_utils +import weather_model_graphs as wmg +from weather_model_graphs.create.mesh.coords import ( + create_directed_mesh_graph, + create_multirange_2d_mesh_primitives, + create_single_level_2d_mesh_primitive, +) +from weather_model_graphs.create.mesh.kinds.flat import ( + create_flat_multiscale_from_coordinates, + create_flat_singlescale_from_coordinates, +) +from weather_model_graphs.create.mesh.kinds.hierarchical import ( + create_hierarchical_from_coordinates, +) + + +# ==================== +# Step 1: Coordinate creation tests +# ==================== + + +class TestSingleLevelCoordinateCreation: + """Tests for create_single_level_2d_mesh_primitive.""" + + def test_returns_undirected_graph(self): + xy = test_utils.create_fake_xy(N=10) + G = create_single_level_2d_mesh_primitive(xy, nx=5, ny=5) + assert isinstance(G, nx.Graph) + assert not isinstance(G, nx.DiGraph) + + def test_nodes_have_pos_and_type(self): + xy = test_utils.create_fake_xy(N=10) + G = create_single_level_2d_mesh_primitive(xy, nx=4, ny=4) + for node in G.nodes: + assert "pos" in G.nodes[node] + assert "type" in G.nodes[node] + assert G.nodes[node]["type"] == "mesh" + assert len(G.nodes[node]["pos"]) == 2 + + def test_correct_number_of_nodes(self): + xy = test_utils.create_fake_xy(N=10) + G = create_single_level_2d_mesh_primitive(xy, nx=5, ny=4) + assert len(G.nodes) == 5 * 4 + + def test_edges_have_adjacency_type(self): + xy = test_utils.create_fake_xy(N=10) + G = create_single_level_2d_mesh_primitive(xy, nx=5, ny=5) + cardinal_count = 0 + diagonal_count = 0 + for u, v, d in G.edges(data=True): + assert "adjacency_type" in d, f"Edge ({u}, {v}) missing adjacency_type" + assert d["adjacency_type"] in ("cardinal", "diagonal") + if d["adjacency_type"] == "cardinal": + cardinal_count += 1 + else: + diagonal_count += 1 + # For a 5x5 grid: cardinal = 2*(5*4) = 40, diagonal = 2*(4*4) = 32 + assert cardinal_count == 2 * (5 * 4) + assert diagonal_count == 2 * (4 * 4) + + def test_graph_has_dx_dy(self): + xy = test_utils.create_fake_xy(N=10) + G = create_single_level_2d_mesh_primitive(xy, nx=5, ny=5) + assert "dx" in G.graph + assert "dy" in G.graph + assert G.graph["dx"] > 0 + assert G.graph["dy"] > 0 + + +class TestMultirangeCoordinateCreation: + """Tests for create_multirange_2d_mesh_primitives.""" + + def test_returns_list_of_undirected_graphs(self): + xy = test_utils.create_fake_xy(N=30) + G_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + assert isinstance(G_list, list) + assert len(G_list) > 0 + for G in G_list: + assert isinstance(G, nx.Graph) + assert not isinstance(G, nx.DiGraph) + + def test_each_level_has_level_attribute(self): + xy = test_utils.create_fake_xy(N=30) + G_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + for i, G in enumerate(G_list): + assert G.graph["level"] == i + for node in G.nodes: + assert G.nodes[node]["level"] == i + + def test_interlevel_refinement_factor_stored(self): + xy = test_utils.create_fake_xy(N=30) + G_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + for G in G_list: + assert G.graph["interlevel_refinement_factor"] == 3 + + def test_edges_have_adjacency_type(self): + xy = test_utils.create_fake_xy(N=30) + G_list = create_multirange_2d_mesh_primitives( + max_num_levels=2, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + for G in G_list: + for u, v, d in G.edges(data=True): + assert "adjacency_type" in d + + def test_coarser_levels_have_fewer_nodes(self): + xy = test_utils.create_fake_xy(N=30) + G_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + if len(G_list) >= 2: + for i in range(len(G_list) - 1): + assert len(G_list[i].nodes) > len(G_list[i + 1].nodes) + + +# ==================== +# Step 2: Connectivity creation tests +# ==================== + + +class TestDirectedMeshGraph: + """Tests for create_directed_mesh_graph.""" + + def test_returns_directed_graph(self): + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=4, ny=4) + G_directed = create_directed_mesh_graph(G_coords, pattern="8-star") + assert isinstance(G_directed, nx.DiGraph) + + def test_4star_has_fewer_edges_than_8star(self): + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=5, ny=5) + G_4star = create_directed_mesh_graph(G_coords, pattern="4-star") + G_8star = create_directed_mesh_graph(G_coords, pattern="8-star") + assert len(G_4star.edges) < len(G_8star.edges) + + def test_4star_only_cardinal_edges(self): + """4-star should only include cardinal (horizontal/vertical) edges.""" + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=4, ny=4) + G_4star = create_directed_mesh_graph(G_coords, pattern="4-star") + + # In a 4x4 grid, 4-star adjacency means each node connects only to + # horizontal/vertical neighbours + # For a 4x4 grid: 2 * (4*3 + 3*4) = 2 * 24 = 48 directed edges + expected_edges = 2 * (4 * 3 + 3 * 4) + assert len(G_4star.edges) == expected_edges + + def test_8star_includes_diagonal_edges(self): + """8-star should include both cardinal and diagonal edges.""" + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=4, ny=4) + G_8star = create_directed_mesh_graph(G_coords, pattern="8-star") + + # Cardinal: 2 * (4*3 + 3*4) = 48 + # Diagonal: 2 * 2 * (3*3) = 36 + expected_edges = 48 + 36 + assert len(G_8star.edges) == expected_edges + + def test_edges_have_len_and_vdiff(self): + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=4, ny=4) + G_directed = create_directed_mesh_graph(G_coords, pattern="8-star") + for u, v, d in G_directed.edges(data=True): + assert "len" in d + assert "vdiff" in d + assert d["len"] > 0 + + def test_invalid_pattern_raises_error(self): + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=4, ny=4) + with pytest.raises(ValueError, match="Unknown connectivity pattern"): + create_directed_mesh_graph(G_coords, pattern="6-star") + + def test_preserves_graph_attributes(self): + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=4, ny=4) + G_directed = create_directed_mesh_graph(G_coords, pattern="8-star") + assert "dx" in G_directed.graph + assert "dy" in G_directed.graph + + def test_bidirectional_edges(self): + """Each undirected edge should produce two directed edges.""" + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=3, ny=3) + G_directed = create_directed_mesh_graph(G_coords, pattern="4-star") + for u, v in G_directed.edges(): + assert G_directed.has_edge(v, u), f"Missing reverse edge ({v}, {u})" + + +class TestFlatSinglescaleFromCoordinates: + """Tests for create_flat_singlescale_from_coordinates.""" + + def test_basic_creation(self): + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=5, ny=5) + G = create_flat_singlescale_from_coordinates(G_coords, pattern="8-star") + assert isinstance(G, nx.DiGraph) + + def test_4star_pattern(self): + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=5, ny=5) + G = create_flat_singlescale_from_coordinates(G_coords, pattern="4-star") + assert isinstance(G, nx.DiGraph) + # Fewer edges than 8-star + G_8 = create_flat_singlescale_from_coordinates(G_coords, pattern="8-star") + assert len(G.edges) < len(G_8.edges) + + +class TestFlatMultiscaleFromCoordinates: + """Tests for create_flat_multiscale_from_coordinates.""" + + def test_basic_creation(self): + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + G = create_flat_multiscale_from_coordinates(G_coords_list) + assert isinstance(G, nx.DiGraph) + + def test_pattern_argument(self): + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + G_4star = create_flat_multiscale_from_coordinates( + G_coords_list, + pattern="4-star", + ) + G_8star = create_flat_multiscale_from_coordinates( + G_coords_list, + pattern="8-star", + ) + assert len(G_4star.edges) < len(G_8star.edges) + + +class TestHierarchicalFromCoordinates: + """Tests for create_hierarchical_from_coordinates.""" + + def test_basic_creation(self): + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + G = create_hierarchical_from_coordinates(G_coords_list) + assert isinstance(G, nx.DiGraph) + + def test_has_up_down_same_edges(self): + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + G = create_hierarchical_from_coordinates(G_coords_list) + directions = set() + for u, v, d in G.edges(data=True): + if "direction" in d: + directions.add(d["direction"]) + assert "same" in directions + assert "up" in directions + assert "down" in directions + + def test_intra_level_pattern(self): + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + G_4star = create_hierarchical_from_coordinates( + G_coords_list, + intra_level={"pattern": "4-star"}, + ) + G_8star = create_hierarchical_from_coordinates( + G_coords_list, + intra_level={"pattern": "8-star"}, + ) + assert len(G_4star.edges) < len(G_8star.edges) + + def test_inter_level_k_parameter(self): + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + G_k1 = create_hierarchical_from_coordinates( + G_coords_list, + inter_level={"pattern": "nearest", "k": 1}, + ) + G_k3 = create_hierarchical_from_coordinates( + G_coords_list, + inter_level={"pattern": "nearest", "k": 3}, + ) + # More neighbours → more inter-level edges + assert len(G_k3.edges) > len(G_k1.edges) + + def test_invalid_inter_level_pattern_raises(self): + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + with pytest.raises(NotImplementedError, match="Inter-level pattern"): + create_hierarchical_from_coordinates( + G_coords_list, + inter_level={"pattern": "some_unknown"}, + ) + + +# ==================== +# New API via create_all_graph_components tests +# ==================== + + +class TestNewAPIFlat: + """Tests for create_all_graph_components with new mesh_layout API (flat).""" + + def test_flat_with_new_api(self): + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + m2m_connectivity_kwargs=dict(pattern="8-star"), + g2m_connectivity="within_radius", + m2g_connectivity="nearest_neighbours", + g2m_connectivity_kwargs=dict(rel_max_dist=0.51), + m2g_connectivity_kwargs=dict(max_num_neighbours=4), + ) + assert isinstance(graph, nx.DiGraph) + + def test_flat_4star_pattern(self): + xy = test_utils.create_fake_xy(N=32) + graph_4 = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + m2m_connectivity_kwargs=dict(pattern="4-star"), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + graph_8 = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + m2m_connectivity_kwargs=dict(pattern="8-star"), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + assert isinstance(graph_4, nx.DiGraph) + assert isinstance(graph_8, nx.DiGraph) + assert len(graph_4.edges) < len(graph_8.edges) + + def test_missing_mesh_node_spacing_raises(self): + xy = test_utils.create_fake_xy(N=32) + with pytest.raises(ValueError, match="mesh_node_spacing"): + wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs={}, + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + + +class TestNewAPIFlatMultiscale: + """Tests for create_all_graph_components with new API (flat_multiscale).""" + + def test_flat_multiscale_with_new_api(self): + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat_multiscale", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict( + pattern="8-star", + ), + g2m_connectivity="within_radius", + m2g_connectivity="nearest_neighbours", + g2m_connectivity_kwargs=dict(rel_max_dist=0.51), + m2g_connectivity_kwargs=dict(max_num_neighbours=4), + ) + assert isinstance(graph, nx.DiGraph) + + def test_flat_multiscale_4star_pattern(self): + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat_multiscale", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict( + pattern="4-star", + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + assert isinstance(graph, nx.DiGraph) + + +class TestNewAPIHierarchical: + """Tests for create_all_graph_components with new API (hierarchical).""" + + def test_hierarchical_with_new_api(self): + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="hierarchical", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict( + intra_level=dict(pattern="8-star"), + inter_level=dict(pattern="nearest", k=1), + ), + g2m_connectivity="within_radius", + m2g_connectivity="nearest_neighbours", + g2m_connectivity_kwargs=dict(rel_max_dist=0.51), + m2g_connectivity_kwargs=dict(max_num_neighbours=4), + ) + assert isinstance(graph, nx.DiGraph) + + def test_hierarchical_4star_intra(self): + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="hierarchical", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict( + intra_level=dict(pattern="4-star"), + inter_level=dict(pattern="nearest", k=1), + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + assert isinstance(graph, nx.DiGraph) + + def test_hierarchical_k3_nearest(self): + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="hierarchical", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict( + intra_level=dict(pattern="8-star"), + inter_level=dict(pattern="nearest", k=3), + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + assert isinstance(graph, nx.DiGraph) + + +# ==================== +# Backward compatibility tests +# ==================== + + +class TestBackwardCompatibility: + """Tests that old-style kwargs still work with deprecation warnings.""" + + def test_old_style_flat_with_mesh_node_distance(self): + xy = test_utils.create_fake_xy(N=32) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + m2m_connectivity_kwargs=dict(mesh_node_distance=3), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + # Should have deprecation warning + deprecation_warnings = [ + x for x in w if issubclass(x.category, DeprecationWarning) + ] + assert len(deprecation_warnings) >= 1 + assert "mesh_node_distance" in str(deprecation_warnings[0].message) + assert isinstance(graph, nx.DiGraph) + + def test_old_style_flat_multiscale(self): + xy = test_utils.create_fake_xy(N=32) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat_multiscale", + mesh_layout="rectilinear", + m2m_connectivity_kwargs=dict( + mesh_node_distance=3, + level_refinement_factor=3, + max_num_levels=3, + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + deprecation_warnings = [ + x for x in w if issubclass(x.category, DeprecationWarning) + ] + assert len(deprecation_warnings) >= 3 # 3 migrated kwargs + assert isinstance(graph, nx.DiGraph) + + def test_old_style_hierarchical(self): + xy = test_utils.create_fake_xy(N=32) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="hierarchical", + mesh_layout="rectilinear", + m2m_connectivity_kwargs=dict( + mesh_node_distance=3, + level_refinement_factor=3, + max_num_levels=3, + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + deprecation_warnings = [ + x for x in w if issubclass(x.category, DeprecationWarning) + ] + assert len(deprecation_warnings) >= 3 + assert isinstance(graph, nx.DiGraph) + + def test_kwargs_dict_not_mutated(self): + """Verify that passing dict kwargs doesn't mutate the caller's dict.""" + xy = test_utils.create_fake_xy(N=32) + original_kwargs = dict( + mesh_node_distance=3, + level_refinement_factor=3, + max_num_levels=3, + ) + kwargs_copy = original_kwargs.copy() + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="hierarchical", + mesh_layout="rectilinear", + m2m_connectivity_kwargs=original_kwargs, + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + + # The original dict should not be modified + assert original_kwargs == kwargs_copy + + +# ==================== +# Error handling tests +# ==================== + + +class TestErrorHandling: + """Tests for proper error handling.""" + + def test_unsupported_mesh_layout_raises(self): + xy = test_utils.create_fake_xy(N=32) + with pytest.raises(NotImplementedError, match="not yet supported"): + wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="hexagonal", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + + def test_unsupported_m2m_connectivity_raises(self): + xy = test_utils.create_fake_xy(N=32) + with pytest.raises(NotImplementedError, match="not implemented"): + wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="some_unknown", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + + def test_mesh_node_spacing_too_large_raises(self): + xy = test_utils.create_fake_xy(N=10) + with pytest.raises(ValueError, match="too large"): + wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=100), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + + +# ==================== +# Equivalence tests: new API == old wrappers +# ==================== + + +class TestEquivalence: + """Verify that the new API produces equivalent results to the old wrappers.""" + + def test_keisler_archetype_matches_new_api(self): + """The keisler archetype function should produce the same result as + calling create_all_graph_components with the new API directly.""" + xy = test_utils.create_fake_xy(N=32) + + graph_archetype = wmg.create.archetype.create_keisler_graph( + coords=xy, mesh_node_distance=3 + ) + + graph_new_api = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + m2m_connectivity_kwargs=dict(pattern="8-star"), + g2m_connectivity="within_radius", + m2g_connectivity="nearest_neighbours", + g2m_connectivity_kwargs=dict(rel_max_dist=0.51), + m2g_connectivity_kwargs=dict(max_num_neighbours=4), + ) + + assert len(graph_archetype.nodes) == len(graph_new_api.nodes) + assert len(graph_archetype.edges) == len(graph_new_api.edges) + + def test_graphcast_archetype_matches_new_api(self): + xy = test_utils.create_fake_xy(N=32) + + graph_archetype = wmg.create.archetype.create_graphcast_graph( + coords=xy, + mesh_node_distance=3, + level_refinement_factor=3, + max_num_levels=3, + ) + + graph_new_api = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat_multiscale", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict( + pattern="8-star", + ), + g2m_connectivity="within_radius", + m2g_connectivity="nearest_neighbours", + g2m_connectivity_kwargs=dict(rel_max_dist=0.51), + m2g_connectivity_kwargs=dict(max_num_neighbours=4), + ) + + assert len(graph_archetype.nodes) == len(graph_new_api.nodes) + assert len(graph_archetype.edges) == len(graph_new_api.edges) + + def test_oskarsson_archetype_matches_new_api(self): + xy = test_utils.create_fake_xy(N=32) + + graph_archetype = wmg.create.archetype.create_oskarsson_hierarchical_graph( + coords=xy, + mesh_node_distance=3, + level_refinement_factor=3, + max_num_levels=3, + ) + + graph_new_api = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="hierarchical", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict( + intra_level=dict(pattern="8-star"), + inter_level=dict(pattern="nearest", k=1), + ), + g2m_connectivity="within_radius", + m2g_connectivity="nearest_neighbours", + g2m_connectivity_kwargs=dict(rel_max_dist=0.51), + m2g_connectivity_kwargs=dict(max_num_neighbours=4), + ) + + assert len(graph_archetype.nodes) == len(graph_new_api.nodes) + assert len(graph_archetype.edges) == len(graph_new_api.edges) + + +# ==================== +# Edge case tests +# ==================== + + +class TestCoordinateCreationEdgeCases: + """Edge cases for coordinate creation step.""" + + def test_minimum_grid_2x2(self): + """Smallest possible grid: 2x2 nodes.""" + xy = test_utils.create_fake_xy(N=10) + G = create_single_level_2d_mesh_primitive(xy, nx=2, ny=2) + assert len(G.nodes) == 4 + # 2x2 grid: cardinal edges = 2*(2*1) = 4, diagonal edges = 2*(1*1) = 2 + cardinal = sum( + 1 for _, _, d in G.edges(data=True) if d["adjacency_type"] == "cardinal" + ) + diagonal = sum( + 1 for _, _, d in G.edges(data=True) if d["adjacency_type"] == "diagonal" + ) + assert cardinal == 4 + assert diagonal == 2 + + def test_single_row_grid(self): + """Grid with only 1 row (nx=5, ny=1).""" + xy = test_utils.create_fake_xy(N=10) + G = create_single_level_2d_mesh_primitive(xy, nx=5, ny=1) + assert len(G.nodes) == 5 + # 5x1 grid: only horizontal cardinal edges, no diagonals + cardinal = sum( + 1 for _, _, d in G.edges(data=True) if d["adjacency_type"] == "cardinal" + ) + diagonal = sum( + 1 for _, _, d in G.edges(data=True) if d["adjacency_type"] == "diagonal" + ) + assert cardinal == 4 # 5-1 = 4 horizontal edges + assert diagonal == 0 + + def test_single_column_grid(self): + """Grid with only 1 column (nx=1, ny=5).""" + xy = test_utils.create_fake_xy(N=10) + G = create_single_level_2d_mesh_primitive(xy, nx=1, ny=5) + assert len(G.nodes) == 5 + cardinal = sum( + 1 for _, _, d in G.edges(data=True) if d["adjacency_type"] == "cardinal" + ) + diagonal = sum( + 1 for _, _, d in G.edges(data=True) if d["adjacency_type"] == "diagonal" + ) + assert cardinal == 4 # 5-1 = 4 vertical edges + assert diagonal == 0 + + def test_1x1_grid_no_edges(self): + """Grid with a single node (1x1): should have no edges.""" + xy = test_utils.create_fake_xy(N=10) + G = create_single_level_2d_mesh_primitive(xy, nx=1, ny=1) + assert len(G.nodes) == 1 + assert len(G.edges) == 0 + + def test_large_grid(self): + """Larger grid should still work correctly.""" + xy = test_utils.create_fake_xy(N=50) + G = create_single_level_2d_mesh_primitive(xy, nx=10, ny=10) + assert len(G.nodes) == 100 + expected_cardinal = 2 * (10 * 9) # 180 + expected_diagonal = 2 * (9 * 9) # 162 + cardinal = sum( + 1 for _, _, d in G.edges(data=True) if d["adjacency_type"] == "cardinal" + ) + diagonal = sum( + 1 for _, _, d in G.edges(data=True) if d["adjacency_type"] == "diagonal" + ) + assert cardinal == expected_cardinal + assert diagonal == expected_diagonal + + def test_node_positions_within_bounds(self): + """Node positions should be within the xy bounds.""" + xy = test_utils.create_fake_xy(N=20) + G = create_single_level_2d_mesh_primitive(xy, nx=5, ny=5) + x_min, y_min = np.amin(xy, axis=0) + x_max, y_max = np.amax(xy, axis=0) + for node in G.nodes: + pos = G.nodes[node]["pos"] + assert pos[0] >= x_min and pos[0] <= x_max + assert pos[1] >= y_min and pos[1] <= y_max + + def test_multirange_with_max_levels_1(self): + """Multi-range with max_num_levels=1 should return single-level list.""" + xy = test_utils.create_fake_xy(N=30) + G_list = create_multirange_2d_mesh_primitives( + max_num_levels=1, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + assert len(G_list) == 1 + assert G_list[0].graph["level"] == 0 + + def test_multirange_with_none_max_levels(self): + """max_num_levels=None should auto-compute levels.""" + xy = test_utils.create_fake_xy(N=30) + G_list = create_multirange_2d_mesh_primitives( + max_num_levels=None, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + assert isinstance(G_list, list) + assert len(G_list) >= 1 + + def test_multirange_refinement_factor_5(self): + """Test with a different refinement factor.""" + xy = test_utils.create_fake_xy(N=50) + G_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=5 + ) + if len(G_list) >= 2: + for i in range(len(G_list) - 1): + assert len(G_list[i].nodes) > len(G_list[i + 1].nodes) + + +class TestConnectivityCreationEdgeCases: + """Edge cases for connectivity creation step.""" + + def test_directed_graph_from_1x1(self): + """Creating directed graph from a single-node coordinate graph.""" + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=1, ny=1) + G_directed = create_directed_mesh_graph(G_coords, pattern="8-star") + assert isinstance(G_directed, nx.DiGraph) + assert len(G_directed.nodes) == 1 + assert len(G_directed.edges) == 0 + + def test_directed_graph_from_2x1(self): + """Creating directed graph from a 2x1 grid.""" + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=2, ny=1) + G_4star = create_directed_mesh_graph(G_coords, pattern="4-star") + G_8star = create_directed_mesh_graph(G_coords, pattern="8-star") + # 2x1: 1 edge, both patterns should have same (no diagonals possible) + assert len(G_4star.edges) == 2 # bidirectional + assert len(G_8star.edges) == 2 + + def test_4star_is_subset_of_8star(self): + """All edges in 4-star should exist in 8-star.""" + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=5, ny=5) + G_4star = create_directed_mesh_graph(G_coords, pattern="4-star") + G_8star = create_directed_mesh_graph(G_coords, pattern="8-star") + for u, v in G_4star.edges(): + assert G_8star.has_edge(u, v), f"4-star edge ({u},{v}) missing from 8-star" + + def test_edge_lengths_are_positive(self): + """All edge lengths should be positive.""" + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=4, ny=4) + G = create_directed_mesh_graph(G_coords, pattern="8-star") + for u, v, d in G.edges(data=True): + assert d["len"] > 0, f"Edge ({u},{v}) has non-positive length" + + def test_vdiff_antisymmetric(self): + """vdiff(u,v) should be -vdiff(v,u).""" + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=4, ny=4) + G = create_directed_mesh_graph(G_coords, pattern="8-star") + for u, v in G.edges(): + if G.has_edge(v, u): + np.testing.assert_allclose( + G.edges[u, v]["vdiff"], -G.edges[v, u]["vdiff"], + err_msg=f"vdiff not antisymmetric for ({u},{v})" + ) + + def test_edge_len_matches_vdiff_norm(self): + """Edge length should equal the norm of vdiff.""" + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=4, ny=4) + G = create_directed_mesh_graph(G_coords, pattern="8-star") + for u, v, d in G.edges(data=True): + expected_len = np.sqrt(np.sum(d["vdiff"] ** 2)) + np.testing.assert_allclose( + d["len"], expected_len, + err_msg=f"Edge ({u},{v}) len doesn't match vdiff norm" + ) + + def test_flat_multiscale_single_level_input(self): + """Flat multiscale with a single-level list should work (degenerate case).""" + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=1, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + G = create_flat_multiscale_from_coordinates(G_coords_list, pattern="8-star") + assert isinstance(G, nx.DiGraph) + assert len(G.nodes) > 0 + + def test_flat_multiscale_4star_vs_8star_edge_count(self): + """4-star flat_multiscale should have fewer edges than 8-star.""" + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + G_4 = create_flat_multiscale_from_coordinates(G_coords_list, pattern="4-star") + G_8 = create_flat_multiscale_from_coordinates(G_coords_list, pattern="8-star") + assert len(G_4.edges) < len(G_8.edges) + + def test_hierarchical_single_level_raises(self): + """Hierarchical with only 1 level should raise ValueError.""" + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=1, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + with pytest.raises(ValueError, match="At least two mesh levels"): + create_hierarchical_from_coordinates(G_coords_list) + + def test_hierarchical_edge_direction_attributes(self): + """Every edge in hierarchical graph must have a 'direction' attribute.""" + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + G = create_hierarchical_from_coordinates(G_coords_list) + for u, v, d in G.edges(data=True): + assert "direction" in d, f"Edge ({u},{v}) missing 'direction'" + assert d["direction"] in ("same", "up", "down") + + def test_hierarchical_up_down_symmetry(self): + """For each 'down' edge (u,v), there should be an 'up' edge (v,u).""" + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + G = create_hierarchical_from_coordinates(G_coords_list) + for u, v, d in G.edges(data=True): + if d.get("direction") == "down": + assert G.has_edge(v, u), f"Missing 'up' edge for 'down' ({u},{v})" + assert G.edges[v, u]["direction"] == "up" + + +class TestAPIEdgeCases: + """Edge cases for the public create_all_graph_components API.""" + + def test_flat_default_pattern_is_8star(self): + """When no m2m_connectivity_kwargs given, flat should default to 8-star.""" + xy = test_utils.create_fake_xy(N=32) + graph_default = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + graph_8star = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + m2m_connectivity_kwargs=dict(pattern="8-star"), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + assert len(graph_default.edges) == len(graph_8star.edges) + + def test_flat_multiscale_default_pattern_is_8star(self): + """When no m2m_connectivity_kwargs given, flat_multiscale defaults to 8-star.""" + xy = test_utils.create_fake_xy(N=32) + graph_default = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat_multiscale", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + graph_8star = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat_multiscale", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict(pattern="8-star"), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + assert len(graph_default.edges) == len(graph_8star.edges) + + def test_hierarchical_default_kwargs(self): + """When no m2m_connectivity_kwargs given, hierarchical has sensible defaults.""" + xy = test_utils.create_fake_xy(N=32) + graph_default = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="hierarchical", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + graph_explicit = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="hierarchical", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict( + intra_level=dict(pattern="8-star"), + inter_level=dict(pattern="nearest", k=1), + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + assert len(graph_default.edges) == len(graph_explicit.edges) + + def test_mesh_layout_is_required(self): + """When mesh_layout not specified, it should raise TypeError.""" + xy = test_utils.create_fake_xy(N=32) + with pytest.raises(TypeError, match="mesh_layout"): + wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + + def test_return_components_flat(self): + """return_components=True should return dict with g2m, m2m, m2g.""" + xy = test_utils.create_fake_xy(N=32) + components = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + m2m_connectivity_kwargs=dict(pattern="8-star"), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + return_components=True, + ) + assert isinstance(components, dict) + assert "g2m" in components + assert "m2m" in components + assert "m2g" in components + for name, g in components.items(): + assert isinstance(g, nx.DiGraph) + + def test_return_components_hierarchical(self): + """return_components=True for hierarchical should contain 3 components.""" + xy = test_utils.create_fake_xy(N=32) + components = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="hierarchical", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict( + intra_level=dict(pattern="8-star"), + inter_level=dict(pattern="nearest", k=1), + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + return_components=True, + ) + assert isinstance(components, dict) + assert "g2m" in components + assert "m2m" in components + assert "m2g" in components + + def test_return_components_flat_multiscale(self): + """return_components=True for flat_multiscale.""" + xy = test_utils.create_fake_xy(N=32) + components = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat_multiscale", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict(pattern="8-star"), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + return_components=True, + ) + assert isinstance(components, dict) + assert set(components.keys()) == {"g2m", "m2m", "m2g"} + + def test_flat_multiscale_no_sub_dicts_interface(self): + """Ensure flat_multiscale accepts a simple pattern arg, not sub-dicts.""" + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat_multiscale", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict(pattern="8-star"), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + assert isinstance(graph, nx.DiGraph) + + def test_decode_mask_with_new_api(self): + """decode_mask should work correctly with the new API.""" + xy = test_utils.create_fake_xy(N=32) + n_points = len(xy) + mask = [True] * (n_points // 2) + [False] * (n_points - n_points // 2) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + m2m_connectivity_kwargs=dict(pattern="8-star"), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + decode_mask=mask, + ) + assert isinstance(graph, nx.DiGraph) + + +class TestBackwardCompatEdgeCases: + """Advanced backward compatibility edge cases.""" + + def test_old_kwargs_with_flat_multiscale_compat(self): + """Old-style flat_multiscale kwargs should trigger deprecation warnings + and be migrated to the new names.""" + xy = test_utils.create_fake_xy(N=32) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat_multiscale", + mesh_layout="rectilinear", + m2m_connectivity_kwargs=dict( + mesh_node_distance=3, + level_refinement_factor=3, + max_num_levels=3, + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + deprecation_warnings = [ + x for x in w if issubclass(x.category, DeprecationWarning) + ] + # Should have 3 deprecation warnings + assert len(deprecation_warnings) >= 3 + # Check the new names appear in the messages + msgs = " ".join(str(x.message) for x in deprecation_warnings) + assert "mesh_node_spacing" in msgs + assert "refinement_factor" in msgs + assert "max_num_refinement_levels" in msgs + assert isinstance(graph, nx.DiGraph) + + +class TestGraphStructuralProperties: + """Tests verifying structural properties of generated graphs.""" + + def test_all_mesh_nodes_have_pos(self): + """Every node in the final graph should have 'pos' attribute.""" + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + for node in graph.nodes: + assert "pos" in graph.nodes[node], f"Node {node} missing 'pos'" + + def test_all_edges_have_component(self): + """Every edge should have a 'component' attribute ('g2m', 'm2m', 'm2g').""" + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + for u, v, d in graph.edges(data=True): + assert "component" in d, f"Edge ({u},{v}) missing 'component'" + assert d["component"] in ("g2m", "m2m", "m2g") + + def test_all_edges_have_len_and_vdiff(self): + """Every edge should have 'len' and 'vdiff' attributes.""" + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + for u, v, d in graph.edges(data=True): + assert "len" in d, f"Edge ({u},{v}) missing 'len'" + assert "vdiff" in d, f"Edge ({u},{v}) missing 'vdiff'" + + def test_graph_is_directed(self): + """Final graph should always be a DiGraph.""" + xy = test_utils.create_fake_xy(N=32) + for connectivity in ["flat"]: + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity=connectivity, + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + assert isinstance(graph, nx.DiGraph) + + def test_flat_4star_strictly_fewer_m2m_edges(self): + """4-star flat should have strictly fewer m2m edges than 8-star.""" + xy = test_utils.create_fake_xy(N=32) + components_4 = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + m2m_connectivity_kwargs=dict(pattern="4-star"), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + return_components=True, + ) + components_8 = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + m2m_connectivity_kwargs=dict(pattern="8-star"), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + return_components=True, + ) + # m2m component specifically should differ + assert len(components_4["m2m"].edges) < len(components_8["m2m"].edges) + # g2m and m2g should be the same (same grid spacing, same connectivity) + assert len(components_4["g2m"].edges) == len(components_8["g2m"].edges) + assert len(components_4["m2g"].edges) == len(components_8["m2g"].edges) + + def test_hierarchical_has_same_up_down_edge_count(self): + """Hierarchical graph should have equal number of up and down edges.""" + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="hierarchical", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict( + intra_level=dict(pattern="8-star"), + inter_level=dict(pattern="nearest", k=1), + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + return_components=True, + ) + m2m = graph["m2m"] + up_count = sum( + 1 for _, _, d in m2m.edges(data=True) if d.get("direction") == "up" + ) + down_count = sum( + 1 for _, _, d in m2m.edges(data=True) if d.get("direction") == "down" + ) + assert up_count == down_count, ( + f"Up edges ({up_count}) != Down edges ({down_count})" + ) + assert up_count > 0, "Should have at least some up/down edges" diff --git a/tests/test_triangular_mesh.py b/tests/test_triangular_mesh.py new file mode 100644 index 0000000..1ff70ec --- /dev/null +++ b/tests/test_triangular_mesh.py @@ -0,0 +1,1046 @@ +""" +Tests for mesh_layout="triangular" support (Issue #80). + +Tests verify: +1. Primitive creation (node count, positions, adjacency_type, type attrs) +2. Single-level directed graph (bidirectional edges, len/vdiff attrs, 6-connectivity) +3. Multirange primitive creation +4. Flat single-scale mesh graph via wrapper + integration +5. Flat multiscale mesh graph (position-based merging) +6. Hierarchical mesh graph +7. Integration through create_all_graph_components for all m2m_connectivity modes +8. Edge cases (spacing too large, single-level hierarchical) +9. Numerical correctness (len symmetry, vdiff reciprocity) +10. Pattern equivalence (4-star == 8-star for triangular) +""" + +import networkx as nx +import numpy as np +import pytest + +import tests.utils as test_utils +import weather_model_graphs as wmg +from weather_model_graphs.create.mesh.coords import create_directed_mesh_graph +from weather_model_graphs.create.mesh.kinds.triangular import ( + create_flat_multiscale_from_triangular_coordinates, + create_flat_multiscale_triangular_mesh_graph, + create_flat_singlescale_triangular_mesh_graph, + create_hierarchical_triangular_mesh_graph, + create_multirange_2d_triangular_mesh_primitives, + create_single_level_2d_triangular_mesh_graph, + create_single_level_2d_triangular_mesh_primitive, +) + + +# =========================== +# Fixtures +# =========================== + + +@pytest.fixture +def xy_small(): + """Small 10x10 domain with 4 corner grid points.""" + return np.array([[0, 0], [10, 0], [0, 10], [10, 10]], dtype=float) + + +@pytest.fixture +def xy_medium(): + """Medium domain with many grid points.""" + return test_utils.create_fake_xy(N=20) + + +@pytest.fixture +def xy_rectangular(): + """Non-square domain.""" + return test_utils.create_rectangular_fake_xy(Nx=15, Ny=10) + + +@pytest.fixture +def xy_offset(): + """Domain not starting at origin.""" + return np.array([[5, 3], [15, 3], [5, 13], [15, 13]], dtype=float) + + +@pytest.fixture +def xy_large(): + """Larger domain with many grid points.""" + return test_utils.create_fake_xy(N=50) + + +@pytest.fixture +def xy_wide(): + """Very wide, short domain.""" + return test_utils.create_rectangular_fake_xy(Nx=30, Ny=5) + + +# =========================== +# Step 1: Triangular Primitive (Coordinate Creation) +# =========================== + + +class TestTriangularPrimitive: + """Tests for create_single_level_2d_triangular_mesh_primitive.""" + + def test_returns_undirected_graph(self, xy_small): + G = create_single_level_2d_triangular_mesh_primitive(xy_small, nx=5, ny=5) + assert isinstance(G, nx.Graph) + assert not isinstance(G, nx.DiGraph) + + def test_nodes_have_pos_and_type(self, xy_small): + G = create_single_level_2d_triangular_mesh_primitive(xy_small, nx=4, ny=4) + for node in G.nodes: + assert "pos" in G.nodes[node] + assert "type" in G.nodes[node] + assert G.nodes[node]["type"] == "mesh" + pos = G.nodes[node]["pos"] + assert len(pos) == 2 + assert np.isfinite(pos).all() + + def test_nonzero_node_count(self, xy_small): + G = create_single_level_2d_triangular_mesh_primitive(xy_small, nx=6, ny=6) + assert G.number_of_nodes() > 0 + + def test_has_edges(self, xy_small): + G = create_single_level_2d_triangular_mesh_primitive(xy_small, nx=6, ny=6) + assert G.number_of_edges() > 0 + + def test_all_edges_are_cardinal(self, xy_small): + """Triangular lattice has only cardinal edges (no diagonal distinction).""" + G = create_single_level_2d_triangular_mesh_primitive(xy_small, nx=5, ny=5) + for u, v, d in G.edges(data=True): + assert "adjacency_type" in d, f"Edge ({u}, {v}) missing adjacency_type" + assert d["adjacency_type"] == "cardinal" + + def test_graph_has_dx_dy(self, xy_small): + G = create_single_level_2d_triangular_mesh_primitive(xy_small, nx=5, ny=5) + assert "dx" in G.graph + assert "dy" in G.graph + assert G.graph["dx"] > 0 + assert G.graph["dy"] > 0 + + def test_positions_within_domain(self, xy_small): + """Mesh node positions should lie within the coordinate domain.""" + xm, xM = xy_small[:, 0].min(), xy_small[:, 0].max() + ym, yM = xy_small[:, 1].min(), xy_small[:, 1].max() + G = create_single_level_2d_triangular_mesh_primitive(xy_small, nx=6, ny=6) + for node in G.nodes: + pos = G.nodes[node]["pos"] + assert xm <= pos[0] <= xM, f"x={pos[0]} out of [{xm}, {xM}]" + assert ym <= pos[1] <= yM, f"y={pos[1]} out of [{ym}, {yM}]" + + def test_raises_on_zero_nodes(self): + """nx=0 or ny=0 should produce 0 nodes and raise.""" + xy = np.array([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=float) + with pytest.raises(ValueError, match="produced 0 nodes"): + create_single_level_2d_triangular_mesh_primitive(xy, nx=0, ny=0) + + def test_rectangular_domain(self, xy_rectangular): + """Works with non-square domains.""" + G = create_single_level_2d_triangular_mesh_primitive( + xy_rectangular, nx=8, ny=5 + ) + assert G.number_of_nodes() > 0 + assert G.number_of_edges() > 0 + + def test_minimal_lattice(self, xy_small): + """Smallest valid lattice (nx=1, ny=1) should produce nodes and edges.""" + G = create_single_level_2d_triangular_mesh_primitive(xy_small, nx=1, ny=1) + assert G.number_of_nodes() >= 2 + assert G.number_of_edges() >= 1 + + def test_large_lattice(self, xy_small): + """Large nx/ny values should produce many nodes.""" + G = create_single_level_2d_triangular_mesh_primitive(xy_small, nx=20, ny=20) + assert G.number_of_nodes() > 100 + + def test_asymmetric_nx_ny(self, xy_small): + """Very different nx and ny should still work.""" + G = create_single_level_2d_triangular_mesh_primitive(xy_small, nx=15, ny=3) + assert G.number_of_nodes() > 0 + assert G.number_of_edges() > 0 + + def test_offset_domain(self, xy_offset): + """Domain not starting at origin: positions should still be within bounds.""" + G = create_single_level_2d_triangular_mesh_primitive(xy_offset, nx=5, ny=5) + xm, xM = 5.0, 15.0 + ym, yM = 3.0, 13.0 + for node in G.nodes: + pos = G.nodes[node]["pos"] + assert xm <= pos[0] <= xM + assert ym <= pos[1] <= yM + + def test_positions_are_numpy_arrays(self, xy_small): + """Node positions should be numpy arrays.""" + G = create_single_level_2d_triangular_mesh_primitive(xy_small, nx=4, ny=4) + for node in G.nodes: + assert isinstance(G.nodes[node]["pos"], np.ndarray) + + def test_no_self_loops(self, xy_small): + """Primitive graph should have no self-loops.""" + G = create_single_level_2d_triangular_mesh_primitive(xy_small, nx=6, ny=6) + for u, v in G.edges(): + assert u != v + + def test_wide_domain(self, xy_wide): + """Very wide, short domain should still produce valid mesh.""" + G = create_single_level_2d_triangular_mesh_primitive(xy_wide, nx=10, ny=3) + assert G.number_of_nodes() > 0 + for node in G.nodes: + pos = G.nodes[node]["pos"] + assert np.isfinite(pos).all() + + +# =========================== +# Step 2: Directed Mesh Graph (Connectivity Creation) +# =========================== + + +class TestTriangularDirectedGraph: + """Tests for directed graph creation from triangular primitives.""" + + def test_returns_digraph(self, xy_small): + G_coords = create_single_level_2d_triangular_mesh_primitive( + xy_small, nx=5, ny=5 + ) + G = create_directed_mesh_graph(G_coords, pattern="4-star") + assert isinstance(G, nx.DiGraph) + + def test_edges_are_bidirectional(self, xy_small): + G_coords = create_single_level_2d_triangular_mesh_primitive( + xy_small, nx=5, ny=5 + ) + G = create_directed_mesh_graph(G_coords, pattern="4-star") + for u, v in G.edges(): + assert G.has_edge(v, u), f"Edge ({u}, {v}) missing reverse" + + def test_edges_have_len_and_vdiff(self, xy_small): + G_coords = create_single_level_2d_triangular_mesh_primitive( + xy_small, nx=5, ny=5 + ) + G = create_directed_mesh_graph(G_coords, pattern="4-star") + for u, v, d in G.edges(data=True): + assert "len" in d, f"Edge ({u}, {v}) missing 'len'" + assert "vdiff" in d, f"Edge ({u}, {v}) missing 'vdiff'" + assert d["len"] > 0 + assert len(d["vdiff"]) == 2 + + def test_len_symmetry(self, xy_small): + """Edge length should be the same in both directions.""" + G_coords = create_single_level_2d_triangular_mesh_primitive( + xy_small, nx=5, ny=5 + ) + G = create_directed_mesh_graph(G_coords, pattern="4-star") + for u, v in G.edges(): + if G.has_edge(v, u): + np.testing.assert_allclose( + G[u][v]["len"], G[v][u]["len"], atol=1e-10 + ) + + def test_vdiff_reciprocity(self, xy_small): + """vdiff(u→v) should equal -vdiff(v→u).""" + G_coords = create_single_level_2d_triangular_mesh_primitive( + xy_small, nx=5, ny=5 + ) + G = create_directed_mesh_graph(G_coords, pattern="4-star") + for u, v in G.edges(): + if G.has_edge(v, u): + np.testing.assert_allclose( + G[u][v]["vdiff"], -G[v][u]["vdiff"], atol=1e-10 + ) + + def test_pattern_4star_equals_8star(self, xy_small): + """For triangular lattice, 4-star and 8-star should produce identical + graphs since all edges are 'cardinal'.""" + G_coords = create_single_level_2d_triangular_mesh_primitive( + xy_small, nx=5, ny=5 + ) + G4 = create_directed_mesh_graph(G_coords, pattern="4-star") + G8 = create_directed_mesh_graph(G_coords, pattern="8-star") + assert G4.number_of_nodes() == G8.number_of_nodes() + assert G4.number_of_edges() == G8.number_of_edges() + + def test_node_count_preserved(self, xy_small): + """Directed graph should have same number of nodes as primitive.""" + G_coords = create_single_level_2d_triangular_mesh_primitive( + xy_small, nx=5, ny=5 + ) + G = create_directed_mesh_graph(G_coords, pattern="4-star") + assert G.number_of_nodes() == G_coords.number_of_nodes() + + def test_edge_count_is_twice_undirected(self, xy_small): + """Directed graph should have exactly 2x the undirected edge count.""" + G_coords = create_single_level_2d_triangular_mesh_primitive( + xy_small, nx=5, ny=5 + ) + G = create_directed_mesh_graph(G_coords, pattern="4-star") + assert G.number_of_edges() == 2 * G_coords.number_of_edges() + + def test_no_self_loops_directed(self, xy_small): + """Directed graph should have no self-loops.""" + G_coords = create_single_level_2d_triangular_mesh_primitive( + xy_small, nx=5, ny=5 + ) + G = create_directed_mesh_graph(G_coords, pattern="4-star") + for u, v in G.edges(): + assert u != v + + def test_pos_preserved_after_direction(self, xy_small): + """Node positions should be preserved after converting to directed.""" + G_coords = create_single_level_2d_triangular_mesh_primitive( + xy_small, nx=5, ny=5 + ) + G = create_directed_mesh_graph(G_coords, pattern="4-star") + for node in G.nodes: + assert "pos" in G.nodes[node] + assert len(G.nodes[node]["pos"]) == 2 + + def test_interior_node_degree_six(self, xy_small): + """Interior nodes of a triangular lattice should have degree 6 + (6 in-edges + 6 out-edges = 12 total in directed graph).""" + G_coords = create_single_level_2d_triangular_mesh_primitive( + xy_small, nx=8, ny=8 + ) + G = create_directed_mesh_graph(G_coords, pattern="4-star") + # At least one interior node should have degree 12 (6 in + 6 out) + max_deg = max(dict(G.degree()).values()) + assert max_deg == 12 + + def test_minimal_lattice_directed(self, xy_small): + """Minimal lattice (nx=1, ny=1) should still produce a valid directed graph.""" + G = create_single_level_2d_triangular_mesh_graph(xy_small, nx=1, ny=1) + assert isinstance(G, nx.DiGraph) + assert G.number_of_nodes() >= 2 + assert G.number_of_edges() >= 2 # at least one bidirectional edge + + +# =========================== +# Multirange primitives +# =========================== + + +class TestMultirangeTriangularPrimitives: + """Tests for create_multirange_2d_triangular_mesh_primitives.""" + + def test_returns_list(self, xy_medium): + G_list = create_multirange_2d_triangular_mesh_primitives( + max_num_levels=3, xy=xy_medium, mesh_node_spacing=2, + interlevel_refinement_factor=3, + ) + assert isinstance(G_list, list) + assert len(G_list) >= 1 + + def test_each_level_is_undirected(self, xy_medium): + G_list = create_multirange_2d_triangular_mesh_primitives( + max_num_levels=3, xy=xy_medium, mesh_node_spacing=2, + interlevel_refinement_factor=3, + ) + for G in G_list: + assert isinstance(G, nx.Graph) + assert not isinstance(G, nx.DiGraph) + + def test_level_attributes_set(self, xy_medium): + G_list = create_multirange_2d_triangular_mesh_primitives( + max_num_levels=3, xy=xy_medium, mesh_node_spacing=2, + interlevel_refinement_factor=3, + ) + for lev, G in enumerate(G_list): + assert G.graph["level"] == lev + for node in G.nodes: + assert G.nodes[node]["level"] == lev + for u, v in G.edges(): + assert G.edges[u, v]["level"] == lev + + def test_finer_level_has_more_nodes(self, xy_medium): + G_list = create_multirange_2d_triangular_mesh_primitives( + max_num_levels=3, xy=xy_medium, mesh_node_spacing=2, + interlevel_refinement_factor=3, + ) + if len(G_list) > 1: + assert G_list[0].number_of_nodes() > G_list[1].number_of_nodes() + + def test_max_num_levels_respected(self, xy_medium): + G_list = create_multirange_2d_triangular_mesh_primitives( + max_num_levels=2, xy=xy_medium, mesh_node_spacing=2, + interlevel_refinement_factor=3, + ) + assert len(G_list) <= 2 + + def test_single_level(self, xy_medium): + """max_num_levels=1 should produce exactly 1 level.""" + G_list = create_multirange_2d_triangular_mesh_primitives( + max_num_levels=1, xy=xy_medium, mesh_node_spacing=2, + interlevel_refinement_factor=3, + ) + assert len(G_list) == 1 + assert G_list[0].graph["level"] == 0 + + def test_refinement_factor_2(self, xy_medium): + """Different refinement factor should still produce valid graphs.""" + G_list = create_multirange_2d_triangular_mesh_primitives( + max_num_levels=3, xy=xy_medium, mesh_node_spacing=2, + interlevel_refinement_factor=2, + ) + assert len(G_list) >= 1 + for G in G_list: + assert G.number_of_nodes() > 0 + + def test_all_levels_cover_same_domain(self, xy_medium): + """All levels should span approximately the same coordinate domain.""" + G_list = create_multirange_2d_triangular_mesh_primitives( + max_num_levels=3, xy=xy_medium, mesh_node_spacing=2, + interlevel_refinement_factor=3, + ) + if len(G_list) < 2: + pytest.skip("Only one level created") + # Check centers are roughly the same across levels + centers = [] + for G in G_list: + positions = np.array([G.nodes[n]["pos"] for n in G.nodes]) + centers.append(positions.mean(axis=0)) + for c in centers[1:]: + np.testing.assert_allclose(c, centers[0], atol=2.0) + + def test_interlevel_refinement_factor_preserved(self, xy_medium): + """Each level should have the refinement factor as a graph attribute.""" + G_list = create_multirange_2d_triangular_mesh_primitives( + max_num_levels=3, xy=xy_medium, mesh_node_spacing=2, + interlevel_refinement_factor=3, + ) + for G in G_list: + assert G.graph["interlevel_refinement_factor"] == 3 + + def test_all_levels_have_edges(self, xy_medium): + """Every level should have at least some edges.""" + G_list = create_multirange_2d_triangular_mesh_primitives( + max_num_levels=3, xy=xy_medium, mesh_node_spacing=2, + interlevel_refinement_factor=3, + ) + for G in G_list: + assert G.number_of_edges() > 0 + + +# =========================== +# Convenience wrapper: single-level +# =========================== + + +class TestSingleLevelTriangularGraph: + """Tests for create_single_level_2d_triangular_mesh_graph.""" + + def test_returns_digraph(self, xy_small): + G = create_single_level_2d_triangular_mesh_graph(xy_small, nx=5, ny=5) + assert isinstance(G, nx.DiGraph) + + def test_has_bidirectional_edges(self, xy_small): + G = create_single_level_2d_triangular_mesh_graph(xy_small, nx=5, ny=5) + for u, v in G.edges(): + assert G.has_edge(v, u) + + def test_edges_have_attributes(self, xy_small): + """Directed graph edges should have len and vdiff.""" + G = create_single_level_2d_triangular_mesh_graph(xy_small, nx=5, ny=5) + for u, v, d in G.edges(data=True): + assert "len" in d + assert "vdiff" in d + + def test_with_rectangular_domain(self, xy_rectangular): + """Should work correctly on non-square domains.""" + G = create_single_level_2d_triangular_mesh_graph(xy_rectangular, nx=8, ny=5) + assert isinstance(G, nx.DiGraph) + assert G.number_of_edges() > 0 + + def test_minimal_grid(self, xy_small): + """Minimal grid (nx=1, ny=1) should produce a valid graph.""" + G = create_single_level_2d_triangular_mesh_graph(xy_small, nx=1, ny=1) + assert G.number_of_nodes() >= 2 + + +# =========================== +# Flat single-scale +# =========================== + + +class TestFlatSinglescaleTriangular: + """Tests for create_flat_singlescale_triangular_mesh_graph.""" + + def test_returns_digraph(self, xy_small): + G = create_flat_singlescale_triangular_mesh_graph(xy_small, mesh_node_distance=2.0) + assert isinstance(G, nx.DiGraph) + + def test_nodes_have_pos(self, xy_small): + G = create_flat_singlescale_triangular_mesh_graph(xy_small, mesh_node_distance=2.0) + for node in G.nodes: + assert "pos" in G.nodes[node] + + def test_raises_on_large_spacing(self, xy_small): + """Spacing larger than domain should raise.""" + with pytest.raises(ValueError, match="too large"): + create_flat_singlescale_triangular_mesh_graph( + xy_small, mesh_node_distance=100.0 + ) + + def test_edges_are_bidirectional(self, xy_small): + """All edges should have a reverse.""" + G = create_flat_singlescale_triangular_mesh_graph(xy_small, mesh_node_distance=2.0) + for u, v in G.edges(): + assert G.has_edge(v, u) + + def test_smaller_spacing_more_nodes(self, xy_small): + """Smaller mesh_node_distance should produce more nodes.""" + G_coarse = create_flat_singlescale_triangular_mesh_graph( + xy_small, mesh_node_distance=3.0 + ) + G_fine = create_flat_singlescale_triangular_mesh_graph( + xy_small, mesh_node_distance=1.5 + ) + assert G_fine.number_of_nodes() > G_coarse.number_of_nodes() + + def test_rectangular_domain(self, xy_rectangular): + """Should work with non-square domains.""" + G = create_flat_singlescale_triangular_mesh_graph( + xy_rectangular, mesh_node_distance=2.0 + ) + assert isinstance(G, nx.DiGraph) + assert G.number_of_nodes() > 0 + + def test_no_nan_positions(self, xy_small): + """No node should have NaN or Inf positions.""" + G = create_flat_singlescale_triangular_mesh_graph(xy_small, mesh_node_distance=2.0) + for node in G.nodes: + pos = G.nodes[node]["pos"] + assert np.isfinite(pos).all() + + def test_spacing_just_fits(self): + """Spacing that just fits one cell should work.""" + xy = np.array([[0, 0], [5, 0], [0, 5], [5, 5]], dtype=float) + G = create_flat_singlescale_triangular_mesh_graph(xy, mesh_node_distance=4.0) + assert G.number_of_nodes() >= 2 + + +# =========================== +# Flat multiscale (triangular-specific merging) +# =========================== + + +class TestFlatMultiscaleTriangular: + """Tests for the triangular flat multiscale graph and position-based merging.""" + + def test_returns_digraph(self, xy_medium): + G = create_flat_multiscale_triangular_mesh_graph( + xy_medium, mesh_node_distance=2.0, + level_refinement_factor=3, max_num_levels=3, + ) + assert isinstance(G, nx.DiGraph) + + def test_has_edges(self, xy_medium): + G = create_flat_multiscale_triangular_mesh_graph( + xy_medium, mesh_node_distance=2.0, + level_refinement_factor=3, max_num_levels=3, + ) + assert G.number_of_edges() > 0 + + def test_edges_have_len_and_vdiff(self, xy_medium): + G = create_flat_multiscale_triangular_mesh_graph( + xy_medium, mesh_node_distance=2.0, + level_refinement_factor=3, max_num_levels=3, + ) + for u, v, d in G.edges(data=True): + assert "len" in d + assert "vdiff" in d + + def test_fewer_nodes_than_sum_of_levels(self, xy_medium): + """Position-based merging should produce fewer nodes than the raw + sum of all levels (coincident nodes get merged).""" + G_coords_list = create_multirange_2d_triangular_mesh_primitives( + max_num_levels=3, xy=xy_medium, mesh_node_spacing=2, + interlevel_refinement_factor=3, + ) + total_raw = sum(g.number_of_nodes() for g in G_coords_list) + G = create_flat_multiscale_from_triangular_coordinates(G_coords_list) + # Merged graph has at most as many nodes (usually fewer) + assert G.number_of_nodes() <= total_raw + + def test_graph_has_dx_dy_dicts(self, xy_medium): + G = create_flat_multiscale_triangular_mesh_graph( + xy_medium, mesh_node_distance=2.0, + level_refinement_factor=3, max_num_levels=3, + ) + assert isinstance(G.graph.get("dx"), dict) + assert isinstance(G.graph.get("dy"), dict) + + def test_bidirectional_edges(self, xy_medium): + """All edges should have a reverse.""" + G = create_flat_multiscale_triangular_mesh_graph( + xy_medium, mesh_node_distance=2.0, + level_refinement_factor=3, max_num_levels=3, + ) + for u, v in G.edges(): + assert G.has_edge(v, u), f"Edge ({u},{v}) no reverse" + + def test_nodes_have_pos(self, xy_medium): + """All nodes should have pos attribute.""" + G = create_flat_multiscale_triangular_mesh_graph( + xy_medium, mesh_node_distance=2.0, + level_refinement_factor=3, max_num_levels=3, + ) + for node in G.nodes: + assert "pos" in G.nodes[node] + assert np.isfinite(G.nodes[node]["pos"]).all() + + def test_single_level_multiscale(self): + """When domain only supports 1 level, flat_multiscale should still work.""" + xy = np.array([[0, 0], [3, 0], [0, 3], [3, 3]], dtype=float) + G = create_flat_multiscale_triangular_mesh_graph( + xy, mesh_node_distance=1.0, + level_refinement_factor=3, max_num_levels=3, + ) + assert isinstance(G, nx.DiGraph) + assert G.number_of_nodes() > 0 + + def test_refinement_factor_2(self, xy_medium): + """Refinement factor of 2 should work.""" + G = create_flat_multiscale_triangular_mesh_graph( + xy_medium, mesh_node_distance=2.0, + level_refinement_factor=2, max_num_levels=3, + ) + assert isinstance(G, nx.DiGraph) + assert G.number_of_edges() > 0 + + def test_no_self_loops(self, xy_medium): + """No self-loops in flat multiscale graph.""" + G = create_flat_multiscale_triangular_mesh_graph( + xy_medium, mesh_node_distance=2.0, + level_refinement_factor=3, max_num_levels=3, + ) + for u, v in G.edges(): + assert u != v + + def test_more_nodes_than_coarsest_level(self, xy_large): + """Multiscale should have more nodes than the coarsest single level.""" + G_coords_list = create_multirange_2d_triangular_mesh_primitives( + max_num_levels=3, xy=xy_large, mesh_node_spacing=2, + interlevel_refinement_factor=3, + ) + if len(G_coords_list) < 2: + pytest.skip("Only one level created") + coarsest_nodes = G_coords_list[-1].number_of_nodes() + G_multi = create_flat_multiscale_from_triangular_coordinates(G_coords_list) + assert G_multi.number_of_nodes() > coarsest_nodes + + +# =========================== +# Hierarchical +# =========================== + + +class TestHierarchicalTriangular: + """Tests for create_hierarchical_triangular_mesh_graph.""" + + def test_returns_digraph(self, xy_medium): + G = create_hierarchical_triangular_mesh_graph( + xy_medium, mesh_node_distance=2.0, + level_refinement_factor=3, max_num_levels=3, + ) + assert isinstance(G, nx.DiGraph) + + def test_has_edges(self, xy_medium): + G = create_hierarchical_triangular_mesh_graph( + xy_medium, mesh_node_distance=2.0, + level_refinement_factor=3, max_num_levels=3, + ) + assert G.number_of_edges() > 0 + + def test_edges_have_level_attribute(self, xy_medium): + G = create_hierarchical_triangular_mesh_graph( + xy_medium, mesh_node_distance=2.0, + level_refinement_factor=3, max_num_levels=3, + ) + for u, v, d in G.edges(data=True): + # Intra-level edges have 'level', inter-level have 'levels' + assert "level" in d or "levels" in d + + def test_multiple_levels_present(self, xy_medium): + G = create_hierarchical_triangular_mesh_graph( + xy_medium, mesh_node_distance=2.0, + level_refinement_factor=3, max_num_levels=3, + ) + levels = set() + for u, v, d in G.edges(data=True): + if "level" in d: + levels.add(d["level"]) + elif "levels" in d: + # Inter-level edges like '0>1' + parts = d["levels"].split(">") + levels.update(int(p) for p in parts) + assert len(levels) >= 2, "Expected multiple levels in hierarchical graph" + + def test_single_level_raises(self, xy_small): + """Hierarchical requires ≥2 levels; too-coarse spacing should raise.""" + with pytest.raises(ValueError): + create_hierarchical_triangular_mesh_graph( + xy_small, mesh_node_distance=20.0, + level_refinement_factor=3, max_num_levels=3, + ) + + def test_nodes_have_pos(self, xy_medium): + """All nodes should have pos attribute.""" + G = create_hierarchical_triangular_mesh_graph( + xy_medium, mesh_node_distance=2.0, + level_refinement_factor=3, max_num_levels=3, + ) + for node in G.nodes: + assert "pos" in G.nodes[node] + assert np.isfinite(G.nodes[node]["pos"]).all() + + def test_custom_intra_level(self, xy_medium): + """Custom intra_level pattern should be accepted.""" + G = create_hierarchical_triangular_mesh_graph( + xy_medium, mesh_node_distance=2.0, + level_refinement_factor=3, max_num_levels=3, + intra_level={"pattern": "8-star"}, + ) + assert isinstance(G, nx.DiGraph) + assert G.number_of_edges() > 0 + + def test_custom_inter_level(self, xy_medium): + """Custom inter_level config should be accepted.""" + G = create_hierarchical_triangular_mesh_graph( + xy_medium, mesh_node_distance=2.0, + level_refinement_factor=3, max_num_levels=3, + inter_level={"pattern": "nearest", "k": 3}, + ) + assert isinstance(G, nx.DiGraph) + assert G.number_of_edges() > 0 + + def test_no_self_loops(self, xy_medium): + """Hierarchical graph should have no self-loops.""" + G = create_hierarchical_triangular_mesh_graph( + xy_medium, mesh_node_distance=2.0, + level_refinement_factor=3, max_num_levels=3, + ) + for u, v in G.edges(): + assert u != v + + def test_has_inter_level_edges(self, xy_medium): + """Should have inter-level edges connecting different levels.""" + G = create_hierarchical_triangular_mesh_graph( + xy_medium, mesh_node_distance=2.0, + level_refinement_factor=3, max_num_levels=3, + ) + inter_level_count = sum( + 1 for _, _, d in G.edges(data=True) if "levels" in d + ) + assert inter_level_count > 0 + + def test_inter_level_edges_have_direction(self, xy_medium): + """Inter-level edges should have 'direction' attribute (up/down).""" + G = create_hierarchical_triangular_mesh_graph( + xy_medium, mesh_node_distance=2.0, + level_refinement_factor=3, max_num_levels=3, + ) + for u, v, d in G.edges(data=True): + if "levels" in d: + assert "direction" in d + assert d["direction"] in ("up", "down") + + +# =========================== +# Integration: create_all_graph_components +# =========================== + + +class TestIntegrationTriangular: + """Full integration tests through create_all_graph_components.""" + + COMMON_KW = dict( + m2g_connectivity="nearest_neighbours", + g2m_connectivity="nearest_neighbours", + m2g_connectivity_kwargs=dict(max_num_neighbours=4), + g2m_connectivity_kwargs=dict(max_num_neighbours=4), + return_components=True, + ) + + def test_flat_triangular(self, xy_medium): + comps = wmg.create.create_all_graph_components( + coords=xy_medium, + m2m_connectivity="flat", + mesh_layout="triangular", + mesh_layout_kwargs=dict(mesh_node_spacing=2.0), + **self.COMMON_KW, + ) + m2m = comps["m2m"] + assert isinstance(m2m, nx.DiGraph) + assert m2m.number_of_nodes() > 0 + assert m2m.number_of_edges() > 0 + # Should also have g2m and m2g + assert comps["g2m"].number_of_edges() > 0 + assert comps["m2g"].number_of_edges() > 0 + + def test_hierarchical_triangular(self, xy_medium): + comps = wmg.create.create_all_graph_components( + coords=xy_medium, + m2m_connectivity="hierarchical", + mesh_layout="triangular", + mesh_layout_kwargs=dict( + mesh_node_spacing=2.0, max_num_refinement_levels=3 + ), + **self.COMMON_KW, + ) + m2m = comps["m2m"] + assert isinstance(m2m, nx.DiGraph) + assert m2m.number_of_nodes() > 0 + + def test_flat_multiscale_triangular(self, xy_medium): + comps = wmg.create.create_all_graph_components( + coords=xy_medium, + m2m_connectivity="flat_multiscale", + mesh_layout="triangular", + mesh_layout_kwargs=dict( + mesh_node_spacing=2.0, max_num_refinement_levels=3 + ), + **self.COMMON_KW, + ) + m2m = comps["m2m"] + assert isinstance(m2m, nx.DiGraph) + assert m2m.number_of_nodes() > 0 + + def test_unsupported_layout_raises(self, xy_small): + with pytest.raises(NotImplementedError, match="not yet supported"): + wmg.create.create_all_graph_components( + coords=xy_small, + m2m_connectivity="flat", + mesh_layout="hexagonal", + mesh_layout_kwargs=dict(mesh_node_spacing=1.0), + **self.COMMON_KW, + ) + + def test_flat_triangular_return_combined(self, xy_medium): + """With return_components=False, returns a single composed graph.""" + kw = dict(self.COMMON_KW) + kw["return_components"] = False + G = wmg.create.create_all_graph_components( + coords=xy_medium, + m2m_connectivity="flat", + mesh_layout="triangular", + mesh_layout_kwargs=dict(mesh_node_spacing=2.0), + **kw, + ) + assert isinstance(G, nx.DiGraph) + assert G.number_of_nodes() > 0 + + def test_flat_pattern_kwarg_forwarded(self, xy_medium): + """m2m_connectivity_kwargs={'pattern': ...} should be forwarded.""" + comps = wmg.create.create_all_graph_components( + coords=xy_medium, + m2m_connectivity="flat", + mesh_layout="triangular", + mesh_layout_kwargs=dict(mesh_node_spacing=2.0), + m2m_connectivity_kwargs=dict(pattern="8-star"), + **self.COMMON_KW, + ) + assert comps["m2m"].number_of_edges() > 0 + + def test_rectilinear_still_works(self, xy_medium): + """Regression: rectilinear layout should not be broken.""" + comps = wmg.create.create_all_graph_components( + coords=xy_medium, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=2.0), + **self.COMMON_KW, + ) + assert comps["m2m"].number_of_nodes() > 0 + + def test_rectilinear_flat_multiscale_still_works(self, xy_medium): + """Regression: rectilinear flat_multiscale should not be broken.""" + comps = wmg.create.create_all_graph_components( + coords=xy_medium, + m2m_connectivity="flat_multiscale", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=2.0, max_num_refinement_levels=3 + ), + **self.COMMON_KW, + ) + assert comps["m2m"].number_of_nodes() > 0 + + def test_rectilinear_hierarchical_still_works(self, xy_medium): + """Regression: rectilinear hierarchical should not be broken.""" + comps = wmg.create.create_all_graph_components( + coords=xy_medium, + m2m_connectivity="hierarchical", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=2.0, max_num_refinement_levels=3 + ), + **self.COMMON_KW, + ) + assert comps["m2m"].number_of_nodes() > 0 + + def test_flat_triangular_with_within_radius(self, xy_medium): + """Triangular flat with within_radius g2m/m2g connectivity.""" + comps = wmg.create.create_all_graph_components( + coords=xy_medium, + m2m_connectivity="flat", + mesh_layout="triangular", + mesh_layout_kwargs=dict(mesh_node_spacing=2.0), + m2g_connectivity="within_radius", + g2m_connectivity="within_radius", + m2g_connectivity_kwargs=dict(max_dist=5.0), + g2m_connectivity_kwargs=dict(max_dist=5.0), + return_components=True, + ) + assert comps["m2m"].number_of_nodes() > 0 + assert comps["g2m"].number_of_edges() > 0 + assert comps["m2g"].number_of_edges() > 0 + + def test_flat_triangular_with_nearest_neighbour(self, xy_medium): + """Triangular flat with nearest_neighbour (singular) connectivity.""" + comps = wmg.create.create_all_graph_components( + coords=xy_medium, + m2m_connectivity="flat", + mesh_layout="triangular", + mesh_layout_kwargs=dict(mesh_node_spacing=2.0), + m2g_connectivity="nearest_neighbour", + g2m_connectivity="nearest_neighbour", + return_components=True, + ) + assert comps["m2m"].number_of_nodes() > 0 + assert comps["g2m"].number_of_edges() > 0 + + def test_flat_no_mesh_node_spacing_raises(self, xy_small): + """Missing mesh_node_spacing should raise ValueError.""" + with pytest.raises(ValueError, match="mesh_node_spacing"): + wmg.create.create_all_graph_components( + coords=xy_small, + m2m_connectivity="flat", + mesh_layout="triangular", + mesh_layout_kwargs=dict(), + **self.COMMON_KW, + ) + + def test_flat_multiscale_no_mesh_node_spacing_raises(self, xy_small): + """Missing mesh_node_spacing in flat_multiscale should raise ValueError.""" + with pytest.raises(ValueError, match="mesh_node_spacing"): + wmg.create.create_all_graph_components( + coords=xy_small, + m2m_connectivity="flat_multiscale", + mesh_layout="triangular", + mesh_layout_kwargs=dict(max_num_refinement_levels=3), + **self.COMMON_KW, + ) + + def test_hierarchical_no_mesh_node_spacing_raises(self, xy_small): + """Missing mesh_node_spacing in hierarchical should raise ValueError.""" + with pytest.raises(ValueError, match="mesh_node_spacing"): + wmg.create.create_all_graph_components( + coords=xy_small, + m2m_connectivity="hierarchical", + mesh_layout="triangular", + mesh_layout_kwargs=dict(max_num_refinement_levels=3), + **self.COMMON_KW, + ) + + def test_all_components_have_nodes(self, xy_medium): + """All three components (g2m, m2m, m2g) should have nodes.""" + comps = wmg.create.create_all_graph_components( + coords=xy_medium, + m2m_connectivity="flat", + mesh_layout="triangular", + mesh_layout_kwargs=dict(mesh_node_spacing=2.0), + **self.COMMON_KW, + ) + for key in ("g2m", "m2m", "m2g"): + assert comps[key].number_of_nodes() > 0 + assert comps[key].number_of_edges() > 0 + + def test_large_domain_triangular(self, xy_large): + """Large domain with small spacing should produce a big graph.""" + comps = wmg.create.create_all_graph_components( + coords=xy_large, + m2m_connectivity="flat", + mesh_layout="triangular", + mesh_layout_kwargs=dict(mesh_node_spacing=3.0), + **self.COMMON_KW, + ) + assert comps["m2m"].number_of_nodes() > 50 + + +# =========================== +# Numerical correctness +# =========================== + + +class TestNumericalCorrectness: + """Test numerical properties of the triangular mesh graph.""" + + def test_edge_lengths_positive(self, xy_medium): + G = create_flat_singlescale_triangular_mesh_graph( + xy_medium, mesh_node_distance=2.0 + ) + for u, v, d in G.edges(data=True): + assert d["len"] > 0 + + def test_vdiff_consistent_with_pos(self, xy_small): + """vdiff should equal pos(u) - pos(v).""" + G = create_single_level_2d_triangular_mesh_graph(xy_small, nx=5, ny=5) + for u, v, d in G.edges(data=True): + pos_u = G.nodes[u]["pos"] + pos_v = G.nodes[v]["pos"] + expected_vdiff = pos_u - pos_v + np.testing.assert_allclose(d["vdiff"], expected_vdiff, atol=1e-10) + + def test_len_consistent_with_vdiff(self, xy_small): + """len should equal the L2 norm of vdiff.""" + G = create_single_level_2d_triangular_mesh_graph(xy_small, nx=5, ny=5) + for u, v, d in G.edges(data=True): + expected_len = np.linalg.norm(d["vdiff"]) + np.testing.assert_allclose(d["len"], expected_len, atol=1e-10) + + def test_no_nan_in_edge_attrs(self, xy_small): + """Edge attributes should contain no NaN or Inf.""" + G = create_single_level_2d_triangular_mesh_graph(xy_small, nx=6, ny=6) + for u, v, d in G.edges(data=True): + assert np.isfinite(d["len"]) + assert np.isfinite(d["vdiff"]).all() + + def test_no_zero_length_edges(self, xy_small): + """All edges should have strictly positive length.""" + G = create_single_level_2d_triangular_mesh_graph(xy_small, nx=6, ny=6) + for u, v, d in G.edges(data=True): + assert d["len"] > 1e-12 + + def test_edge_lengths_roughly_uniform_for_interior(self, xy_small): + """For a uniform triangular lattice, all edges should have similar + length (within a narrow tolerance, accounting for scaling).""" + G = create_single_level_2d_triangular_mesh_graph(xy_small, nx=8, ny=8) + lengths = [d["len"] for _, _, d in G.edges(data=True)] + # In a uniformly scaled equilateral mesh, all edges should be + # within ~50% of each other (accounting for aspect ratio scaling) + max_len = max(lengths) + min_len = min(lengths) + assert min_len > 0 + ratio = max_len / min_len + # For equilateral triangles with potentially different x/y scaling, + # the ratio should still be reasonable + assert ratio < 3.0, f"Edge length ratio {ratio} too large" + + def test_scaled_domain_produces_scaled_lengths(self): + """Doubling the domain should roughly double edge lengths.""" + xy1 = np.array([[0, 0], [10, 0], [0, 10], [10, 10]], dtype=float) + xy2 = np.array([[0, 0], [20, 0], [0, 20], [20, 20]], dtype=float) + G1 = create_single_level_2d_triangular_mesh_graph(xy1, nx=5, ny=5) + G2 = create_single_level_2d_triangular_mesh_graph(xy2, nx=5, ny=5) + avg_len1 = np.mean([d["len"] for _, _, d in G1.edges(data=True)]) + avg_len2 = np.mean([d["len"] for _, _, d in G2.edges(data=True)]) + np.testing.assert_allclose(avg_len2 / avg_len1, 2.0, rtol=0.1) + + def test_no_nan_in_positions(self, xy_medium): + """No node should have NaN in positions.""" + G = create_flat_singlescale_triangular_mesh_graph( + xy_medium, mesh_node_distance=2.0 + ) + for node in G.nodes: + pos = G.nodes[node]["pos"] + assert isinstance(pos, np.ndarray) + assert np.isfinite(pos).all()