Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions src/weather_model_graphs/create/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
split_graph_by_edge_attribute,
split_on_edge_attribute_existance,
)
from ..spatial_index import create_spatial_index
from .grid import create_grid_graph_nodes
from .mesh.kinds.flat import (
create_flat_multiscale_mesh_graph,
Expand Down Expand Up @@ -255,12 +256,12 @@ def connect_nodes_across_graphs(
Graph containing the nodes in `G_source` and `G_target` and directed edges
from nodes in `G_source` to nodes in `G_target`
"""
source_nodes_list = list(G_source.nodes)
source_nodes_list = sorted(G_source.nodes) # Sort nodes for consistent indexing
target_nodes_list = list(G_target.nodes)

# build kd tree for source nodes (e.g. the mesh nodes when constructing m2g)
xy_source = np.array([G_source.nodes[node]["pos"] for node in G_source.nodes])
kdt_s = scipy.spatial.KDTree(xy_source)
# build spatial index for source nodes (e.g. the mesh nodes when constructing m2g)
xy_source = np.array([G_source.nodes[node]["pos"] for node in source_nodes_list])
spatial_index = create_spatial_index(xy_source, method="kdtree")

# Determine method and perform checks once
# Conditionally define _find_neighbour_node_idxs_in_source_mesh for use in
Expand Down Expand Up @@ -325,8 +326,8 @@ def _edge_filter(edge_prop):
)

def _find_neighbour_node_idxs_in_source_mesh(xy_target):
neigh_idx = kdt_s.query(xy_target, 1)[1]
return [neigh_idx]
distances, indices = spatial_index.query(np.array([xy_target]), k=1)
return indices[0].tolist() # Convert to list to match original

elif method == "nearest_neighbours":
if max_num_neighbours is None:
Expand All @@ -339,8 +340,8 @@ def _find_neighbour_node_idxs_in_source_mesh(xy_target):
)

def _find_neighbour_node_idxs_in_source_mesh(xy_target):
neigh_idxs = kdt_s.query(xy_target, max_num_neighbours)[1]
return neigh_idxs
distances, indices = spatial_index.query(np.array([xy_target]), k=max_num_neighbours)
return indices[0].tolist() # Convert to list

elif method == "within_radius":
if max_num_neighbours is not None:
Expand All @@ -351,13 +352,13 @@ def _find_neighbour_node_idxs_in_source_mesh(xy_target):
if max_dist is not None:
if rel_max_dist is not None:
raise Exception(
"to use `witin_radius` method you should only set one of `max_dist` or `rel_max_dist"
"to use `within_radius` method you should only set one of `max_dist` or `rel_max_dist"
)
query_dist = max_dist
elif rel_max_dist is not None:
if max_dist is not None:
raise Exception(
"to use `witin_radius` method you should only set one of `max_dist` or `rel_max_dist"
"to use `within_radius` method you should only set one of `max_dist` or `rel_max_dist"
)
# Figure out longest edge in (lowest level) mesh graph
longest_edge = 0.0
Expand Down Expand Up @@ -386,12 +387,12 @@ def _find_neighbour_node_idxs_in_source_mesh(xy_target):
query_dist = longest_edge * rel_max_dist
else:
raise Exception(
"to use `witin_radius` method you shold set `max_dist` or `rel_max_dist"
"to use `within_radius` method you should set `max_dist` or `rel_max_dist"
)

def _find_neighbour_node_idxs_in_source_mesh(xy_target):
neigh_idxs = kdt_s.query_ball_point(xy_target, query_dist)
return neigh_idxs
distances, indices = spatial_index.query(np.array([xy_target]), radius=query_dist)
return indices.tolist() # Convert to list

else:
raise NotImplementedError(method)
Expand Down
213 changes: 213 additions & 0 deletions src/weather_model_graphs/visualise/spatial_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
"""
Spatial indexing utilities for efficient neighbor queries in weather model graphs.

This module provides optimized spatial indexing using KD-Trees and Ball Trees
for fast neighbor searches, reducing complexity from O(N²) to O(N log N).
"""

from abc import ABC, abstractmethod
from typing import List, Tuple, Union

import numpy as np
from scipy.spatial import KDTree, cKDTree

# Optional BallTree
try:
from sklearn.neighbors import BallTree
HAS_SKLEARN = True
except ImportError:
HAS_SKLEARN = False


class SpatialIndex(ABC):
    """Common interface for point-set indexes that answer neighbour queries.

    The constructor stores the points and immediately asks the subclass to
    build its index; subclasses implement ``_build_index`` and ``query``.
    """

    def __init__(self, points: np.ndarray):
        """
        Store ``points`` and build the index over them.

        Parameters
        ----------
        points : np.ndarray
            Array of shape (N, D) where N is number of points, D is dimension.
        """
        self.points = points
        self._build_index()

    @abstractmethod
    def _build_index(self):
        """Construct the underlying index structure from ``self.points``."""

    @abstractmethod
    def query(
        self,
        query_points: np.ndarray,
        k: int = 1,
        radius: float = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Look up neighbours of ``query_points`` in the index.

        Parameters
        ----------
        query_points : np.ndarray
            Points to query, shape (M, D)
        k : int
            Number of nearest neighbors
        radius : float
            Search radius (for radius queries)

        Returns
        -------
        Tuple[np.ndarray, np.ndarray]
            Distances and indices of neighbors
        """


class KDTreeIndex(SpatialIndex):
    """KD-Tree based spatial index backed by ``scipy.spatial.cKDTree``."""

    def _build_index(self):
        """Build the KD-Tree over ``self.points``."""
        self.index = cKDTree(self.points)  # Use cKDTree for better performance

    def query(
        self,
        query_points: np.ndarray,
        k: int = 1,
        radius: float = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Query the KD-Tree for neighbours.

        When ``radius`` is given a radius search is performed and ``k`` is
        ignored; otherwise the ``k`` nearest neighbours are returned.

        Parameters
        ----------
        query_points : np.ndarray
            Points to query, shape (M, D). For radius queries a single
            point of shape (D,) is also accepted.
        k : int
            Number of nearest neighbors (used only when ``radius`` is None)
        radius : float
            Search radius; when not None a radius query is performed

        Returns
        -------
        Tuple[np.ndarray, np.ndarray]
            Distances and indices of neighbors.

            - Radius query: two flat 1-D arrays holding the matches of all
              query points concatenated in query order. A query point with
              no neighbour inside ``radius`` contributes nothing, so both
              arrays may be empty.
            - k-nearest query: the arrays returned by ``cKDTree.query``;
              for ``k == 1`` they are reshaped to a single column.
        """
        if radius is not None:
            # A lone (D,) point would make query_ball_point return a plain
            # list of ints rather than one list per query point.
            query_points = np.atleast_2d(query_points)
            neighbour_lists = self.index.query_ball_point(query_points, radius)

            distances = []
            flat_indices = []
            for query_point, idx_list in zip(query_points, neighbour_lists):
                if not idx_list:
                    # No sentinel entry: a placeholder index such as -1 is a
                    # valid Python/NumPy index and would silently select the
                    # last point when callers index with the result.
                    continue
                offsets = self.points[idx_list] - query_point
                distances.extend(np.linalg.norm(offsets, axis=1))
                flat_indices.extend(idx_list)

            # Explicit dtypes keep an empty result usable as an index array.
            return (
                np.array(distances, dtype=float),
                np.array(flat_indices, dtype=np.intp),
            )

        # k-nearest neighbors
        distances, indices = self.index.query(query_points, k=k)
        if k == 1:
            # cKDTree drops the neighbour axis for k == 1; restore it so
            # callers can always index ``indices[i]`` as a sequence.
            distances = distances.reshape(-1, 1)
            indices = indices.reshape(-1, 1)
        return distances, indices


class BallTreeIndex(SpatialIndex):
    """Ball Tree based spatial index using sklearn."""

    def __init__(self, points: np.ndarray):
        """
        Build a Ball Tree over ``points``.

        Parameters
        ----------
        points : np.ndarray
            Array of shape (N, D) where N is number of points, D is dimension.

        Raises
        ------
        ImportError
            If scikit-learn could not be imported when this module was loaded.
        """
        if not HAS_SKLEARN:
            raise ImportError("BallTree requires scikit-learn. Install with: pip install scikit-learn")
        super().__init__(points)

    def _build_index(self):
        """Build the Ball Tree over ``self.points``."""
        self.index = BallTree(self.points)

    def query(
        self,
        query_points: np.ndarray,
        k: int = 1,
        radius: float = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Query the Ball Tree for neighbours.

        When ``radius`` is given a radius search is performed and ``k`` is
        ignored; otherwise the ``k`` nearest neighbours are returned.

        Parameters
        ----------
        query_points : np.ndarray
            Points to query, shape (M, D). For radius queries a single
            point of shape (D,) is also accepted.
        k : int
            Number of nearest neighbors (used only when ``radius`` is None)
        radius : float
            Search radius; when not None a radius query is performed

        Returns
        -------
        Tuple[np.ndarray, np.ndarray]
            Distances and indices of neighbors.

            - Radius query: two flat 1-D arrays holding the matches of all
              query points concatenated in query order. A query point with
              no neighbour inside ``radius`` contributes nothing, so both
              arrays may be empty.
            - k-nearest query: the arrays returned by ``BallTree.query``.
        """
        if radius is not None:
            # Promote a lone (D,) point to a one-row batch before querying.
            query_points = np.atleast_2d(query_points)
            neighbour_arrays = self.index.query_radius(query_points, radius)

            distances = []
            flat_indices = []
            for query_point, idx_array in zip(query_points, neighbour_arrays):
                if len(idx_array) == 0:
                    # No sentinel entry: a placeholder index such as -1 is a
                    # valid Python/NumPy index and would silently select the
                    # last point when callers index with the result.
                    continue
                offsets = self.points[idx_array] - query_point
                distances.extend(np.linalg.norm(offsets, axis=1))
                flat_indices.extend(idx_array)

            # Explicit dtypes keep an empty result usable as an index array.
            return (
                np.array(distances, dtype=float),
                np.array(flat_indices, dtype=np.intp),
            )

        # k-nearest neighbors
        distances, indices = self.index.query(query_points, k=k)
        return distances, indices


def create_spatial_index(points: np.ndarray, method: str = "kdtree") -> SpatialIndex:
    """
    Build the spatial index implementation selected by ``method``.

    Parameters
    ----------
    points : np.ndarray
        Array of shape (N, D) where N is number of points, D is dimension
    method : str
        Indexing method: "kdtree" or "balltree" (matched case-insensitively)

    Returns
    -------
    SpatialIndex
        Index built over ``points``

    Raises
    ------
    ValueError
        If ``method`` names neither supported index type.
    """
    index_classes = {
        "kdtree": KDTreeIndex,
        "balltree": BallTreeIndex,
    }
    index_cls = index_classes.get(method.lower())
    if index_cls is None:
        raise ValueError(f"Unknown method: {method}. Use 'kdtree' or 'balltree'")
    return index_cls(points)


def find_neighbors_vectorized(
    query_points: np.ndarray,
    index: SpatialIndex,
    k: int = 1,
    radius: float = None
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Run a batched neighbour lookup against a pre-built spatial index.

    This simply forwards all arguments to ``index.query``.

    Parameters
    ----------
    query_points : np.ndarray
        Points to query, shape (M, D)
    index : SpatialIndex
        Pre-built spatial index
    k : int
        Number of nearest neighbors
    radius : float
        Search radius

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        Distances and indices of neighbors, exactly as returned by
        ``index.query``
    """
    neighbours = index.query(query_points, k=k, radius=radius)
    return neighbours