-
Notifications
You must be signed in to change notification settings - Fork 94
Expand file tree
/
Copy path__init__.py
More file actions
2948 lines (2613 loc) · 117 KB
/
Copy path__init__.py
File metadata and controls
2948 lines (2613 loc) · 117 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Out-of-core CRS reprojection and multi-raster merge.
Public API
----------
reproject(raster, target_crs, ...)
Reproject a DataArray to a new coordinate reference system.
merge(rasters, ...)
Merge multiple DataArrays into a single mosaic.
"""
from __future__ import annotations
import math
import numpy as np
import xarray as xr
from xrspatial.utils import _dask_task_name_kwargs, _validate_raster
from ._crs_utils import _detect_band_nodata, _detect_nodata, _detect_source_crs, _resolve_crs
from ._grid import (_MAX_OUTPUT_PIXELS, _chunk_bounds, _compute_chunk_layout, _compute_output_grid,
_make_output_coords, _validate_grid_params)
from ._interpolate import _resample_cupy_native, _resample_numpy, _validate_resampling
from ._itrf import itrf_transform
from ._itrf import list_frames as itrf_frames
from ._merge import _merge_arrays_numpy, _validate_strategy
from ._transform import ApproximateTransform
from ._vertical import (depth_to_ellipsoidal, ellipsoidal_to_depth, ellipsoidal_to_orthometric,
geoid_height, geoid_height_raster, orthometric_to_ellipsoidal)
# Names exported by a star-import of this package: the reproject/merge entry
# points named in the module docstring plus the vertical-datum and ITRF
# helpers imported above.
__all__ = [
    'reproject', 'merge',
    'geoid_height', 'geoid_height_raster',
    'ellipsoidal_to_orthometric', 'orthometric_to_ellipsoidal',
    'depth_to_ellipsoidal', 'ellipsoidal_to_depth',
    'itrf_transform', 'itrf_frames',
]
# ---------------------------------------------------------------------------
# Source geometry helpers
# ---------------------------------------------------------------------------
# Dimension names that _find_spatial_dims accepts as the vertical (y) and
# horizontal (x) spatial axes. Matching is an exact, case-sensitive set
# lookup, so only the spellings listed here are recognised.
_Y_NAMES = {'y', 'lat', 'latitude', 'Y', 'Lat', 'Latitude'}
_X_NAMES = {'x', 'lon', 'longitude', 'X', 'Lon', 'Longitude'}
# Output byte budget above which merge() auto-promotes an in-memory mosaic
# to the lazy dask path instead of allocating the whole array.
_MERGE_OOM_THRESHOLD = 512 * 1024 * 1024 # 512 MB
# Byte budget above which reproject() auto-promotes an in-memory raster
# (numpy or cupy) to the chunked dask path. Compared against the input
# array size and one float64 output array independently, not against
# total memory: the eager numpy path holds ~7 output-sized float64
# temporaries (coordinate grids, pixel-index grids, result), so
# eager-path peak RSS can reach ~7x this budget. That multiplier is why a
# small input upsampled to a large output exhausted memory long before
# the _MAX_OUTPUT_PIXELS guard tripped (#3267). The same budget gates the
# cupy promotion added in #3281.
_REPROJECT_OOM_THRESHOLD = 512 * 1024 * 1024 # 512 MB
# Map friendly vertical datum tokens to EPSG codes so attrs['vertical_crs']
# from reproject output matches the convention used by xrspatial.geotiff,
# which also writes EPSG ints to attrs['vertical_crs'].
# Keys are the string tokens accepted for the vertical-CRS keyword arguments
# of reproject(); values are plain ints.
_VERTICAL_DATUM_EPSG = {
    'EGM96': 5773, # EGM96 height
    'EGM2008': 3855, # EGM2008 height
    'ellipsoidal': 4979, # WGS 84 (3D, ellipsoidal height)
}
# Unique marker meaning "the deprecated ``src_vertical_crs`` /
# ``tgt_vertical_crs`` keyword was not supplied". A dedicated object is used
# rather than None so that an explicit ``src_vertical_crs=None`` is still
# recognised as use of the old spelling and triggers the warning.
_DEPRECATED = object()


def _resolve_deprecated_vertical_kwarg(old_name, old_val, new_name, new_val):
    """Pick the effective value for a renamed vertical-CRS keyword.

    Parameters
    ----------
    old_name : str
        Deprecated keyword name, used only in messages.
    old_val : object
        Value received under the deprecated name, or ``_DEPRECATED`` when
        the caller did not pass it.
    new_name : str
        Replacement keyword name, used only in messages.
    new_val : object
        Value received under the replacement name.

    Returns
    -------
    object
        ``new_val`` when the deprecated keyword was not passed, otherwise
        ``old_val``.

    Raises
    ------
    TypeError
        If the deprecated keyword was passed and ``new_val`` is not None.
        The ``DeprecationWarning`` is emitted before this error is raised.
    """
    legacy_supplied = old_val is not _DEPRECATED
    if not legacy_supplied:
        return new_val

    from warnings import warn

    # stacklevel=3 skips this helper and its direct caller so the warning is
    # attributed to the code that invoked that caller.
    warn(
        f"reproject(): {old_name!r} is deprecated, use {new_name!r} instead.",
        DeprecationWarning,
        stacklevel=3,
    )

    both_spellings_given = new_val is not None
    if both_spellings_given:
        raise TypeError(
            f"reproject(): pass either {new_name!r} or the deprecated "
            f"{old_name!r}, not both."
        )
    return old_val
def _find_spatial_dims(raster):
    """Return the ``(ydim, xdim)`` dimension names of *raster*.

    Every name in ``raster.dims`` is compared against ``_Y_NAMES`` and
    ``_X_NAMES``; if several names match the same axis, the last one wins.
    When either axis is left unmatched, the final two entries of
    ``raster.dims`` are returned instead, in that order.
    """
    dim_names = raster.dims
    y_match = None
    x_match = None
    for name in dim_names:
        if name in _Y_NAMES:
            y_match = name
            continue
        if name in _X_NAMES:
            x_match = name

    if y_match is None or x_match is None:
        # Positional fallback: treat the trailing two dims as (y, x).
        return dim_names[-2], dim_names[-1]
    return y_match, x_match
# Relative tolerance used by the regular-spacing check. Pixel-to-world
# transforms can leave coordinates read from real GeoTIFFs a few ULPs away
# from a perfectly uniform grid; 1e-6 absorbs that drift while still flagging
# the single-pixel perturbation case in #2184.
_REGULAR_COORD_RTOL = 1e-6


def _validate_regular_axis(coords, axis_name, func_name, rtol=_REGULAR_COORD_RTOL):
    """Reject a coordinate axis that is not strictly monotonic and evenly spaced.

    The pixel-resolution arithmetic in `_source_bounds` and the chunk workers
    takes a uniform grid for granted; an irregular or non-monotonic axis
    would otherwise yield wrong georeferencing without any error (#2184).

    Parameters
    ----------
    coords : array-like
        1-D coordinate values along one axis.
    axis_name : str
        Axis label ("x" or "y") inserted into error messages.
    func_name : str
        Name of the calling function, used as the error-message prefix.
    rtol : float
        Largest permitted deviation of any step from the median step,
        expressed as a fraction of the absolute median step.

    Raises
    ------
    ValueError
        If *coords* is not 1-D, or (for two or more samples) holds
        non-finite values, is not strictly ascending or strictly
        descending, or has a step further than ``rtol`` (relative) from
        the median step.
    """
    values = np.asarray(coords)
    if values.ndim != 1:
        raise ValueError(
            f"{func_name}(): coordinate '{axis_name}' must be 1-D, "
            f"got shape {values.shape}."
        )

    # Zero or one sample: there is no spacing to examine, so accept as-is.
    if values.size < 2:
        return

    if not np.isfinite(values).all():
        raise ValueError(
            f"{func_name}(): coordinate '{axis_name}' contains non-finite "
            f"values (NaN or inf)."
        )

    # Work with float64 steps so the median / tolerance arithmetic below is
    # done in floating point even for integer coordinates.
    steps = np.diff(np.asarray(values, dtype=np.float64))

    # A zero step (repeated coordinate) is neither > 0 nor < 0, so it fails
    # both tests and is rejected. A sign-only comparison would let it pass.
    rising = np.all(steps > 0)
    falling = np.all(steps < 0)
    if not rising and not falling:
        raise ValueError(
            f"{func_name}(): coordinate '{axis_name}' must be strictly "
            f"monotonic (all ascending or all descending). The reproject "
            f"pipeline assumes a uniformly-spaced grid; see #2184."
        )

    typical_step = float(np.median(steps))
    tolerance = rtol * abs(typical_step)
    departures = np.abs(steps - typical_step)
    largest = float(np.max(departures))
    if largest > tolerance:
        # Point at the offending pair of samples in the original array so the
        # caller does not have to recompute the differences to find it.
        where = int(np.argmax(departures))
        before = float(values[where])
        after = float(values[where + 1])
        raise ValueError(
            f"{func_name}(): coordinate '{axis_name}' is not regularly "
            f"spaced. Median step is {typical_step!r}; worst deviation is "
            f"{largest!r} at index {where} (between {axis_name}[{where}]"
            f"={before!r} and {axis_name}[{where + 1}]"
            f"={after!r}). The reproject pipeline "
            f"assumes a uniformly-spaced grid; see #2184."
        )
def _validate_source_coords(raster, func_name):
    """Run the regular-axis check on the y axis, then the x axis, of *raster*.

    Any ``ValueError`` raised by ``_validate_regular_axis`` propagates to the
    caller; *func_name* is forwarded as the error-message prefix.
    """
    y_dim, x_dim = _find_spatial_dims(raster)
    for label, dim in (('y', y_dim), ('x', x_dim)):
        _validate_regular_axis(raster.coords[dim].values, label, func_name)
def _source_bounds(raster):
    """Return the ``(left, bottom, right, top)`` extent of *raster*.

    The coordinates are treated as pixel centres, so the extent is widened by
    half a pixel on every side. The pixel size along an axis is the absolute
    difference of its first two coordinates, or 1.0 when the axis has a
    single sample.
    """
    y_dim, x_dim = _find_spatial_dims(raster)
    ys = raster.coords[y_dim].values
    xs = raster.coords[x_dim].values

    # Per-axis pixel size; a lone sample has no neighbour to measure against.
    step_y = abs(float(ys[1] - ys[0])) if len(ys) > 1 else 1.0
    step_x = abs(float(xs[1] - xs[0])) if len(xs) > 1 else 1.0

    lowest_x = float(np.min(xs))
    highest_x = float(np.max(xs))
    lowest_y = float(np.min(ys))
    highest_y = float(np.max(ys))

    # Shift from outermost pixel centres to outermost pixel edges.
    return (
        lowest_x - step_x / 2,
        lowest_y - step_y / 2,
        highest_x + step_x / 2,
        highest_y + step_y / 2,
    )
def _is_y_descending(raster):
    """Report whether the y coordinate runs from large (top) to small (bottom).

    An axis with fewer than two samples is reported as descending.
    """
    y_dim = _find_spatial_dims(raster)[0]
    ys = raster.coords[y_dim].values
    return True if len(ys) < 2 else float(ys[0]) > float(ys[-1])
def _is_x_descending(raster):
    """Report whether the x coordinate runs from large (right) to small (left).

    Horizontal counterpart of :func:`_is_y_descending`. An axis with fewer
    than two samples is reported as ascending, in line with
    :func:`_make_output_coords`, which always emits ascending x.
    """
    x_dim = _find_spatial_dims(raster)[1]
    xs = raster.coords[x_dim].values
    return False if len(xs) < 2 else float(xs[0]) > float(xs[-1])
# ---------------------------------------------------------------------------
# Per-chunk coordinate transform
# ---------------------------------------------------------------------------
def _transform_coords(transformer, chunk_bounds, chunk_shape,
                      transform_precision, src_crs=None, tgt_crs=None):
    """Map every output pixel centre of a chunk into source-CRS coordinates.

    Three strategies are tried in order:

    1. If *transform_precision* is non-zero and both CRSes are supplied,
       ``try_numba_transform`` is consulted; a non-None result is returned
       directly (a fast path that bypasses pyproj for common CRS pairs).
    2. If *transform_precision* is 0, every pixel centre is passed through
       ``transformer.transform`` exactly, 256 output rows at a time.
    3. Otherwise an ``ApproximateTransform`` control grid is evaluated at
       every (row, col) index.

    Parameters
    ----------
    transformer : object
        Exposes ``transform(x, y) -> (x', y')`` on 1-D arrays; the caller
        builds it to map target-CRS coordinates to source-CRS coordinates.
    chunk_bounds : (left, bottom, right, top)
        Extent of the output chunk.
    chunk_shape : (height, width)
        Pixel dimensions of the output chunk.
    transform_precision : int
        0 selects the exact path; any other value is forwarded to
        ``ApproximateTransform`` as ``precision``.
    src_crs, tgt_crs : optional
        Only used to gate and feed the Numba fast path.

    Returns
    -------
    src_y, src_x : ndarray (height, width)
    """
    exact = transform_precision == 0
    crs_known = src_crs is not None and tgt_crs is not None

    # Precision 0 is the documented request for exact per-pixel pyproj
    # results, so the approximate Numba shortcut is skipped in that case.
    if not exact and crs_known:
        try:
            from ._projections import try_numba_transform
            shortcut = try_numba_transform(
                src_crs, tgt_crs, chunk_bounds, chunk_shape,
            )
            if shortcut is not None:
                return shortcut
        except ImportError:
            # Fast path unavailable (ModuleNotFoundError is a subclass of
            # ImportError): continue with pyproj below.
            pass

    n_rows, n_cols = chunk_shape
    west, south, east, north = chunk_bounds
    pixel_w = (east - west) / n_cols
    pixel_h = (north - south) / n_rows

    if exact:
        # Pixel-centre x values are identical for every row; build them once.
        centre_x = west + (np.arange(n_cols, dtype=np.float64) + 0.5) * pixel_w
        grid_x = np.empty((n_rows, n_cols), dtype=np.float64)
        grid_y = np.empty((n_rows, n_cols), dtype=np.float64)

        # Feed pyproj one strip of rows per call so the flattened temporaries
        # are at most ``rows_per_strip * n_cols`` long rather than chunk-sized.
        rows_per_strip = 256
        for first in range(0, n_rows, rows_per_strip):
            last = min(first + rows_per_strip, n_rows)
            count = last - first
            # Row 0 is the top of the chunk, so y decreases with row index.
            centre_y = north - (np.arange(first, last, dtype=np.float64) + 0.5) * pixel_h
            # Flatten the strip row-major: x repeats per row, y per column.
            moved_x, moved_y = transformer.transform(
                np.tile(centre_x, count),
                np.repeat(centre_y, n_cols),
            )
            grid_x[first:last] = np.asarray(moved_x, dtype=np.float64).reshape(count, n_cols)
            grid_y[first:last] = np.asarray(moved_y, dtype=np.float64).reshape(count, n_cols)
        return grid_y, grid_x

    # Approximate path: interpolate on a coarse control grid.
    control = ApproximateTransform(
        transformer, chunk_bounds, chunk_shape,
        precision=transform_precision,
    )
    full_shape = (n_rows, n_cols)
    row_index = np.broadcast_to(
        np.arange(n_rows, dtype=np.float64)[:, np.newaxis], full_shape,
    )
    col_index = np.broadcast_to(
        np.arange(n_cols, dtype=np.float64)[np.newaxis, :], full_shape,
    )
    return control(row_index, col_index)
# ---------------------------------------------------------------------------
# Per-chunk worker functions
# ---------------------------------------------------------------------------
def _reproject_chunk_numpy(
    source_data, source_bounds_tuple, source_shape, source_y_desc,
    src_wkt, tgt_wkt,
    chunk_bounds_tuple, chunk_shape,
    resampling, nodata, transform_precision,
    source_x_desc=False,
    band_nodata=None,
):
    """Reproject a single output chunk (numpy backend).

    Called inside ``dask.delayed`` for the dask path, or directly for numpy.
    CRS objects are passed as WKT strings for pickle safety.

    The work proceeds in four stages: (1) map every output pixel of the
    chunk to source-CRS coordinates, (2) convert those to fractional source
    row/column indices, (3) slice the smallest padded source window that
    covers them, and (4) resample that window with ``_resample_numpy``.

    Parameters
    ----------
    source_data : array-like
        Source pixels, sliced as ``[rows, cols]``; a 3-D source is read
        with the band axis last (``shape[2]``). If the sliced window has a
        ``compute`` attribute it is materialised before resampling.
    source_bounds_tuple : (left, bottom, right, top)
        Source extent, in source-CRS units.
    source_shape : (rows, cols)
        Source pixel dimensions, used for resolution and overlap tests.
    source_y_desc : bool
        When True, source row 0 is at ``top``; otherwise at ``bottom``.
    src_wkt, tgt_wkt : str
        Source and target CRS, decoded with ``_crs_from_wkt``.
    chunk_bounds_tuple : (left, bottom, right, top)
        Extent of the output chunk, in target-CRS units.
    chunk_shape : (height, width)
        Output chunk pixel dimensions.
    resampling
        Forwarded unchanged to ``_resample_numpy``.
    nodata : float
        Fill value for empty chunks and the ``nodata`` argument of
        ``_resample_numpy``. For single-band sources it is also the source
        sentinel that is replaced by NaN before resampling.
    transform_precision : int
        0 skips the Numba fast path and selects the exact pyproj transform
        in ``_transform_coords``; other values are forwarded to it.
    source_x_desc : bool, default False
        Mirrors ``source_y_desc`` for the horizontal axis: when True,
        source column 0 is at the maximum x and column ``src_w-1`` is at
        the minimum x. Defaults to False so older callers keep working.
    band_nodata : sequence or None
        Per-band source sentinels, indexed by band number. Only consulted
        on the 3-D path; when None every band uses ``nodata``.

    Returns
    -------
    ndarray
        Shape ``chunk_shape`` for 2-D sources, ``(*chunk_shape, n_bands)``
        for 3-D sources. Integer sources come back in their own dtype;
        empty-chunk fills for every other source dtype are float64.
    """
    from ._crs_utils import _crs_from_wkt
    src_crs = _crs_from_wkt(src_wkt)
    tgt_crs = _crs_from_wkt(tgt_wkt)
    # Try Numba fast path first (avoids creating pyproj Transformer).
    # transform_precision == 0 forces the exact pyproj path, so skip Numba.
    numba_result = None
    if transform_precision != 0:
        try:
            from ._projections import try_numba_transform
            numba_result = try_numba_transform(
                src_crs, tgt_crs, chunk_bounds_tuple, chunk_shape,
            )
        except (ImportError, ModuleNotFoundError):
            pass
    if numba_result is not None:
        src_y, src_x = numba_result
    else:
        # Fallback: create pyproj Transformer (expensive)
        from ._crs_utils import _require_pyproj
        pyproj = _require_pyproj()
        # Note the argument order: the transformer runs target -> source,
        # because each output pixel is looked up in the source raster.
        transformer = pyproj.Transformer.from_crs(
            tgt_crs, src_crs, always_xy=True
        )
        # Pass src_crs/tgt_crs as None: the numba fast path was already
        # tried above and returned None, and _transform_coords gates its
        # own try_numba_transform retry on both CRSes being non-None.
        # Re-trying would repeat the CRS param parsing and chunk-sized
        # coordinate allocations for nothing (#3106).
        src_y, src_x = _transform_coords(
            transformer, chunk_bounds_tuple, chunk_shape, transform_precision,
        )
    # Convert source CRS coordinates to source pixel coordinates
    src_left, src_bottom, src_right, src_top = source_bounds_tuple
    src_h, src_w = source_shape
    src_res_x = (src_right - src_left) / src_w
    src_res_y = (src_top - src_bottom) / src_h
    # The bounds are pixel edges; subtracting 0.5 turns an edge-relative
    # offset into a pixel-centre index. Descending axes measure from the
    # far edge so that index 0 is still the first stored row/column.
    if source_x_desc:
        src_col_px = (src_right - src_x) / src_res_x - 0.5
    else:
        src_col_px = (src_x - src_left) / src_res_x - 0.5
    if source_y_desc:
        src_row_px = (src_top - src_y) / src_res_y - 0.5
    else:
        src_row_px = (src_y - src_bottom) / src_res_y - 0.5
    # Determine source window needed
    # (nan-aware reductions: NaN coordinates are ignored here, and an
    # all-NaN chunk yields a non-finite result that is caught below).
    r_min = np.nanmin(src_row_px)
    r_max = np.nanmax(src_row_px)
    c_min = np.nanmin(src_col_px)
    c_max = np.nanmax(src_col_px)
    # 3-D source: empty-chunk returns must carry the band axis or the
    # dask map_blocks template (which is 3-D for 3-D sources) sees a
    # shape mismatch (#2027).
    if source_data.ndim == 3:
        _empty_shape = (*chunk_shape, source_data.shape[2])
    else:
        _empty_shape = chunk_shape
    # Empty-chunk fills must match the dtype the data path returns and the
    # dask template advertises (#3096): integer sources round-trip back to
    # their dtype, floats stay float64. Without this, a single no-overlap
    # chunk promoted the whole assembled dask output to float64. The
    # resolved nodata is guaranteed representable for integer rasters
    # (#2185/#2572).
    if np.issubdtype(source_data.dtype, np.integer):
        _empty_dtype = source_data.dtype
    else:
        _empty_dtype = np.float64
    # No finite source coordinate at all: the chunk is entirely nodata.
    if not np.isfinite(r_min) or not np.isfinite(r_max):
        return np.full(_empty_shape, nodata, dtype=_empty_dtype)
    if not np.isfinite(c_min) or not np.isfinite(c_max):
        return np.full(_empty_shape, nodata, dtype=_empty_dtype)
    # Pad the window by a few pixels on each side (2 before, 3 after as an
    # exclusive end) -- presumably to cover the resampling kernel footprint;
    # confirm against _resample_numpy.
    r_min = int(np.floor(r_min)) - 2
    r_max = int(np.ceil(r_max)) + 3
    c_min = int(np.floor(c_min)) - 2
    c_max = int(np.ceil(c_max)) + 3
    # Check overlap
    if r_min >= src_h or r_max <= 0 or c_min >= src_w or c_max <= 0:
        return np.full(_empty_shape, nodata, dtype=_empty_dtype)
    # Clip to source bounds
    r_min_clip = max(0, r_min)
    r_max_clip = min(src_h, r_max)
    c_min_clip = max(0, c_min)
    c_max_clip = min(src_w, c_max)
    # Guard: cap source window to prevent OOM if projection maps a small
    # output chunk to a huge source area (e.g. polar stereographic edges).
    # An oversized window is not an error: the chunk is returned as nodata.
    _MAX_WINDOW_PIXELS = 64 * 1024 * 1024 # 64 Mpix (~512 MB for float64)
    win_pixels = (r_max_clip - r_min_clip) * (c_max_clip - c_min_clip)
    if win_pixels > _MAX_WINDOW_PIXELS:
        return np.full(_empty_shape, nodata, dtype=_empty_dtype)
    # Extract source window
    window = source_data[r_min_clip:r_max_clip, c_min_clip:c_max_clip]
    # Lazy (dask-like) sources are materialised only for this window.
    if hasattr(window, 'compute'):
        window = window.compute()
    window = np.asarray(window)
    # Remembered so integer sources can be cast back after resampling.
    orig_dtype = window.dtype
    # Adjust coordinates relative to window
    local_row = src_row_px - r_min_clip
    local_col = src_col_px - c_min_clip
    # Multi-band: reproject each band separately, share coordinates
    if window.ndim == 3:
        n_bands = window.shape[2]
        bands = []
        for b in range(n_bands):
            band_data = window[:, :, b].astype(np.float64)
            # Mask this band with its own source sentinel when the raster
            # declares per-band nodata; otherwise fall back to the single
            # resolved sentinel (#2647).
            src_nd = band_nodata[b] if band_nodata is not None else nodata
            if not np.isnan(src_nd):
                band_data[band_data == src_nd] = np.nan
            band_result = _resample_numpy(band_data, local_row, local_col,
                                          resampling=resampling, nodata=nodata)
            # Round, clamp to the dtype range, then cast, so out-of-range
            # resampled values saturate instead of wrapping.
            if np.issubdtype(orig_dtype, np.integer):
                info = np.iinfo(orig_dtype)
                band_result = np.clip(np.round(band_result), info.min, info.max).astype(orig_dtype)
            bands.append(band_result)
        # Re-assemble with the band axis last, matching the source layout.
        return np.stack(bands, axis=-1)
    # Single-band path
    window = window.astype(np.float64)
    # Convert sentinel nodata to NaN so numba kernels can detect it
    if not np.isnan(nodata):
        window[window == nodata] = np.nan
    result = _resample_numpy(window, local_row, local_col,
                             resampling=resampling, nodata=nodata)
    # Clamp and cast back for integer source dtypes
    if np.issubdtype(orig_dtype, np.integer):
        info = np.iinfo(orig_dtype)
        result = np.clip(np.round(result), info.min, info.max).astype(orig_dtype)
    return result
def _reproject_chunk_cupy(
    source_data, source_bounds_tuple, source_shape, source_y_desc,
    src_wkt, tgt_wkt,
    chunk_bounds_tuple, chunk_shape,
    resampling, nodata, transform_precision,
    source_x_desc=False,
    band_nodata=None,
):
    """CuPy variant of ``_reproject_chunk_numpy``.

    Takes the same arguments as :func:`_reproject_chunk_numpy`.
    ``source_x_desc`` carries the horizontal direction flag (same meaning
    as in :func:`_reproject_chunk_numpy`).

    The coordinate transform is attempted on the GPU with
    ``try_cuda_transform`` first; when that is unavailable, returns None, or
    is skipped because ``transform_precision == 0``, the coordinates are
    computed on the CPU through ``_transform_coords`` and arrive as host
    arrays. Either way the source window is moved to the device and
    resampled with ``_resample_cupy_native``.

    Returns
    -------
    cupy.ndarray
        Shape ``chunk_shape`` for 2-D sources, ``(*chunk_shape, n_bands)``
        for 3-D sources. Integer sources come back in their own dtype;
        empty-chunk fills for every other source dtype are float64.
    """
    import cupy as cp
    from ._crs_utils import _crs_from_wkt
    src_crs = _crs_from_wkt(src_wkt)
    tgt_crs = _crs_from_wkt(tgt_wkt)
    # 3-D source: empty-chunk returns must carry the band axis to match
    # the dask+cupy map_blocks template (#2027).
    if source_data.ndim == 3:
        _empty_shape = (*chunk_shape, source_data.shape[2])
    else:
        _empty_shape = chunk_shape
    # Empty-chunk fills must match the dtype the data path returns (#3096);
    # see the matching block in _reproject_chunk_numpy.
    if np.issubdtype(source_data.dtype, np.integer):
        _empty_dtype = source_data.dtype
    else:
        _empty_dtype = np.float64
    # Try CUDA transform first (keeps coordinates on-device).
    # transform_precision == 0 forces the exact pyproj path, so skip CUDA.
    cuda_result = None
    if (transform_precision != 0
            and src_crs is not None and tgt_crs is not None):
        try:
            from ._projections_cuda import try_cuda_transform
            cuda_result = try_cuda_transform(
                src_crs, tgt_crs, chunk_bounds_tuple, chunk_shape,
            )
        except (ImportError, ModuleNotFoundError):
            pass
    if cuda_result is not None:
        src_y, src_x = cuda_result  # cupy arrays
        src_left, src_bottom, src_right, src_top = source_bounds_tuple
        src_h, src_w = source_shape
        src_res_x = (src_right - src_left) / src_w
        src_res_y = (src_top - src_bottom) / src_h
        # Pixel coordinate math stays on GPU via cupy operators
        # (same edge-to-centre conversion as the numpy worker: the -0.5
        # turns an edge-relative offset into a pixel-centre index).
        if source_x_desc:
            src_col_px = (src_right - src_x) / src_res_x - 0.5
        else:
            src_col_px = (src_x - src_left) / src_res_x - 0.5
        if source_y_desc:
            src_row_px = (src_top - src_y) / src_res_y - 0.5
        else:
            src_row_px = (src_y - src_bottom) / src_res_y - 0.5
        # Need min/max on CPU for window selection.
        # Stack the four reductions and pull them across in one device-to-host
        # transfer to avoid four separate synchronous syncs.
        mins_maxes = cp.stack([
            cp.nanmin(src_row_px), cp.nanmax(src_row_px),
            cp.nanmin(src_col_px), cp.nanmax(src_col_px),
        ])
        r_min_val, r_max_val, c_min_val, c_max_val = (
            float(v) for v in mins_maxes.get()
        )
        # No finite source coordinate at all: the chunk is entirely nodata.
        if not (np.isfinite(r_min_val) and np.isfinite(r_max_val)
                and np.isfinite(c_min_val) and np.isfinite(c_max_val)):
            return cp.full(_empty_shape, nodata, dtype=_empty_dtype)
        # Same window padding as the numpy worker (2 before, 3 after).
        r_min = int(np.floor(r_min_val)) - 2
        r_max = int(np.ceil(r_max_val)) + 3
        c_min = int(np.floor(c_min_val)) - 2
        c_max = int(np.ceil(c_max_val)) + 3
        # Coordinates stay as CuPy arrays for native CUDA resampling
    else:
        # CPU fallback (Numba JIT or pyproj)
        from ._crs_utils import _require_pyproj
        pyproj = _require_pyproj()
        # Transformer runs target -> source: output pixels are looked up in
        # the source raster.
        transformer = pyproj.Transformer.from_crs(
            tgt_crs, src_crs, always_xy=True
        )
        # Unlike the numpy worker, the CRSes are forwarded here, so
        # _transform_coords may still take its own Numba fast path.
        src_y, src_x = _transform_coords(
            transformer, chunk_bounds_tuple, chunk_shape, transform_precision,
            src_crs=src_crs, tgt_crs=tgt_crs,
        )
        src_left, src_bottom, src_right, src_top = source_bounds_tuple
        src_h, src_w = source_shape
        src_res_x = (src_right - src_left) / src_w
        src_res_y = (src_top - src_bottom) / src_h
        if source_x_desc:
            src_col_px = (src_right - src_x) / src_res_x - 0.5
        else:
            src_col_px = (src_x - src_left) / src_res_x - 0.5
        if source_y_desc:
            src_row_px = (src_top - src_y) / src_res_y - 0.5
        else:
            src_row_px = (src_y - src_bottom) / src_res_y - 0.5
        # Host-side reductions: the coordinates on this branch come from
        # _transform_coords rather than the CUDA transform.
        r_min = np.nanmin(src_row_px)
        r_max = np.nanmax(src_row_px)
        c_min = np.nanmin(src_col_px)
        c_max = np.nanmax(src_col_px)
        if not np.isfinite(r_min) or not np.isfinite(r_max):
            return cp.full(_empty_shape, nodata, dtype=_empty_dtype)
        if not np.isfinite(c_min) or not np.isfinite(c_max):
            return cp.full(_empty_shape, nodata, dtype=_empty_dtype)
        r_min = int(np.floor(r_min)) - 2
        r_max = int(np.ceil(r_max)) + 3
        c_min = int(np.floor(c_min)) - 2
        c_max = int(np.ceil(c_max)) + 3
    # From here on both branches share the same window logic.
    # No overlap between the padded window and the source: all nodata.
    if r_min >= src_h or r_max <= 0 or c_min >= src_w or c_max <= 0:
        return cp.full(_empty_shape, nodata, dtype=_empty_dtype)
    r_min_clip = max(0, r_min)
    r_max_clip = min(src_h, r_max)
    c_min_clip = max(0, c_min)
    c_max_clip = min(src_w, c_max)
    # Guard: cap source window to prevent GPU OOM if projection maps a
    # small output chunk to a huge source area (matches numpy path).
    _MAX_WINDOW_PIXELS = 64 * 1024 * 1024 # 64 Mpix (~512 MB for float64)
    win_pixels = (r_max_clip - r_min_clip) * (c_max_clip - c_min_clip)
    if win_pixels > _MAX_WINDOW_PIXELS:
        return cp.full(_empty_shape, nodata, dtype=_empty_dtype)
    window = source_data[r_min_clip:r_max_clip, c_min_clip:c_max_clip]
    # Lazy (dask-like) sources are materialised only for this window, and
    # anything that is not already a CuPy array is uploaded to the device.
    if hasattr(window, 'compute'):
        window = window.compute()
    if not isinstance(window, cp.ndarray):
        window = cp.asarray(window)
    # Remembered so integer sources can be cast back after resampling.
    orig_dtype = window.dtype
    # Adjust coordinates relative to window (stays on GPU if CuPy)
    local_row = src_row_px - r_min_clip
    local_col = src_col_px - c_min_clip
    # Multi-band: reproject each band separately, share coordinates.
    # Matches the 3-D branch in _reproject_chunk_numpy so 3-D inputs work
    # on cupy and dask+cupy backends instead of crashing with a CUDA
    # signature mismatch (#2027).
    if window.ndim == 3:
        n_bands = window.shape[2]
        # The coordinate arrays are shared by every band. On the CPU
        # transform fallback they arrive as numpy; convert them to the
        # device once here, otherwise _resample_cupy_native re-uploads
        # the same two chunk-sized arrays on every band iteration (#3268).
        if not isinstance(local_row, cp.ndarray):
            local_row = cp.asarray(local_row)
        if not isinstance(local_col, cp.ndarray):
            local_col = cp.asarray(local_col)
        bands = []
        for b in range(n_bands):
            band_data = window[:, :, b].astype(cp.float64)
            # Mask this band with its own source sentinel when the raster
            # declares per-band nodata; otherwise fall back to the single
            # resolved sentinel (#2647). Pre-converting to NaN here lets
            # each band use a different source sentinel; the native kernel
            # still fills out-of-bounds pixels with the resolved `nodata`.
            src_nd = band_nodata[b] if band_nodata is not None else nodata
            if not np.isnan(src_nd):
                band_data = cp.where(
                    band_data == src_nd, cp.nan, band_data,
                )
            # Always resample through the native CUDA kernels so the cupy
            # backend matches numpy exactly. They accept CPU coordinate
            # arrays (transferring them to the GPU) and do the
            # nodata->NaN conversion internally, so they serve both the
            # on-device coordinate path and the pyproj fallback. Using
            # cupyx.scipy.ndimage.map_coordinates here instead would
            # diverge from numpy: it bleeds the cval=0.0 constant into the
            # half-pixel boundary band rather than renormalizing, and its
            # order=3 path is a B-spline rather than Catmull-Rom (#2620).
            band_result = _resample_cupy_native(
                band_data, local_row, local_col,
                resampling=resampling, nodata=nodata,
            )
            # Round, clamp to the dtype range, then cast, so out-of-range
            # resampled values saturate instead of wrapping.
            if np.issubdtype(orig_dtype, np.integer):
                info = np.iinfo(orig_dtype)
                band_result = cp.clip(
                    cp.round(band_result), info.min, info.max,
                ).astype(orig_dtype)
            bands.append(band_result)
        # Re-assemble with the band axis last, matching the source layout.
        return cp.stack(bands, axis=-1)
    window = window.astype(cp.float64)
    # Always resample through the native CUDA kernels for numpy parity.
    # local_row/local_col may be CuPy (on-device transform) or numpy
    # (pyproj fallback); _resample_cupy_native handles both and does the
    # nodata->NaN conversion internally. The previous
    # cupyx.scipy.ndimage.map_coordinates fallback diverged from numpy at
    # chunk edges and for cubic resampling (#2620).
    result = _resample_cupy_native(window, local_row, local_col,
                                   resampling=resampling, nodata=nodata)
    # Clamp and cast back for integer source dtypes (parity with numpy path)
    if np.issubdtype(orig_dtype, np.integer):
        info = np.iinfo(orig_dtype)
        result = cp.clip(cp.round(result), info.min, info.max).astype(orig_dtype)
    return result
# ---------------------------------------------------------------------------
# reproject()
# ---------------------------------------------------------------------------
def reproject(
raster,
target_crs,
*,
source_crs=None,
resolution=None,
bounds=None,
width=None,
height=None,
resampling='bilinear',
nodata=None,
transform_precision=16,
chunk_size=None,
name=None,
max_memory=None,
source_vertical_crs=None,
target_vertical_crs=None,
bounds_policy="auto",
src_vertical_crs=_DEPRECATED,
tgt_vertical_crs=_DEPRECATED,
):
"""Reproject a raster DataArray to a new coordinate reference system.
Supports numpy, cupy, dask+numpy, and dask+cupy backends. For dask
inputs, the computation is fully lazy: each output chunk independently
reads only the source pixels it needs.
Numpy inputs whose input or output working set exceeds ~512 MB are
routed through the same lazy dask path when dask is installed, so the
result is dask-backed in that case. Without dask, a streaming
fallback bounds memory via ``max_memory``.
Parameters
----------
raster : xr.DataArray
Input raster with y/x coordinates.
target_crs
Target CRS in any format accepted by ``pyproj.CRS()``.
source_crs : optional
Source CRS. Auto-detected from *raster* if None.
resolution : float or (float, float) or None
Output pixel size in target CRS units.
bounds : (left, bottom, right, top) or None
Explicit output extent in target CRS.
width, height : int or None
Explicit output grid dimensions.
resampling : str
One of 'nearest', 'bilinear', 'cubic'.
nodata : float or None
Nodata value. Auto-detected if None. For integer input dtypes,
an explicit value that does not fit the dtype range raises
``ValueError`` (e.g. ``nodata=-9999`` with a ``uint8`` raster).
Attrs/rioxarray-derived out-of-range values emit a
``UserWarning`` and fall back to ``dtype.min`` for signed or
``dtype.max`` for unsigned so legacy files still load (#2572).
transform_precision : int
Control-grid subdivisions for the coordinate transform (default 16).
Higher values increase accuracy at the cost of more pyproj calls.
Set to 0 for exact per-pixel transforms matching GDAL/rasterio.
chunk_size : int or (int, int) or None
Output chunk size for dask. If None, defaults to 512 for the
standard dask path and 2048 for the in-memory streaming and
dask+cupy paths (chosen to amortize kernel launch overhead).
name : str or None
Name for the output DataArray.
max_memory : int or str or None
Maximum memory budget for the reprojection working set.
Accepts bytes (int) or human-readable strings like ``'4GB'``,
``'512MB'``. Controls how many output tiles are processed
in parallel for large-dataset streaming mode. Default None
uses 1GB. Has no effect for small datasets that fit in memory.
source_vertical_crs : str or None
Source vertical datum for height values. One of:
- ``'EGM96'`` -- orthometric heights relative to EGM96 geoid (MSL)
- ``'EGM2008'`` -- orthometric heights relative to EGM2008 geoid
- ``'ellipsoidal'`` -- heights relative to the WGS84 ellipsoid
- ``None`` -- no vertical transformation (default)
target_vertical_crs : str or None
Target vertical datum. Same options as *source_vertical_crs*.
Both must be set to trigger a vertical transformation.
src_vertical_crs : str or None
Deprecated alias for *source_vertical_crs*. Passing it emits a
``DeprecationWarning``.
tgt_vertical_crs : str or None
Deprecated alias for *target_vertical_crs*. Passing it emits a
``DeprecationWarning``.
bounds_policy : {"auto", "raw", "clamp", "percentile"}, default "auto"
How to derive the output extent from the source extent when
``bounds`` is not supplied. Only relevant when projecting near a
singularity (antimeridian, pole, projection edge):
- ``"raw"``: use the true projected extent of the source corners
and edges. No clamp, no percentile, no heuristic. The output
may be very large if the input straddles a projection
singularity. Use this when you want a true projection of the
source extent.
- ``"clamp"``: trim geographic source bounds inward by 0.01 deg
from +/-180 longitude and +/-90 latitude before projecting.
Avoids infinities at singularities. No percentile fallback.
No-op on projected source CRSes (UTM, Mercator, etc.) since
the clamp only applies in degrees.
- ``"percentile"``: project a dense interior grid of the source
extent and use the 2nd/98th percentiles of the result as the
output bounds. Rejects projection outliers at the cost of
trimming valid pixels.
- ``"auto"`` (default): apply ``"clamp"`` for geographic source
CRSes and fall back to ``"percentile"`` when the projected
extent is more than 50x the source extent. Matches the
historical behaviour.
When ``"auto"``, ``"clamp"``, or ``"percentile"`` actually alters
the bounds, a ``UserWarning`` is emitted naming the policy and
reporting the per-side delta versus the raw projected bounds.
Filter with ``warnings.filterwarnings`` if the crop is intentional.
Returns
-------
xr.DataArray
The output ``attrs['crs']`` is in WKT format.
Whenever *target_vertical_crs* is set, ``attrs['vertical_crs']``
records the target vertical datum's EPSG code (5773 for EGM96,
3855 for EGM2008, 4979 for ellipsoidal WGS84) to match the
convention used by ``xrspatial.geotiff``. The friendly string
token (``'EGM96'`` etc.) is preserved under ``attrs['vertical_datum']``.
Both attrs are written even when no shift is applied (e.g. when
*source_vertical_crs* equals *target_vertical_crs*, or when only the
target is given), so the output's vertical reference is always
explicit.
The output y coordinate is always emitted in descending order
(top-down, north-up) and the output x coordinate is always
emitted in ascending order (left-to-right) regardless of the
input directions. This matches the standard raster convention
and the output of common GIS libraries. Inputs with descending
x are detected from the x coordinate values and handled the
same way as descending y: the pixel-index mapping is mirrored
so the output values stay correct.
Non-spatial coords from the input (such as a scalar ``time``
coord or a non-dimension coord that is not aligned to the
spatial dims) are carried through to the output. Coords that
are aligned to the input y or x dims are dropped because their
values do not apply to the rebuilt grid.
Examples
--------
>>> import xarray as xr
>>> import numpy as np
>>> from xrspatial.reproject import reproject
>>> raster = xr.DataArray(
... np.random.rand(64, 64),
... dims=['y', 'x'],
... coords={'y': np.linspace(50, 40, 64),
... 'x': np.linspace(-5, 5, 64)},
... attrs={'crs': 'EPSG:4326'},
... )
>>> result = reproject(raster, 'EPSG:3857')
>>> result.attrs['crs'].startswith(('PROJCRS', 'PROJCS'))
True
"""
# Back-compat shim for the old abbreviated kwarg names. These were
# renamed to source_vertical_crs / target_vertical_crs to match the
# source_crs / target_crs spelling used by the rest of the signature.
source_vertical_crs = _resolve_deprecated_vertical_kwarg(
'src_vertical_crs', src_vertical_crs,
'source_vertical_crs', source_vertical_crs)
target_vertical_crs = _resolve_deprecated_vertical_kwarg(
'tgt_vertical_crs', tgt_vertical_crs,
'target_vertical_crs', target_vertical_crs)
_validate_raster(raster, func_name='reproject', name='raster',
ndim=(2, 3))
# Reject irregular / non-monotonic source coords before any CRS
# resolution or grid math. _source_bounds() infers pixel size from
# only the first two coord samples and downstream pixel math assumes
# uniform spacing, so an unchecked irregular input would silently
# produce wrong georeferencing (#2184).
_validate_source_coords(raster, 'reproject')
_validate_grid_params(
resolution=resolution,
bounds=bounds,
width=width,
height=height,
transform_precision=transform_precision,
func_name='reproject',
)
_validate_resampling(resampling)
from ._grid import _validate_bounds_policy
_validate_bounds_policy(bounds_policy, func_name='reproject')
# Reject unknown vertical-datum tokens at the API boundary so we never
# write None into attrs['vertical_crs'] for typos / unsupported values.
for _name, _val in (('source_vertical_crs', source_vertical_crs),
('target_vertical_crs', target_vertical_crs)):
if _val is not None and _val not in _VERTICAL_DATUM_EPSG:
raise ValueError(
f"Unknown {_name}={_val!r}; expected one of "
f"{sorted(_VERTICAL_DATUM_EPSG)} or None."
)
# Normalize 3-D inputs to canonical (y, x, band) layout.
# The per-chunk workers slice the source as ``source_data[r:, c:]`` and
# assume the band axis is trailing. A rasterio/rioxarray-style
# ``(band, y, x)`` input would otherwise slice the band/y axes instead
# of the y/x axes and either crash or return wrong-shape data (#2182).
# We record the input's original dim order so the output can be
# transposed back at the end, preserving downstream expectations.
_input_dims = tuple(raster.dims)
if raster.ndim == 3:
_ydim_in, _xdim_in = _find_spatial_dims(raster)
_band_dims_in = [d for d in _input_dims
if d not in (_ydim_in, _xdim_in)]
_band_dim_in = _band_dims_in[0] if _band_dims_in else None
if _band_dim_in is not None:
_canonical = (_ydim_in, _xdim_in, _band_dim_in)
if _input_dims != _canonical:
raster = raster.transpose(*_canonical)
# Resolve CRS
src_crs = _resolve_crs(source_crs)
if src_crs is None:
src_crs = _detect_source_crs(raster)
if src_crs is None:
raise ValueError(
"Could not detect source CRS. Pass source_crs explicitly."
)
tgt_crs = _resolve_crs(target_crs)
# Detect nodata. Pass the raster dtype so integer rasters get an
# integer-compatible sentinel (dtype min for signed, dtype max for
# unsigned) instead of NaN. Without this hint, the worker's
# cast-back step would collapse NaN to 0 and `attrs['nodata']`
# would contradict the array contents (#2185).
nd = _detect_nodata(raster, nodata, dtype=raster.dtype)
# Multi-band rasters can declare a distinct source sentinel per band
# via the rasterio `nodatavals` tuple. `nd` is the single resolved
# output sentinel; `band_nd` carries the raw per-band source sentinels
# so each band is masked with its own value before resampling (#2647).
# `None` means one scalar covers every band -- the workers use `nd`.
# The raster is in canonical (y, x, band) layout here, so the band
# axis is trailing.
_n_bands = raster.shape[2] if raster.ndim == 3 else None
band_nd = _detect_band_nodata(raster, nodata, _n_bands)
# Source geometry
src_bounds = _source_bounds(raster)
_ydim, _xdim = _find_spatial_dims(raster)
src_shape = (raster.sizes[_ydim], raster.sizes[_xdim])
y_desc = _is_y_descending(raster)
x_desc = _is_x_descending(raster)
# Detect backend before computing the output grid so the grid's
# output-size guard can tell a lazy dask output (never materialized
# in full) from a materializing backend.
from ..utils import has_dask_array, is_cupy_array
data = raster.data
is_dask = False
if has_dask_array():
import dask.array as _da
is_dask = isinstance(data, _da.Array)
is_cupy = False
if is_dask:
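# Dask-backed input: ask is_cupy_backed() whether the data is
# CuPy-backed. If the helper cannot be imported, is_cupy stays False.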
try:
from ..utils import is_cupy_backed
is_cupy = is_cupy_backed(raster)
except (ImportError, ModuleNotFoundError):
pass
else:
is_cupy = is_cupy_array(data)
# Compute the output grid with the size guard disabled: whether the
# guard applies depends on the final execution path (a lazy dask
# output never materializes the full grid), and the path decision
# below needs the output shape. The guard is re-applied after the
# path is known -- same pattern merge() uses.
grid = _compute_output_grid(
src_bounds, src_shape, src_crs, tgt_crs,
resolution=resolution, bounds=bounds,
width=width, height=height,
bounds_policy=bounds_policy,
lazy_output=True,
)
out_bounds = grid['bounds']