Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
fail-fast: false
matrix:
# use all supported versions from https://devguide.python.org/versions/
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12", "3.13"]

with:
python-version: ${{ matrix.python-version }}
Expand All @@ -31,6 +31,6 @@ jobs:
uses: geo-engine/geoengine-python/.github/workflows/test-python.yml@main

with:
python-version: 3.9
python-version: "3.10"
use-uv: true
coverage: true
16 changes: 9 additions & 7 deletions geoengine/colorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import warnings
from typing import Dict, List, Tuple, Union, cast
import numpy as np
import numpy.typing as npt
from matplotlib.colors import Colormap
from matplotlib.cm import ScalarMappable
import geoengine_openapi_client
Expand Down Expand Up @@ -67,8 +66,9 @@ def linear_with_mpl_cmap(
raise ValueError(f"underColor must be a RGBA color specification, got {under_color} instead.")

# get the map, and transform it to a list of (uint8) rgba values
list_of_rgba_colors: List[npt.NDArray[np.uint8]] = ScalarMappable(cmap=color_map).to_rgba(
np.linspace(min_max[0], min_max[1], n_steps), bytes=True)
list_of_rgba_colors = ScalarMappable(cmap=color_map).to_rgba(
np.linspace(min_max[0], min_max[1], n_steps), bytes=True
)

# if you want to remap the colors, you can do it here (e.g. cutting of the most extreme colors)
values_of_breakpoints: List[float] = np.linspace(min_max[0], min_max[1], n_steps).tolist()
Expand Down Expand Up @@ -120,8 +120,9 @@ def logarithmic_with_mpl_cmap(
raise ValueError(f"underColor must be a RGBA color specification, got {under_color} instead.")

# get the map, and transform it to a list of (uint8) rgba values
list_of_rgba_colors: List[npt.NDArray[np.uint8]] = ScalarMappable(cmap=color_map).to_rgba(
np.linspace(min_max[0], min_max[1], n_steps), bytes=True)
list_of_rgba_colors = ScalarMappable(cmap=color_map).to_rgba(
np.linspace(min_max[0], min_max[1], n_steps), bytes=True
)

# if you want to remap the colors, you can do it here (e.g. cutting of the most extreme colors)
values_of_breakpoints: List[float] = np.logspace(np.log10(min_max[0]), np.log10(min_max[1]), n_steps).tolist()
Expand Down Expand Up @@ -192,8 +193,9 @@ def palette_with_colormap(
f"Number of available colors: {n_colors_of_cmap}"))

# we only need to generate enough different colors for all values specified in the colors parameter
list_of_rgba_colors: List[npt.NDArray[np.uint8]] = ScalarMappable(cmap=color_map).to_rgba(
np.linspace(0, len(values), len(values)), bytes=True)
list_of_rgba_colors = ScalarMappable(cmap=color_map).to_rgba(
np.linspace(0, len(values), len(values)), bytes=True
)

# generate the dict with value: color mapping
color_mapping: Dict[float, Rgba] = dict(zip(
Expand Down
8 changes: 4 additions & 4 deletions geoengine/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ def has_null_values(self) -> bool:

@property
def time_start_ms(self) -> np.datetime64:
return np.datetime64(self.time.start, 'ms')
return self.time.start.astype('datetime64[ms]')

@property
def time_end_ms(self) -> np.datetime64:
return np.datetime64(self.time.end, 'ms')
def time_end_ms(self) -> Optional[np.datetime64]:
return None if self.time.end is None else self.time.end.astype('datetime64[ms]')

@property
def pixel_size(self) -> Tuple[float, float]:
Expand Down Expand Up @@ -290,7 +290,7 @@ def single_band(self, index: int) -> RasterTile2D:
def to_numpy_masked_array_stack(self) -> np.ma.MaskedArray:
'''Return the raster stack as a 3D masked numpy array'''
arrays = [self.single_band(i).to_numpy_masked_array() for i in range(0, len(self.data))]
stack = np.stack(arrays, axis=0)
stack = np.ma.stack(arrays, axis=0)
return stack

def to_xarray(self, clip_with_bounds: Optional[gety.SpatialBounds] = None) -> xr.DataArray:
Expand Down
9 changes: 4 additions & 5 deletions geoengine/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from owslib.wcs import WebCoverageService
from vega import VegaLite
import websockets
import websockets.client
import xarray as xr
import pyarrow as pa

Expand Down Expand Up @@ -585,7 +584,7 @@ async def raster_stream(
if url is None:
raise InputException('Invalid websocket url')

async with websockets.client.connect(
async with websockets.asyncio.client.connect(
uri=self.__replace_http_with_ws(url),
extra_headers=session.auth_header,
open_timeout=open_timeout,
Expand All @@ -594,7 +593,7 @@ async def raster_stream(

tile_bytes: Optional[bytes] = None

while websocket.open:
while websocket.state == websockets.protocol.State.OPEN:
async def read_new_bytes() -> Optional[bytes]:
# already send the next request to speed up the process
try:
Expand Down Expand Up @@ -792,7 +791,7 @@ def process_bytes(batch_bytes: Optional[bytes]) -> Optional[gpd.GeoDataFrame]:
if url is None:
raise InputException('Invalid websocket url')

async with websockets.client.connect(
async with websockets.asyncio.client.connect(
uri=self.__replace_http_with_ws(url),
extra_headers=session.auth_header,
open_timeout=open_timeout,
Expand All @@ -801,7 +800,7 @@ def process_bytes(batch_bytes: Optional[bytes]) -> Optional[gpd.GeoDataFrame]:

batch_bytes: Optional[bytes] = None

while websocket.open:
while websocket.state == websockets.protocol.State.OPEN:
async def read_new_bytes() -> Optional[bytes]:
# already send the next request to speed up the process
try:
Expand Down
7 changes: 0 additions & 7 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,3 @@ ignore_missing_imports = True

[mypy-sklearn.*]
ignore_missing_imports = True

# testcontainers is typed, but it doesn't correctly declare itself as such.
# Hopefully it can be fixed one day:
#
# https://github.com/testcontainers/testcontainers-python/issues/305
[mypy-testcontainers.*]
ignore_missing_imports = True
48 changes: 26 additions & 22 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,31 @@ classifiers =
[options]
package_dir =
packages = find:
python_requires = >=3.9
python_requires = >=3.10
install_requires =
geoengine-openapi-client == 0.0.22
geopandas >=0.9,<0.15
matplotlib >=3.5,<3.8
numpy >=1.21,<2.1
owslib >=0.27,<0.32
geoengine-openapi-client == 0.0.23
geopandas >=1.0,<2.0
matplotlib >=3.5,<3.11
numpy >=1.21,<2.3
owslib >=0.27,<0.34
pillow >=10.0,<12
pyarrow >=17.0,<18
python-dotenv >=0.19,<1.1
pyarrow >=17.0,<21
python-dotenv >=0.19,<1.2
rasterio >=1.3,<2
requests >= 2.26,<3
rioxarray >=0.9.1, <0.19
rioxarray >=0.9.1, <0.20
StrEnum >=0.4.6,<0.5 # TODO: use from stdlib when `python_requires = >=3.11`
vega >= 3.5,<4
websockets >= 10.0,<11
xarray >=0.19,<2024.12
urllib3 >= 2.1, < 2.4
pydantic >= 2.10.6, < 2.11
skl2onnx >=1.17,<2
vega >= 3.5,<4.2
websockets >= 14.0,<16
xarray >=0.19,<2025.5
urllib3 >= 2.1, < 2.5
pydantic >= 2.10.6, < 2.12
skl2onnx >=1.17,<2 ; python_version<"3.13"
skl2onnx @ git+https://github.com/onnx/sklearn-onnx@1035fdf ; python_version>="3.13" # TODO: remove when skl2onnx 1.19 is released
onnx == 1.17 ; python_version<"3.13" # TODO: remove when skl2onnx 1.19 is released
onnx == 1.18 ; python_version>="3.13" # TODO: remove when skl2onnx 1.19 is released

[[onnx]]

[options.extras_require]
dev =
Expand All @@ -44,22 +49,21 @@ dev =
pdoc3 >=0.10,<0.11
pycodestyle >=2.8,<3 # formatter
pylint >=3.3,<4 # code linter
setuptools >=42,<76
twine >=3.4,<5 # PyPI
setuptools >=42,<81
twine >=3.4,<6 # PyPI
types-requests >=2.26,<3 # mypy type hints
types-setuptools >= 71.1 # mypy type hints
types-setuptools >=71.1,<81 # mypy type hints
wheel >=0.37,<0.46
test =
psycopg >=3.2,<4
pytest >=6.3,<8
pytest >=6.3,<9
pytest-cov >=6.0,<7
requests_mock >=1.9,<2
scikit-learn >=1.5,<1.6
testcontainers[postgres] >=4.9,<5
scikit-learn >=1.5,<1.7
examples =
cartopy >=0.22,<0.25 # for WMS example
ipympl >=0.9.4,<0.10 # for ML example
ipyvuetify >=1.10,<1.11 # for ML app
ipyvuetify >=1.10,<1.12 # for ML app
ipywidgets >=8.1.5,<9 # for ML example
nbconvert >=7,<8 # convert notebooks to Python
scipy >=1.7,<2 # for WMS example
Expand Down
31 changes: 7 additions & 24 deletions tests/test_colorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,30 +124,13 @@ def test_colormap_not_available(self):

result = str(ctx.exception)

expected_end = "supported values are 'Accent', 'Accent_r', "\
"'Blues', 'Blues_r', 'BrBG', 'BrBG_r', 'BuGn', 'BuGn_r', 'BuPu', 'BuPu_r', 'CMRmap', "\
"'CMRmap_r', 'Dark2', 'Dark2_r', 'GnBu', 'GnBu_r', 'Greens', 'Greens_r', 'Greys', 'Greys_r', "\
"'OrRd', 'OrRd_r', 'Oranges', 'Oranges_r', 'PRGn', 'PRGn_r', 'Paired', 'Paired_r', 'Pastel1', "\
"'Pastel1_r', 'Pastel2', 'Pastel2_r', 'PiYG', 'PiYG_r', 'PuBu', 'PuBuGn', 'PuBuGn_r', "\
"'PuBu_r', 'PuOr', 'PuOr_r', 'PuRd', 'PuRd_r', 'Purples', 'Purples_r', 'RdBu', 'RdBu_r', "\
"'RdGy', 'RdGy_r', 'RdPu', 'RdPu_r', 'RdYlBu', 'RdYlBu_r', 'RdYlGn', 'RdYlGn_r', 'Reds', "\
"'Reds_r', 'Set1', 'Set1_r', 'Set2', 'Set2_r', 'Set3', 'Set3_r', 'Spectral', 'Spectral_r', "\
"'Wistia', 'Wistia_r', 'YlGn', 'YlGnBu', 'YlGnBu_r', 'YlGn_r', 'YlOrBr', 'YlOrBr_r', "\
"'YlOrRd', 'YlOrRd_r', 'afmhot', 'afmhot_r', 'autumn', 'autumn_r', 'binary', 'binary_r', "\
"'bone', 'bone_r', 'brg', 'brg_r', 'bwr', 'bwr_r', 'cividis', 'cividis_r', 'cool', 'cool_r', "\
"'coolwarm', 'coolwarm_r', 'copper', 'copper_r', 'cubehelix', 'cubehelix_r', 'flag', "\
"'flag_r', 'gist_earth', 'gist_earth_r', 'gist_gray', 'gist_gray_r', 'gist_heat', "\
"'gist_heat_r', 'gist_ncar', 'gist_ncar_r', 'gist_rainbow', 'gist_rainbow_r', 'gist_stern', "\
"'gist_stern_r', 'gist_yarg', 'gist_yarg_r', 'gnuplot', 'gnuplot2', 'gnuplot2_r', "\
"'gnuplot_r', 'gray', 'gray_r', 'hot', 'hot_r', 'hsv', 'hsv_r', 'inferno', 'inferno_r', "\
"'jet', 'jet_r', 'magma', " "'magma_r', 'nipy_spectral', 'nipy_spectral_r', 'ocean', "\
"'ocean_r', 'pink', 'pink_r', 'plasma', 'plasma_r', 'prism', 'prism_r', 'rainbow', "\
"'rainbow_r', 'seismic', 'seismic_r', 'spring', 'spring_r', 'summer', 'summer_r', 'tab10', "\
"'tab10_r', 'tab20', 'tab20_r', 'tab20b', 'tab20b_r', 'tab20c', 'tab20c_r', 'terrain', "\
"'terrain_r', 'turbo', " "'turbo_r', 'twilight', 'twilight_r', 'twilight_shifted', "\
"'twilight_shifted_r', 'viridis', 'viridis_r', 'winter', 'winter_r'"

assert result.endswith(expected_end)
expected_start = "'some_map' is not a valid value for"
expected_contains = "; supported values are 'Accent',"

assert result.startswith(expected_start), \
f"The error should start with `{expected_start}`, but starts with `{result[:len(expected_start)]}`"
assert expected_contains in result, \
f"The error should contain `{expected_contains}`, but does not. Full error: {result}"

def test_defaults(self):
"""Tests the manipulation of the default values."""
Expand Down
9 changes: 5 additions & 4 deletions tests/test_workflow_raster_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import rioxarray
import pyarrow as pa
import xarray as xr
import websockets.protocol
from geoengine.types import RasterBandDescriptor
import geoengine as ge
from . import UrllibMocker
Expand All @@ -34,9 +35,9 @@ async def __aexit__(self, *args):
pass

@property
def open(self) -> bool:
def state(self) -> websockets.protocol.State:
'''Mock open impl'''
return len(self.__tiles) > 0
return websockets.protocol.State.OPEN if len(self.__tiles) > 0 else websockets.protocol.State.CLOSED

async def recv(self):
return self.__tiles.pop()
Expand Down Expand Up @@ -130,7 +131,7 @@ def test_streaming_workflow(self):
resolution=ge.SpatialResolution(45.0, 22.5),
)

with unittest.mock.patch("websockets.client.connect", return_value=MockWebsocket()):
with unittest.mock.patch("websockets.asyncio.client.connect", return_value=MockWebsocket()):
async def inner1():
tiles = []

Expand All @@ -141,7 +142,7 @@ async def inner1():

asyncio.run(inner1())

with unittest.mock.patch("websockets.client.connect", return_value=MockWebsocket()):
with unittest.mock.patch("websockets.asyncio.client.connect", return_value=MockWebsocket()):
async def inner2():
array = await workflow.raster_stream_into_xarray(query_rect)
assert array.shape == (2, 1, 8, 8) # time, band, y, x
Expand Down
9 changes: 5 additions & 4 deletions tests/test_workflow_vector_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import geopandas as gpd
import numpy as np
import pandas as pd
import websockets.protocol
import geoengine as ge
from . import UrllibMocker

Expand Down Expand Up @@ -51,9 +52,9 @@ async def __aexit__(self, *args):
pass

@property
def open(self) -> bool:
def state(self) -> websockets.protocol.State:
'''Mock open impl'''
return len(self.__chunks) > 0
return websockets.protocol.State.OPEN if len(self.__chunks) > 0 else websockets.protocol.State.CLOSED

async def recv(self):
return self.__chunks.pop(0)
Expand Down Expand Up @@ -152,7 +153,7 @@ def test_streaming_workflow(self):
resolution=ge.SpatialResolution(0.5, 0.5),
)

with unittest.mock.patch("websockets.client.connect", return_value=MockWebsocket()):
with unittest.mock.patch("websockets.asyncio.client.connect", return_value=MockWebsocket()):
async def inner1():
chunks = []

Expand All @@ -163,7 +164,7 @@ async def inner1():

asyncio.run(inner1())

with unittest.mock.patch("websockets.client.connect", return_value=MockWebsocket()):
with unittest.mock.patch("websockets.asyncio.client.connect", return_value=MockWebsocket()):
async def inner2():
data_frame = await workflow.vector_stream_into_geopandas(query_rect)

Expand Down