-
Notifications
You must be signed in to change notification settings - Fork 48
feat: auto-detect distance metric from CRS via SpatialCoordinateValue… #86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
FAbdullah17
wants to merge
22
commits into
mllam:main
Choose a base branch
from
FAbdullah17:feature/distance-engine
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
f0c19a4
feat: auto-detect distance metric from CRS via SpatialCoordinateValue…
FAbdullah17 14f719f
fix: address review feedback for CRS-aware distance metric integratio…
FAbdullah17 7ca16dd
refactor(hierarchical): rename selector variable to spatial_coord_sel…
FAbdullah17 d8fbae4
feat(base): warn when graph_crs is None and euclidean is assumed
FAbdullah17 718afe5
refactor(base): use loguru warning for rectilinear geographic notice
FAbdullah17 f5bb64a
chore(base): trim rectilinear geographic warning message
FAbdullah17 e050204
refactor(base): require explicit distance_metric in connect_nodes_acr…
FAbdullah17 a4c5a50
test(spatial): capture loguru warnings instead of pytest.warns
FAbdullah17 8da3663
test(spatial): reword haversine unit wording to radians
FAbdullah17 bf76435
test: add haversine longitude wrap-around regression tests
FAbdullah17 96f2a99
refactor: require explicit distance_metric in hierarchical mesh builder
FAbdullah17 854e062
refactor: rename with_radius to within_radius
FAbdullah17 8dbfee1
feat: accept haversine within_radius radius in degrees
FAbdullah17 bce4d8c
feat: return haversine within_radius distances in degrees
FAbdullah17 8c1a610
feat: return haversine k_nearest_to distances in degrees
FAbdullah17 0d75615
test: assert non-zero haversine k_nearest_to distance in degrees
FAbdullah17 cbd09e9
test: align haversine k_nearest_to expectations with degree outputs
FAbdullah17 97f0e94
docs: add changelog entry for CRS-aware degree-based metric handling
FAbdullah17 b1cf19d
Merge branch 'main' into feature/distance-engine
leifdenby ee70506
test: make within_radius inclusive distance assertion deterministic
FAbdullah17 44b2cd7
chore: apply CI lint and notebook API fixes
FAbdullah17 949a71d
fix: add mesh_node_distance kwargs in notebook flat mesh examples
FAbdullah17 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,335 @@ | ||
| { | ||
| "cells": [ | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "419f7f83", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "# Auto-Detection of Distance Metric from CRS\n", | ||
| "\n", | ||
| "This notebook demonstrates the `SpatialCoordinateValuesSelector` class introduced\n", | ||
| "in [issue #75](https://github.com/mllam/weather-model-graphs/issues/75).\n", | ||
| "\n", | ||
| "The class wraps a ball-tree to provide efficient nearest-neighbour and radius\n", | ||
| "queries using the correct distance metric for the coordinate reference system (CRS):\n", | ||
| "\n", | ||
| "| CRS type | Distance metric | When to use |\n", | ||
| "|---|---|---|\n", | ||
| "| Geographic (lat/lon) — e.g. `ccrs.PlateCarree()` | **Haversine** | Global / sparse grids |\n", | ||
| "| Projected — e.g. `ccrs.LambertConformal()` | **Euclidean** | Regional / LAM grids |\n", | ||
| "\n", | ||
| "Distances returned by the haversine path are in **metres**; Euclidean distances are\n", | ||
| "in the same units as the input coordinates." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "fd3d7fce", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## 1. Setup and Imports" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "fe32be32", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "import warnings\n", | ||
| "\n", | ||
| "import cartopy.crs as ccrs\n", | ||
| "import matplotlib.pyplot as plt\n", | ||
| "import numpy as np\n", | ||
| "\n", | ||
| "import weather_model_graphs as wmg\n", | ||
| "from weather_model_graphs.spatial import SpatialCoordinateValuesSelector\n", | ||
| "\n", | ||
| "print(\"weather_model_graphs imported OK\")\n", | ||
| "print(f\"SpatialCoordinateValuesSelector: {SpatialCoordinateValuesSelector}\")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "b0d01140", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## 2. Euclidean Distance Queries (Projected CRS)\n", | ||
| "\n", | ||
| "For a projected coordinate system the coordinates are Cartesian (e.g. metres).\n", | ||
| "We use the `\"euclidean\"` metric." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "6c276b4e", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "rng = np.random.default_rng(42)\n", | ||
| "\n", | ||
| "# 50 random Cartesian points in a 100 km × 100 km domain (units: metres)\n", | ||
| "n_pts = 50\n", | ||
| "projected_coords = rng.random((n_pts, 2)) * 1e5 # shape (50, 2)\n", | ||
| "\n", | ||
| "sel_euc = SpatialCoordinateValuesSelector(\"euclidean\", projected_coords)\n", | ||
| "print(f\"metric: {sel_euc.distance_metric}\")\n", | ||
| "\n", | ||
| "# ---- k-nearest-to ----\n", | ||
| "query_pt = np.array([5e4, 5e4]) # centre of the domain\n", | ||
| "idxs, dists = sel_euc.k_nearest_to(query_pt, k=5)\n", | ||
| "print(\"\\n5 nearest neighbours (Euclidean):\")\n", | ||
| "for rank, (i, d) in enumerate(zip(idxs, dists), 1):\n", | ||
| " print(f\" rank {rank}: index={i:3d} dist={d/1e3:.2f} km coord={projected_coords[i]}\")\n", | ||
| "\n", | ||
| "# ---- with_radius ----\n", | ||
| "radius_m = 2e4 # 20 km\n", | ||
| "idxs_r, dists_r = sel_euc.with_radius(query_pt, radius=radius_m)\n", | ||
| "print(f\"\\nPoints within {radius_m/1e3:.0f} km of centre: {len(idxs_r)}\")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "79300d69", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n", | ||
| "\n", | ||
| "for ax, (highlight_idxs, title, radius) in zip(\n", | ||
| " axes,\n", | ||
| " [\n", | ||
| " (idxs, \"k=5 nearest neighbours\", None),\n", | ||
| " (idxs_r, f\"within_radius = {radius_m/1e3:.0f} km\", radius_m),\n", | ||
| " ],\n", | ||
| "):\n", | ||
| " ax.scatter(\n", | ||
| " projected_coords[:, 0] / 1e3,\n", | ||
| " projected_coords[:, 1] / 1e3,\n", | ||
| " c=\"lightgrey\", s=30, zorder=2, label=\"all points\",\n", | ||
| " )\n", | ||
| " ax.scatter(\n", | ||
| " projected_coords[highlight_idxs, 0] / 1e3,\n", | ||
| " projected_coords[highlight_idxs, 1] / 1e3,\n", | ||
| " c=\"steelblue\", s=60, zorder=3, label=\"selected\",\n", | ||
| " )\n", | ||
| " ax.scatter(*query_pt / 1e3, c=\"red\", s=100, marker=\"*\", zorder=4, label=\"query\")\n", | ||
| " if radius is not None:\n", | ||
| " circle = plt.Circle(query_pt / 1e3, radius / 1e3, fill=False, color=\"red\", lw=1.5, ls=\"--\")\n", | ||
| " ax.add_patch(circle)\n", | ||
| " ax.set_aspect(\"equal\")\n", | ||
| " ax.set_xlabel(\"x (km)\")\n", | ||
| " ax.set_ylabel(\"y (km)\")\n", | ||
| " ax.set_title(title)\n", | ||
| " ax.legend(loc=\"upper right\")\n", | ||
| "\n", | ||
| "fig.suptitle(\"Euclidean (projected CRS) distance queries\", fontsize=13)\n", | ||
| "plt.tight_layout()\n", | ||
| "plt.show()" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "9ebfe0fb", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## 3. Haversine Distance Queries (Geographic CRS)\n", | ||
| "\n", | ||
| "For a geographic CRS the coordinates are longitude / latitude in degrees.\n", | ||
| "We use the `\"haversine\"` metric, and **all distances are returned in metres**." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "871d59b9", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# 50 random lon/lat points over Europe (lon: -10 to 30, lat: 45 to 70)\n", | ||
| "geo_coords = np.column_stack([\n", | ||
| " rng.uniform(-10, 30, n_pts), # longitude\n", | ||
| " rng.uniform(45, 70, n_pts), # latitude\n", | ||
| "])\n", | ||
| "\n", | ||
| "sel_hav = SpatialCoordinateValuesSelector(\"haversine\", geo_coords)\n", | ||
| "print(f\"metric: {sel_hav.distance_metric}\")\n", | ||
| "\n", | ||
| "# Known result: 1 degree of longitude at latitude 55° ≈ 63,800 m\n", | ||
| "# (cos(55°) * 111,195 m/deg ≈ 63,800 m)\n", | ||
| "query_geo = np.array([10.0, 55.0]) # lon=10°, lat=55°\n", | ||
| "idxs_h, dists_h = sel_hav.k_nearest_to(query_geo, k=5)\n", | ||
| "print(\"\\n5 nearest neighbours (Haversine, distances in metres):\")\n", | ||
| "for rank, (i, d) in enumerate(zip(idxs_h, dists_h), 1):\n", | ||
| " print(f\" rank {rank}: index={i:3d} dist={d/1e3:7.1f} km lon={geo_coords[i,0]:.2f}° lat={geo_coords[i,1]:.2f}°\")\n", | ||
| "\n", | ||
| "# Radius query: 500 km around the query point\n", | ||
| "radius_500km = 500_000\n", | ||
| "idxs_hr, dists_hr = sel_hav.with_radius(query_geo, radius=radius_500km)\n", | ||
| "print(f\"\\nPoints within {radius_500km/1e3:.0f} km: {len(idxs_hr)}\")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "bcd4e330", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## 4. CRS-Based Automatic Metric Selection (`for_crs`)\n", | ||
| "\n", | ||
| "The class-method `SpatialCoordinateValuesSelector.for_crs(crs, coords)` inspects\n", | ||
| "`crs.is_geographic` and picks the correct metric automatically — no manual choice needed." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "65ec5fc4", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "cases = {\n", | ||
| " \"PlateCarree (geographic)\": ccrs.PlateCarree(),\n", | ||
| " \"Mercator (geographic)\": ccrs.Mercator(),\n", | ||
| " \"LambertConformal (projected)\": ccrs.LambertConformal(),\n", | ||
| " \"Stereographic (projected)\": ccrs.Stereographic(),\n", | ||
| " \"Mollweide (projected)\": ccrs.Mollweide(),\n", | ||
| "}\n", | ||
| "\n", | ||
| "print(f\"{'CRS':<40} {'is_geographic':<15} {'auto-selected metric'}\")\n", | ||
| "print(\"-\" * 75)\n", | ||
| "for name, crs in cases.items():\n", | ||
| " sel = SpatialCoordinateValuesSelector.for_crs(crs, projected_coords)\n", | ||
| " print(f\"{name:<40} {str(crs.is_geographic):<15} {sel.distance_metric}\")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "efd3e377", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## 5. Integration with Graph Connectivity\n", | ||
| "\n", | ||
| "`SpatialCoordinateValuesSelector` is now used automatically inside\n", | ||
| "`wmg.create.create_all_graph_components()` for all nearest-neighbour and\n", | ||
| "radius queries. When you supply `graph_crs`, the correct metric is chosen\n", | ||
| "for you — no other changes are required at the call site." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "46b71d7f", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# --- Projected CRS (euclidean) ---\n", | ||
| "lon = np.linspace(0.0, 9.0, 8)\n", | ||
| "lat = np.linspace(50.0, 58.0, 8)\n", | ||
| "lo, la = np.meshgrid(lon, lat)\n", | ||
| "lonlat_coords = np.column_stack([lo.ravel(), la.ravel()])\n", | ||
| "\n", | ||
| "# Projected: convert to Oblique Mercator metres for demonstration\n", | ||
| "from pyproj import Transformer\n", | ||
| "transformer = Transformer.from_crs(\"EPSG:4326\", \"EPSG:3857\", always_xy=True)\n", | ||
| "x, y = transformer.transform(lonlat_coords[:, 0], lonlat_coords[:, 1])\n", | ||
| "proj_coords = np.column_stack([x, y])\n", | ||
| "\n", | ||
| "# Build graph with projected CRS — euclidean metric is used automatically\n", | ||
| "import pyproj\n", | ||
| "web_mercator = pyproj.CRS(\"EPSG:3857\")\n", | ||
| "G_proj = wmg.create.create_all_graph_components(\n", | ||
| " coords=proj_coords,\n", | ||
| " m2m_connectivity=\"flat\",\n", | ||
| " g2m_connectivity=\"nearest_neighbour\",\n", | ||
| " m2g_connectivity=\"nearest_neighbour\",\n", | ||
| " graph_crs=web_mercator,\n", | ||
| ")\n", | ||
| "lens_proj = [d[\"len\"] for _, _, d in G_proj.edges(data=True) if \"len\" in d]\n", | ||
| "print(f\"Projected graph: {G_proj.number_of_nodes()} nodes, {G_proj.number_of_edges()} edges\")\n", | ||
| "print(f\"Edge 'len' range: {min(lens_proj)/1e3:.1f} – {max(lens_proj)/1e3:.1f} km (euclidean, metres)\")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "1f8660d6", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## 6. Rectilinear + Geographic CRS Warning\n", | ||
| "\n", | ||
| "When a rectilinear mesh is overlaid on geographic (lat/lon) coordinates a\n", | ||
| "`UserWarning` is raised automatically, because equally-spaced lon/lat values are\n", | ||
| "**not** equally spaced on a sphere." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "45835a0a", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "# Trigger the warning intentionally by using a geographic CRS with a\n", | ||
| "# rectilinear (flat) mesh. The haversine metric is still used for all\n", | ||
| "# neighbour queries; the warning just flags the mesh *layout*.\n", | ||
| "\n", | ||
| "with warnings.catch_warnings(record=True) as caught:\n", | ||
| " warnings.simplefilter(\"always\")\n", | ||
| " G_geo = wmg.create.create_all_graph_components(\n", | ||
| " coords=lonlat_coords,\n", | ||
| " m2m_connectivity=\"flat\",\n", | ||
| " g2m_connectivity=\"nearest_neighbour\",\n", | ||
| " m2g_connectivity=\"nearest_neighbour\",\n", | ||
| " graph_crs=ccrs.PlateCarree(),\n", | ||
| " )\n", | ||
| "\n", | ||
| "print(f\"Warnings raised: {len(caught)}\")\n", | ||
| "for w in caught:\n", | ||
| " print(f\"\\n[{w.category.__name__}] {w.message}\")\n", | ||
| "\n", | ||
| "lens_geo = [d[\"len\"] for _, _, d in G_geo.edges(data=True) if \"len\" in d]\n", | ||
| "print(f\"\\nGeographic graph: {G_geo.number_of_nodes()} nodes, {G_geo.number_of_edges()} edges\")\n", | ||
| "print(f\"Edge 'len' range: {min(lens_geo)/1e3:.1f} – {max(lens_geo)/1e3:.1f} km (haversine, metres)\")" | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "markdown", | ||
| "id": "2b15f521", | ||
| "metadata": {}, | ||
| "source": [ | ||
| "## 7. Running the Test Suite\n", | ||
| "\n", | ||
| "Run the new tests from the notebook. All tests should pass." | ||
| ] | ||
| }, | ||
| { | ||
| "cell_type": "code", | ||
| "execution_count": null, | ||
| "id": "16fdd506", | ||
| "metadata": {}, | ||
| "outputs": [], | ||
| "source": [ | ||
| "import subprocess, sys\n", | ||
| "\n", | ||
| "result = subprocess.run(\n", | ||
| " [sys.executable, \"-m\", \"pytest\", \"tests/test_spatial_index.py\", \"-v\", \"--tb=short\"],\n", | ||
| " capture_output=True,\n", | ||
| " text=True,\n", | ||
| " cwd=\"..\", # run from repo root (parent of docs/)\n", | ||
| ")\n", | ||
| "print(result.stdout)\n", | ||
| "if result.returncode != 0:\n", | ||
| " print(result.stderr)" | ||
| ] | ||
| } | ||
| ], | ||
| "metadata": { | ||
| "language_info": { | ||
| "name": "python" | ||
| } | ||
| }, | ||
| "nbformat": 4, | ||
| "nbformat_minor": 5 | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.