Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
f0c19a4
feat: auto-detect distance metric from CRS via SpatialCoordinateValue…
FAbdullah17 Mar 3, 2026
14f719f
fix: address review feedback for CRS-aware distance metric integratio…
FAbdullah17 Mar 5, 2026
7ca16dd
refactor(hierarchical): rename selector variable to spatial_coord_sel…
FAbdullah17 Mar 7, 2026
d8fbae4
feat(base): warn when graph_crs is None and euclidean is assumed
FAbdullah17 Mar 7, 2026
718afe5
refactor(base): use loguru warning for rectilinear geographic notice
FAbdullah17 Mar 7, 2026
f5bb64a
chore(base): trim rectilinear geographic warning message
FAbdullah17 Mar 7, 2026
e050204
refactor(base): require explicit distance_metric in connect_nodes_acr…
FAbdullah17 Mar 7, 2026
a4c5a50
test(spatial): capture loguru warnings instead of pytest.warns
FAbdullah17 Mar 7, 2026
8da3663
test(spatial): reword haversine unit wording to radians
FAbdullah17 Mar 7, 2026
bf76435
test: add haversine longitude wrap-around regression tests
FAbdullah17 Mar 8, 2026
96f2a99
refactor: require explicit distance_metric in hierarchical mesh builder
FAbdullah17 Mar 22, 2026
854e062
refactor: rename with_radius to within_radius
FAbdullah17 Mar 22, 2026
8dbfee1
feat: accept haversine within_radius radius in degrees
FAbdullah17 Mar 22, 2026
bce4d8c
feat: return haversine within_radius distances in degrees
FAbdullah17 Mar 22, 2026
8c1a610
feat: return haversine k_nearest_to distances in degrees
FAbdullah17 Mar 24, 2026
0d75615
test: assert non-zero haversine k_nearest_to distance in degrees
FAbdullah17 Mar 24, 2026
cbd09e9
test: align haversine k_nearest_to expectations with degree outputs
FAbdullah17 Mar 24, 2026
97f0e94
docs: add changelog entry for CRS-aware degree-based metric handling
FAbdullah17 Mar 24, 2026
b1cf19d
Merge branch 'main' into feature/distance-engine
leifdenby Mar 24, 2026
ee70506
test: make within_radius inclusive distance assertion deterministic
FAbdullah17 Mar 24, 2026
44b2cd7
chore: apply CI lint and notebook API fixes
FAbdullah17 Mar 25, 2026
949a71d
fix: add mesh_node_distance kwargs in notebook flat mesh examples
FAbdullah17 Mar 25, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
335 changes: 335 additions & 0 deletions docs/distance_metric_auto_detection.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,335 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "419f7f83",
"metadata": {},
"source": [
"# Auto-Detection of Distance Metric from CRS\n",
"\n",
"This notebook demonstrates the `SpatialCoordinateValuesSelector` class introduced\n",
"in [issue #75](https://github.com/mllam/weather-model-graphs/issues/75).\n",
"\n",
"The class wraps a ball-tree to provide efficient nearest-neighbour and radius\n",
"queries using the correct distance metric for the coordinate reference system (CRS):\n",
"\n",
"| CRS type | Distance metric | When to use |\n",
"|---|---|---|\n",
"| Geographic (lat/lon) — e.g. `ccrs.PlateCarree()` | **Haversine** | Global / sparse grids |\n",
"| Projected — e.g. `ccrs.LambertConformal()` | **Euclidean** | Regional / LAM grids |\n",
"\n",
"Distances returned by the haversine path are in **metres**; Euclidean distances are\n",
"in the same units as the input coordinates."
]
},
{
"cell_type": "markdown",
"id": "fd3d7fce",
"metadata": {},
"source": [
"## 1. Setup and Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fe32be32",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"\n",
"import cartopy.crs as ccrs\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"import weather_model_graphs as wmg\n",
"from weather_model_graphs.spatial import SpatialCoordinateValuesSelector\n",
"\n",
"print(\"weather_model_graphs imported OK\")\n",
"print(f\"SpatialCoordinateValuesSelector: {SpatialCoordinateValuesSelector}\")"
]
},
{
"cell_type": "markdown",
"id": "b0d01140",
"metadata": {},
"source": [
"## 2. Euclidean Distance Queries (Projected CRS)\n",
"\n",
"For a projected coordinate system the coordinates are Cartesian (e.g. metres).\n",
"We use the `\"euclidean\"` metric."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6c276b4e",
"metadata": {},
"outputs": [],
"source": [
"rng = np.random.default_rng(42)\n",
"\n",
"# 50 random Cartesian points in a 100 km × 100 km domain (units: metres)\n",
"n_pts = 50\n",
"projected_coords = rng.random((n_pts, 2)) * 1e5 # shape (50, 2)\n",
"\n",
"sel_euc = SpatialCoordinateValuesSelector(\"euclidean\", projected_coords)\n",
"print(f\"metric: {sel_euc.distance_metric}\")\n",
"\n",
"# ---- k-nearest-to ----\n",
"query_pt = np.array([5e4, 5e4]) # centre of the domain\n",
"idxs, dists = sel_euc.k_nearest_to(query_pt, k=5)\n",
"print(\"\\n5 nearest neighbours (Euclidean):\")\n",
"for rank, (i, d) in enumerate(zip(idxs, dists), 1):\n",
" print(f\" rank {rank}: index={i:3d} dist={d/1e3:.2f} km coord={projected_coords[i]}\")\n",
"\n",
"# ---- with_radius ----\n",
"radius_m = 2e4 # 20 km\n",
"idxs_r, dists_r = sel_euc.with_radius(query_pt, radius=radius_m)\n",
"print(f\"\\nPoints within {radius_m/1e3:.0f} km of centre: {len(idxs_r)}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "79300d69",
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n",
"\n",
"for ax, (highlight_idxs, title, radius) in zip(\n",
" axes,\n",
" [\n",
" (idxs, \"k=5 nearest neighbours\", None),\n",
" (idxs_r, f\"within_radius = {radius_m/1e3:.0f} km\", radius_m),\n",
" ],\n",
"):\n",
" ax.scatter(\n",
" projected_coords[:, 0] / 1e3,\n",
" projected_coords[:, 1] / 1e3,\n",
" c=\"lightgrey\", s=30, zorder=2, label=\"all points\",\n",
" )\n",
" ax.scatter(\n",
" projected_coords[highlight_idxs, 0] / 1e3,\n",
" projected_coords[highlight_idxs, 1] / 1e3,\n",
" c=\"steelblue\", s=60, zorder=3, label=\"selected\",\n",
" )\n",
" ax.scatter(*query_pt / 1e3, c=\"red\", s=100, marker=\"*\", zorder=4, label=\"query\")\n",
" if radius is not None:\n",
" circle = plt.Circle(query_pt / 1e3, radius / 1e3, fill=False, color=\"red\", lw=1.5, ls=\"--\")\n",
" ax.add_patch(circle)\n",
" ax.set_aspect(\"equal\")\n",
" ax.set_xlabel(\"x (km)\")\n",
" ax.set_ylabel(\"y (km)\")\n",
" ax.set_title(title)\n",
" ax.legend(loc=\"upper right\")\n",
"\n",
"fig.suptitle(\"Euclidean (projected CRS) distance queries\", fontsize=13)\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "9ebfe0fb",
"metadata": {},
"source": [
"## 3. Haversine Distance Queries (Geographic CRS)\n",
"\n",
"For a geographic CRS the coordinates are longitude / latitude in degrees.\n",
"We use the `\"haversine\"` metric, and **all distances are returned in metres**."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "871d59b9",
"metadata": {},
"outputs": [],
"source": [
"# 50 random lon/lat points over Europe (lon: -10 to 30, lat: 45 to 70)\n",
"geo_coords = np.column_stack([\n",
" rng.uniform(-10, 30, n_pts), # longitude\n",
" rng.uniform(45, 70, n_pts), # latitude\n",
"])\n",
"\n",
"sel_hav = SpatialCoordinateValuesSelector(\"haversine\", geo_coords)\n",
"print(f\"metric: {sel_hav.distance_metric}\")\n",
"\n",
"# Known result: 1 degree of longitude at latitude 55° ≈ 63,800 m\n",
"# (cos(55°) * 111,195 m/deg ≈ 63,800 m)\n",
"query_geo = np.array([10.0, 55.0]) # lon=10°, lat=55°\n",
"idxs_h, dists_h = sel_hav.k_nearest_to(query_geo, k=5)\n",
"print(\"\\n5 nearest neighbours (Haversine, distances in metres):\")\n",
"for rank, (i, d) in enumerate(zip(idxs_h, dists_h), 1):\n",
" print(f\" rank {rank}: index={i:3d} dist={d/1e3:7.1f} km lon={geo_coords[i,0]:.2f}° lat={geo_coords[i,1]:.2f}°\")\n",
"\n",
"# Radius query: 500 km around the query point\n",
"radius_500km = 500_000\n",
"idxs_hr, dists_hr = sel_hav.with_radius(query_geo, radius=radius_500km)\n",
"print(f\"\\nPoints within {radius_500km/1e3:.0f} km: {len(idxs_hr)}\")"
]
},
{
"cell_type": "markdown",
"id": "bcd4e330",
"metadata": {},
"source": [
"## 4. CRS-Based Automatic Metric Selection (`for_crs`)\n",
"\n",
"The class-method `SpatialCoordinateValuesSelector.for_crs(crs, coords)` inspects\n",
"`crs.is_geographic` and picks the correct metric automatically — no manual choice needed."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "65ec5fc4",
"metadata": {},
"outputs": [],
"source": [
"cases = {\n",
" \"PlateCarree (geographic)\": ccrs.PlateCarree(),\n",
" \"Mercator (geographic)\": ccrs.Mercator(),\n",
" \"LambertConformal (projected)\": ccrs.LambertConformal(),\n",
" \"Stereographic (projected)\": ccrs.Stereographic(),\n",
" \"Mollweide (projected)\": ccrs.Mollweide(),\n",
"}\n",
"\n",
"print(f\"{'CRS':<40} {'is_geographic':<15} {'auto-selected metric'}\")\n",
"print(\"-\" * 75)\n",
"for name, crs in cases.items():\n",
" sel = SpatialCoordinateValuesSelector.for_crs(crs, projected_coords)\n",
" print(f\"{name:<40} {str(crs.is_geographic):<15} {sel.distance_metric}\")"
]
},
{
"cell_type": "markdown",
"id": "efd3e377",
"metadata": {},
"source": [
"## 5. Integration with Graph Connectivity\n",
"\n",
"`SpatialCoordinateValuesSelector` is now used automatically inside\n",
"`wmg.create.create_all_graph_components()` for all nearest-neighbour and\n",
"radius queries. When you supply `graph_crs`, the correct metric is chosen\n",
"for you — no other changes are required at the call site."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "46b71d7f",
"metadata": {},
"outputs": [],
"source": [
"# --- Projected CRS (euclidean) ---\n",
"lon = np.linspace(0.0, 9.0, 8)\n",
"lat = np.linspace(50.0, 58.0, 8)\n",
"lo, la = np.meshgrid(lon, lat)\n",
"lonlat_coords = np.column_stack([lo.ravel(), la.ravel()])\n",
"\n",
"# Projected: convert to Oblique Mercator metres for demonstration\n",
"from pyproj import Transformer\n",
"transformer = Transformer.from_crs(\"EPSG:4326\", \"EPSG:3857\", always_xy=True)\n",
"x, y = transformer.transform(lonlat_coords[:, 0], lonlat_coords[:, 1])\n",
"proj_coords = np.column_stack([x, y])\n",
"\n",
"# Build graph with projected CRS — euclidean metric is used automatically\n",
"import pyproj\n",
"web_mercator = pyproj.CRS(\"EPSG:3857\")\n",
"G_proj = wmg.create.create_all_graph_components(\n",
" coords=proj_coords,\n",
" m2m_connectivity=\"flat\",\n",
" g2m_connectivity=\"nearest_neighbour\",\n",
" m2g_connectivity=\"nearest_neighbour\",\n",
" graph_crs=web_mercator,\n",
")\n",
"lens_proj = [d[\"len\"] for _, _, d in G_proj.edges(data=True) if \"len\" in d]\n",
"print(f\"Projected graph: {G_proj.number_of_nodes()} nodes, {G_proj.number_of_edges()} edges\")\n",
"print(f\"Edge 'len' range: {min(lens_proj)/1e3:.1f} – {max(lens_proj)/1e3:.1f} km (euclidean, metres)\")"
]
},
{
"cell_type": "markdown",
"id": "1f8660d6",
"metadata": {},
"source": [
"## 6. Rectilinear + Geographic CRS Warning\n",
"\n",
"When a rectilinear mesh is overlaid on geographic (lat/lon) coordinates a\n",
"`UserWarning` is raised automatically, because equally-spaced lon/lat values are\n",
"**not** equally spaced on a sphere."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "45835a0a",
"metadata": {},
"outputs": [],
"source": [
"# Trigger the warning intentionally by using a geographic CRS with a\n",
"# rectilinear (flat) mesh. The haversine metric is still used for all\n",
"# neighbour queries; the warning just flags the mesh *layout*.\n",
"\n",
"with warnings.catch_warnings(record=True) as caught:\n",
" warnings.simplefilter(\"always\")\n",
" G_geo = wmg.create.create_all_graph_components(\n",
" coords=lonlat_coords,\n",
" m2m_connectivity=\"flat\",\n",
" g2m_connectivity=\"nearest_neighbour\",\n",
" m2g_connectivity=\"nearest_neighbour\",\n",
" graph_crs=ccrs.PlateCarree(),\n",
" )\n",
"\n",
"print(f\"Warnings raised: {len(caught)}\")\n",
"for w in caught:\n",
" print(f\"\\n[{w.category.__name__}] {w.message}\")\n",
"\n",
"lens_geo = [d[\"len\"] for _, _, d in G_geo.edges(data=True) if \"len\" in d]\n",
"print(f\"\\nGeographic graph: {G_geo.number_of_nodes()} nodes, {G_geo.number_of_edges()} edges\")\n",
"print(f\"Edge 'len' range: {min(lens_geo)/1e3:.1f} – {max(lens_geo)/1e3:.1f} km (haversine, metres)\")"
]
},
{
"cell_type": "markdown",
"id": "2b15f521",
"metadata": {},
"source": [
"## 7. Running the Test Suite\n",
"\n",
"Run the new tests from the notebook. All tests should pass."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16fdd506",
"metadata": {},
"outputs": [],
"source": [
"import subprocess, sys\n",
"\n",
"result = subprocess.run(\n",
" [sys.executable, \"-m\", \"pytest\", \"tests/test_spatial_index.py\", \"-v\", \"--tb=short\"],\n",
" capture_output=True,\n",
" text=True,\n",
" cwd=\"..\", # run from repo root (parent of docs/)\n",
")\n",
"print(result.stdout)\n",
"if result.returncode != 0:\n",
" print(result.stderr)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"networkx>=3.3",
"scipy>=1.13.0",
"pyproj>=3.7.0",
"scikit-learn>=1.3.0",
]
requires-python = ">=3.10"
readme = "README.md"
Expand Down
1 change: 1 addition & 0 deletions src/weather_model_graphs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
replace_node_labels_with_unique_ids,
split_graph_by_edge_attribute,
)
from .spatial import SpatialCoordinateValuesSelector
Loading