
Commit 379e3d3

Kinda use map_partitions in indexed sjoin
xref geopandas#114 (comment). This is an alternative way of writing the sjoin for the case where both sides have spatial partitions, using just high-level DataFrame APIs instead of generating the low-level dask graph by hand. There isn't a ton of advantage to this, because selecting the partitions generates a low-level graph, so you lose Blockwise fusion regardless.
1 parent 1073e40 commit 379e3d3
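
For illustration only (not part of this commit): a minimal sketch of the pattern the message describes, written against plain dask.dataframe with a pandas join standing in for geopandas.sjoin. The column names and partition pairings are invented; in the real sjoin the pairing comes from joining the two sides' spatial partitions.

# Sketch of the high-level pattern: select the needed partition pairs with
# .partitions[...] and apply the per-partition join with map_partitions,
# instead of writing the low-level graph by hand.
import pandas as pd
import dask.dataframe as dd

left = dd.from_pandas(pd.DataFrame({"a": range(8)}), npartitions=4)
right = dd.from_pandas(pd.DataFrame({"b": range(8)}), npartitions=4)

parts_left = [0, 1, 3]   # hypothetical left partition for each output partition
parts_right = [2, 0, 1]  # hypothetical right partition for each output partition

# Selecting partitions generates a low-level graph layer, which is why the
# commit message notes that Blockwise fusion is lost either way.
joined = left.partitions[parts_left].map_partitions(
    lambda l, r: l.join(r),      # stand-in for geopandas.sjoin
    right.partitions[parts_right],
    align_dataframes=False,      # partition pairing is already explicit
    transform_divisions=False,
    enforce_metadata=False,
)
print(joined.compute())

The keyword arguments mirror the ones used in the diff below; alignment and division handling are turned off, presumably because the pairing of partitions is already fixed by the selection.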

File tree

2 files changed: 57 additions, 26 deletions

dask_geopandas/sjoin.py

Lines changed: 47 additions & 23 deletions

@@ -7,6 +7,11 @@
 from .core import from_geopandas, GeoDataFrame
 
 
+def partitions_are_unchanged(part_idxs: np.ndarray, npartitions: int) -> bool:
+    "Whether selecting these partition indices would result in an identical DataFrame."
+    return len(part_idxs) == npartitions and (part_idxs[:-1] < part_idxs[1:]).all()
+
+
 def sjoin(left, right, how="inner", op="intersects"):
     """
     Spatial join of two GeoDataFrames.
@@ -58,33 +63,52 @@ def sjoin(left, right, how="inner", op="intersects"):
            how="inner",
            op="intersects",
        )
-        parts_left = np.asarray(parts.index)
-        parts_right = np.asarray(parts["index_right"].values)
-        using_spatial_partitions = True
-    else:
-        # Unknown spatial partitions -> full cartesian (cross) product of all
-        # combinations of the partitions of the left and right dataframe
-        n_left = left.npartitions
-        n_right = right.npartitions
-        parts_left = np.repeat(np.arange(n_left), n_right)
-        parts_right = np.tile(np.arange(n_right), n_left)
-        using_spatial_partitions = False
+        parts_left = parts.index.values
+        parts_right = parts["index_right"].values
+        # Sub-select just the partitions from each input we need---unless we need all of them.
+        left_sub = (
+            left
+            if partitions_are_unchanged(parts_left, left.npartitions)
+            else left.partitions[parts_left]
+        )
+        right_sub = (
+            right
+            if partitions_are_unchanged(parts_right, right.npartitions)
+            else right.partitions[parts_right]
+        )
+
+        joined = left_sub.map_partitions(
+            geopandas.sjoin,
+            right_sub,
+            how,
+            op,
+            enforce_metadata=False,
+            transform_divisions=False,
+            align_dataframes=False,
+            meta=meta,
+        )
+
+        # TODO preserve spatial partitions of the output if only left has spatial
+        # partitions
+        joined.spatial_partitions = [
+            left.spatial_partitions.iloc[l].intersection(
+                right.spatial_partitions.iloc[r]
+            )
+            for l, r in zip(parts_left, parts_right)
+        ]
+        return joined
+
+    # Unknown spatial partitions -> full cartesian (cross) product of all
+    # combinations of the partitions of the left and right dataframe
+    n_left = left.npartitions
+    n_right = right.npartitions
+    parts_left = np.repeat(np.arange(n_left), n_right)
+    parts_right = np.tile(np.arange(n_right), n_left)
 
     dsk = {}
-    new_spatial_partitions = []
     for i, (l, r) in enumerate(zip(parts_left, parts_right)):
        dsk[(name, i)] = (geopandas.sjoin, (left._name, l), (right._name, r), how, op)
-        # TODO preserve spatial partitions of the output if only left has spatial
-        # partitions
-        if using_spatial_partitions:
-            lr = left.spatial_partitions.iloc[l]
-            rr = right.spatial_partitions.iloc[r]
-            # extent = lr.intersection(rr).buffer(buffer).intersection(lr.union(rr))
-            extent = lr.intersection(rr)
-            new_spatial_partitions.append(extent)
 
     divisions = [None] * (len(dsk) + 1)
     graph = HighLevelGraph.from_collections(name, dsk, dependencies=[left, right])
-    if not using_spatial_partitions:
-        new_spatial_partitions = None
-    return GeoDataFrame(graph, name, meta, divisions, new_spatial_partitions)
+    return GeoDataFrame(graph, name, meta, divisions, None)
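
Aside (not part of the diff): the short-circuit above only treats a selection as a no-op when the requested indices are exactly 0..npartitions-1 in ascending order. A small check of the new helper, assuming this commit is installed so it can be imported from dask_geopandas.sjoin:

import numpy as np

from dask_geopandas.sjoin import partitions_are_unchanged  # added by this commit

print(partitions_are_unchanged(np.array([0, 1, 2, 3]), 4))  # True: selecting these is a no-op
print(partitions_are_unchanged(np.array([0, 0, 1, 2]), 4))  # False: a partition repeats
print(partitions_are_unchanged(np.array([0, 2]), 4))        # False: only a subset is selected

Whenever the indices repeat or skip partitions, the .partitions[...] sub-selection is taken, producing the low-level graph layer the commit message says defeats Blockwise fusion.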

tests/test_sjoin.py

Lines changed: 10 additions & 3 deletions

@@ -1,17 +1,24 @@
+import pytest
+
 import geopandas
 from geopandas.testing import assert_geodataframe_equal
 
 import dask_geopandas
 
 
-def test_sjoin_dask_geopandas():
+@pytest.mark.parametrize(
+    "npartitions_left, npartitions_right", [(4, 4), (1, 3), (3, 1), (3, 4)]
+)
+def test_sjoin_dask_geopandas(npartitions_left, npartitions_right):
     df_points = geopandas.read_file(geopandas.datasets.get_path("naturalearth_cities"))
-    ddf_points = dask_geopandas.from_geopandas(df_points, npartitions=4)
+    ddf_points = dask_geopandas.from_geopandas(df_points, npartitions=npartitions_left)
 
     df_polygons = geopandas.read_file(
        geopandas.datasets.get_path("naturalearth_lowres")
     )
-    ddf_polygons = dask_geopandas.from_geopandas(df_polygons, npartitions=4)
+    ddf_polygons = dask_geopandas.from_geopandas(
+        df_polygons, npartitions=npartitions_right
+    )
 
     expected = geopandas.sjoin(df_points, df_polygons, op="within", how="inner")
     expected = expected.sort_index()