Skip to content

Commit 6d7fc18

Browse files
committed
Merge branch 'main' of github.com:geo-engine/geoengine-python into add-ml-model-shape
2 parents da6ee10 + ad36bed commit 6d7fc18

File tree

9 files changed

+61
-79
lines changed

9 files changed

+61
-79
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
fail-fast: false
1919
matrix:
2020
# use all supported versions from https://devguide.python.org/versions/
21-
python-version: ["3.9", "3.10", "3.11", "3.12"]
21+
python-version: ["3.10", "3.11", "3.12", "3.13"]
2222

2323
with:
2424
python-version: ${{ matrix.python-version }}
@@ -31,6 +31,6 @@ jobs:
3131
uses: geo-engine/geoengine-python/.github/workflows/test-python.yml@main
3232

3333
with:
34-
python-version: 3.9
34+
python-version: "3.10"
3535
use-uv: true
3636
coverage: true

geoengine/colorizer.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import warnings
88
from typing import Dict, List, Tuple, Union, cast
99
import numpy as np
10-
import numpy.typing as npt
1110
from matplotlib.colors import Colormap
1211
from matplotlib.cm import ScalarMappable
1312
import geoengine_openapi_client
@@ -67,8 +66,9 @@ def linear_with_mpl_cmap(
6766
raise ValueError(f"underColor must be a RGBA color specification, got {under_color} instead.")
6867

6968
# get the map, and transform it to a list of (uint8) rgba values
70-
list_of_rgba_colors: List[npt.NDArray[np.uint8]] = ScalarMappable(cmap=color_map).to_rgba(
71-
np.linspace(min_max[0], min_max[1], n_steps), bytes=True)
69+
list_of_rgba_colors = ScalarMappable(cmap=color_map).to_rgba(
70+
np.linspace(min_max[0], min_max[1], n_steps), bytes=True
71+
)
7272

7373
# if you want to remap the colors, you can do it here (e.g. cutting of the most extreme colors)
7474
values_of_breakpoints: List[float] = np.linspace(min_max[0], min_max[1], n_steps).tolist()
@@ -120,8 +120,9 @@ def logarithmic_with_mpl_cmap(
120120
raise ValueError(f"underColor must be a RGBA color specification, got {under_color} instead.")
121121

122122
# get the map, and transform it to a list of (uint8) rgba values
123-
list_of_rgba_colors: List[npt.NDArray[np.uint8]] = ScalarMappable(cmap=color_map).to_rgba(
124-
np.linspace(min_max[0], min_max[1], n_steps), bytes=True)
123+
list_of_rgba_colors = ScalarMappable(cmap=color_map).to_rgba(
124+
np.linspace(min_max[0], min_max[1], n_steps), bytes=True
125+
)
125126

126127
# if you want to remap the colors, you can do it here (e.g. cutting of the most extreme colors)
127128
values_of_breakpoints: List[float] = np.logspace(np.log10(min_max[0]), np.log10(min_max[1]), n_steps).tolist()
@@ -192,8 +193,9 @@ def palette_with_colormap(
192193
f"Number of available colors: {n_colors_of_cmap}"))
193194

194195
# we only need to generate enough different colors for all values specified in the colors parameter
195-
list_of_rgba_colors: List[npt.NDArray[np.uint8]] = ScalarMappable(cmap=color_map).to_rgba(
196-
np.linspace(0, len(values), len(values)), bytes=True)
196+
list_of_rgba_colors = ScalarMappable(cmap=color_map).to_rgba(
197+
np.linspace(0, len(values), len(values)), bytes=True
198+
)
197199

198200
# generate the dict with value: color mapping
199201
color_mapping: Dict[float, Rgba] = dict(zip(

geoengine/raster.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,11 @@ def has_null_values(self) -> bool:
8686

8787
@property
8888
def time_start_ms(self) -> np.datetime64:
89-
return np.datetime64(self.time.start, 'ms')
89+
return self.time.start.astype('datetime64[ms]')
9090

9191
@property
92-
def time_end_ms(self) -> np.datetime64:
93-
return np.datetime64(self.time.end, 'ms')
92+
def time_end_ms(self) -> Optional[np.datetime64]:
93+
return None if self.time.end is None else self.time.end.astype('datetime64[ms]')
9494

9595
@property
9696
def pixel_size(self) -> Tuple[float, float]:
@@ -290,7 +290,7 @@ def single_band(self, index: int) -> RasterTile2D:
290290
def to_numpy_masked_array_stack(self) -> np.ma.MaskedArray:
291291
'''Return the raster stack as a 3D masked numpy array'''
292292
arrays = [self.single_band(i).to_numpy_masked_array() for i in range(0, len(self.data))]
293-
stack = np.stack(arrays, axis=0)
293+
stack = np.ma.stack(arrays, axis=0)
294294
return stack
295295

296296
def to_xarray(self, clip_with_bounds: Optional[gety.SpatialBounds] = None) -> xr.DataArray:

geoengine/workflow.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from owslib.wcs import WebCoverageService
2727
from vega import VegaLite
2828
import websockets
29-
import websockets.client
3029
import xarray as xr
3130
import pyarrow as pa
3231

@@ -585,7 +584,7 @@ async def raster_stream(
585584
if url is None:
586585
raise InputException('Invalid websocket url')
587586

588-
async with websockets.client.connect(
587+
async with websockets.asyncio.client.connect(
589588
uri=self.__replace_http_with_ws(url),
590589
extra_headers=session.auth_header,
591590
open_timeout=open_timeout,
@@ -594,7 +593,7 @@ async def raster_stream(
594593

595594
tile_bytes: Optional[bytes] = None
596595

597-
while websocket.open:
596+
while websocket.state == websockets.protocol.State.OPEN:
598597
async def read_new_bytes() -> Optional[bytes]:
599598
# already send the next request to speed up the process
600599
try:
@@ -792,7 +791,7 @@ def process_bytes(batch_bytes: Optional[bytes]) -> Optional[gpd.GeoDataFrame]:
792791
if url is None:
793792
raise InputException('Invalid websocket url')
794793

795-
async with websockets.client.connect(
794+
async with websockets.asyncio.client.connect(
796795
uri=self.__replace_http_with_ws(url),
797796
extra_headers=session.auth_header,
798797
open_timeout=open_timeout,
@@ -801,7 +800,7 @@ def process_bytes(batch_bytes: Optional[bytes]) -> Optional[gpd.GeoDataFrame]:
801800

802801
batch_bytes: Optional[bytes] = None
803802

804-
while websocket.open:
803+
while websocket.state == websockets.protocol.State.OPEN:
805804
async def read_new_bytes() -> Optional[bytes]:
806805
# already send the next request to speed up the process
807806
try:

mypy.ini

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,3 @@ ignore_missing_imports = True
4242

4343
[mypy-sklearn.*]
4444
ignore_missing_imports = True
45-
46-
# testcontainers is typed, but it doesn't correctly declare itself as such.
47-
# Hopefully it can be fixed one day:
48-
#
49-
# https://github.com/testcontainers/testcontainers-python/issues/305
50-
[mypy-testcontainers.*]
51-
ignore_missing_imports = True

setup.cfg

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,31 @@ classifiers =
1616
[options]
1717
package_dir =
1818
packages = find:
19-
python_requires = >=3.9
19+
python_requires = >=3.10
2020
install_requires =
2121
geoengine-openapi-client @ git+https://github.com/geo-engine/openapi-client@ml-model-input-outpt-shape-2#subdirectory=python
22-
geopandas >=0.9,<0.15
23-
matplotlib >=3.5,<3.8
24-
numpy >=1.21,<2.1
25-
owslib >=0.27,<0.32
22+
geopandas >=1.0,<2.0
23+
matplotlib >=3.5,<3.11
24+
numpy >=1.21,<2.3
25+
owslib >=0.27,<0.34
2626
pillow >=10.0,<12
27-
pyarrow >=17.0,<18
28-
python-dotenv >=0.19,<1.1
27+
pyarrow >=17.0,<21
28+
python-dotenv >=0.19,<1.2
2929
rasterio >=1.3,<2
3030
requests >= 2.26,<3
31-
rioxarray >=0.9.1, <0.19
31+
rioxarray >=0.9.1, <0.20
3232
StrEnum >=0.4.6,<0.5 # TODO: use from stdlib when `python_requires = >=3.11`
33-
vega >= 3.5,<4
34-
websockets >= 10.0,<11
35-
xarray >=0.19,<2024.12
36-
urllib3 >= 2.1, < 2.4
37-
pydantic >= 2.10.6, < 2.11
38-
skl2onnx >=1.17,<2
39-
onnx == 1.17
33+
vega >= 3.5,<4.2
34+
websockets >= 14.0,<16
35+
xarray >=0.19,<2025.5
36+
urllib3 >= 2.1, < 2.5
37+
pydantic >= 2.10.6, < 2.12
38+
skl2onnx >=1.17,<2 ; python_version<"3.13"
39+
skl2onnx @ git+https://github.com/onnx/sklearn-onnx@1035fdf ; python_version>="3.13" # TODO: remove when skl2onnx 1.19 is released
40+
onnx == 1.17 ; python_version<"3.13" # TODO: remove when skl2onnx 1.19 is released
41+
onnx == 1.18 ; python_version>="3.13" # TODO: remove when skl2onnx 1.19 is released
42+
43+
[[onnx]]
4044

4145
[options.extras_require]
4246
dev =
@@ -45,22 +49,21 @@ dev =
4549
pdoc3 >=0.10,<0.11
4650
pycodestyle >=2.8,<3 # formatter
4751
pylint >=3.3,<4 # code linter
48-
setuptools >=42,<76
49-
twine >=3.4,<5 # PyPI
52+
setuptools >=42,<81
53+
twine >=3.4,<6 # PyPI
5054
types-requests >=2.26,<3 # mypy type hints
51-
types-setuptools >= 71.1 # mypy type hints
55+
types-setuptools >=71.1,<81 # mypy type hints
5256
wheel >=0.37,<0.46
5357
test =
5458
psycopg >=3.2,<4
55-
pytest >=6.3,<8
59+
pytest >=6.3,<9
5660
pytest-cov >=6.0,<7
5761
requests_mock >=1.9,<2
58-
scikit-learn >=1.5,<1.6
59-
testcontainers[postgres] >=4.9,<5
62+
scikit-learn >=1.5,<1.7
6063
examples =
6164
cartopy >=0.22,<0.25 # for WMS example
6265
ipympl >=0.9.4,<0.10 # for ML example
63-
ipyvuetify >=1.10,<1.11 # for ML app
66+
ipyvuetify >=1.10,<1.12 # for ML app
6467
ipywidgets >=8.1.5,<9 # for ML example
6568
nbconvert >=7,<8 # convert notebooks to Python
6669
scipy >=1.7,<2 # for WMS example

tests/test_colorizer.py

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -124,30 +124,13 @@ def test_colormap_not_available(self):
124124

125125
result = str(ctx.exception)
126126

127-
expected_end = "supported values are 'Accent', 'Accent_r', "\
128-
"'Blues', 'Blues_r', 'BrBG', 'BrBG_r', 'BuGn', 'BuGn_r', 'BuPu', 'BuPu_r', 'CMRmap', "\
129-
"'CMRmap_r', 'Dark2', 'Dark2_r', 'GnBu', 'GnBu_r', 'Greens', 'Greens_r', 'Greys', 'Greys_r', "\
130-
"'OrRd', 'OrRd_r', 'Oranges', 'Oranges_r', 'PRGn', 'PRGn_r', 'Paired', 'Paired_r', 'Pastel1', "\
131-
"'Pastel1_r', 'Pastel2', 'Pastel2_r', 'PiYG', 'PiYG_r', 'PuBu', 'PuBuGn', 'PuBuGn_r', "\
132-
"'PuBu_r', 'PuOr', 'PuOr_r', 'PuRd', 'PuRd_r', 'Purples', 'Purples_r', 'RdBu', 'RdBu_r', "\
133-
"'RdGy', 'RdGy_r', 'RdPu', 'RdPu_r', 'RdYlBu', 'RdYlBu_r', 'RdYlGn', 'RdYlGn_r', 'Reds', "\
134-
"'Reds_r', 'Set1', 'Set1_r', 'Set2', 'Set2_r', 'Set3', 'Set3_r', 'Spectral', 'Spectral_r', "\
135-
"'Wistia', 'Wistia_r', 'YlGn', 'YlGnBu', 'YlGnBu_r', 'YlGn_r', 'YlOrBr', 'YlOrBr_r', "\
136-
"'YlOrRd', 'YlOrRd_r', 'afmhot', 'afmhot_r', 'autumn', 'autumn_r', 'binary', 'binary_r', "\
137-
"'bone', 'bone_r', 'brg', 'brg_r', 'bwr', 'bwr_r', 'cividis', 'cividis_r', 'cool', 'cool_r', "\
138-
"'coolwarm', 'coolwarm_r', 'copper', 'copper_r', 'cubehelix', 'cubehelix_r', 'flag', "\
139-
"'flag_r', 'gist_earth', 'gist_earth_r', 'gist_gray', 'gist_gray_r', 'gist_heat', "\
140-
"'gist_heat_r', 'gist_ncar', 'gist_ncar_r', 'gist_rainbow', 'gist_rainbow_r', 'gist_stern', "\
141-
"'gist_stern_r', 'gist_yarg', 'gist_yarg_r', 'gnuplot', 'gnuplot2', 'gnuplot2_r', "\
142-
"'gnuplot_r', 'gray', 'gray_r', 'hot', 'hot_r', 'hsv', 'hsv_r', 'inferno', 'inferno_r', "\
143-
"'jet', 'jet_r', 'magma', " "'magma_r', 'nipy_spectral', 'nipy_spectral_r', 'ocean', "\
144-
"'ocean_r', 'pink', 'pink_r', 'plasma', 'plasma_r', 'prism', 'prism_r', 'rainbow', "\
145-
"'rainbow_r', 'seismic', 'seismic_r', 'spring', 'spring_r', 'summer', 'summer_r', 'tab10', "\
146-
"'tab10_r', 'tab20', 'tab20_r', 'tab20b', 'tab20b_r', 'tab20c', 'tab20c_r', 'terrain', "\
147-
"'terrain_r', 'turbo', " "'turbo_r', 'twilight', 'twilight_r', 'twilight_shifted', "\
148-
"'twilight_shifted_r', 'viridis', 'viridis_r', 'winter', 'winter_r'"
149-
150-
assert result.endswith(expected_end)
127+
expected_start = "'some_map' is not a valid value for"
128+
expected_contains = "; supported values are 'Accent',"
129+
130+
assert result.startswith(expected_start), \
131+
f"The error should start with `{expected_start}`, but starts with `{result[:len(expected_start)]}`"
132+
assert expected_contains in result, \
133+
f"The error should contain `{expected_contains}`, but does not. Full error: {result}"
151134

152135
def test_defaults(self):
153136
"""Tests the manipulation of the default values."""

tests/test_workflow_raster_stream.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import rioxarray
1111
import pyarrow as pa
1212
import xarray as xr
13+
import websockets.protocol
1314
from geoengine.types import RasterBandDescriptor
1415
import geoengine as ge
1516
from . import UrllibMocker
@@ -34,9 +35,9 @@ async def __aexit__(self, *args):
3435
pass
3536

3637
@property
37-
def open(self) -> bool:
38+
def state(self) -> websockets.protocol.State:
3839
'''Mock open impl'''
39-
return len(self.__tiles) > 0
40+
return websockets.protocol.State.OPEN if len(self.__tiles) > 0 else websockets.protocol.State.CLOSED
4041

4142
async def recv(self):
4243
return self.__tiles.pop()
@@ -130,7 +131,7 @@ def test_streaming_workflow(self):
130131
resolution=ge.SpatialResolution(45.0, 22.5),
131132
)
132133

133-
with unittest.mock.patch("websockets.client.connect", return_value=MockWebsocket()):
134+
with unittest.mock.patch("websockets.asyncio.client.connect", return_value=MockWebsocket()):
134135
async def inner1():
135136
tiles = []
136137

@@ -141,7 +142,7 @@ async def inner1():
141142

142143
asyncio.run(inner1())
143144

144-
with unittest.mock.patch("websockets.client.connect", return_value=MockWebsocket()):
145+
with unittest.mock.patch("websockets.asyncio.client.connect", return_value=MockWebsocket()):
145146
async def inner2():
146147
array = await workflow.raster_stream_into_xarray(query_rect)
147148
assert array.shape == (2, 1, 8, 8) # time, band, y, x

tests/test_workflow_vector_stream.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import geopandas as gpd
1111
import numpy as np
1212
import pandas as pd
13+
import websockets.protocol
1314
import geoengine as ge
1415
from . import UrllibMocker
1516

@@ -51,9 +52,9 @@ async def __aexit__(self, *args):
5152
pass
5253

5354
@property
54-
def open(self) -> bool:
55+
def state(self) -> websockets.protocol.State:
5556
'''Mock open impl'''
56-
return len(self.__chunks) > 0
57+
return websockets.protocol.State.OPEN if len(self.__chunks) > 0 else websockets.protocol.State.CLOSED
5758

5859
async def recv(self):
5960
return self.__chunks.pop(0)
@@ -152,7 +153,7 @@ def test_streaming_workflow(self):
152153
resolution=ge.SpatialResolution(0.5, 0.5),
153154
)
154155

155-
with unittest.mock.patch("websockets.client.connect", return_value=MockWebsocket()):
156+
with unittest.mock.patch("websockets.asyncio.client.connect", return_value=MockWebsocket()):
156157
async def inner1():
157158
chunks = []
158159

@@ -163,7 +164,7 @@ async def inner1():
163164

164165
asyncio.run(inner1())
165166

166-
with unittest.mock.patch("websockets.client.connect", return_value=MockWebsocket()):
167+
with unittest.mock.patch("websockets.asyncio.client.connect", return_value=MockWebsocket()):
167168
async def inner2():
168169
data_frame = await workflow.vector_stream_into_geopandas(query_rect)
169170

0 commit comments

Comments
 (0)