diff --git a/CHANGELOG.md b/CHANGELOG.md index e036489..ef84eba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,11 +5,15 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [Unreleased] +## [Unreleased](https://github.com/mllam/weather-model-graphs/compare/v0.3.0...HEAD) ### Added -- Added a standalone graph consistency checking tool (`wmg.diagnostics.check_graph_consistency`) to ensure structural health, such as verifying all grid nodes successfully connect to the mesh (#42). +- Add `mesh_layout` argument to mesh graph creation functions, with `rectilinear` + as the first supported layout. Uses a two-step architecture separating coordinate + creation from connectivity creation, enabling future alternative layouts (e.g. triangular). + [\#78](https://github.com/mllam/weather-model-graphs/issues/78), @prajwal-tech07 +- Add a standalone graph consistency checking tool (`wmg.diagnostics.check_graph_consistency`) to ensure structural health, such as verifying all grid nodes successfully connect to the mesh (#42). - Add Django-style graph filtering via `filter_graph`, for example to select nodes by type (`node__type="mesh"`), edges by component (`edge__component="g2m"`), long edges (`edge__len__gt=...`), and spatial diff --git a/docs/creating_the_graph.ipynb b/docs/creating_the_graph.ipynb index 8cbc913..76bcbcd 100644 --- a/docs/creating_the_graph.ipynb +++ b/docs/creating_the_graph.ipynb @@ -411,8 +411,11 @@ "graph = wmg.create.create_all_graph_components(\n", " m2m_connectivity=\"flat_multiscale\",\n", " coords=xy,\n", - " m2m_connectivity_kwargs=dict(\n", - " mesh_node_distance=2, level_refinement_factor=3, max_num_levels=None\n", + " mesh_layout=\"rectilinear\",\n", + " mesh_layout_kwargs=dict(\n", + " mesh_node_spacing=2,\n", + " refinement_factor=3,\n", + " max_num_refinement_levels=3,\n", " ),\n", " g2m_connectivity=\"nearest_neighbour\",\n", " m2g_connectivity=\"nearest_neighbour\",\n", diff --git a/src/weather_model_graphs/create/archetype.py b/src/weather_model_graphs/create/archetype.py index b4d717e..6c94166 100644 --- a/src/weather_model_graphs/create/archetype.py +++ b/src/weather_model_graphs/create/archetype.py @@ -57,18 +57,15 @@ def create_keisler_graph( """ return create_all_graph_components( coords=coords, + coords_crs=coords_crs, + graph_crs=graph_crs, + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=mesh_node_distance), m2m_connectivity="flat", - m2m_connectivity_kwargs=dict(mesh_node_distance=mesh_node_distance), g2m_connectivity="within_radius", + g2m_connectivity_kwargs=dict(rel_max_dist=0.51), m2g_connectivity="nearest_neighbours", - g2m_connectivity_kwargs=dict( - rel_max_dist=0.51, - ), - m2g_connectivity_kwargs=dict( - max_num_neighbours=4, - ), - coords_crs=coords_crs, - graph_crs=graph_crs, + m2g_connectivity_kwargs=dict(max_num_neighbours=4), decode_mask=decode_mask, return_components=return_components, ) @@ -132,22 +129,19 @@ def create_graphcast_graph( """ return create_all_graph_components( coords=coords, - m2m_connectivity="flat_multiscale", - m2m_connectivity_kwargs=dict( - mesh_node_distance=mesh_node_distance, - level_refinement_factor=level_refinement_factor, - max_num_levels=max_num_levels, + coords_crs=coords_crs, + graph_crs=graph_crs, + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=mesh_node_distance, + refinement_factor=level_refinement_factor, + max_num_refinement_levels=max_num_levels, ), + m2m_connectivity="flat_multiscale", g2m_connectivity="within_radius", + g2m_connectivity_kwargs=dict(rel_max_dist=0.51), m2g_connectivity="nearest_neighbours", - g2m_connectivity_kwargs=dict( - rel_max_dist=0.51, - ), - m2g_connectivity_kwargs=dict( - max_num_neighbours=4, - ), - coords_crs=coords_crs, - graph_crs=graph_crs, + m2g_connectivity_kwargs=dict(max_num_neighbours=4), decode_mask=decode_mask, return_components=return_components, ) @@ -216,22 +210,19 @@ def create_oskarsson_hierarchical_graph( """ return create_all_graph_components( coords=coords, - m2m_connectivity="hierarchical", - m2m_connectivity_kwargs=dict( - mesh_node_distance=mesh_node_distance, - level_refinement_factor=level_refinement_factor, - max_num_levels=max_num_levels, + coords_crs=coords_crs, + graph_crs=graph_crs, + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=mesh_node_distance, + refinement_factor=level_refinement_factor, + max_num_refinement_levels=max_num_levels, ), + m2m_connectivity="hierarchical", g2m_connectivity="within_radius", + g2m_connectivity_kwargs=dict(rel_max_dist=0.51), m2g_connectivity="nearest_neighbours", - g2m_connectivity_kwargs=dict( - rel_max_dist=0.51, - ), - m2g_connectivity_kwargs=dict( - max_num_neighbours=4, - ), - coords_crs=coords_crs, - graph_crs=graph_crs, + m2g_connectivity_kwargs=dict(max_num_neighbours=4), decode_mask=decode_mask, return_components=return_components, ) diff --git a/src/weather_model_graphs/create/base.py b/src/weather_model_graphs/create/base.py index fbfa93b..27be089 100644 --- a/src/weather_model_graphs/create/base.py +++ b/src/weather_model_graphs/create/base.py @@ -8,7 +8,7 @@ function uses `connect_nodes_across_graphs` to connect nodes across the component graphs. """ -from typing import Iterable +from typing import Iterable, List, Union import networkx import networkx as nx @@ -23,11 +23,75 @@ split_on_edge_attribute_existance, ) from .grid import create_grid_graph_nodes -from .mesh.kinds.flat import ( - create_flat_multiscale_mesh_graph, - create_flat_singlescale_mesh_graph, +from .mesh.connectivity.flat import ( + create_flat_multiscale_from_coordinates, + create_flat_singlescale_from_coordinates, ) -from .mesh.kinds.hierarchical import create_hierarchical_multiscale_mesh_graph +from .mesh.connectivity.hierarchical import create_hierarchical_from_coordinates +from .mesh.coords import ( + create_multirange_2d_mesh_primitives, + create_single_level_2d_mesh_primitive, +) + + +def _migrate_deprecated_kwargs(mesh_layout_kwargs, m2m_connectivity_kwargs): + """Migrate old-style kwargs to the new mesh_layout_kwargs structure. + + In the old API, ``mesh_node_distance``, ``level_refinement_factor``, and + ``max_num_levels`` were passed via ``m2m_connectivity_kwargs``. In the new + design these belong in ``mesh_layout_kwargs`` (as ``mesh_node_spacing``, + ``refinement_factor``, and ``max_num_refinement_levels`` respectively). + + This helper emits deprecation warnings for each migrated key and moves + the value into *mesh_layout_kwargs*. It is intended to be removed once the + old API is no longer supported. + + Parameters + ---------- + mesh_layout_kwargs : dict + Mutable dict of mesh layout keyword arguments. + m2m_connectivity_kwargs : dict + Mutable dict of m2m connectivity keyword arguments. + + Returns + ------- + tuple[dict, dict] + Updated (mesh_layout_kwargs, m2m_connectivity_kwargs). + """ + if ( + "mesh_node_distance" in m2m_connectivity_kwargs + and "mesh_node_spacing" not in mesh_layout_kwargs + ): + logger.warning( + "Passing 'mesh_node_distance' in m2m_connectivity_kwargs is deprecated. " + "Use mesh_layout_kwargs=dict(mesh_node_spacing=...) instead." + ) + mesh_layout_kwargs["mesh_node_spacing"] = m2m_connectivity_kwargs.pop( + "mesh_node_distance" + ) + if ( + "level_refinement_factor" in m2m_connectivity_kwargs + and "refinement_factor" not in mesh_layout_kwargs + ): + logger.warning( + "Passing 'level_refinement_factor' in m2m_connectivity_kwargs is deprecated. " + "Use mesh_layout_kwargs=dict(refinement_factor=...) instead." + ) + mesh_layout_kwargs["refinement_factor"] = m2m_connectivity_kwargs.pop( + "level_refinement_factor" + ) + if ( + "max_num_levels" in m2m_connectivity_kwargs + and "max_num_refinement_levels" not in mesh_layout_kwargs + ): + logger.warning( + "Passing 'max_num_levels' in m2m_connectivity_kwargs is deprecated. " + "Use mesh_layout_kwargs=dict(max_num_refinement_levels=...) instead." + ) + mesh_layout_kwargs["max_num_refinement_levels"] = m2m_connectivity_kwargs.pop( + "max_num_levels" + ) + return mesh_layout_kwargs, m2m_connectivity_kwargs def create_all_graph_components( @@ -35,9 +99,11 @@ def create_all_graph_components( m2m_connectivity: str, m2g_connectivity: str, g2m_connectivity: str, - m2m_connectivity_kwargs={}, - m2g_connectivity_kwargs={}, - g2m_connectivity_kwargs={}, + mesh_layout: str, + mesh_layout_kwargs: dict = None, + m2m_connectivity_kwargs: dict = None, + m2g_connectivity_kwargs: dict = None, + g2m_connectivity_kwargs: dict = None, coords_crs: pyproj.crs.CRS | None = None, graph_crs: pyproj.crs.CRS | None = None, decode_mask: Iterable[bool] | None = None, @@ -48,6 +114,14 @@ def create_all_graph_components( grid-to-mesh (g2m), mesh-to-mesh (m2m) and mesh-to-grid (m2g), representing the encode-process-decode respectively. + The mesh graph creation follows a two-step process: + 1. **Coordinate creation** (controlled by `mesh_layout` + `mesh_layout_kwargs`): + Creates an undirected graph (nx.Graph) with node positions and spatial + adjacency edges annotated with adjacency types. + 2. **Connectivity creation** (controlled by `m2m_connectivity` + `m2m_connectivity_kwargs`): + Converts the coordinate graph to directed connectivity (nx.DiGraph) + based on the specified pattern and connectivity method. + For each graph component, the method for connecting nodes across graphs should be specified (with the `*_connectivity` arguments, e.g. `m2g_connectivity`). And the method-specific arguments should be passed as keyword arguments using @@ -61,15 +135,30 @@ def create_all_graph_components( - "within_radius": Find all neighbours in grid within an absolute distance of `max_dist` or relative distance of `rel_max_dist` from each node in mesh + mesh_layout: + - "rectilinear": Uniform regular grid with ``mesh_node_spacing`` resolution. + Produces an undirected mesh primitive with 4-star (cardinal) and + 8-star (cardinal + diagonal) spatial adjacency edges. + + mesh_layout_kwargs (for mesh_layout="rectilinear"): + - mesh_node_spacing: float, distance between mesh nodes in coordinate units. + - refinement_factor: int, refinement factor between levels + (for multi-level and hierarchical mesh graphs, default: 3) + - max_num_refinement_levels: int, maximum number of mesh levels + (for multi-level and hierarchical mesh graphs) + + Wherever the ``pattern`` argument appears below it defines the spatial + neighbourhood connectivity: + - ``"4-star"``: only cardinal directions (horizontal and vertical neighbours) + - ``"8-star"``: cardinal plus diagonal neighbours (all 8 surrounding nodes) + m2m_connectivity: - - "flat": Create a single-level 2D mesh graph with `mesh_node_distance`, - similar to Keisler et al. (2022) - - "flat_multiscale": Create a flat multiscale mesh graph with `max_num_levels`, - `mesh_node_distance` and `level_refinement_factor`, - similar to GraphCast, Lam et al. (2023) - - "hierarchical": Create a hierarchical mesh graph with `max_num_levels`, - `mesh_node_distance` and `level_refinement_factor`, - similar to Oskarsson et al. (2023) + - "flat": Create a single-level directed mesh graph. + m2m_connectivity_kwargs: pattern (default: "8-star") + - "flat_multiscale": Create a flat multiscale mesh graph. + m2m_connectivity_kwargs: pattern (default: "8-star") + - "hierarchical": Create a hierarchical mesh graph with up/down connections. + m2m_connectivity_kwargs: intra_level=dict(pattern=...), inter_level=dict(pattern=..., k=...) m2g_connectivity: - "nearest_neighbour": Find the nearest neighbour in mesh for each node in grid @@ -96,6 +185,29 @@ def create_all_graph_components( """ graph_components: dict[networkx.DiGraph] = {} + # Initialize mutable default arguments (and copy to avoid mutating caller's dicts) + if mesh_layout_kwargs is None: + mesh_layout_kwargs = {} + else: + mesh_layout_kwargs = dict(mesh_layout_kwargs) + if m2m_connectivity_kwargs is None: + m2m_connectivity_kwargs = {} + else: + m2m_connectivity_kwargs = dict(m2m_connectivity_kwargs) + if m2g_connectivity_kwargs is None: + m2g_connectivity_kwargs = {} + else: + m2g_connectivity_kwargs = dict(m2g_connectivity_kwargs) + if g2m_connectivity_kwargs is None: + g2m_connectivity_kwargs = {} + else: + g2m_connectivity_kwargs = dict(g2m_connectivity_kwargs) + + # Migrate deprecated kwargs (to be removed in a future version) + mesh_layout_kwargs, m2m_connectivity_kwargs = _migrate_deprecated_kwargs( + mesh_layout_kwargs, m2m_connectivity_kwargs + ) + assert ( len(coords.shape) == 2 and coords.shape[1] == 2 ), "Grid node coordinates should be given as an array of shape [num_grid_nodes, 2]." @@ -124,29 +236,93 @@ def create_all_graph_components( xy_tuple = coord_transformer.transform(xx=coords[:, 0], yy=coords[:, 1]) xy = np.stack(xy_tuple, axis=1) + # Validate m2m_connectivity early so that we raise a clear NotImplementedError + # before any coordinate creation is attempted + _supported_m2m_connectivity = {"flat", "hierarchical", "flat_multiscale"} + if m2m_connectivity not in _supported_m2m_connectivity: + raise NotImplementedError( + f"Kind {m2m_connectivity} not implemented. " + f"Supported: {sorted(_supported_m2m_connectivity)}" + ) + + # ----------------------------------------------------------------------- + # Step 1: Coordinate creation — produces the mesh primitive graph(s) + # Result type depends on m2m_connectivity: + # - flat: G_mesh_coords: nx.Graph + # - hierarchical/flat_multiscale: G_mesh_coords: List[nx.Graph] + # ----------------------------------------------------------------------- + G_mesh_coords: Union[networkx.Graph, List[networkx.Graph]] + + if mesh_layout == "rectilinear": + mesh_node_spacing = mesh_layout_kwargs.get( + "mesh_node_spacing" + ) or mesh_layout_kwargs.get("grid_spacing") + if mesh_node_spacing is None: + raise ValueError( + "mesh_layout='rectilinear' requires 'mesh_node_spacing' in " + "mesh_layout_kwargs (or 'mesh_node_distance' in " + "m2m_connectivity_kwargs for backward compatibility)." + ) + + if m2m_connectivity == "flat": + # Single-level mesh: compute nx/ny from spacing + range_x, range_y = np.ptp(xy, axis=0) + nx_mesh = int(range_x / mesh_node_spacing) + ny_mesh = int(range_y / mesh_node_spacing) + if nx_mesh == 0 or ny_mesh == 0: + raise ValueError( + "The given `mesh_node_spacing` is too large for the provided " + f"coordinates. Got mesh_node_spacing={mesh_node_spacing}, but the " + f"x-range is {range_x} and y-range is {range_y}. Maybe you " + "want to decrease the `mesh_node_spacing` so that the mesh nodes " + "are spaced closer together?" + ) + G_mesh_coords = create_single_level_2d_mesh_primitive(xy, nx_mesh, ny_mesh) + else: + # Multi-level mesh: build kwargs for create_multirange_2d_mesh_primitives + primitives_kwargs = dict(xy=xy, mesh_node_spacing=mesh_node_spacing) + if "refinement_factor" in mesh_layout_kwargs: + primitives_kwargs["interlevel_refinement_factor"] = mesh_layout_kwargs[ + "refinement_factor" + ] + if "max_num_refinement_levels" in mesh_layout_kwargs: + primitives_kwargs["max_num_levels"] = mesh_layout_kwargs[ + "max_num_refinement_levels" + ] + G_mesh_coords = create_multirange_2d_mesh_primitives(**primitives_kwargs) + else: + raise NotImplementedError( + f"mesh_layout='{mesh_layout}' is not yet supported. " + "Currently only 'rectilinear' is implemented." + ) + + # ----------------------------------------------------------------------- + # Step 2: Connectivity creation — converts mesh primitives to directed graph + # ----------------------------------------------------------------------- if m2m_connectivity == "flat": - graph_components["m2m"] = create_flat_singlescale_mesh_graph( - xy, - **m2m_connectivity_kwargs, + graph_components["m2m"] = create_flat_singlescale_from_coordinates( + G_mesh_coords, **m2m_connectivity_kwargs ) grid_connect_graph = graph_components["m2m"] + elif m2m_connectivity == "hierarchical": - # hierarchical mesh graph have three sub-graphs: - # `m2m` (mesh-to-mesh), `mesh_up` (up edge connections) and `mesh_down` (down edge connections) - graph_components["m2m"] = create_hierarchical_multiscale_mesh_graph( - xy=xy, - **m2m_connectivity_kwargs, + # hierarchical mesh graph has three sub-graphs: + # `m2m` (mesh-to-mesh), `mesh_up` (up edge connections) and + # `mesh_down` (down edge connections) + graph_components["m2m"] = create_hierarchical_from_coordinates( + G_mesh_coords, **m2m_connectivity_kwargs ) # Only connect grid to bottom level of hierarchy grid_connect_graph = split_graph_by_edge_attribute( graph_components["m2m"], "level" )[0] + elif m2m_connectivity == "flat_multiscale": - graph_components["m2m"] = create_flat_multiscale_mesh_graph( - xy=xy, - **m2m_connectivity_kwargs, + graph_components["m2m"] = create_flat_multiscale_from_coordinates( + G_mesh_coords, **m2m_connectivity_kwargs ) grid_connect_graph = graph_components["m2m"] + else: raise NotImplementedError(f"Kind {m2m_connectivity} not implemented") diff --git a/src/weather_model_graphs/create/mesh/__init__.py b/src/weather_model_graphs/create/mesh/__init__.py index 573d9f9..2d3d7cc 100644 --- a/src/weather_model_graphs/create/mesh/__init__.py +++ b/src/weather_model_graphs/create/mesh/__init__.py @@ -1 +1,6 @@ -from .mesh import create_single_level_2d_mesh_graph +from .coords import ( + create_directed_mesh_graph, + create_multirange_2d_mesh_primitives, + create_single_level_2d_mesh_graph, + create_single_level_2d_mesh_primitive, +) diff --git a/src/weather_model_graphs/create/mesh/kinds/__init__.py b/src/weather_model_graphs/create/mesh/connectivity/__init__.py similarity index 100% rename from src/weather_model_graphs/create/mesh/kinds/__init__.py rename to src/weather_model_graphs/create/mesh/connectivity/__init__.py diff --git a/src/weather_model_graphs/create/mesh/connectivity/flat.py b/src/weather_model_graphs/create/mesh/connectivity/flat.py new file mode 100644 index 0000000..5c192fc --- /dev/null +++ b/src/weather_model_graphs/create/mesh/connectivity/flat.py @@ -0,0 +1,262 @@ +from typing import List + +import networkx +import numpy as np + +from ....networkx_utils import prepend_node_index +from .. import coords as mesh_coords + + +def _check_required_graph_attributes(G: networkx.Graph, context: str): + """Check that a coordinate graph has the required node and edge attributes. + + Parameters + ---------- + G : networkx.Graph + The coordinate graph to validate. + context : str + Description of where this check is being called, for error messages. + + Raises + ------ + ValueError + If required attributes are missing. + """ + # Check at least one node has required attributes + if len(G.nodes) > 0: + sample_node = next(iter(G.nodes)) + if "pos" not in G.nodes[sample_node]: + raise ValueError( + f"{context}: coordinate graph nodes must have a 'pos' attribute " + "(np.ndarray of shape [2,])." + ) + if "type" not in G.nodes[sample_node]: + raise ValueError( + f"{context}: coordinate graph nodes must have a 'type' attribute." + ) + # Check at least one edge has required attributes + if len(G.edges) > 0: + sample_edge = next(iter(G.edges)) + if "adjacency_type" not in G.edges[sample_edge]: + raise ValueError( + f"{context}: coordinate graph edges must have an 'adjacency_type' " + "attribute ('cardinal' or 'diagonal')." + ) + + +def create_flat_multiscale_from_coordinates( + G_coords_list: List[networkx.Graph], + **kwargs, +): + """ + Create flat multiscale mesh graph from a list of coordinate graphs. + + This is the connectivity creation step for flat multiscale meshes. + It takes undirected mesh primitive graphs (one per level) and produces a + single directed mesh graph where all levels are merged into one flat graph. + + In a flat multiscale graph, coarser levels are merged into the finer level + by coincident node positions (no separate inter-level connectivity needed). + + Parameters + ---------- + G_coords_list : list of networkx.Graph + List of undirected mesh primitive graphs, one per level. Each graph + must have: + - Node attributes: ``"pos"`` (np.ndarray of shape [2,]), ``"type"`` (str) + - Edge attributes: ``"adjacency_type"`` (str, ``"cardinal"`` or ``"diagonal"``) + - Graph attribute: ``"interlevel_refinement_factor"`` (int) + Created by ``create_multirange_2d_mesh_primitives``. + **kwargs + Additional keyword arguments passed to ``create_directed_mesh_graph`` + (e.g. ``pattern="8-star"``). + + Returns + ------- + G_tot : networkx.DiGraph + The merged flat multiscale mesh graph + """ + # Validate required attributes on first graph + _check_required_graph_attributes( + G_coords_list[0], "create_flat_multiscale_from_coordinates" + ) + + # Assert interlevel_refinement_factor is set (no silent default) + if "interlevel_refinement_factor" not in G_coords_list[0].graph: + raise ValueError( + "The coordinate graphs must have an 'interlevel_refinement_factor' " + "graph attribute. This is set by create_multirange_2d_mesh_primitives." + ) + interlevel_refinement_factor = G_coords_list[0].graph[ + "interlevel_refinement_factor" + ] + + # Check that interlevel_refinement_factor is an odd integer + if ( + int(interlevel_refinement_factor) != interlevel_refinement_factor + or interlevel_refinement_factor % 2 != 1 + ): + raise ValueError( + "The `interlevel_refinement_factor` must be an odd integer. " + f"Given value: {interlevel_refinement_factor}." + ) + + # Convert each level's coordinate graph to directed graph with chosen pattern + G_all_levels = [ + mesh_coords.create_directed_mesh_graph(g_coords, **kwargs) + for g_coords in G_coords_list + ] + + # combine all levels to one graph + G_tot = G_all_levels[0] + # First node at level l+1 share position with node (offset, offset) at level l + level_offset = interlevel_refinement_factor // 2 + + first_level_nodes = list(G_all_levels[0].nodes) + # Last nodes in first layer has pos (nx-1, ny-1) + num_nodes_x = first_level_nodes[-1][0] + 1 + num_nodes_y = first_level_nodes[-1][1] + 1 + + for lev in range(1, len(G_all_levels)): + nodes = list(G_all_levels[lev - 1].nodes) + ij = ( + np.array(nodes) + .reshape((num_nodes_x, num_nodes_y, 2))[ + level_offset::interlevel_refinement_factor, + level_offset::interlevel_refinement_factor, + :, + ] + .reshape( + int(num_nodes_x * num_nodes_y / (interlevel_refinement_factor**2)), + 2, + ) + ) + ij = [tuple(x) for x in ij] + G_all_levels[lev] = networkx.relabel_nodes( + G_all_levels[lev], dict(zip(G_all_levels[lev].nodes, ij)) + ) + G_tot = networkx.compose(G_tot, G_all_levels[lev]) + + # Update number of nodes in x- and y-direction for next iteration + num_nodes_x //= interlevel_refinement_factor + num_nodes_y //= interlevel_refinement_factor + + # Relabel mesh nodes to start with 0 + G_tot = prepend_node_index(G_tot, 0) + + # add dx and dy to graph + G_tot.graph["dx"] = {i: g.graph["dx"] for i, g in enumerate(G_all_levels)} + G_tot.graph["dy"] = {i: g.graph["dy"] for i, g in enumerate(G_all_levels)} + + return G_tot + + +def create_flat_singlescale_from_coordinates(G_coords: networkx.Graph, **kwargs): + """ + Create a flat single-scale directed mesh graph from a mesh primitive graph. + + This is the connectivity creation step for flat single-scale meshes. + It converts an undirected mesh primitive graph to a directed mesh graph + using the specified connectivity pattern. + + Parameters + ---------- + G_coords : networkx.Graph + Undirected mesh primitive graph. Must have: + - Node attributes: ``"pos"`` (np.ndarray of shape [2,]), ``"type"`` (str) + - Edge attributes: ``"adjacency_type"`` (str, ``"cardinal"`` or ``"diagonal"``) + Created by ``create_single_level_2d_mesh_primitive``. + **kwargs + Additional keyword arguments passed to ``create_directed_mesh_graph`` + (e.g. ``pattern="8-star"``). + + Returns + ------- + networkx.DiGraph + The flat single-scale directed mesh graph + """ + _check_required_graph_attributes( + G_coords, "create_flat_singlescale_from_coordinates" + ) + return mesh_coords.create_directed_mesh_graph(G_coords, **kwargs) + + +def create_flat_multiscale_mesh_graph( + xy, mesh_node_distance: float, level_refinement_factor: int, max_num_levels: int +): + """ + Create flat mesh graph by merging the single-level mesh + graphs across all levels in `G_all_levels`. + + Internally uses the two-step process: + 1. create_multirange_2d_mesh_primitives (coordinate creation) + 2. create_flat_multiscale_from_coordinates (connectivity creation) + + Parameters + ---------- + xy : np.ndarray [N_grid_points, 2] + Grid point coordinates, with first column representing + x coordinates and second column y coordinates. N_grid_points is the + total number of grid points. + mesh_node_distance: float + Distance (in x- and y-direction) between created mesh nodes, + in coordinate system of xy + level_refinement_factor: int + Refinement factor between grid points and bottom level of mesh hierarchy + NOTE: Must be an odd integer >1 to create proper multiscale graph + max_num_levels : int + Maximum number of levels in the multi-scale graph + Returns + ------- + G_tot : networkx.Graph + The merged mesh graph + """ + G_coords_list = mesh_coords.create_multirange_2d_mesh_primitives( + max_num_levels=max_num_levels, + xy=xy, + mesh_node_spacing=mesh_node_distance, + interlevel_refinement_factor=level_refinement_factor, + ) + + return create_flat_multiscale_from_coordinates( + G_coords_list, + pattern="8-star", + ) + + +def create_flat_singlescale_mesh_graph(xy, mesh_node_distance: float): + """ + Create flat mesh graph of single level + + Internally uses the two-step process: + 1. create_single_level_2d_mesh_primitive (coordinate creation) + 2. create_directed_mesh_graph (connectivity creation, pattern="8-star") + + Parameters + ---------- + xy : np.ndarray [N_grid_points, 2] + Grid point coordinates, with first column representing + x coordinates and second column y coordinates. N_grid_points is the + total number of grid points. + mesh_node_distance: float + Distance (in x- and y-direction) between created mesh nodes, + in coordinate system of xy + Returns + ------- + G_flat : networkx.Graph + The flat mesh graph + """ + # Compute number of mesh nodes in x and y dimensions + range_x, range_y = np.ptp(xy, axis=0) + nx = int(range_x / mesh_node_distance) + ny = int(range_y / mesh_node_distance) + + if nx == 0 or ny == 0: + raise ValueError( + "The given `mesh_node_distance` is too large for the provided coordinates. " + f"Got mesh_node_distance={mesh_node_distance}, but the x-range is {range_x} " + f"and y-range is {range_y}. Maybe you want to decrease the `mesh_node_distance`" + " so that the mesh nodes are spaced closer together?" + ) + + return mesh_coords.create_single_level_2d_mesh_graph(xy=xy, nx=nx, ny=ny) diff --git a/src/weather_model_graphs/create/mesh/connectivity/hierarchical.py b/src/weather_model_graphs/create/mesh/connectivity/hierarchical.py new file mode 100644 index 0000000..e870f4f --- /dev/null +++ b/src/weather_model_graphs/create/mesh/connectivity/hierarchical.py @@ -0,0 +1,225 @@ +from typing import Dict, List, Optional + +import networkx +import numpy as np +import scipy + +from ....networkx_utils import prepend_node_index +from .. import coords as mesh_coords + + +def create_hierarchical_from_coordinates( + G_coords_list: List[networkx.Graph], + intra_level: Dict[str, object] = {"pattern": "8-star"}, + inter_level: Dict[str, object] = {"pattern": "nearest", "k": 1}, +): + """ + Create a hierarchical multiscale mesh graph from a list of mesh primitive + graphs. + + This is the connectivity creation step for hierarchical meshes. + It takes undirected mesh primitive graphs (one per level) and produces a + directed mesh graph with intra-level connectivity and inter-level + up/down connections. + + The ``intra_level["pattern"]`` defines the spatial neighbourhood connectivity + within each mesh level: + - ``"4-star"``: only cardinal directions (horizontal and vertical neighbours) + - ``"8-star"``: cardinal directions plus diagonals (all 8 surrounding neighbours) + + Parameters + ---------- + G_coords_list : list of networkx.Graph + List of undirected mesh primitive graphs, one per level. Each graph + must have: + - Node attributes: ``"pos"`` (np.ndarray of shape [2,]), ``"type"`` (str) + - Edge attributes: ``"adjacency_type"`` (str, ``"cardinal"`` or ``"diagonal"``) + Created by ``create_multirange_2d_mesh_primitives``. + intra_level : dict + Configuration for intra-level connectivity. Keys: + - ``"pattern"`` (str): ``"4-star"`` or ``"8-star"``. + Default: ``{"pattern": "8-star"}`` + inter_level : dict + Configuration for inter-level connectivity. Keys: + - ``"pattern"`` (str): Currently only ``"nearest"`` is supported. + - ``"k"`` (int): Number of nearest neighbours for inter-level connections. + Default: ``{"pattern": "nearest", "k": 1}`` + + Returns + ------- + networkx.DiGraph + A directed graph containing the hierarchical mesh with intra-level + edges (direction="same"), inter-level down edges (direction="down"), + and inter-level up edges (direction="up"). + """ + intra_level_pattern = intra_level.get("pattern", "8-star") + inter_level_pattern = inter_level.get("pattern", "nearest") + inter_level_k = inter_level.get("k", 1) + + if inter_level_pattern != "nearest": + raise NotImplementedError( + f"Inter-level pattern '{inter_level_pattern}' is not yet supported " + "for hierarchical graphs. Only 'nearest' is currently implemented." + ) + + # Convert each level's coordinate graph to directed graph with chosen pattern + Gs_all_levels = [ + mesh_coords.create_directed_mesh_graph(g_coords, pattern=intra_level_pattern) + for g_coords in G_coords_list + ] + + n_mesh_levels = len(Gs_all_levels) + + if n_mesh_levels < 2: + raise ValueError( + "At least two mesh levels are required for hierarchical mesh graph. " + "You may need to reduce the level refinement factor " + "or increase the max number of levels " + "or number of grid points." + ) + + # Relabel nodes of each level with level index first + Gs_all_levels = [ + prepend_node_index(graph, level_i) + for level_i, graph in enumerate(Gs_all_levels) + ] + + # add `direction` attribute to all edges with value `same` + for i, G in enumerate(Gs_all_levels): + for u, v in G.edges: + G.edges[u, v]["direction"] = "same" + G.edges[u, v]["level"] = i + + # Create inter-level mesh edges + up_graphs = [] + down_graphs = [] + for G_from, G_to in zip( + Gs_all_levels[1:], + Gs_all_levels[:-1], + ): + from_level = G_from.graph["level"] + to_level = G_to.graph["level"] + + # start out from graph at from level + G_down = G_from.copy() + G_down.clear_edges() + G_down = networkx.DiGraph(G_down) + + # Add nodes of to level + G_down.add_nodes_from(G_to.nodes(data=True)) + + # build kd tree for mesh point pos + # order in vm should be same as in vm_xy + v_to_list = list(G_to.nodes) + v_from_list = list(G_from.nodes) + v_from_xy = np.array([xy for _, xy in G_from.nodes.data("pos")]) + kdt_m = scipy.spatial.KDTree(v_from_xy) + + # add edges from coarser to finer level + for v in v_to_list: + # find k nearest neighbours (index to vm_xy) + neigh_idx = kdt_m.query(G_down.nodes[v]["pos"], inter_level_k)[1] + if inter_level_k == 1: + neigh_idx = [neigh_idx] + + for idx in neigh_idx: + u = v_from_list[idx] + + # add edge from coarser to finer + G_down.add_edge(u, v) + d = np.sqrt( + np.sum((G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"]) ** 2) + ) + G_down.edges[u, v]["len"] = d + G_down.edges[u, v]["vdiff"] = ( + G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"] + ) + G_down.edges[u, v]["levels"] = f"{from_level}>{to_level}" + G_down.edges[u, v]["direction"] = "down" + + G_up = networkx.DiGraph() + G_up.add_nodes_from(G_down.nodes(data=True)) + for u, v, data in G_down.edges(data=True): + data = data.copy() + data["levels"] = f"{to_level}>{from_level}" + data["direction"] = "up" + G_up.add_edge(v, u, **data) + + up_graphs.append(G_up) + down_graphs.append(G_down) + + G_up_all = networkx.compose_all(up_graphs) + G_down_all = networkx.compose_all(down_graphs) + G_all_levels = networkx.compose_all(Gs_all_levels) + + G_m2m = networkx.compose_all([G_all_levels, G_up_all, G_down_all]) + + # add dx and dy to graph + for prop in ("dx", "dy"): + G_m2m.graph[prop] = {i: g.graph[prop] for i, g in enumerate(Gs_all_levels)} + + return G_m2m + + +def create_hierarchical_multiscale_mesh_graph( + xy: np.ndarray, + mesh_node_distance: float, + level_refinement_factor: float, + max_num_levels: int, + intra_level: Optional[Dict[str, object]] = None, + inter_level: Optional[Dict[str, object]] = None, +): + """ + Create a hierarchical multiscale mesh graph with nearest neighbour + connections within each level (horizontally, vertically, and diagonally), and + connections between levels (coarse to fine and fine to coarse) using the + nearest neighbour connection. + + Internally uses the two-step process: + 1. create_multirange_2d_mesh_primitives (coordinate creation) + 2. create_hierarchical_from_coordinates (connectivity creation) + + Parameters + ---------- + xy : np.ndarray + 2D array of mesh point positions, shaped [N_points, 2]. + mesh_node_distance : float + Distance (in x- and y-direction) between created mesh nodes in bottom level, + in coordinate system of xy + level_refinement_factor : float + Refinement factor between grid points and bottom level of mesh hierarchy + max_num_levels : int + The number of levels in the hierarchical mesh graph. + intra_level : dict, optional + Configuration for intra-level connectivity. Keys: + - ``"pattern"`` (str): ``"4-star"`` or ``"8-star"``. + Default: ``{"pattern": "8-star"}`` + inter_level : dict, optional + Configuration for inter-level connectivity. Keys: + - ``"pattern"`` (str): Currently only ``"nearest"`` is supported. + - ``"k"`` (int): Number of nearest neighbours. + Default: ``{"pattern": "nearest", "k": 1}`` + + Returns + ------- + networkx.DiGraph + A directed graph containing the hierarchical mesh with intra-level, + up, and down edges. + """ + G_coords_list = mesh_coords.create_multirange_2d_mesh_primitives( + max_num_levels=max_num_levels, + xy=xy, + mesh_node_spacing=mesh_node_distance, + interlevel_refinement_factor=level_refinement_factor, + ) + + kwargs = {} + if intra_level is not None: + kwargs["intra_level"] = intra_level + if inter_level is not None: + kwargs["inter_level"] = inter_level + + return create_hierarchical_from_coordinates( + G_coords_list, + **kwargs, + ) diff --git a/src/weather_model_graphs/create/mesh/coords.py b/src/weather_model_graphs/create/mesh/coords.py new file mode 100644 index 0000000..acb4898 --- /dev/null +++ b/src/weather_model_graphs/create/mesh/coords.py @@ -0,0 +1,336 @@ +import networkx +import numpy as np +from loguru import logger + + +def create_single_level_2d_mesh_primitive(xy: np.ndarray, nx: int, ny: int): + """ + Create an undirected mesh primitive graph (nx.Graph) with node positions + and spatial adjacency edges, representing the coordinate creation step. + + A mesh primitive is an undirected graph that encodes all potential + neighbourhood connectivity edges. It serves as a blueprint from which + directed connectivity graphs can later be built by selecting a subset + of edges (e.g. 4-star or 8-star pattern). + + This produces a graph where: + - Nodes have a ``"pos"`` attribute (np.ndarray of shape [2,] with x and y + coordinates) and a ``"type"`` attribute (str, always ``"mesh"``). + - Edges have an ``"adjacency_type"`` attribute (str): ``"cardinal"`` for + horizontal/vertical neighbours (4-star connectivity) or ``"diagonal"`` + for diagonal neighbours (additional edges in 8-star connectivity). + + This is the first step in the two-step mesh creation process: + 1. Coordinate creation (this function) -> nx.Graph with spatial adjacency + 2. Connectivity creation (create_directed_mesh_graph) -> nx.DiGraph + + Parameters + ---------- + xy : np.ndarray + Grid point coordinates, shaped [N_grid_points, 2], with first column + representing x coordinates and second column y coordinates. + nx : int + Number of nodes in x direction + ny : int + Number of nodes in y direction + + Returns + ------- + networkx.Graph + Undirected mesh primitive graph with node positions and annotated + spatial adjacency edges. + """ + xm, xM = np.amin(xy[:, 0]), np.amax(xy[:, 0]) + ym, yM = np.amin(xy[:, 1]), np.amax(xy[:, 1]) + + # avoid nodes on border + dx = (xM - xm) / nx + dy = (yM - ym) / ny + lx = np.linspace(xm + dx / 2, xM - dx / 2, nx) + ly = np.linspace(ym + dy / 2, yM - dy / 2, ny) + + mg = np.meshgrid(lx, ly) + g = networkx.grid_2d_graph(len(lx), len(ly)) + + # Node name and `pos` attribute takes form (x, y) + for node in g.nodes: + node_xi, node_yi = node # Extract x and y index from node to index mx + g.nodes[node]["pos"] = np.array( + [mg[0][node_yi, node_xi], mg[1][node_yi, node_xi]] + ) + g.nodes[node]["type"] = "mesh" + + # Mark existing grid_2d_graph edges as cardinal (4-star adjacency) + for u, v in g.edges(): + g.edges[u, v]["adjacency_type"] = "cardinal" + + # Add diagonal edges (8-star adjacency) + diagonal_edges = [ + ((x, y), (x + 1, y + 1)) for x in range(nx - 1) for y in range(ny - 1) + ] + [((x + 1, y), (x, y + 1)) for x in range(nx - 1) for y in range(ny - 1)] + g.add_edges_from(diagonal_edges) + for u, v in diagonal_edges: + g.edges[u, v]["adjacency_type"] = "diagonal" + + g.graph["dx"] = dx + g.graph["dy"] = dy + + return g + + +def create_directed_mesh_graph(G_undirected: networkx.Graph, pattern: str = "8-star"): + """ + Convert an undirected mesh primitive graph with spatial adjacency edges to a + directed mesh graph (nx.DiGraph) based on the specified connectivity pattern. + + This is the second step in the two-step mesh creation process: + 1. Coordinate creation (create_single_level_2d_mesh_primitive) -> nx.Graph + 2. Connectivity creation (this function) -> nx.DiGraph + + The ``pattern`` argument defines the spatial neighbourhood connectivity: + - ``"4-star"``: only cardinal directions (horizontal and vertical neighbours) + - ``"8-star"``: cardinal directions plus diagonals (all 8 surrounding neighbours) + + Parameters + ---------- + G_undirected : networkx.Graph + Undirected mesh primitive graph. Expected node attributes: + - ``"pos"``: np.ndarray of shape [2,], spatial coordinates. + Expected edge attributes: + - ``"adjacency_type"``: str, either ``"cardinal"`` or ``"diagonal"``. + Additional edge attributes (e.g. ``"level"``) are preserved in the + output directed graph. + pattern : str + Connectivity pattern. Options: + - ``"4-star"``: only cardinal edges (horizontal/vertical neighbours) + - ``"8-star"``: all edges (cardinal + diagonal neighbours) + + Returns + ------- + networkx.DiGraph + Directed graph with bidirectional edges, each having ``"len"`` and + ``"vdiff"`` attributes. All original edge attributes from the + primitive graph are preserved. + """ + if pattern == "4-star": + # Filter to only cardinal edges, preserving edge data + edges_to_use = [ + (u, v, d) + for u, v, d in G_undirected.edges(data=True) + if d.get("adjacency_type") == "cardinal" + ] + elif pattern == "8-star": + # Use all edges with their data + edges_to_use = list(G_undirected.edges(data=True)) + else: + raise ValueError( + f"Unknown connectivity pattern: '{pattern}'. " + "Choose '4-star' or '8-star'." + ) + + # Create filtered undirected graph with only selected edges (preserving attrs) + g_filtered = networkx.Graph() + g_filtered.add_nodes_from(G_undirected.nodes(data=True)) + g_filtered.add_edges_from(edges_to_use) + + # Convert to directed graph (creates edges in both directions) + dg = networkx.DiGraph(g_filtered) + for u, v in g_filtered.edges(): + d = np.sqrt( + np.sum((G_undirected.nodes[u]["pos"] - G_undirected.nodes[v]["pos"]) ** 2) + ) + dg.edges[u, v]["len"] = d + dg.edges[u, v]["vdiff"] = ( + G_undirected.nodes[u]["pos"] - G_undirected.nodes[v]["pos"] + ) + # Ensure reverse edge exists and has attributes + dg.edges[v, u]["len"] = d + dg.edges[v, u]["vdiff"] = ( + G_undirected.nodes[v]["pos"] - G_undirected.nodes[u]["pos"] + ) + + # Preserve graph-level attributes (dx, dy, level, etc.) + dg.graph.update(G_undirected.graph) + + return dg + + +def create_single_level_2d_mesh_graph(xy: np.ndarray, nx: int, ny: int): + """ + Create directed graph with nx * ny nodes representing a 2D grid with + positions spanning the range of xy coordinate values (first dimension + is assumed to be x and y coordinate values respectively). Each nodes is + connected to its eight nearest neighbours, both horizontally, vertically + and diagonally as directed edges (which means that the graph contains two + edges between each pair of connected nodes). + + The nodes contain a "pos" attribute with the x and y + coordinates of the node, and an "type" attribute with the + type of the node (i.e. "mesh" for mesh nodes). + + The edges contain a "len" attribute with the length of the edge + and a "vdiff" attribute with the vector difference between the + nodes. + + Internally, this uses the two-step process: + 1. create_single_level_2d_mesh_primitive (coordinate creation) + 2. create_directed_mesh_graph (connectivity creation, pattern="8-star") + + Parameters + ---------- + xy : np.ndarray [N_grid_points, 2] + Grid point coordinates, with first column representing + x coordinates and second column y coordinates. N_grid_points is the + total number of grid points. + nx : int + Number of nodes in x direction + ny : int + Number of nodes in y direction + + Returns + ------- + networkx.DiGraph + Graph representing the 2D grid + """ + G_coords = create_single_level_2d_mesh_primitive(xy, nx, ny) + return create_directed_mesh_graph(G_coords, pattern="8-star") + + +def create_multirange_2d_mesh_primitives( + max_num_levels: int, + xy: np.ndarray, + mesh_node_spacing: float = 3, + interlevel_refinement_factor: float = 3, +): + """ + Create a list of undirected mesh primitive graphs (nx.Graph) representing + different levels of mesh resolution spanning the spatial domain of the + xy coordinates. + + This is the coordinate creation step for multi-level and hierarchical mesh + graphs. Each returned graph contains nodes with spatial positions and edges + annotated with adjacency type (``"cardinal"`` or ``"diagonal"``). + + The graphs can be consumed by connectivity creation functions to produce + directed mesh graphs for flat_multiscale or hierarchical architectures. + + Parameters + ---------- + max_num_levels : int + Number of edge-distance levels in mesh graph + xy : np.ndarray + Grid point coordinates, shaped [N_grid_points, 2] + mesh_node_spacing : float + Distance (in x- and y-direction) between created mesh nodes, + in coordinate system of xy + interlevel_refinement_factor : float + Refinement factor between grid points and bottom level of mesh hierarchy + + Returns + ------- + G_all_levels : list of networkx.Graph + List of undirected mesh primitive graphs for each level, each with + node positions and annotated spatial adjacency edges. + Each graph has ``"level"`` and ``"interlevel_refinement_factor"`` + graph attributes. + """ + # Compute the size along x and y direction of area to cover with graph + # This is measured in the Cartesian coordinates of xy + coord_extent = np.ptp(xy, axis=0) + # Number of nodes that would fit on bottom level of hierarchy, + # in both directions + max_nodes_bottom = (coord_extent / mesh_node_spacing).astype(int) + + # Find the number of mesh levels possible in x- and y-direction, + # and the number of leaf nodes that would correspond to + # max_nodes_bottom/(interlevel_refinement_factor^mesh_levels) = 1 + max_mesh_levels_float = np.log(max_nodes_bottom) / np.log( + interlevel_refinement_factor + ) + + max_mesh_levels = max_mesh_levels_float.astype(int) # (2,) + nleaf = interlevel_refinement_factor**max_mesh_levels + # leaves at the bottom in each direction, if using max_mesh_levels + + # As we can not instantiate different number of mesh levels in each + # direction, create mesh levels corresponding to the minimum of the two + mesh_levels_to_create = max_mesh_levels.min() + + if max_num_levels: + # Limit the levels in mesh graph + mesh_levels_to_create = min(mesh_levels_to_create, max_num_levels) + + logger.debug(f"mesh_levels: {mesh_levels_to_create}, nleaf: {nleaf}") + + # multi resolution tree levels + G_all_levels = [] + for lev in range(mesh_levels_to_create): # 0-index mesh levels + # Compute number of nodes on level separate for each direction + nodes_x, nodes_y = (nleaf / (interlevel_refinement_factor**lev)).astype(int) + g = create_single_level_2d_mesh_primitive(xy, nodes_x, nodes_y) + # Add level information to nodes, edges and full graph + for node in g.nodes: + g.nodes[node]["level"] = lev + for edge in g.edges: + g.edges[edge]["level"] = lev + g.graph["level"] = lev + # Store refinement factor so connectivity step can use it + g.graph["interlevel_refinement_factor"] = interlevel_refinement_factor + G_all_levels.append(g) + + return G_all_levels + + +def create_multirange_2d_mesh_graphs( + max_num_levels: int, + xy: np.ndarray, + mesh_node_distance: float = 3, + level_refinement_factor: float = 3, + pattern: str = "8-star", +): + """ + Create a list of 2D grid mesh graphs representing different levels of edge-length + scales spanning the spatial domain of the xy coordinates. + This list of graphs can then later be for example a) flattened into single graph + containing multiple ranges of connections or b) combined into a hierarchical graph. + + Each graph in the list contains a "level" attribute with the level index of the graph. + + Internally uses the two-step process: + 1. create_multirange_2d_mesh_primitives (coordinate creation) + 2. create_directed_mesh_graph (connectivity creation) + + Parameters + ---------- + max_num_levels : int + Number of edge-distance levels in mesh graph + xy : np.ndarray + Grid point coordinates, shaped [N_grid_points, 2] + mesh_node_distance : float + Distance (in x- and y-direction) between created mesh nodes, + in coordinate system of xy + level_refinement_factor : float + Refinement factor between grid points and bottom level of mesh hierarchy + pattern : str + Connectivity pattern for directed graph creation: ``"4-star"`` or + ``"8-star"`` (default: ``"8-star"``) + + Returns + ------- + G_all_levels : list of networkx.DiGraph + List of networkx graphs for each level representing the connectivity + of the mesh within each level + """ + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=max_num_levels, + xy=xy, + mesh_node_spacing=mesh_node_distance, + interlevel_refinement_factor=level_refinement_factor, + ) + + G_all_levels = [] + for g_coords in G_coords_list: + g_directed = create_directed_mesh_graph(g_coords, pattern=pattern) + G_all_levels.append(g_directed) + + return G_all_levels diff --git a/src/weather_model_graphs/create/mesh/kinds/flat.py b/src/weather_model_graphs/create/mesh/kinds/flat.py deleted file mode 100644 index 92f47b0..0000000 --- a/src/weather_model_graphs/create/mesh/kinds/flat.py +++ /dev/null @@ -1,123 +0,0 @@ -import networkx -import numpy as np - -from ....networkx_utils import prepend_node_index -from .. import mesh as mesh_graph - - -def create_flat_multiscale_mesh_graph( - xy, mesh_node_distance: float, level_refinement_factor: int, max_num_levels: int -): - """ - Create flat mesh graph by merging the single-level mesh - graphs across all levels in `G_all_levels`. - - Parameters - ---------- - xy : np.ndarray [N_grid_points, 2] - Grid point coordinates, with first column representing - x coordinates and second column y coordinates. N_grid_points is the - total number of grid points. - mesh_node_distance: float - Distance (in x- and y-direction) between created mesh nodes, - in coordinate system of xy - level_refinement_factor: int - Refinement factor between grid points and bottom level of mesh hierarchy - NOTE: Must be an odd integer >1 to create proper multiscale graph - max_num_levels : int - Maximum number of levels in the multi-scale graph - Returns - ------- - G_tot : networkx.Graph - The merged mesh graph - """ - # Check that level_refinement_factor is an odd integer - if ( - int(level_refinement_factor) != level_refinement_factor - or level_refinement_factor % 2 != 1 - ): - raise ValueError( - "The `level_refinement_factor` must be an odd integer. " - f"Given value: {level_refinement_factor}." - ) - - G_all_levels: list[networkx.DiGraph] = mesh_graph.create_multirange_2d_mesh_graphs( - max_num_levels=max_num_levels, - xy=xy, - mesh_node_distance=mesh_node_distance, - level_refinement_factor=level_refinement_factor, - ) - - # combine all levels to one graph - G_tot = G_all_levels[0] - # First node at level l+1 share position with node (offset, offset) at level l - level_offset = level_refinement_factor // 2 - - first_level_nodes = list(G_all_levels[0].nodes) - # Last nodes in first layer has pos (nx-1, ny-1) - num_nodes_x = first_level_nodes[-1][0] + 1 - num_nodes_y = first_level_nodes[-1][1] + 1 - - for lev in range(1, len(G_all_levels)): - nodes = list(G_all_levels[lev - 1].nodes) - ij = ( - np.array(nodes) - .reshape((num_nodes_x, num_nodes_y, 2))[ - level_offset::level_refinement_factor, - level_offset::level_refinement_factor, - :, - ] - .reshape(int(num_nodes_x * num_nodes_y / (level_refinement_factor**2)), 2) - ) - ij = [tuple(x) for x in ij] - G_all_levels[lev] = networkx.relabel_nodes( - G_all_levels[lev], dict(zip(G_all_levels[lev].nodes, ij)) - ) - G_tot = networkx.compose(G_tot, G_all_levels[lev]) - - # Update number of nodes in x- and y-direction for next iteraion - num_nodes_x //= level_refinement_factor - num_nodes_y //= level_refinement_factor - - # Relabel mesh nodes to start with 0 - G_tot = prepend_node_index(G_tot, 0) - - # add dx and dy to graph - G_tot.graph["dx"] = {i: g.graph["dx"] for i, g in enumerate(G_all_levels)} - G_tot.graph["dy"] = {i: g.graph["dy"] for i, g in enumerate(G_all_levels)} - - return G_tot - - -def create_flat_singlescale_mesh_graph(xy, mesh_node_distance: float): - """ - Create flat mesh graph of single level - - Parameters - ---------- - xy : np.ndarray [N_grid_points, 2] - Grid point coordinates, with first column representing - x coordinates and second column y coordinates. N_grid_points is the - total number of grid points. - mesh_node_distance: float - Distance (in x- and y-direction) between created mesh nodes, - in coordinate system of xy - Returns - ------- - G_flat : networkx.Graph - The flat mesh graph - """ - # Compute number of mesh nodes in x and y dimensions - range_x, range_y = np.ptp(xy, axis=0) - nx = int(range_x / mesh_node_distance) - ny = int(range_y / mesh_node_distance) - - if nx == 0 or ny == 0: - raise ValueError( - "The given `mesh_node_distance` is too large for the provided coordinates. " - f"Got mesh_node_distance={mesh_node_distance}, but the x-range is {range_x} " - f"and y-range is {range_y}. Maybe you want to decrease the `mesh_node_distance`" - " so that the mesh nodes are spaced closer together?" - ) - - return mesh_graph.create_single_level_2d_mesh_graph(xy=xy, nx=nx, ny=ny) diff --git a/src/weather_model_graphs/create/mesh/kinds/hierarchical.py b/src/weather_model_graphs/create/mesh/kinds/hierarchical.py deleted file mode 100644 index b897693..0000000 --- a/src/weather_model_graphs/create/mesh/kinds/hierarchical.py +++ /dev/null @@ -1,132 +0,0 @@ -import networkx -import numpy as np -import scipy - -from ....networkx_utils import prepend_node_index -from .. import mesh as mesh_graph - - -def create_hierarchical_multiscale_mesh_graph( - xy, - mesh_node_distance: float, - level_refinement_factor: float, - max_num_levels: int, -): - """ - Create a hierarchical multiscale mesh graph with nearest neighbour - connections within each level (horizontally, vertically, and diagonally), and - connections between levels (coarse to fine and fine to coarse) using the - nearest neighbour connection. - - Parameters - ---------- - xy: np.ndarray - 2D array of mesh point positions. - Distance (in x- and y-direction) between created mesh nodes in bottom level, - in coordinate system of xy - mesh_node_distance: float - Distance (in x- and y-direction) between created mesh nodes in bottom level, - in coordinate system of xy - level_refinement_factor: float - Refinement factor between grid points and bottom level of mesh hierarchy - max_num_levels: int - The number of levels in the hierarchical mesh graph. - - Returns - ------- - dict - A dictionary containing the hierarchical mesh graph, the mesh down graph, and - the mesh up graph, with keys "m2m", "mesh_down", and "mesh_up" respectively. - """ - Gs_all_levels: list[networkx.DiGraph] = mesh_graph.create_multirange_2d_mesh_graphs( - max_num_levels=max_num_levels, - xy=xy, - mesh_node_distance=mesh_node_distance, - level_refinement_factor=level_refinement_factor, - ) - n_mesh_levels = len(Gs_all_levels) - - if n_mesh_levels < 2: - raise ValueError( - "At least two mesh levels are required for hierarchical mesh graph. " - "You may need to reduce the level refinement factor " - f"or increase the max number of levels {max_num_levels} " - f"or number of grid points {xy.shape[0]}." - ) - - # Relabel nodes of each level with level index first - - Gs_all_levels = [ - prepend_node_index(graph, level_i) - for level_i, graph in enumerate(Gs_all_levels) - ] - - # add `direction` attribute to all edges with value `same`` - for i, G in enumerate(Gs_all_levels): - for u, v in G.edges: - G.edges[u, v]["direction"] = "same" - G.edges[u, v]["level"] = i - - # Create inter-level mesh edges - up_graphs = [] - down_graphs = [] - for G_from, G_to in zip( - Gs_all_levels[1:], - Gs_all_levels[:-1], - ): - from_level = G_from.graph["level"] - to_level = G_to.graph["level"] - - # start out from graph at from level - G_down = G_from.copy() - G_down.clear_edges() - G_down = networkx.DiGraph(G_down) - - # Add nodes of to level - G_down.add_nodes_from(G_to.nodes(data=True)) - - # build kd tree for mesh point pos - # order in vm should be same as in vm_xy - v_to_list = list(G_to.nodes) - v_from_list = list(G_from.nodes) - v_from_xy = np.array([xy for _, xy in G_from.nodes.data("pos")]) - kdt_m = scipy.spatial.KDTree(v_from_xy) - - # add edges from mesh to grid - for v in v_to_list: - # find 1(?) nearest neighbours (index to vm_xy) - neigh_idx = kdt_m.query(G_down.nodes[v]["pos"], 1)[1] - u = v_from_list[neigh_idx] - - # add edge from mesh to grid - G_down.add_edge(u, v) - d = np.sqrt(np.sum((G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"]) ** 2)) - G_down.edges[u, v]["len"] = d - G_down.edges[u, v]["vdiff"] = ( - G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"] - ) - G_down.edges[u, v]["levels"] = f"{from_level}>{to_level}" - G_down.edges[u, v]["direction"] = "down" - - G_up = networkx.DiGraph() - G_up.add_nodes_from(G_down.nodes(data=True)) - for u, v, data in G_down.edges(data=True): - data = data.copy() - data["levels"] = f"{to_level}>{from_level}" - data["direction"] = "up" - G_up.add_edge(v, u, **data) - - up_graphs.append(G_up) - down_graphs.append(G_down) - - G_up_all = networkx.compose_all(up_graphs) - G_down_all = networkx.compose_all(down_graphs) - G_all_levels = networkx.compose_all(Gs_all_levels) - - G_m2m = networkx.compose_all([G_all_levels, G_up_all, G_down_all]) - - # add dx and dy to graph - for prop in ("dx", "dy"): - G_m2m.graph[prop] = {i: g.graph[prop] for i, g in enumerate(Gs_all_levels)} - - return G_m2m diff --git a/src/weather_model_graphs/create/mesh/mesh.py b/src/weather_model_graphs/create/mesh/mesh.py deleted file mode 100644 index 75c82d6..0000000 --- a/src/weather_model_graphs/create/mesh/mesh.py +++ /dev/null @@ -1,150 +0,0 @@ -import networkx -import numpy as np -from loguru import logger - - -def create_single_level_2d_mesh_graph(xy, nx, ny): - """ - Create directed graph with nx * ny nodes representing a 2D grid with - positions spanning the range of xy coordinate values (first dimension - is assumed to be x and y coordinate values respectively). Each nodes is - connected to its eight nearest neighbours, both horizontally, vertically - and diagonally as directed edges (which means that the graph contains two - edges between each pair of connected nodes). - - The nodes contain a "pos" attribute with the x and y - coordinates of the node, and an "type" attribute with the - type of the node (i.e. "mesh" for mesh nodes). - - The edges contain a "len" attribute with the length of the edge - and a "vdiff" attribute with the vector difference between the - nodes. - - Parameters - ---------- - xy : np.ndarray [N_grid_points, 2] - Grid point coordinates, with first column representing - x coordinates and second column y coordinates. N_grid_points is the - total number of grid points. - nx : int - Number of nodes in x direction - ny : int - Number of nodes in y direction - - Returns - ------- - networkx.DiGraph - Graph representing the 2D grid - """ - xm, xM = np.amin(xy[:, 0]), np.amax(xy[:, 0]) - ym, yM = np.amin(xy[:, 1]), np.amax(xy[:, 1]) - - # avoid nodes on border - dx = (xM - xm) / nx - dy = (yM - ym) / ny - lx = np.linspace(xm + dx / 2, xM - dx / 2, nx) - ly = np.linspace(ym + dy / 2, yM - dy / 2, ny) - - mg = np.meshgrid(lx, ly) - g = networkx.grid_2d_graph(len(lx), len(ly)) - - # Node name and `pos` attribute takes form (x, y) - for node in g.nodes: - node_xi, node_yi = node # Extract x and y index from node to index mx - g.nodes[node]["pos"] = np.array( - [mg[0][node_yi, node_xi], mg[1][node_yi, node_xi]] - ) - g.nodes[node]["type"] = "mesh" - - # add diagonal edges - g.add_edges_from( - [((x, y), (x + 1, y + 1)) for x in range(nx - 1) for y in range(ny - 1)] - + [((x + 1, y), (x, y + 1)) for x in range(nx - 1) for y in range(ny - 1)] - ) - - # turn into directed graph - dg = networkx.DiGraph(g) - for u, v in g.edges(): - d = np.sqrt(np.sum((g.nodes[u]["pos"] - g.nodes[v]["pos"]) ** 2)) - dg.edges[u, v]["len"] = d - dg.edges[u, v]["vdiff"] = g.nodes[u]["pos"] - g.nodes[v]["pos"] - dg.add_edge(v, u) - dg.edges[v, u]["len"] = d - dg.edges[v, u]["vdiff"] = g.nodes[v]["pos"] - g.nodes[u]["pos"] - - dg.graph["dx"] = dx - dg.graph["dy"] = dy - - return dg - - -def create_multirange_2d_mesh_graphs( - max_num_levels, xy, mesh_node_distance=3, level_refinement_factor=3 -): - """ - Create a list of 2D grid mesh graphs representing different levels of edge-length - scales spanning the spatial domain of the xy coordinates. - This list of graphs can then later be for example a) flattened into single graph - containing multiple ranges of connections or b) combined into a hierarchical graph. - - Each graph in the list contains a "level" attribute with the level index of the graph. - - Parameters - ---------- - max_num_levels : int - Number of edge-distance levels in mesh graph - xy : np.ndarray - Grid point coordinates, shaped [N_grid_points, 2] - mesh_node_distance: float - Distance (in x- and y-direction) between created mesh nodes, - in coordinate system of xy - level_refinement_factor: float - Refinement factor between grid points and bottom level of mesh hierarchy - - Returns - ------- - G_all_levels : list of networkx.Graph - List of networkx graphs for each level representing the connectivity - of the mesh within each level - """ - # Compute the size along x and y direction of area to cover with graph - # This is measured in the Cartesian coordiantes of xy - coord_extent = np.ptp(xy, axis=0) - # Number of nodes that would fit on bottom level of hierarchy, - # in both directions - max_nodes_bottom = (coord_extent / mesh_node_distance).astype(int) - - # Find the number of mesh levels possible in x- and y-direction, - # and the number of leaf nodes that would correspond to - # max_nodes_bottom/(level_refinement_factor^mesh_levels) = 1 - max_mesh_levels_float = np.log(max_nodes_bottom) / np.log(level_refinement_factor) - - max_mesh_levels = max_mesh_levels_float.astype(int) # (2,) - nleaf = level_refinement_factor**max_mesh_levels - # leaves at the bottom in each direction, if using max_mesh_levels - - # As we can not instantiate different number of mesh levels in each - # direction, create mesh levels corresponding to the minimum of the two - mesh_levels_to_create = max_mesh_levels.min() - - if max_num_levels: - # Limit the levels in mesh graph - mesh_levels_to_create = min(mesh_levels_to_create, max_num_levels) - - logger.debug(f"mesh_levels: {mesh_levels_to_create}, nleaf: {nleaf}") - - # multi resolution tree levels - G_all_levels = [] - for lev in range(mesh_levels_to_create): # 0-index mesh levels - # Compute number of nodes on level separate for each direction - nodes_x, nodes_y = (nleaf / (level_refinement_factor**lev)).astype(int) - g = create_single_level_2d_mesh_graph(xy, nodes_x, nodes_y) - # Add level information to nodes, edges and full graph - for node in g.nodes: - g.nodes[node]["level"] = lev - for edge in g.edges: - g.edges[edge]["level"] = lev - g.graph["level"] = lev - G_all_levels.append(g) - - return G_all_levels diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py index b45d486..858299e 100644 --- a/tests/test_graph_creation.py +++ b/tests/test_graph_creation.py @@ -75,6 +75,7 @@ def test_create_graph_generic(m2g_connectivity, g2m_connectivity, m2m_connectivi graph = wmg.create.create_all_graph_components( coords=xy, m2m_connectivity=m2m_connectivity, + mesh_layout="rectilinear", m2m_connectivity_kwargs=m2m_kwargs, g2m_connectivity=g2m_connectivity, g2m_connectivity_kwargs=g2m_kwargs, diff --git a/tests/test_graph_plots.py b/tests/test_graph_plots.py index c5b1e10..eba23b5 100644 --- a/tests/test_graph_plots.py +++ b/tests/test_graph_plots.py @@ -14,6 +14,7 @@ def test_plot(): graph = wmg.create.create_all_graph_components( m2m_connectivity="flat_multiscale", coords=xy, + mesh_layout="rectilinear", m2m_connectivity_kwargs=dict( max_num_levels=3, mesh_node_distance=2, diff --git a/tests/test_mesh_layout.py b/tests/test_mesh_layout.py new file mode 100644 index 0000000..845d31a --- /dev/null +++ b/tests/test_mesh_layout.py @@ -0,0 +1,1332 @@ +""" +Tests for the mesh_layout parameter and two-step coordinate/connectivity +architecture introduced in Issue #78. + +These tests verify: +1. The new API (mesh_layout, mesh_layout_kwargs with refinement_factor and + max_num_refinement_levels, m2m_connectivity_kwargs with pattern for flat/ + flat_multiscale and intra_level/inter_level sub-dicts for hierarchical) +2. The two-step process (coordinate creation → connectivity creation) +3. The 4-star vs 8-star pattern functionality +4. Backward compatibility with old-style kwargs +5. Edge annotations on coordinate graphs +6. Error handling for invalid inputs +""" + +import io +import warnings + +import networkx as nx +import numpy as np +import pytest +from loguru import logger + +import tests.utils as test_utils +import weather_model_graphs as wmg +from weather_model_graphs.create.mesh.connectivity.flat import ( + create_flat_multiscale_from_coordinates, + create_flat_singlescale_from_coordinates, +) +from weather_model_graphs.create.mesh.connectivity.hierarchical import ( + create_hierarchical_from_coordinates, +) +from weather_model_graphs.create.mesh.coords import ( + create_directed_mesh_graph, + create_multirange_2d_mesh_primitives, + create_single_level_2d_mesh_primitive, +) + +# ==================== +# Step 1: Coordinate creation tests +# ==================== + + +class TestSingleLevelCoordinateCreation: + """Tests for create_single_level_2d_mesh_primitive.""" + + def test_returns_undirected_graph(self): + xy = test_utils.create_fake_xy(N=10) + G = create_single_level_2d_mesh_primitive(xy, nx=5, ny=5) + assert isinstance(G, nx.Graph) + assert not isinstance(G, nx.DiGraph) + + def test_nodes_have_pos_and_type(self): + xy = test_utils.create_fake_xy(N=10) + G = create_single_level_2d_mesh_primitive(xy, nx=4, ny=4) + for node in G.nodes: + assert "pos" in G.nodes[node] + assert "type" in G.nodes[node] + assert G.nodes[node]["type"] == "mesh" + assert len(G.nodes[node]["pos"]) == 2 + + def test_correct_number_of_nodes(self): + xy = test_utils.create_fake_xy(N=10) + G = create_single_level_2d_mesh_primitive(xy, nx=5, ny=4) + assert len(G.nodes) == 5 * 4 + + def test_edges_have_adjacency_type(self): + xy = test_utils.create_fake_xy(N=10) + G = create_single_level_2d_mesh_primitive(xy, nx=5, ny=5) + cardinal_count = 0 + diagonal_count = 0 + for u, v, d in G.edges(data=True): + assert "adjacency_type" in d, f"Edge ({u}, {v}) missing adjacency_type" + assert d["adjacency_type"] in ("cardinal", "diagonal") + if d["adjacency_type"] == "cardinal": + cardinal_count += 1 + else: + diagonal_count += 1 + # For a 5x5 grid: cardinal = 2*(5*4) = 40, diagonal = 2*(4*4) = 32 + assert cardinal_count == 2 * (5 * 4) + assert diagonal_count == 2 * (4 * 4) + + def test_graph_has_dx_dy(self): + xy = test_utils.create_fake_xy(N=10) + G = create_single_level_2d_mesh_primitive(xy, nx=5, ny=5) + assert "dx" in G.graph + assert "dy" in G.graph + assert G.graph["dx"] > 0 + assert G.graph["dy"] > 0 + + +class TestMultirangeCoordinateCreation: + """Tests for create_multirange_2d_mesh_primitives.""" + + def test_returns_list_of_undirected_graphs(self): + xy = test_utils.create_fake_xy(N=30) + G_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + assert isinstance(G_list, list) + assert len(G_list) > 0 + for G in G_list: + assert isinstance(G, nx.Graph) + assert not isinstance(G, nx.DiGraph) + + def test_each_level_has_level_attribute(self): + xy = test_utils.create_fake_xy(N=30) + G_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + for i, G in enumerate(G_list): + assert G.graph["level"] == i + for node in G.nodes: + assert G.nodes[node]["level"] == i + + def test_interlevel_refinement_factor_stored(self): + xy = test_utils.create_fake_xy(N=30) + G_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + for G in G_list: + assert G.graph["interlevel_refinement_factor"] == 3 + + def test_edges_have_adjacency_type(self): + xy = test_utils.create_fake_xy(N=30) + G_list = create_multirange_2d_mesh_primitives( + max_num_levels=2, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + for G in G_list: + for u, v, d in G.edges(data=True): + assert "adjacency_type" in d + + def test_coarser_levels_have_fewer_nodes(self): + xy = test_utils.create_fake_xy(N=30) + G_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + if len(G_list) >= 2: + for i in range(len(G_list) - 1): + assert len(G_list[i].nodes) > len(G_list[i + 1].nodes) + + +# ==================== +# Step 2: Connectivity creation tests +# ==================== + + +class TestDirectedMeshGraph: + """Tests for create_directed_mesh_graph.""" + + def test_returns_directed_graph(self): + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=4, ny=4) + G_directed = create_directed_mesh_graph(G_coords, pattern="8-star") + assert isinstance(G_directed, nx.DiGraph) + + def test_4star_has_fewer_edges_than_8star(self): + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=5, ny=5) + G_4star = create_directed_mesh_graph(G_coords, pattern="4-star") + G_8star = create_directed_mesh_graph(G_coords, pattern="8-star") + assert len(G_4star.edges) < len(G_8star.edges) + + def test_4star_only_cardinal_edges(self): + """4-star should only include cardinal (horizontal/vertical) edges.""" + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=4, ny=4) + G_4star = create_directed_mesh_graph(G_coords, pattern="4-star") + + # In a 4x4 grid, 4-star adjacency means each node connects only to + # horizontal/vertical neighbours + # For a 4x4 grid: 2 * (4*3 + 3*4) = 2 * 24 = 48 directed edges + expected_edges = 2 * (4 * 3 + 3 * 4) + assert len(G_4star.edges) == expected_edges + + def test_8star_includes_diagonal_edges(self): + """8-star should include both cardinal and diagonal edges.""" + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=4, ny=4) + G_8star = create_directed_mesh_graph(G_coords, pattern="8-star") + + # Cardinal: 2 * (4*3 + 3*4) = 48 + # Diagonal: 2 * 2 * (3*3) = 36 + expected_edges = 48 + 36 + assert len(G_8star.edges) == expected_edges + + def test_edges_have_len_and_vdiff(self): + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=4, ny=4) + G_directed = create_directed_mesh_graph(G_coords, pattern="8-star") + for u, v, d in G_directed.edges(data=True): + assert "len" in d + assert "vdiff" in d + assert d["len"] > 0 + + def test_invalid_pattern_raises_error(self): + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=4, ny=4) + with pytest.raises(ValueError, match="Unknown connectivity pattern"): + create_directed_mesh_graph(G_coords, pattern="6-star") + + def test_preserves_graph_attributes(self): + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=4, ny=4) + G_directed = create_directed_mesh_graph(G_coords, pattern="8-star") + assert "dx" in G_directed.graph + assert "dy" in G_directed.graph + + def test_bidirectional_edges(self): + """Each undirected edge should produce two directed edges.""" + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=3, ny=3) + G_directed = create_directed_mesh_graph(G_coords, pattern="4-star") + for u, v in G_directed.edges(): + assert G_directed.has_edge(v, u), f"Missing reverse edge ({v}, {u})" + + +class TestFlatSinglescaleFromCoordinates: + """Tests for create_flat_singlescale_from_coordinates.""" + + def test_basic_creation(self): + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=5, ny=5) + G = create_flat_singlescale_from_coordinates(G_coords, pattern="8-star") + assert isinstance(G, nx.DiGraph) + + def test_4star_pattern(self): + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=5, ny=5) + G = create_flat_singlescale_from_coordinates(G_coords, pattern="4-star") + assert isinstance(G, nx.DiGraph) + # Fewer edges than 8-star + G_8 = create_flat_singlescale_from_coordinates(G_coords, pattern="8-star") + assert len(G.edges) < len(G_8.edges) + + +class TestFlatMultiscaleFromCoordinates: + """Tests for create_flat_multiscale_from_coordinates.""" + + def test_basic_creation(self): + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + G = create_flat_multiscale_from_coordinates(G_coords_list) + assert isinstance(G, nx.DiGraph) + + def test_pattern_argument(self): + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + G_4star = create_flat_multiscale_from_coordinates( + G_coords_list, + pattern="4-star", + ) + G_8star = create_flat_multiscale_from_coordinates( + G_coords_list, + pattern="8-star", + ) + assert len(G_4star.edges) < len(G_8star.edges) + + +class TestHierarchicalFromCoordinates: + """Tests for create_hierarchical_from_coordinates.""" + + def test_basic_creation(self): + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + G = create_hierarchical_from_coordinates(G_coords_list) + assert isinstance(G, nx.DiGraph) + + def test_has_up_down_same_edges(self): + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + G = create_hierarchical_from_coordinates(G_coords_list) + directions = set() + for u, v, d in G.edges(data=True): + if "direction" in d: + directions.add(d["direction"]) + assert "same" in directions + assert "up" in directions + assert "down" in directions + + def test_intra_level_pattern(self): + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + G_4star = create_hierarchical_from_coordinates( + G_coords_list, + intra_level={"pattern": "4-star"}, + ) + G_8star = create_hierarchical_from_coordinates( + G_coords_list, + intra_level={"pattern": "8-star"}, + ) + assert len(G_4star.edges) < len(G_8star.edges) + + def test_inter_level_k_parameter(self): + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + G_k1 = create_hierarchical_from_coordinates( + G_coords_list, + inter_level={"pattern": "nearest", "k": 1}, + ) + G_k3 = create_hierarchical_from_coordinates( + G_coords_list, + inter_level={"pattern": "nearest", "k": 3}, + ) + # More neighbours → more inter-level edges + assert len(G_k3.edges) > len(G_k1.edges) + + def test_invalid_inter_level_pattern_raises(self): + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + with pytest.raises(NotImplementedError, match="Inter-level pattern"): + create_hierarchical_from_coordinates( + G_coords_list, + inter_level={"pattern": "some_unknown"}, + ) + + +# ==================== +# New API via create_all_graph_components tests +# ==================== + + +class TestNewAPIFlat: + """Tests for create_all_graph_components with new mesh_layout API (flat).""" + + def test_flat_with_new_api(self): + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + m2m_connectivity_kwargs=dict(pattern="8-star"), + g2m_connectivity="within_radius", + m2g_connectivity="nearest_neighbours", + g2m_connectivity_kwargs=dict(rel_max_dist=0.51), + m2g_connectivity_kwargs=dict(max_num_neighbours=4), + ) + assert isinstance(graph, nx.DiGraph) + + def test_flat_4star_pattern(self): + xy = test_utils.create_fake_xy(N=32) + graph_4 = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + m2m_connectivity_kwargs=dict(pattern="4-star"), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + graph_8 = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + m2m_connectivity_kwargs=dict(pattern="8-star"), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + assert isinstance(graph_4, nx.DiGraph) + assert isinstance(graph_8, nx.DiGraph) + assert len(graph_4.edges) < len(graph_8.edges) + + def test_missing_mesh_node_spacing_raises(self): + xy = test_utils.create_fake_xy(N=32) + with pytest.raises(ValueError, match="mesh_node_spacing"): + wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs={}, + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + + +class TestNewAPIFlatMultiscale: + """Tests for create_all_graph_components with new API (flat_multiscale).""" + + def test_flat_multiscale_with_new_api(self): + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat_multiscale", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict( + pattern="8-star", + ), + g2m_connectivity="within_radius", + m2g_connectivity="nearest_neighbours", + g2m_connectivity_kwargs=dict(rel_max_dist=0.51), + m2g_connectivity_kwargs=dict(max_num_neighbours=4), + ) + assert isinstance(graph, nx.DiGraph) + + def test_flat_multiscale_4star_pattern(self): + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat_multiscale", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict( + pattern="4-star", + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + assert isinstance(graph, nx.DiGraph) + + +class TestNewAPIHierarchical: + """Tests for create_all_graph_components with new API (hierarchical).""" + + def test_hierarchical_with_new_api(self): + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="hierarchical", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict( + intra_level=dict(pattern="8-star"), + inter_level=dict(pattern="nearest", k=1), + ), + g2m_connectivity="within_radius", + m2g_connectivity="nearest_neighbours", + g2m_connectivity_kwargs=dict(rel_max_dist=0.51), + m2g_connectivity_kwargs=dict(max_num_neighbours=4), + ) + assert isinstance(graph, nx.DiGraph) + + def test_hierarchical_4star_intra(self): + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="hierarchical", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict( + intra_level=dict(pattern="4-star"), + inter_level=dict(pattern="nearest", k=1), + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + assert isinstance(graph, nx.DiGraph) + + def test_hierarchical_k3_nearest(self): + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="hierarchical", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict( + intra_level=dict(pattern="8-star"), + inter_level=dict(pattern="nearest", k=3), + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + assert isinstance(graph, nx.DiGraph) + + +# ==================== +# Backward compatibility tests +# ==================== + + +class TestBackwardCompatibility: + """Tests that old-style kwargs still work with deprecation warnings.""" + + def test_old_style_flat_with_mesh_node_distance(self): + xy = test_utils.create_fake_xy(N=32) + log_output = io.StringIO() + handler_id = logger.add(log_output, format="{message}", level="WARNING") + try: + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + m2m_connectivity_kwargs=dict(mesh_node_distance=3), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + finally: + logger.remove(handler_id) + assert "mesh_node_distance" in log_output.getvalue() + assert isinstance(graph, nx.DiGraph) + + def test_old_style_flat_multiscale(self): + xy = test_utils.create_fake_xy(N=32) + log_output = io.StringIO() + handler_id = logger.add(log_output, format="{message}", level="WARNING") + try: + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat_multiscale", + mesh_layout="rectilinear", + m2m_connectivity_kwargs=dict( + mesh_node_distance=3, + level_refinement_factor=3, + max_num_levels=3, + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + finally: + logger.remove(handler_id) + log_text = log_output.getvalue() + assert "mesh_node_distance" in log_text + assert "level_refinement_factor" in log_text + assert "max_num_levels" in log_text + assert isinstance(graph, nx.DiGraph) + + def test_old_style_hierarchical(self): + xy = test_utils.create_fake_xy(N=32) + log_output = io.StringIO() + handler_id = logger.add(log_output, format="{message}", level="WARNING") + try: + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="hierarchical", + mesh_layout="rectilinear", + m2m_connectivity_kwargs=dict( + mesh_node_distance=3, + level_refinement_factor=3, + max_num_levels=3, + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + finally: + logger.remove(handler_id) + log_text = log_output.getvalue() + assert "mesh_node_distance" in log_text + assert "level_refinement_factor" in log_text + assert "max_num_levels" in log_text + assert isinstance(graph, nx.DiGraph) + + def test_kwargs_dict_not_mutated(self): + """Verify that passing dict kwargs doesn't mutate the caller's dict.""" + xy = test_utils.create_fake_xy(N=32) + original_kwargs = dict( + mesh_node_distance=3, + level_refinement_factor=3, + max_num_levels=3, + ) + kwargs_copy = original_kwargs.copy() + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="hierarchical", + mesh_layout="rectilinear", + m2m_connectivity_kwargs=original_kwargs, + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + + # The original dict should not be modified + assert original_kwargs == kwargs_copy + + +# ==================== +# Error handling tests +# ==================== + + +class TestErrorHandling: + """Tests for proper error handling.""" + + def test_unsupported_mesh_layout_raises(self): + xy = test_utils.create_fake_xy(N=32) + with pytest.raises(NotImplementedError, match="not yet supported"): + wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="triangular", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + + def test_unsupported_m2m_connectivity_raises(self): + xy = test_utils.create_fake_xy(N=32) + with pytest.raises(NotImplementedError, match="not implemented"): + wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="some_unknown", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + + def test_mesh_node_spacing_too_large_raises(self): + xy = test_utils.create_fake_xy(N=10) + with pytest.raises(ValueError, match="too large"): + wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=100), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + + +# ==================== +# Equivalence tests: new API == old wrappers +# ==================== + + +class TestEquivalence: + """Verify that the new API produces equivalent results to the old wrappers.""" + + def test_keisler_archetype_matches_new_api(self): + """The keisler archetype function should produce the same result as + calling create_all_graph_components with the new API directly.""" + xy = test_utils.create_fake_xy(N=32) + + graph_archetype = wmg.create.archetype.create_keisler_graph( + coords=xy, mesh_node_distance=3 + ) + + graph_new_api = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + m2m_connectivity_kwargs=dict(pattern="8-star"), + g2m_connectivity="within_radius", + m2g_connectivity="nearest_neighbours", + g2m_connectivity_kwargs=dict(rel_max_dist=0.51), + m2g_connectivity_kwargs=dict(max_num_neighbours=4), + ) + + assert len(graph_archetype.nodes) == len(graph_new_api.nodes) + assert len(graph_archetype.edges) == len(graph_new_api.edges) + + def test_graphcast_archetype_matches_new_api(self): + xy = test_utils.create_fake_xy(N=32) + + graph_archetype = wmg.create.archetype.create_graphcast_graph( + coords=xy, + mesh_node_distance=3, + level_refinement_factor=3, + max_num_levels=3, + ) + + graph_new_api = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat_multiscale", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict( + pattern="8-star", + ), + g2m_connectivity="within_radius", + m2g_connectivity="nearest_neighbours", + g2m_connectivity_kwargs=dict(rel_max_dist=0.51), + m2g_connectivity_kwargs=dict(max_num_neighbours=4), + ) + + assert len(graph_archetype.nodes) == len(graph_new_api.nodes) + assert len(graph_archetype.edges) == len(graph_new_api.edges) + + def test_oskarsson_archetype_matches_new_api(self): + xy = test_utils.create_fake_xy(N=32) + + graph_archetype = wmg.create.archetype.create_oskarsson_hierarchical_graph( + coords=xy, + mesh_node_distance=3, + level_refinement_factor=3, + max_num_levels=3, + ) + + graph_new_api = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="hierarchical", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict( + intra_level=dict(pattern="8-star"), + inter_level=dict(pattern="nearest", k=1), + ), + g2m_connectivity="within_radius", + m2g_connectivity="nearest_neighbours", + g2m_connectivity_kwargs=dict(rel_max_dist=0.51), + m2g_connectivity_kwargs=dict(max_num_neighbours=4), + ) + + assert len(graph_archetype.nodes) == len(graph_new_api.nodes) + assert len(graph_archetype.edges) == len(graph_new_api.edges) + + +# ==================== +# Edge case tests +# ==================== + + +class TestCoordinateCreationEdgeCases: + """Edge cases for coordinate creation step.""" + + def test_minimum_grid_2x2(self): + """Smallest possible grid: 2x2 nodes.""" + xy = test_utils.create_fake_xy(N=10) + G = create_single_level_2d_mesh_primitive(xy, nx=2, ny=2) + assert len(G.nodes) == 4 + # 2x2 grid: cardinal edges = 2*(2*1) = 4, diagonal edges = 2*(1*1) = 2 + cardinal = sum( + 1 for _, _, d in G.edges(data=True) if d["adjacency_type"] == "cardinal" + ) + diagonal = sum( + 1 for _, _, d in G.edges(data=True) if d["adjacency_type"] == "diagonal" + ) + assert cardinal == 4 + assert diagonal == 2 + + def test_single_row_grid(self): + """Grid with only 1 row (nx=5, ny=1).""" + xy = test_utils.create_fake_xy(N=10) + G = create_single_level_2d_mesh_primitive(xy, nx=5, ny=1) + assert len(G.nodes) == 5 + # 5x1 grid: only horizontal cardinal edges, no diagonals + cardinal = sum( + 1 for _, _, d in G.edges(data=True) if d["adjacency_type"] == "cardinal" + ) + diagonal = sum( + 1 for _, _, d in G.edges(data=True) if d["adjacency_type"] == "diagonal" + ) + assert cardinal == 4 # 5-1 = 4 horizontal edges + assert diagonal == 0 + + def test_single_column_grid(self): + """Grid with only 1 column (nx=1, ny=5).""" + xy = test_utils.create_fake_xy(N=10) + G = create_single_level_2d_mesh_primitive(xy, nx=1, ny=5) + assert len(G.nodes) == 5 + cardinal = sum( + 1 for _, _, d in G.edges(data=True) if d["adjacency_type"] == "cardinal" + ) + diagonal = sum( + 1 for _, _, d in G.edges(data=True) if d["adjacency_type"] == "diagonal" + ) + assert cardinal == 4 # 5-1 = 4 vertical edges + assert diagonal == 0 + + def test_1x1_grid_no_edges(self): + """Grid with a single node (1x1): should have no edges.""" + xy = test_utils.create_fake_xy(N=10) + G = create_single_level_2d_mesh_primitive(xy, nx=1, ny=1) + assert len(G.nodes) == 1 + assert len(G.edges) == 0 + + def test_large_grid(self): + """Larger grid should still work correctly.""" + xy = test_utils.create_fake_xy(N=50) + G = create_single_level_2d_mesh_primitive(xy, nx=10, ny=10) + assert len(G.nodes) == 100 + expected_cardinal = 2 * (10 * 9) # 180 + expected_diagonal = 2 * (9 * 9) # 162 + cardinal = sum( + 1 for _, _, d in G.edges(data=True) if d["adjacency_type"] == "cardinal" + ) + diagonal = sum( + 1 for _, _, d in G.edges(data=True) if d["adjacency_type"] == "diagonal" + ) + assert cardinal == expected_cardinal + assert diagonal == expected_diagonal + + def test_node_positions_within_bounds(self): + """Node positions should be within the xy bounds.""" + xy = test_utils.create_fake_xy(N=20) + G = create_single_level_2d_mesh_primitive(xy, nx=5, ny=5) + x_min, y_min = np.amin(xy, axis=0) + x_max, y_max = np.amax(xy, axis=0) + for node in G.nodes: + pos = G.nodes[node]["pos"] + assert pos[0] >= x_min and pos[0] <= x_max + assert pos[1] >= y_min and pos[1] <= y_max + + def test_multirange_with_max_levels_1(self): + """Multi-range with max_num_levels=1 should return single-level list.""" + xy = test_utils.create_fake_xy(N=30) + G_list = create_multirange_2d_mesh_primitives( + max_num_levels=1, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + assert len(G_list) == 1 + assert G_list[0].graph["level"] == 0 + + def test_multirange_with_none_max_levels(self): + """max_num_levels=None should auto-compute levels.""" + xy = test_utils.create_fake_xy(N=30) + G_list = create_multirange_2d_mesh_primitives( + max_num_levels=None, + xy=xy, + mesh_node_spacing=3, + interlevel_refinement_factor=3, + ) + assert isinstance(G_list, list) + assert len(G_list) >= 1 + + def test_multirange_refinement_factor_5(self): + """Test with a different refinement factor.""" + xy = test_utils.create_fake_xy(N=50) + G_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=5 + ) + if len(G_list) >= 2: + for i in range(len(G_list) - 1): + assert len(G_list[i].nodes) > len(G_list[i + 1].nodes) + + +class TestConnectivityCreationEdgeCases: + """Edge cases for connectivity creation step.""" + + def test_directed_graph_from_1x1(self): + """Creating directed graph from a single-node coordinate graph.""" + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=1, ny=1) + G_directed = create_directed_mesh_graph(G_coords, pattern="8-star") + assert isinstance(G_directed, nx.DiGraph) + assert len(G_directed.nodes) == 1 + assert len(G_directed.edges) == 0 + + def test_directed_graph_from_2x1(self): + """Creating directed graph from a 2x1 grid.""" + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=2, ny=1) + G_4star = create_directed_mesh_graph(G_coords, pattern="4-star") + G_8star = create_directed_mesh_graph(G_coords, pattern="8-star") + # 2x1: 1 edge, both patterns should have same (no diagonals possible) + assert len(G_4star.edges) == 2 # bidirectional + assert len(G_8star.edges) == 2 + + def test_4star_is_subset_of_8star(self): + """All edges in 4-star should exist in 8-star.""" + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=5, ny=5) + G_4star = create_directed_mesh_graph(G_coords, pattern="4-star") + G_8star = create_directed_mesh_graph(G_coords, pattern="8-star") + for u, v in G_4star.edges(): + assert G_8star.has_edge(u, v), f"4-star edge ({u},{v}) missing from 8-star" + + def test_edge_lengths_are_positive(self): + """All edge lengths should be positive.""" + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=4, ny=4) + G = create_directed_mesh_graph(G_coords, pattern="8-star") + for u, v, d in G.edges(data=True): + assert d["len"] > 0, f"Edge ({u},{v}) has non-positive length" + + def test_vdiff_antisymmetric(self): + """vdiff(u,v) should be -vdiff(v,u).""" + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=4, ny=4) + G = create_directed_mesh_graph(G_coords, pattern="8-star") + for u, v in G.edges(): + if G.has_edge(v, u): + np.testing.assert_allclose( + G.edges[u, v]["vdiff"], + -G.edges[v, u]["vdiff"], + err_msg=f"vdiff not antisymmetric for ({u},{v})", + ) + + def test_edge_len_matches_vdiff_norm(self): + """Edge length should equal the norm of vdiff.""" + xy = test_utils.create_fake_xy(N=10) + G_coords = create_single_level_2d_mesh_primitive(xy, nx=4, ny=4) + G = create_directed_mesh_graph(G_coords, pattern="8-star") + for u, v, d in G.edges(data=True): + expected_len = np.sqrt(np.sum(d["vdiff"] ** 2)) + np.testing.assert_allclose( + d["len"], + expected_len, + err_msg=f"Edge ({u},{v}) len doesn't match vdiff norm", + ) + + def test_flat_multiscale_single_level_input(self): + """Flat multiscale with a single-level list should work (degenerate case).""" + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=1, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + G = create_flat_multiscale_from_coordinates(G_coords_list, pattern="8-star") + assert isinstance(G, nx.DiGraph) + assert len(G.nodes) > 0 + + def test_flat_multiscale_4star_vs_8star_edge_count(self): + """4-star flat_multiscale should have fewer edges than 8-star.""" + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + G_4 = create_flat_multiscale_from_coordinates(G_coords_list, pattern="4-star") + G_8 = create_flat_multiscale_from_coordinates(G_coords_list, pattern="8-star") + assert len(G_4.edges) < len(G_8.edges) + + def test_hierarchical_single_level_raises(self): + """Hierarchical with only 1 level should raise ValueError.""" + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=1, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + with pytest.raises(ValueError, match="At least two mesh levels"): + create_hierarchical_from_coordinates(G_coords_list) + + def test_hierarchical_edge_direction_attributes(self): + """Every edge in hierarchical graph must have a 'direction' attribute.""" + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + G = create_hierarchical_from_coordinates(G_coords_list) + for u, v, d in G.edges(data=True): + assert "direction" in d, f"Edge ({u},{v}) missing 'direction'" + assert d["direction"] in ("same", "up", "down") + + def test_hierarchical_up_down_symmetry(self): + """For each 'down' edge (u,v), there should be an 'up' edge (v,u).""" + xy = test_utils.create_fake_xy(N=30) + G_coords_list = create_multirange_2d_mesh_primitives( + max_num_levels=3, xy=xy, mesh_node_spacing=3, interlevel_refinement_factor=3 + ) + G = create_hierarchical_from_coordinates(G_coords_list) + for u, v, d in G.edges(data=True): + if d.get("direction") == "down": + assert G.has_edge(v, u), f"Missing 'up' edge for 'down' ({u},{v})" + assert G.edges[v, u]["direction"] == "up" + + +class TestAPIEdgeCases: + """Edge cases for the public create_all_graph_components API.""" + + def test_flat_default_pattern_is_8star(self): + """When no m2m_connectivity_kwargs given, flat should default to 8-star.""" + xy = test_utils.create_fake_xy(N=32) + graph_default = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + graph_8star = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + m2m_connectivity_kwargs=dict(pattern="8-star"), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + assert len(graph_default.edges) == len(graph_8star.edges) + + def test_flat_multiscale_default_pattern_is_8star(self): + """When no m2m_connectivity_kwargs given, flat_multiscale defaults to 8-star.""" + xy = test_utils.create_fake_xy(N=32) + graph_default = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat_multiscale", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + graph_8star = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat_multiscale", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict(pattern="8-star"), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + assert len(graph_default.edges) == len(graph_8star.edges) + + def test_hierarchical_default_kwargs(self): + """When no m2m_connectivity_kwargs given, hierarchical has sensible defaults.""" + xy = test_utils.create_fake_xy(N=32) + graph_default = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="hierarchical", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + graph_explicit = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="hierarchical", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict( + intra_level=dict(pattern="8-star"), + inter_level=dict(pattern="nearest", k=1), + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + assert len(graph_default.edges) == len(graph_explicit.edges) + + def test_mesh_layout_is_required(self): + """When mesh_layout not specified, it should raise TypeError.""" + xy = test_utils.create_fake_xy(N=32) + with pytest.raises(TypeError, match="mesh_layout"): + wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + + def test_return_components_flat(self): + """return_components=True should return dict with g2m, m2m, m2g.""" + xy = test_utils.create_fake_xy(N=32) + components = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + m2m_connectivity_kwargs=dict(pattern="8-star"), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + return_components=True, + ) + assert isinstance(components, dict) + assert "g2m" in components + assert "m2m" in components + assert "m2g" in components + for name, g in components.items(): + assert isinstance(g, nx.DiGraph) + + def test_return_components_hierarchical(self): + """return_components=True for hierarchical should contain 3 components.""" + xy = test_utils.create_fake_xy(N=32) + components = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="hierarchical", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict( + intra_level=dict(pattern="8-star"), + inter_level=dict(pattern="nearest", k=1), + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + return_components=True, + ) + assert isinstance(components, dict) + assert "g2m" in components + assert "m2m" in components + assert "m2g" in components + + def test_return_components_flat_multiscale(self): + """return_components=True for flat_multiscale.""" + xy = test_utils.create_fake_xy(N=32) + components = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat_multiscale", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict(pattern="8-star"), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + return_components=True, + ) + assert isinstance(components, dict) + assert set(components.keys()) == {"g2m", "m2m", "m2g"} + + def test_flat_multiscale_no_sub_dicts_interface(self): + """Ensure flat_multiscale accepts a simple pattern arg, not sub-dicts.""" + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat_multiscale", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict(pattern="8-star"), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + assert isinstance(graph, nx.DiGraph) + + def test_decode_mask_with_new_api(self): + """decode_mask should work correctly with the new API.""" + xy = test_utils.create_fake_xy(N=32) + n_points = len(xy) + mask = [True] * (n_points // 2) + [False] * (n_points - n_points // 2) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + m2m_connectivity_kwargs=dict(pattern="8-star"), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + decode_mask=mask, + ) + assert isinstance(graph, nx.DiGraph) + + +class TestBackwardCompatEdgeCases: + """Advanced backward compatibility edge cases.""" + + def test_old_kwargs_with_flat_multiscale_compat(self): + """Old-style flat_multiscale kwargs should trigger deprecation warnings + (via loguru) and be migrated to the new names.""" + xy = test_utils.create_fake_xy(N=32) + log_output = io.StringIO() + handler_id = logger.add(log_output, format="{message}", level="WARNING") + try: + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat_multiscale", + mesh_layout="rectilinear", + m2m_connectivity_kwargs=dict( + mesh_node_distance=3, + level_refinement_factor=3, + max_num_levels=3, + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + finally: + logger.remove(handler_id) + log_text = log_output.getvalue() + assert "mesh_node_spacing" in log_text + assert "refinement_factor" in log_text + assert "max_num_refinement_levels" in log_text + assert isinstance(graph, nx.DiGraph) + + +class TestGraphStructuralProperties: + """Tests verifying structural properties of generated graphs.""" + + def test_all_mesh_nodes_have_pos(self): + """Every node in the final graph should have 'pos' attribute.""" + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + for node in graph.nodes: + assert "pos" in graph.nodes[node], f"Node {node} missing 'pos'" + + def test_all_edges_have_component(self): + """Every edge should have a 'component' attribute ('g2m', 'm2m', 'm2g').""" + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + for u, v, d in graph.edges(data=True): + assert "component" in d, f"Edge ({u},{v}) missing 'component'" + assert d["component"] in ("g2m", "m2m", "m2g") + + def test_all_edges_have_len_and_vdiff(self): + """Every edge should have 'len' and 'vdiff' attributes.""" + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + for u, v, d in graph.edges(data=True): + assert "len" in d, f"Edge ({u},{v}) missing 'len'" + assert "vdiff" in d, f"Edge ({u},{v}) missing 'vdiff'" + + def test_graph_is_directed(self): + """Final graph should always be a DiGraph.""" + xy = test_utils.create_fake_xy(N=32) + for connectivity in ["flat"]: + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity=connectivity, + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + ) + assert isinstance(graph, nx.DiGraph) + + def test_flat_4star_strictly_fewer_m2m_edges(self): + """4-star flat should have strictly fewer m2m edges than 8-star.""" + xy = test_utils.create_fake_xy(N=32) + components_4 = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + m2m_connectivity_kwargs=dict(pattern="4-star"), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + return_components=True, + ) + components_8 = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="flat", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict(mesh_node_spacing=3), + m2m_connectivity_kwargs=dict(pattern="8-star"), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + return_components=True, + ) + # m2m component specifically should differ + assert len(components_4["m2m"].edges) < len(components_8["m2m"].edges) + # g2m and m2g should be the same (same grid spacing, same connectivity) + assert len(components_4["g2m"].edges) == len(components_8["g2m"].edges) + assert len(components_4["m2g"].edges) == len(components_8["m2g"].edges) + + def test_hierarchical_has_same_up_down_edge_count(self): + """Hierarchical graph should have equal number of up and down edges.""" + xy = test_utils.create_fake_xy(N=32) + graph = wmg.create.create_all_graph_components( + coords=xy, + m2m_connectivity="hierarchical", + mesh_layout="rectilinear", + mesh_layout_kwargs=dict( + mesh_node_spacing=3, + refinement_factor=3, + max_num_refinement_levels=3, + ), + m2m_connectivity_kwargs=dict( + intra_level=dict(pattern="8-star"), + inter_level=dict(pattern="nearest", k=1), + ), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + return_components=True, + ) + m2m = graph["m2m"] + up_count = sum( + 1 for _, _, d in m2m.edges(data=True) if d.get("direction") == "up" + ) + down_count = sum( + 1 for _, _, d in m2m.edges(data=True) if d.get("direction") == "down" + ) + assert ( + up_count == down_count + ), f"Up edges ({up_count}) != Down edges ({down_count})" + assert up_count > 0, "Should have at least some up/down edges"