diff --git a/CHANGELOG.md b/CHANGELOG.md index a729d95..4ad92df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed + +- Auto-detect spatial metric from CRS, with degree-based external haversine inputs/outputs and internal radian conversion. [\#86](https://github.com/mllam/weather-model-graphs/pull/86), @FAbdullah17 + ### Added - Add Django-style graph filtering via `filter_graph`, for example to select diff --git a/docs/distance_metric_auto_detection.ipynb b/docs/distance_metric_auto_detection.ipynb new file mode 100644 index 0000000..d218950 --- /dev/null +++ b/docs/distance_metric_auto_detection.ipynb @@ -0,0 +1,361 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "419f7f83", + "metadata": {}, + "source": [ + "# Auto-Detection of Distance Metric from CRS\n", + "\n", + "This notebook demonstrates the `SpatialCoordinateValuesSelector` class introduced\n", + "in [issue #75](https://github.com/mllam/weather-model-graphs/issues/75).\n", + "\n", + "The class wraps a ball-tree to provide efficient nearest-neighbour and radius\n", + "queries using the correct distance metric for the coordinate reference system (CRS):\n", + "\n", + "| CRS type | Distance metric | When to use |\n", + "|---|---|---|\n", + "| Geographic (lat/lon) — e.g. `ccrs.PlateCarree()` | **Haversine** | Global / sparse grids |\n", + "| Projected — e.g. `ccrs.LambertConformal()` | **Euclidean** | Regional / LAM grids |\n", + "\n", + "Distances returned by the haversine path are in **degrees**; Euclidean distances are\n", + "in the same units as the input coordinates." + ] + }, + { + "cell_type": "markdown", + "id": "fd3d7fce", + "metadata": {}, + "source": [ + "## 1. Setup and Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe32be32", + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "\n", + "import cartopy.crs as ccrs\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "import weather_model_graphs as wmg\n", + "from weather_model_graphs.spatial import SpatialCoordinateValuesSelector\n", + "\n", + "print(\"weather_model_graphs imported OK\")\n", + "print(f\"SpatialCoordinateValuesSelector: {SpatialCoordinateValuesSelector}\")" + ] + }, + { + "cell_type": "markdown", + "id": "b0d01140", + "metadata": {}, + "source": [ + "## 2. Euclidean Distance Queries (Projected CRS)\n", + "\n", + "For a projected coordinate system the coordinates are Cartesian (e.g. metres).\n", + "We use the `\"euclidean\"` metric." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6c276b4e", + "metadata": {}, + "outputs": [], + "source": [ + "rng = np.random.default_rng(42)\n", + "\n", + "# 50 random Cartesian points in a 100 km × 100 km domain (units: metres)\n", + "n_pts = 50\n", + "projected_coords = rng.random((n_pts, 2)) * 1e5 # shape (50, 2)\n", + "\n", + "sel_euc = SpatialCoordinateValuesSelector(\"euclidean\", projected_coords)\n", + "print(f\"metric: {sel_euc.distance_metric}\")\n", + "\n", + "# ---- k-nearest-to ----\n", + "query_pt = np.array([5e4, 5e4]) # centre of the domain\n", + "idxs, dists = sel_euc.k_nearest_to(query_pt, k=5)\n", + "print(\"\\n5 nearest neighbours (Euclidean):\")\n", + "for rank, (i, d) in enumerate(zip(idxs, dists), 1):\n", + " print(\n", + " f\" rank {rank}: index={i:3d} dist={d/1e3:.2f} km coord={projected_coords[i]}\"\n", + " )\n", + "\n", + "# ---- within_radius ----\n", + "radius_m = 2e4 # 20 km\n", + "idxs_r, dists_r = sel_euc.within_radius(query_pt, radius=radius_m)\n", + "print(f\"\\nPoints within {radius_m/1e3:.0f} km of centre: {len(idxs_r)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79300d69", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n", + "\n", + "for ax, (highlight_idxs, title, radius) in zip(\n", + " axes,\n", + " [\n", + " (idxs, \"k=5 nearest neighbours\", None),\n", + " (idxs_r, f\"within_radius = {radius_m/1e3:.0f} km\", radius_m),\n", + " ],\n", + "):\n", + " ax.scatter(\n", + " projected_coords[:, 0] / 1e3,\n", + " projected_coords[:, 1] / 1e3,\n", + " c=\"lightgrey\",\n", + " s=30,\n", + " zorder=2,\n", + " label=\"all points\",\n", + " )\n", + " ax.scatter(\n", + " projected_coords[highlight_idxs, 0] / 1e3,\n", + " projected_coords[highlight_idxs, 1] / 1e3,\n", + " c=\"steelblue\",\n", + " s=60,\n", + " zorder=3,\n", + " label=\"selected\",\n", + " )\n", + " ax.scatter(*query_pt / 1e3, c=\"red\", s=100, marker=\"*\", zorder=4, label=\"query\")\n", + " if radius is not None:\n", + " circle = plt.Circle(\n", + " query_pt / 1e3, radius / 1e3, fill=False, color=\"red\", lw=1.5, ls=\"--\"\n", + " )\n", + " ax.add_patch(circle)\n", + " ax.set_aspect(\"equal\")\n", + " ax.set_xlabel(\"x (km)\")\n", + " ax.set_ylabel(\"y (km)\")\n", + " ax.set_title(title)\n", + " ax.legend(loc=\"upper right\")\n", + "\n", + "fig.suptitle(\"Euclidean (projected CRS) distance queries\", fontsize=13)\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "9ebfe0fb", + "metadata": {}, + "source": [ + "## 3. Haversine Distance Queries (Geographic CRS)\n", + "\n", + "For a geographic CRS the coordinates are longitude / latitude in degrees.\n", + "We use the `\"haversine\"` metric, and **all distances are returned in metres**." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "871d59b9", + "metadata": {}, + "outputs": [], + "source": [ + "# 50 random lon/lat points over Europe (lon: -10 to 30, lat: 45 to 70)\n", + "geo_coords = np.column_stack(\n", + " [\n", + " rng.uniform(-10, 30, n_pts), # longitude\n", + " rng.uniform(45, 70, n_pts), # latitude\n", + " ]\n", + ")\n", + "\n", + "sel_hav = SpatialCoordinateValuesSelector(\"haversine\", geo_coords)\n", + "print(f\"metric: {sel_hav.distance_metric}\")\n", + "\n", + "# Known result: 1 degree of longitude at latitude 55° ≈ 63,800 m\n", + "# (cos(55°) * 111,195 m/deg ≈ 63,800 m)\n", + "query_geo = np.array([10.0, 55.0]) # lon=10°, lat=55°\n", + "idxs_h, dists_h = sel_hav.k_nearest_to(query_geo, k=5)\n", + "print(\"\\n5 nearest neighbours (Haversine, distances in degrees):\")\n", + "for rank, (i, d) in enumerate(zip(idxs_h, dists_h), 1):\n", + " print(\n", + " f\" rank {rank}: index={i:3d} dist={d:6.2f}° lon={geo_coords[i,0]:.2f}° lat={geo_coords[i,1]:.2f}°\"\n", + " )\n", + "\n", + "# Radius query: 5° around the query point\n", + "radius_deg = 5.0\n", + "idxs_hr, dists_hr = sel_hav.within_radius(query_geo, radius=radius_deg)\n", + "print(f\"\\nPoints within {radius_deg:.1f}°: {len(idxs_hr)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "bcd4e330", + "metadata": {}, + "source": [ + "## 4. CRS-Based Automatic Metric Selection (`for_crs`)\n", + "\n", + "The class-method `SpatialCoordinateValuesSelector.for_crs(crs, coords)` inspects\n", + "`crs.is_geographic` and picks the correct metric automatically — no manual choice needed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65ec5fc4", + "metadata": {}, + "outputs": [], + "source": [ + "cases = {\n", + " \"PlateCarree (geographic)\": ccrs.PlateCarree(),\n", + " \"Mercator (geographic)\": ccrs.Mercator(),\n", + " \"LambertConformal (projected)\": ccrs.LambertConformal(),\n", + " \"Stereographic (projected)\": ccrs.Stereographic(),\n", + " \"Mollweide (projected)\": ccrs.Mollweide(),\n", + "}\n", + "\n", + "print(f\"{'CRS':<40} {'is_geographic':<15} {'auto-selected metric'}\")\n", + "print(\"-\" * 75)\n", + "for name, crs in cases.items():\n", + " sel = SpatialCoordinateValuesSelector.for_crs(crs, projected_coords)\n", + " print(f\"{name:<40} {str(crs.is_geographic):<15} {sel.distance_metric}\")" + ] + }, + { + "cell_type": "markdown", + "id": "efd3e377", + "metadata": {}, + "source": [ + "## 5. Integration with Graph Connectivity\n", + "\n", + "`SpatialCoordinateValuesSelector` is now used automatically inside\n", + "`wmg.create.create_all_graph_components()` for all nearest-neighbour and\n", + "radius queries. When you supply `graph_crs`, the correct metric is chosen\n", + "for you — no other changes are required at the call site." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46b71d7f", + "metadata": {}, + "outputs": [], + "source": [ + "# --- Projected CRS (euclidean) ---\n", + "lon = np.linspace(0.0, 9.0, 8)\n", + "lat = np.linspace(50.0, 58.0, 8)\n", + "lo, la = np.meshgrid(lon, lat)\n", + "lonlat_coords = np.column_stack([lo.ravel(), la.ravel()])\n", + "\n", + "# Projected: convert to Oblique Mercator metres for demonstration\n", + "from pyproj import Transformer\n", + "\n", + "transformer = Transformer.from_crs(\"EPSG:4326\", \"EPSG:3857\", always_xy=True)\n", + "x, y = transformer.transform(lonlat_coords[:, 0], lonlat_coords[:, 1])\n", + "proj_coords = np.column_stack([x, y])\n", + "\n", + "# Build graph with projected CRS — euclidean metric is used automatically\n", + "import pyproj\n", + "\n", + "web_mercator = pyproj.CRS(\"EPSG:3857\")\n", + "G_proj = wmg.create.create_all_graph_components(\n", + " coords=proj_coords,\n", + " m2m_connectivity=\"flat\",\n", + " m2m_connectivity_kwargs=dict(mesh_node_distance=300000.0),\n", + " g2m_connectivity=\"nearest_neighbour\",\n", + " m2g_connectivity=\"nearest_neighbour\",\n", + " graph_crs=web_mercator,\n", + ")\n", + "lens_proj = [d[\"len\"] for _, _, d in G_proj.edges(data=True) if \"len\" in d]\n", + "print(\n", + " f\"Projected graph: {G_proj.number_of_nodes()} nodes, {G_proj.number_of_edges()} edges\"\n", + ")\n", + "print(\n", + " f\"Edge 'len' range: {min(lens_proj)/1e3:.1f} – {max(lens_proj)/1e3:.1f} km (euclidean, metres)\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "1f8660d6", + "metadata": {}, + "source": [ + "## 6. Rectilinear + Geographic CRS Warning\n", + "\n", + "When a rectilinear mesh is overlaid on geographic (lat/lon) coordinates a\n", + "`UserWarning` is raised automatically, because equally-spaced lon/lat values are\n", + "**not** equally spaced on a sphere." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45835a0a", + "metadata": {}, + "outputs": [], + "source": [ + "# Trigger the warning intentionally by using a geographic CRS with a\n", + "# rectilinear (flat) mesh. The haversine metric is still used for all\n", + "# neighbour queries; the warning just flags the mesh *layout*.\n", + "\n", + "with warnings.catch_warnings(record=True) as caught:\n", + " warnings.simplefilter(\"always\")\n", + " G_geo = wmg.create.create_all_graph_components(\n", + " coords=lonlat_coords,\n", + " m2m_connectivity=\"flat\",\n", + " m2m_connectivity_kwargs=dict(mesh_node_distance=3.0),\n", + " g2m_connectivity=\"nearest_neighbour\",\n", + " m2g_connectivity=\"nearest_neighbour\",\n", + " graph_crs=ccrs.PlateCarree(),\n", + " )\n", + "\n", + "print(f\"Warnings raised: {len(caught)}\")\n", + "for w in caught:\n", + " print(f\"\\n[{w.category.__name__}] {w.message}\")\n", + "\n", + "lens_geo = [d[\"len\"] for _, _, d in G_geo.edges(data=True) if \"len\" in d]\n", + "print(\n", + " f\"\\nGeographic graph: {G_geo.number_of_nodes()} nodes, {G_geo.number_of_edges()} edges\"\n", + ")\n", + "print(\n", + " f\"Edge 'len' range: {min(lens_geo)/1e3:.1f} – {max(lens_geo)/1e3:.1f} km (haversine, metres)\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "2b15f521", + "metadata": {}, + "source": [ + "## 7. Running the Test Suite\n", + "\n", + "Run the new tests from the notebook. All tests should pass." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16fdd506", + "metadata": {}, + "outputs": [], + "source": [ + "import subprocess, sys\n", + "\n", + "result = subprocess.run(\n", + " [sys.executable, \"-m\", \"pytest\", \"tests/test_spatial_index.py\", \"-v\", \"--tb=short\"],\n", + " capture_output=True,\n", + " text=True,\n", + " cwd=\"..\", # run from repo root (parent of docs/)\n", + ")\n", + "print(result.stdout)\n", + "if result.returncode != 0:\n", + " print(result.stderr)" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index 01eb298..25d2788 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "networkx>=3.3", "scipy>=1.13.0", "pyproj>=3.7.0", + "scikit-learn>=1.3.0", ] requires-python = ">=3.10" readme = "README.md" diff --git a/src/weather_model_graphs/__init__.py b/src/weather_model_graphs/__init__.py index 273d4db..b1c92a8 100644 --- a/src/weather_model_graphs/__init__.py +++ b/src/weather_model_graphs/__init__.py @@ -11,3 +11,4 @@ replace_node_labels_with_unique_ids, split_graph_by_edge_attribute, ) +from .spatial import SpatialCoordinateValuesSelector diff --git a/src/weather_model_graphs/create/base.py b/src/weather_model_graphs/create/base.py index f922b2d..55badf1 100644 --- a/src/weather_model_graphs/create/base.py +++ b/src/weather_model_graphs/create/base.py @@ -8,14 +8,12 @@ function uses `connect_nodes_across_graphs` to connect nodes across the component graphs. """ - from typing import Iterable import networkx import networkx as nx import numpy as np import pyproj -import scipy.spatial from loguru import logger from ..networkx_utils import ( @@ -23,6 +21,7 @@ split_graph_by_edge_attribute, split_on_edge_attribute_existance, ) +from ..spatial import SpatialCoordinateValuesSelector from .grid import create_grid_graph_nodes from .mesh.kinds.flat import ( create_flat_multiscale_mesh_graph, @@ -125,6 +124,34 @@ def create_all_graph_components( xy_tuple = coord_transformer.transform(xx=coords[:, 0], yy=coords[:, 1]) xy = np.stack(xy_tuple, axis=1) + # Build a spatial index for the graph coordinates so that all downstream + # neighbour queries use the correct distance metric for the CRS. + if graph_crs is not None: + spatial_coord_selector = SpatialCoordinateValuesSelector.for_crs(graph_crs, xy) + else: + logger.warning( + "No `graph_crs` provided: using Euclidean distance metric for spatial neighbour queries." + ) + # No graph_crs provided: assume projected (Cartesian) coordinates, + # so Euclidean distance is used as the default metric. + spatial_coord_selector = SpatialCoordinateValuesSelector("euclidean", xy) + + # Warn when a rectilinear mesh is being built on top of geographic (lat/lon) + # coordinates. Equally-spaced lon/lat values are *not* equally spaced on a + # sphere, so the mesh node density will vary strongly with latitude. + _is_geographic = getattr(graph_crs, "is_geographic", False) + if _is_geographic and m2m_connectivity in ( + "flat", + "flat_multiscale", + "hierarchical", + ): + logger.warning( + f"m2m_connectivity='{m2m_connectivity}' places mesh nodes on a " + "rectilinear (equally-spaced lon/lat) grid, but the graph CRS is " + "geographic. Equally-spaced longitude/latitude values are NOT equally " + "spaced on a sphere — mesh node density will vary with latitude." + ) + if m2m_connectivity == "flat": graph_components["m2m"] = create_flat_singlescale_mesh_graph( xy, @@ -136,6 +163,7 @@ def create_all_graph_components( # `m2m` (mesh-to-mesh), `mesh_up` (up edge connections) and `mesh_down` (down edge connections) graph_components["m2m"] = create_hierarchical_multiscale_mesh_graph( xy=xy, + distance_metric=spatial_coord_selector.distance_metric, **m2m_connectivity_kwargs, ) # Only connect grid to bottom level of hierarchy @@ -157,6 +185,7 @@ def create_all_graph_components( G_source=G_grid, G_target=grid_connect_graph, method=g2m_connectivity, + distance_metric=spatial_coord_selector.distance_metric, **g2m_connectivity_kwargs, ) graph_components["g2m"] = G_g2m @@ -175,6 +204,7 @@ def create_all_graph_components( G_source=grid_connect_graph, G_target=decode_grid, method=m2g_connectivity, + distance_metric=spatial_coord_selector.distance_metric, **m2g_connectivity_kwargs, ) graph_components["m2g"] = G_m2g @@ -212,10 +242,12 @@ def create_all_graph_components( def connect_nodes_across_graphs( G_source, G_target, + *, method="nearest_neighbour", max_dist=None, rel_max_dist=None, max_num_neighbours=None, + distance_metric: str, ): """ Create a new graph containing the nodes in `G_source` and `G_target` and add @@ -249,6 +281,9 @@ def connect_nodes_across_graphs( relative to longest edge in (bottom level of) `G_source` and `G_target`. max_num_neighbours : int Maximum number of neighbours to search for in `G_target` for each node in `G_source` + distance_metric : str + Distance metric used for neighbour search. Supported values are + ``"euclidean"`` and ``"haversine"``. Returns ------- @@ -259,13 +294,14 @@ def connect_nodes_across_graphs( source_nodes_list = list(G_source.nodes) target_nodes_list = list(G_target.nodes) - # build kd tree for source nodes (e.g. the mesh nodes when constructing m2g) + # Build spatial selector for source nodes (e.g. mesh nodes when constructing m2g). xy_source = np.array([G_source.nodes[node]["pos"] for node in G_source.nodes]) - kdt_s = scipy.spatial.KDTree(xy_source) + spatial_coord_selector = SpatialCoordinateValuesSelector(distance_metric, xy_source) - # Determine method and perform checks once - # Conditionally define _find_neighbour_node_idxs_in_source_mesh for use in - # loop later + # Determine method and perform checks once. + # Conditionally define _find_neighbour_node_idxs_in_source_mesh for use in loop later. + # Each helper returns (indices_array, distances_array) so that edge lengths + # can be taken directly from the tree without recomputing. if method == "containing_rectangle": if ( max_dist is not None @@ -283,7 +319,11 @@ def connect_nodes_across_graphs( # which is at a relative distance of 1. This relative distance is equal # to the diagonal of one rectangle. rad_graph = connect_nodes_across_graphs( - G_source, G_target, method="within_radius", rel_max_dist=1.0 + G_source, + G_target, + method="within_radius", + rel_max_dist=1.0, + distance_metric=distance_metric, ) # Filter edges to those that fit within a rectangle of measurements dx,dy @@ -326,8 +366,8 @@ def _edge_filter(edge_prop): ) def _find_neighbour_node_idxs_in_source_mesh(xy_target): - neigh_idx = kdt_s.query(xy_target, 1)[1] - return [neigh_idx] + idxs, dists = spatial_coord_selector.k_nearest_to(xy_target, k=1) + return idxs, dists elif method == "nearest_neighbours": if max_num_neighbours is None: @@ -340,15 +380,17 @@ def _find_neighbour_node_idxs_in_source_mesh(xy_target): ) def _find_neighbour_node_idxs_in_source_mesh(xy_target): - neigh_idxs = kdt_s.query(xy_target, max_num_neighbours)[1] - return neigh_idxs + idxs, dists = spatial_coord_selector.k_nearest_to( + xy_target, k=max_num_neighbours + ) + return idxs, dists elif method == "within_radius": if max_num_neighbours is not None: raise Exception( "to use `within_radius` method you should not set `max_num_neighbours`" ) - # Determine actual query length to use + # Determine actual query radius to use if max_dist is not None: if rel_max_dist is not None: raise Exception( @@ -391,8 +433,10 @@ def _find_neighbour_node_idxs_in_source_mesh(xy_target): ) def _find_neighbour_node_idxs_in_source_mesh(xy_target): - neigh_idxs = kdt_s.query_ball_point(xy_target, query_dist) - return neigh_idxs + idxs, dists = spatial_coord_selector.within_radius( + xy_target, radius=query_dist + ) + return idxs, dists else: raise NotImplementedError(method) @@ -407,21 +451,12 @@ def _find_neighbour_node_idxs_in_source_mesh(xy_target): # add edges for target_node in target_nodes_list: xy_target = G_target.nodes[target_node]["pos"] - neigh_idxs = _find_neighbour_node_idxs_in_source_mesh(xy_target) - for i in neigh_idxs: + neigh_idxs, neigh_dists = _find_neighbour_node_idxs_in_source_mesh(xy_target) + for i, d in zip(neigh_idxs, neigh_dists): source_node = source_nodes_list[i] # add edge from source to target G_connect.add_edge(source_node, target_node) - d = np.sqrt( - np.sum( - ( - G_connect.nodes[source_node]["pos"] - - G_connect.nodes[target_node]["pos"] - ) - ** 2 - ) - ) - G_connect.edges[source_node, target_node]["len"] = d + G_connect.edges[source_node, target_node]["len"] = float(d) G_connect.edges[source_node, target_node]["vdiff"] = ( G_connect.nodes[source_node]["pos"] - G_connect.nodes[target_node]["pos"] diff --git a/src/weather_model_graphs/create/mesh/kinds/hierarchical.py b/src/weather_model_graphs/create/mesh/kinds/hierarchical.py index b897693..39d9524 100644 --- a/src/weather_model_graphs/create/mesh/kinds/hierarchical.py +++ b/src/weather_model_graphs/create/mesh/kinds/hierarchical.py @@ -1,8 +1,8 @@ import networkx import numpy as np -import scipy from ....networkx_utils import prepend_node_index +from ....spatial import SpatialCoordinateValuesSelector from .. import mesh as mesh_graph @@ -11,6 +11,7 @@ def create_hierarchical_multiscale_mesh_graph( mesh_node_distance: float, level_refinement_factor: float, max_num_levels: int, + distance_metric: str, ): """ Create a hierarchical multiscale mesh graph with nearest neighbour @@ -31,6 +32,10 @@ def create_hierarchical_multiscale_mesh_graph( Refinement factor between grid points and bottom level of mesh hierarchy max_num_levels: int The number of levels in the hierarchical mesh graph. + distance_metric : {'euclidean', 'haversine'} + Distance metric used when computing inter-level nearest-neighbour edges + and storing edge ``"len"`` attributes. Pass ``'haversine'`` when *xy* + contains longitude/latitude coordinates (geographic CRS). Returns ------- @@ -85,22 +90,27 @@ def create_hierarchical_multiscale_mesh_graph( # Add nodes of to level G_down.add_nodes_from(G_to.nodes(data=True)) - # build kd tree for mesh point pos + # build spatial coordinate selector for source (coarser) mesh node positions # order in vm should be same as in vm_xy v_to_list = list(G_to.nodes) v_from_list = list(G_from.nodes) v_from_xy = np.array([xy for _, xy in G_from.nodes.data("pos")]) - kdt_m = scipy.spatial.KDTree(v_from_xy) + spatial_coord_selector = SpatialCoordinateValuesSelector( + distance_metric, v_from_xy + ) # add edges from mesh to grid for v in v_to_list: - # find 1(?) nearest neighbours (index to vm_xy) - neigh_idx = kdt_m.query(G_down.nodes[v]["pos"], 1)[1] + # find nearest neighbour in coarser level + neigh_idxs, neigh_dists = spatial_coord_selector.k_nearest_to( + G_down.nodes[v]["pos"], k=1 + ) + neigh_idx = int(neigh_idxs[0]) + d = float(neigh_dists[0]) u = v_from_list[neigh_idx] # add edge from mesh to grid G_down.add_edge(u, v) - d = np.sqrt(np.sum((G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"]) ** 2)) G_down.edges[u, v]["len"] = d G_down.edges[u, v]["vdiff"] = ( G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"] diff --git a/src/weather_model_graphs/spatial.py b/src/weather_model_graphs/spatial.py new file mode 100644 index 0000000..cb9088c --- /dev/null +++ b/src/weather_model_graphs/spatial.py @@ -0,0 +1,226 @@ +""" +Spatial indexing for metric-aware nearest-neighbour and radius queries. + +Provides :class:`SpatialCoordinateValuesSelector`, which wraps a ball-tree to +deliver k-nearest-neighbour and radius lookups using either the Euclidean or the +Haversine distance metric. The correct metric is chosen automatically when the +object is created via the :meth:`SpatialCoordinateValuesSelector.for_crs` class +method. +""" + +from __future__ import annotations + +from typing import Tuple + +import numpy as np + +try: + from sklearn.neighbors import BallTree as _BallTree + + _HAS_SKLEARN = True +except ImportError: # pragma: no cover + _HAS_SKLEARN = False + + +class SpatialCoordinateValuesSelector: + """ + Metric-aware spatial index for selecting coordinate values by proximity. + + Wraps a ball-tree to provide fast k-nearest-neighbour and radius queries. + The tree is built once at construction time; subsequent queries are cheap. + + Two distance metrics are supported: + + * ``"euclidean"`` – standard Cartesian distance, appropriate for projected + coordinate systems (e.g. Lambert Conformal, UTM). + * ``"haversine"`` – great-circle distance on a sphere, appropriate for + geographic coordinate systems expressed as longitude/latitude in degrees + (e.g. PlateCarree). + + Parameters + ---------- + distance_metric : {'euclidean', 'haversine'} + Distance metric to use for the underlying ball-tree. + coords : np.ndarray, shape (N, 2) + Coordinate array. For ``"euclidean"`` these are arbitrary Cartesian + (x, y) values. For ``"haversine"`` these must be **longitude/latitude + in degrees** (first column longitude, second column latitude). + + Raises + ------ + ValueError + If *distance_metric* is not ``"euclidean"`` or ``"haversine"``. + ImportError + If *distance_metric* is ``"haversine"`` and ``scikit-learn`` is not + installed. + + Examples + -------- + Euclidean (projected CRS): + + >>> import numpy as np + >>> coords = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]) + >>> sel = SpatialCoordinateValuesSelector("euclidean", coords) + >>> idxs, dists = sel.k_nearest_to([1.5, 0.0], k=2) + >>> idxs.tolist() + [2, 1] + + Haversine (geographic CRS, lon/lat degrees): + + >>> coords_geo = np.array([[0.0, 0.0], [10.0, 0.0], [20.0, 0.0]]) + >>> sel_geo = SpatialCoordinateValuesSelector("haversine", coords_geo) + >>> idxs, dists = sel_geo.k_nearest_to([5.0, 0.0], k=2) + """ + + def __init__(self, distance_metric: str, coords: np.ndarray) -> None: + _VALID_METRICS = ("euclidean", "haversine") + if distance_metric not in _VALID_METRICS: + raise ValueError( + f"distance_metric must be one of {_VALID_METRICS!r}, " + f"got {distance_metric!r}." + ) + + if distance_metric == "haversine" and not _HAS_SKLEARN: + raise ImportError( + "scikit-learn is required for the 'haversine' distance metric. " + "Install it with: pip install scikit-learn" + ) + + self.distance_metric: str = distance_metric + self._coords: np.ndarray = np.asarray(coords, dtype=float) + + if distance_metric == "haversine": + # BallTree with haversine expects [latitude, longitude] in **radians**. + # coords are stored as [longitude, latitude] in degrees throughout the + # rest of the codebase, so we swap columns and convert. + tree_coords = np.deg2rad(self._coords[:, ::-1]) # (N, 2) [lat_rad, lon_rad] + self._tree = _BallTree(tree_coords, metric="haversine") + else: + self._tree = _BallTree(self._coords, metric="euclidean") + + # Public query methods + def k_nearest_to(self, point: np.ndarray, k: int) -> Tuple[np.ndarray, np.ndarray]: + """ + Return the *k* nearest coordinate values to *point*. + + Parameters + ---------- + point : array-like, shape (2,) + Query point. For ``"euclidean"`` this is a Cartesian (x, y) + coordinate; for ``"haversine"`` this is a (longitude, latitude) + pair in degrees. + k : int + Number of nearest neighbours to return. + + Returns + ------- + indices : np.ndarray, shape (k,) + Indices into the original *coords* array (passed to ``__init__``) + of the *k* nearest neighbours, ordered by increasing distance. + distances : np.ndarray, shape (k,) + Corresponding distances. For ``"euclidean"`` these are in the same + units as *coords*; for ``"haversine"`` these are in **degrees**. + """ + tree_point = self._prepare_query_point(point) + raw_dists, raw_idxs = self._tree.query(tree_point, k=k) + indices = raw_idxs.flatten() + distances = ( + np.rad2deg(raw_dists.flatten()) + if self.distance_metric == "haversine" + else raw_dists.flatten() + ) + return indices, distances + + def within_radius( + self, point: np.ndarray, radius: float + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Return all coordinate values within *radius* of *point*. + + Parameters + ---------- + point : array-like, shape (2,) + Query point (same coordinate convention as for :meth:`k_nearest_to`). + radius : float + Search radius. For ``"euclidean"`` this is in the same units as + *coords*; for ``"haversine"`` this is in **degrees**. + + Returns + ------- + indices : np.ndarray + Indices into the original *coords* array of all neighbours within + *radius*, **unsorted**. + distances : np.ndarray + Distances to each returned neighbour. For ``"euclidean"`` these + are in the same units as *coords*; for ``"haversine"`` these are + in **degrees**. + """ + tree_point = self._prepare_query_point(point) + raw_idxs, raw_dists = self._tree.query_radius( + tree_point, + r=np.deg2rad(radius) if self.distance_metric == "haversine" else radius, + return_distance=True, + ) + indices = raw_idxs[0] + distances = ( + np.rad2deg(raw_dists[0]) + if self.distance_metric == "haversine" + else raw_dists[0] + ) + return indices, distances + + # Factory class-method + @classmethod + def for_crs( + cls, + crs, + coords: np.ndarray, + ) -> "SpatialCoordinateValuesSelector": + """ + Create a :class:`SpatialCoordinateValuesSelector` appropriate for *crs*. + + Inspects the ``is_geographic`` property of *crs* to choose the metric: + + * Geographic CRS (``crs.is_geographic is True``) → ``"haversine"`` + * Projected CRS → ``"euclidean"`` + + Parameters + ---------- + crs : cartopy.crs.CRS or pyproj.CRS + Coordinate reference system of *coords*. Must expose an + ``is_geographic`` attribute (both *cartopy* and *pyproj* CRS + objects do). + coords : np.ndarray, shape (N, 2) + Coordinate array in the given *crs*. + + Returns + ------- + SpatialCoordinateValuesSelector + Configured with the appropriate distance metric. + + Examples + -------- + >>> import cartopy.crs as ccrs + >>> import numpy as np + >>> coords = np.column_stack([np.linspace(-10, 10, 50), + ... np.linspace(50, 60, 50)]) + >>> sel = SpatialCoordinateValuesSelector.for_crs(ccrs.PlateCarree(), coords) + >>> sel.distance_metric + 'haversine' + >>> sel2 = SpatialCoordinateValuesSelector.for_crs(ccrs.LambertConformal(), coords) + >>> sel2.distance_metric + 'euclidean' + """ + is_geographic = getattr(crs, "is_geographic", False) + # pyproj CRS exposes is_geographic as a bool property; cartopy CRS does too. + metric = "haversine" if is_geographic else "euclidean" + return cls(metric, coords) + + # Private helpers + def _prepare_query_point(self, point: np.ndarray) -> np.ndarray: + """Convert a query point to the internal representation used by the tree.""" + pt = np.asarray(point, dtype=float).reshape(1, 2) + if self.distance_metric == "haversine": + # swap [lon, lat] → [lat, lon] and convert to radians + return np.deg2rad(pt[:, ::-1]) + return pt diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py index b45d486..3cd1615 100644 --- a/tests/test_graph_creation.py +++ b/tests/test_graph_creation.py @@ -263,6 +263,10 @@ def test_edgeless_nodes_preservation_in_different_graphs( ) graph_target = wmg.create.grid.create_grid_graph_nodes(coordinates_grid) graph = wmg.create.base.connect_nodes_across_graphs( - G_source=graph_source, G_target=graph_target, method=method, **method_kwargs + G_source=graph_source, + G_target=graph_target, + method=method, + distance_metric="euclidean", + **method_kwargs, ) assert set(graph.nodes) == set(graph_source.nodes) | set(graph_target.nodes) diff --git a/tests/test_spatial_index.py b/tests/test_spatial_index.py new file mode 100644 index 0000000..5744079 --- /dev/null +++ b/tests/test_spatial_index.py @@ -0,0 +1,357 @@ +""" +Tests for :class:`weather_model_graphs.spatial.SpatialCoordinateValuesSelector`. + +Covers: +- Initialisation (valid / invalid metric) +- Euclidean k-nearest-to and within_radius queries +- Haversine k-nearest-to and within_radius queries (distances in degrees) +- Factory method SpatialCoordinateValuesSelector.for_crs() +- Warning emitted for rectilinear mesh + geographic CRS in create_all_graph_components +""" + +import warnings + +import cartopy.crs as ccrs +import numpy as np +import pyproj +import pytest +from loguru import logger + +import weather_model_graphs as wmg +from weather_model_graphs.spatial import SpatialCoordinateValuesSelector + + +# Fixtures +@pytest.fixture() +def simple_euclidean_coords(): + """Five points on a horizontal line: x = 0, 1, 2, 3, 4; y = 0.""" + return np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [3.0, 0.0], [4.0, 0.0]]) + + +@pytest.fixture() +def simple_geo_coords(): + """Five lon/lat points along the equator (y = 0 °), 0–40 ° longitude.""" + return np.array([[0.0, 0.0], [10.0, 0.0], [20.0, 0.0], [30.0, 0.0], [40.0, 0.0]]) + + +# Initialisation +class TestInit: + def test_euclidean_metric_stored(self, simple_euclidean_coords): + sel = SpatialCoordinateValuesSelector("euclidean", simple_euclidean_coords) + assert sel.distance_metric == "euclidean" + + def test_haversine_metric_stored(self, simple_geo_coords): + sel = SpatialCoordinateValuesSelector("haversine", simple_geo_coords) + assert sel.distance_metric == "haversine" + + def test_invalid_metric_raises(self, simple_euclidean_coords): + with pytest.raises(ValueError, match="distance_metric must be one of"): + SpatialCoordinateValuesSelector("manhattan", simple_euclidean_coords) + + def test_coords_stored_as_float_array(self, simple_euclidean_coords): + sel = SpatialCoordinateValuesSelector("euclidean", simple_euclidean_coords) + assert sel._coords.dtype == np.float64 + + +# Euclidean – k_nearest_to +class TestEuclideanKNearest: + def test_self_distance_is_zero(self, simple_euclidean_coords): + sel = SpatialCoordinateValuesSelector("euclidean", simple_euclidean_coords) + idxs, dists = sel.k_nearest_to([0.0, 0.0], k=1) + assert idxs[0] == 0 + assert dists[0] == pytest.approx(0.0) + + def test_nearest_of_two(self): + coords = np.array([[0.0, 0.0], [10.0, 0.0]]) + sel = SpatialCoordinateValuesSelector("euclidean", coords) + idxs, dists = sel.k_nearest_to([3.0, 0.0], k=1) + assert idxs[0] == 0 # [0,0] is closer than [10,0] + + def test_k_neighbours_returned(self, simple_euclidean_coords): + sel = SpatialCoordinateValuesSelector("euclidean", simple_euclidean_coords) + idxs, dists = sel.k_nearest_to([2.0, 0.0], k=3) + assert len(idxs) == 3 + assert len(dists) == 3 + + def test_distances_sorted_ascending(self, simple_euclidean_coords): + sel = SpatialCoordinateValuesSelector("euclidean", simple_euclidean_coords) + idxs, dists = sel.k_nearest_to([0.5, 0.0], k=3) + assert list(dists) == sorted(dists) + + def test_known_euclidean_distance(self): + # Two points: (0,0) and (3,4) – distance = 5 + coords = np.array([[0.0, 0.0], [3.0, 4.0]]) + sel = SpatialCoordinateValuesSelector("euclidean", coords) + idxs, dists = sel.k_nearest_to([0.0, 0.0], k=2) + assert dists[1] == pytest.approx(5.0) + + +# Euclidean – within_radius +class TestEuclideanWithRadius: + def test_returns_points_within_radius(self, simple_euclidean_coords): + sel = SpatialCoordinateValuesSelector("euclidean", simple_euclidean_coords) + idxs, dists = sel.within_radius([2.0, 0.0], radius=1.5) + # should include indices 1, 2, 3 (x=1, 2, 3) + assert set(idxs) == {1, 2, 3} + + def test_excludes_points_beyond_radius(self, simple_euclidean_coords): + sel = SpatialCoordinateValuesSelector("euclidean", simple_euclidean_coords) + idxs, _ = sel.within_radius([2.0, 0.0], radius=0.5) + assert set(idxs) == {2} + + def test_distances_within_radius(self, simple_euclidean_coords): + sel = SpatialCoordinateValuesSelector("euclidean", simple_euclidean_coords) + idxs, dists = sel.within_radius([2.0, 0.0], radius=1.5) + # All returned distances must be ≤ radius + assert all(d <= 1.5 + 1e-9 for d in dists) + + def test_zero_radius_returns_only_self(self, simple_euclidean_coords): + sel = SpatialCoordinateValuesSelector("euclidean", simple_euclidean_coords) + idxs, dists = sel.within_radius([1.0, 0.0], radius=0.0) + assert set(idxs) == {1} + assert dists[0] == pytest.approx(0.0) + + +# Haversine – k_nearest_to + + +class TestHaversineKNearest: + def test_distance_is_returned_in_degrees(self, simple_geo_coords): + sel = SpatialCoordinateValuesSelector("haversine", simple_geo_coords) + idxs, dists = sel.k_nearest_to([-10.0, 0.0], k=1) + assert idxs[0] == 0 + assert dists[0] == pytest.approx(10.0, rel=1e-4) + + def test_distances_in_degrees(self, simple_geo_coords): + """10° longitude at equator should be returned as 10.0 degrees.""" + sel = SpatialCoordinateValuesSelector("haversine", simple_geo_coords) + idxs, dists = sel.k_nearest_to([0.0, 0.0], k=2) + # nearest is self (0 deg), second is [10, 0] = 10 deg + assert dists[1] == pytest.approx(10.0, rel=1e-4) + + def test_distances_are_haversine_degrees(self): + """For geographic coords, haversine distances are returned in degrees.""" + coords = np.array([[0.0, 0.0], [1.0, 0.0]]) + sel = SpatialCoordinateValuesSelector("haversine", coords) + _, d_hav = sel.k_nearest_to([0.0, 0.0], k=2) + assert d_hav[1] == pytest.approx(1.0, rel=1e-4) + + +# Haversine – within_radius +class TestHaversineWithRadius: + def test_radius_in_degrees_inclusive(self, simple_geo_coords): + """A 12° radius from origin includes points at 0° and 10° lon.""" + sel = SpatialCoordinateValuesSelector("haversine", simple_geo_coords) + radius_deg = 12.0 + idxs, dists = sel.within_radius([0.0, 0.0], radius=radius_deg) + order = np.argsort(dists) + sorted_idxs = idxs[order] + sorted_dists = dists[order] + np.testing.assert_array_equal(sorted_idxs, np.array([0, 1])) + np.testing.assert_allclose(sorted_dists, np.array([0.0, 10.0]), rtol=1e-4) + + def test_radius_in_degrees_exclusive(self, simple_geo_coords): + """A 5° radius from origin excludes the 10° point.""" + sel = SpatialCoordinateValuesSelector("haversine", simple_geo_coords) + radius_deg = 5.0 + idxs, _ = sel.within_radius([0.0, 0.0], radius=radius_deg) + assert set(idxs) == {0} + + +class TestHaversineLongitudeWrapAround: + """Regression tests for periodic longitude behaviour around +/-180 deg.""" + + @pytest.fixture() + def equator_periodic_coords(self): + # Equally spaced points on the equator, matching the periodic-domain notebook. + lons = np.linspace(0.0, 360.0, 40, endpoint=False) + lats = np.zeros_like(lons) + return np.column_stack([lons, lats]) + + def test_nearest_neighbour_wraps_across_dateline(self, equator_periodic_coords): + """A query near 360 deg should see 0 deg as nearest via spherical wrap-around.""" + sel = SpatialCoordinateValuesSelector("haversine", equator_periodic_coords) + idxs, dists = sel.k_nearest_to([359.0, 0.0], k=2) + nearest_lons = equator_periodic_coords[idxs, 0] + + # 0 deg is 1 degree away on the sphere, while 351 deg is 8 degrees away. + assert 0.0 in nearest_lons + assert dists.min() == pytest.approx(1.0, rel=1e-4) + + def test_radius_query_crosses_longitude_seam(self, equator_periodic_coords): + """Radius query near 360 deg should include 0 deg neighbour across seam.""" + sel = SpatialCoordinateValuesSelector("haversine", equator_periodic_coords) + idxs, _ = sel.within_radius([359.0, 0.0], radius=5.0) + lons_in_radius = set(equator_periodic_coords[idxs, 0]) + + # 0 deg is within 1 degree across seam; 351 deg is 8 degrees away and excluded. + assert 0.0 in lons_in_radius + assert 351.0 not in lons_in_radius + + +# Factory: for_crs + + +class TestForCrs: + def test_geographic_crs_gives_haversine(self): + coords = np.random.default_rng(0).random((10, 2)) + # pyproj.CRS('EPSG:4326') is a true geographic CRS: is_geographic=True + sel = SpatialCoordinateValuesSelector.for_crs(pyproj.CRS("EPSG:4326"), coords) + assert sel.distance_metric == "haversine" + + def test_projected_crs_gives_euclidean(self): + coords = np.random.default_rng(0).random((10, 2)) * 1e6 + sel = SpatialCoordinateValuesSelector.for_crs(ccrs.LambertConformal(), coords) + assert sel.distance_metric == "euclidean" + + def test_mollweide_projected_gives_euclidean(self): + coords = np.random.default_rng(0).random((10, 2)) * 1e6 + sel = SpatialCoordinateValuesSelector.for_crs(ccrs.Mollweide(), coords) + assert sel.distance_metric == "euclidean" + + def test_equivalent_to_manual_construction_euclidean(self): + """for_crs on a projected CRS should produce the same results as the + manually constructed euclidean selector.""" + rng = np.random.default_rng(1) + coords = rng.random((20, 2)) * 1e5 + query = [5e4, 5e4] + sel_factory = SpatialCoordinateValuesSelector.for_crs( + ccrs.LambertConformal(), coords + ) + sel_manual = SpatialCoordinateValuesSelector("euclidean", coords) + idxs_f, dists_f = sel_factory.k_nearest_to(query, k=3) + idxs_m, dists_m = sel_manual.k_nearest_to(query, k=3) + np.testing.assert_array_equal(idxs_f, idxs_m) + np.testing.assert_allclose(dists_f, dists_m) + + +# Integration: rectilinear + geographic warning + + +class TestRectilinearGeographicWarning: + """ + When create_all_graph_components is called with a geographic graph_crs and + a rectilinear m2m_connectivity, a UserWarning should be raised. + """ + + def _make_lonlat_coords(self, n=10): + lon = np.linspace(-10.0, 10.0, n) + lat = np.linspace(50.0, 60.0, n) + lo, la = np.meshgrid(lon, lat) + return np.column_stack([lo.ravel(), la.ravel()]) + + @pytest.mark.parametrize("m2m", ["flat", "flat_multiscale", "hierarchical"]) + def test_warning_raised_for_geographic_crs(self, m2m): + # Use a 30x30 grid over a ~29-degree domain so hierarchical can build >=2 levels + lon = np.linspace(0.0, 29.0, 30) + lat = np.linspace(45.0, 74.0, 30) + lo, la = np.meshgrid(lon, lat) + large_coords = np.column_stack([lo.ravel(), la.ravel()]) + + # pyproj EPSG:4326 has is_geographic=True → triggers warning + geo_crs = pyproj.CRS("EPSG:4326") + + kwargs = dict( + coords=large_coords, + m2m_connectivity=m2m, + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + graph_crs=geo_crs, + ) + if m2m == "flat": + kwargs["m2m_connectivity_kwargs"] = dict(mesh_node_distance=3) + elif m2m == "flat_multiscale": + kwargs["m2m_connectivity_kwargs"] = dict( + max_num_levels=2, mesh_node_distance=3, level_refinement_factor=3 + ) + elif m2m == "hierarchical": + kwargs["m2m_connectivity_kwargs"] = dict( + max_num_levels=2, mesh_node_distance=3, level_refinement_factor=3 + ) + warning_messages = [] + sink_id = logger.add( + lambda msg: warning_messages.append(msg.record["message"]), + level="WARNING", + ) + try: + wmg.create.create_all_graph_components(**kwargs) + finally: + logger.remove(sink_id) + + assert any("rectilinear" in message for message in warning_messages) + + def test_no_warning_for_projected_crs(self): + """No UserWarning for a projected CRS.""" + # Use a small Cartesian grid (pretend it's in some projected CRS) + coords = np.column_stack([np.linspace(0, 1e5, 20), np.linspace(0, 1e5, 20)]) + + class _FakeProjectedCRS: + """Minimal projected CRS stub.""" + + is_geographic = False + + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + # Should *not* raise + wmg.create.create_all_graph_components( + coords=coords, + m2m_connectivity="flat", + m2m_connectivity_kwargs=dict(mesh_node_distance=3000), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + graph_crs=_FakeProjectedCRS(), + ) + + +# Integration: graph creation uses correct metric end-to-end +class TestIntegrationGraphCreation: + """ + Smoke-test that graph creation completes without error when a geographic + CRS is supplied, and that the haversine-based edge lengths are physically + reasonable in degrees for a ~10° domain. + """ + + def _make_lonlat_coords(self, n=8): + lon = np.linspace(0.0, 9.0, n) + lat = np.linspace(50.0, 59.0, n) + lo, la = np.meshgrid(lon, lat) + return np.column_stack([lo.ravel(), la.ravel()]) + + def test_graph_created_with_geographic_crs(self): + coords = self._make_lonlat_coords() + # pyproj EPSG:4326 has is_geographic=True → haversine metric used, + # and the rectilinear/geographic warning is logged. + warning_messages = [] + sink_id = logger.add( + lambda msg: warning_messages.append(msg.record["message"]), + level="WARNING", + ) + try: + G = wmg.create.create_all_graph_components( + coords=coords, + m2m_connectivity="flat", + m2m_connectivity_kwargs=dict(mesh_node_distance=3), + g2m_connectivity="nearest_neighbour", + m2g_connectivity="nearest_neighbour", + graph_crs=pyproj.CRS("EPSG:4326"), + return_components=False, + ) + finally: + logger.remove(sink_id) + + assert any("rectilinear" in message for message in warning_messages) + # The g2m / m2g edges use haversine (distances in degrees). + # The m2m internal mesh edges still use Euclidean (degrees) because + # create_single_level_2d_mesh_graph does not receive the CRS. + g2m_m2g_lens = [ + d["len"] + for _, _, d in G.edges(data=True) + if d.get("component") in ("g2m", "m2g") and "len" in d + ] + assert len(g2m_m2g_lens) > 0, "Expected g2m/m2g edges with 'len' attribute" + # For a ~9 degree domain, haversine edge lengths should stay in a + # plausible degree range and remain well below half the globe. + assert all(1e-3 < length < 20.0 for length in g2m_m2g_lens), ( + f"g2m/m2g edge lengths out of expected haversine range: " + f"min={min(g2m_m2g_lens):.6f} deg, max={max(g2m_m2g_lens):.6f} deg" + )