Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
fail-fast: false
matrix:
# use all supported versions from https://devguide.python.org/versions/
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12", "3.13"]

with:
python-version: ${{ matrix.python-version }}
Expand All @@ -31,6 +31,6 @@ jobs:
uses: geo-engine/geoengine-python/.github/workflows/test-python.yml@main

with:
python-version: 3.9
python-version: "3.10"
use-uv: true
coverage: true
16 changes: 9 additions & 7 deletions geoengine/colorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import warnings
from typing import Dict, List, Tuple, Union, cast
import numpy as np
import numpy.typing as npt
from matplotlib.colors import Colormap
from matplotlib.cm import ScalarMappable
import geoengine_openapi_client
Expand Down Expand Up @@ -67,8 +66,9 @@ def linear_with_mpl_cmap(
raise ValueError(f"underColor must be a RGBA color specification, got {under_color} instead.")

# get the map, and transform it to a list of (uint8) rgba values
list_of_rgba_colors: List[npt.NDArray[np.uint8]] = ScalarMappable(cmap=color_map).to_rgba(
np.linspace(min_max[0], min_max[1], n_steps), bytes=True)
list_of_rgba_colors = ScalarMappable(cmap=color_map).to_rgba(
np.linspace(min_max[0], min_max[1], n_steps), bytes=True
)

# if you want to remap the colors, you can do it here (e.g. cutting of the most extreme colors)
values_of_breakpoints: List[float] = np.linspace(min_max[0], min_max[1], n_steps).tolist()
Expand Down Expand Up @@ -120,8 +120,9 @@ def logarithmic_with_mpl_cmap(
raise ValueError(f"underColor must be a RGBA color specification, got {under_color} instead.")

# get the map, and transform it to a list of (uint8) rgba values
list_of_rgba_colors: List[npt.NDArray[np.uint8]] = ScalarMappable(cmap=color_map).to_rgba(
np.linspace(min_max[0], min_max[1], n_steps), bytes=True)
list_of_rgba_colors = ScalarMappable(cmap=color_map).to_rgba(
np.linspace(min_max[0], min_max[1], n_steps), bytes=True
)

# if you want to remap the colors, you can do it here (e.g. cutting of the most extreme colors)
values_of_breakpoints: List[float] = np.logspace(np.log10(min_max[0]), np.log10(min_max[1]), n_steps).tolist()
Expand Down Expand Up @@ -192,8 +193,9 @@ def palette_with_colormap(
f"Number of available colors: {n_colors_of_cmap}"))

# we only need to generate enough different colors for all values specified in the colors parameter
list_of_rgba_colors: List[npt.NDArray[np.uint8]] = ScalarMappable(cmap=color_map).to_rgba(
np.linspace(0, len(values), len(values)), bytes=True)
list_of_rgba_colors = ScalarMappable(cmap=color_map).to_rgba(
np.linspace(0, len(values), len(values)), bytes=True
)

# generate the dict with value: color mapping
color_mapping: Dict[float, Rgba] = dict(zip(
Expand Down
8 changes: 4 additions & 4 deletions geoengine/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ def has_null_values(self) -> bool:

@property
def time_start_ms(self) -> np.datetime64:
return np.datetime64(self.time.start, 'ms')
return self.time.start.astype('datetime64[ms]')

@property
def time_end_ms(self) -> np.datetime64:
return np.datetime64(self.time.end, 'ms')
def time_end_ms(self) -> Optional[np.datetime64]:
return None if self.time.end is None else self.time.end.astype('datetime64[ms]')

@property
def pixel_size(self) -> Tuple[float, float]:
Expand Down Expand Up @@ -290,7 +290,7 @@ def single_band(self, index: int) -> RasterTile2D:
def to_numpy_masked_array_stack(self) -> np.ma.MaskedArray:
'''Return the raster stack as a 3D masked numpy array'''
arrays = [self.single_band(i).to_numpy_masked_array() for i in range(0, len(self.data))]
stack = np.stack(arrays, axis=0)
stack = np.ma.stack(arrays, axis=0)
return stack

def to_xarray(self, clip_with_bounds: Optional[gety.SpatialBounds] = None) -> xr.DataArray:
Expand Down
9 changes: 4 additions & 5 deletions geoengine/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from owslib.wcs import WebCoverageService
from vega import VegaLite
import websockets
import websockets.client
import xarray as xr
import pyarrow as pa

Expand Down Expand Up @@ -585,7 +584,7 @@ async def raster_stream(
if url is None:
raise InputException('Invalid websocket url')

async with websockets.client.connect(
async with websockets.asyncio.client.connect(
uri=self.__replace_http_with_ws(url),
extra_headers=session.auth_header,
open_timeout=open_timeout,
Expand All @@ -594,7 +593,7 @@ async def raster_stream(

tile_bytes: Optional[bytes] = None

while websocket.open:
while websocket.state == websockets.protocol.State.OPEN:
async def read_new_bytes() -> Optional[bytes]:
# already send the next request to speed up the process
try:
Expand Down Expand Up @@ -792,7 +791,7 @@ def process_bytes(batch_bytes: Optional[bytes]) -> Optional[gpd.GeoDataFrame]:
if url is None:
raise InputException('Invalid websocket url')

async with websockets.client.connect(
async with websockets.asyncio.client.connect(
uri=self.__replace_http_with_ws(url),
extra_headers=session.auth_header,
open_timeout=open_timeout,
Expand All @@ -801,7 +800,7 @@ def process_bytes(batch_bytes: Optional[bytes]) -> Optional[gpd.GeoDataFrame]:

batch_bytes: Optional[bytes] = None

while websocket.open:
while websocket.state == websockets.protocol.State.OPEN:
async def read_new_bytes() -> Optional[bytes]:
# already send the next request to speed up the process
try:
Expand Down
7 changes: 0 additions & 7 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,3 @@ ignore_missing_imports = True

[mypy-sklearn.*]
ignore_missing_imports = True

# testcontainers is typed, but it doesn't correctly declare itself as such.
# Hopefully it can be fixed one day:
#
# https://github.com/testcontainers/testcontainers-python/issues/305
[mypy-testcontainers.*]
ignore_missing_imports = True
47 changes: 25 additions & 22 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,30 @@ classifiers =
[options]
package_dir =
packages = find:
python_requires = >=3.9
python_requires = >=3.10
install_requires =
geoengine-openapi-client == 0.0.22
geopandas >=0.9,<0.15
matplotlib >=3.5,<3.8
numpy >=1.21,<2.1
owslib >=0.27,<0.32
geoengine-openapi-client == 0.0.23
geopandas >=1.0,<2.0
matplotlib >=3.5,<3.11
numpy >=1.21,<2.3
owslib >=0.27,<0.34
pillow >=10.0,<12
pyarrow >=17.0,<18
python-dotenv >=0.19,<1.1
pyarrow >=17.0,<21
python-dotenv >=0.19,<1.2
rasterio >=1.3,<2
requests >= 2.26,<3
rioxarray >=0.9.1, <0.19
rioxarray >=0.9.1, <0.20
StrEnum >=0.4.6,<0.5 # TODO: use from stdlib when `python_requires = >=3.11`
vega >= 3.5,<4
websockets >= 10.0,<11
xarray >=0.19,<2024.12
urllib3 >= 2.1, < 2.4
pydantic >= 2.10.6, < 2.11
skl2onnx >=1.17,<2
vega >= 3.5,<4.2
websockets >= 14.0,<16
xarray >=0.19,<2025.5
urllib3 >= 2.1, < 2.5
pydantic >= 2.10.6, < 2.12
skl2onnx >=1.17,<2 ; python_version<"3.13"
skl2onnx @ git+https://github.com/onnx/sklearn-onnx ; python_version>="3.13" # TODO: remove when skl2onnx 1.19 is released
onnx @ https://test-files.pythonhosted.org/packages/a6/17/d5fda1165f0eac5055a9b63a00b13734a65f08ff6c9b0a54467c39b2dfea/onnx-1.18.0rc2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl ; python_version>="3.13" # TODO: remove when skl2onnx 1.19 is released

[[onnx]]

[options.extras_require]
dev =
Expand All @@ -44,22 +48,21 @@ dev =
pdoc3 >=0.10,<0.11
pycodestyle >=2.8,<3 # formatter
pylint >=3.3,<4 # code linter
setuptools >=42,<76
twine >=3.4,<5 # PyPI
setuptools >=42,<81
twine >=3.4,<6 # PyPI
types-requests >=2.26,<3 # mypy type hints
types-setuptools >= 71.1 # mypy type hints
types-setuptools >=71.1,<81 # mypy type hints
wheel >=0.37,<0.46
test =
psycopg >=3.2,<4
pytest >=6.3,<8
pytest >=6.3,<9
pytest-cov >=6.0,<7
requests_mock >=1.9,<2
scikit-learn >=1.5,<1.6
testcontainers[postgres] >=4.9,<5
scikit-learn >=1.5,<1.7
examples =
cartopy >=0.22,<0.25 # for WMS example
ipympl >=0.9.4,<0.10 # for ML example
ipyvuetify >=1.10,<1.11 # for ML app
ipyvuetify >=1.10,<1.12 # for ML app
ipywidgets >=8.1.5,<9 # for ML example
nbconvert >=7,<8 # convert notebooks to Python
scipy >=1.7,<2 # for WMS example
Expand Down
31 changes: 7 additions & 24 deletions tests/test_colorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,30 +124,13 @@ def test_colormap_not_available(self):

result = str(ctx.exception)

expected_end = "supported values are 'Accent', 'Accent_r', "\
"'Blues', 'Blues_r', 'BrBG', 'BrBG_r', 'BuGn', 'BuGn_r', 'BuPu', 'BuPu_r', 'CMRmap', "\
"'CMRmap_r', 'Dark2', 'Dark2_r', 'GnBu', 'GnBu_r', 'Greens', 'Greens_r', 'Greys', 'Greys_r', "\
"'OrRd', 'OrRd_r', 'Oranges', 'Oranges_r', 'PRGn', 'PRGn_r', 'Paired', 'Paired_r', 'Pastel1', "\
"'Pastel1_r', 'Pastel2', 'Pastel2_r', 'PiYG', 'PiYG_r', 'PuBu', 'PuBuGn', 'PuBuGn_r', "\
"'PuBu_r', 'PuOr', 'PuOr_r', 'PuRd', 'PuRd_r', 'Purples', 'Purples_r', 'RdBu', 'RdBu_r', "\
"'RdGy', 'RdGy_r', 'RdPu', 'RdPu_r', 'RdYlBu', 'RdYlBu_r', 'RdYlGn', 'RdYlGn_r', 'Reds', "\
"'Reds_r', 'Set1', 'Set1_r', 'Set2', 'Set2_r', 'Set3', 'Set3_r', 'Spectral', 'Spectral_r', "\
"'Wistia', 'Wistia_r', 'YlGn', 'YlGnBu', 'YlGnBu_r', 'YlGn_r', 'YlOrBr', 'YlOrBr_r', "\
"'YlOrRd', 'YlOrRd_r', 'afmhot', 'afmhot_r', 'autumn', 'autumn_r', 'binary', 'binary_r', "\
"'bone', 'bone_r', 'brg', 'brg_r', 'bwr', 'bwr_r', 'cividis', 'cividis_r', 'cool', 'cool_r', "\
"'coolwarm', 'coolwarm_r', 'copper', 'copper_r', 'cubehelix', 'cubehelix_r', 'flag', "\
"'flag_r', 'gist_earth', 'gist_earth_r', 'gist_gray', 'gist_gray_r', 'gist_heat', "\
"'gist_heat_r', 'gist_ncar', 'gist_ncar_r', 'gist_rainbow', 'gist_rainbow_r', 'gist_stern', "\
"'gist_stern_r', 'gist_yarg', 'gist_yarg_r', 'gnuplot', 'gnuplot2', 'gnuplot2_r', "\
"'gnuplot_r', 'gray', 'gray_r', 'hot', 'hot_r', 'hsv', 'hsv_r', 'inferno', 'inferno_r', "\
"'jet', 'jet_r', 'magma', " "'magma_r', 'nipy_spectral', 'nipy_spectral_r', 'ocean', "\
"'ocean_r', 'pink', 'pink_r', 'plasma', 'plasma_r', 'prism', 'prism_r', 'rainbow', "\
"'rainbow_r', 'seismic', 'seismic_r', 'spring', 'spring_r', 'summer', 'summer_r', 'tab10', "\
"'tab10_r', 'tab20', 'tab20_r', 'tab20b', 'tab20b_r', 'tab20c', 'tab20c_r', 'terrain', "\
"'terrain_r', 'turbo', " "'turbo_r', 'twilight', 'twilight_r', 'twilight_shifted', "\
"'twilight_shifted_r', 'viridis', 'viridis_r', 'winter', 'winter_r'"

assert result.endswith(expected_end)
expected_start = "'some_map' is not a valid value for"
expected_contains = "; supported values are 'Accent',"

assert result.startswith(expected_start), \
f"The error should start with `{expected_start}`, but starts with `{result[:len(expected_start)]}`"
assert expected_contains in result, \
f"The error should contain `{expected_contains}`, but does not. Full error: {result}"

def test_defaults(self):
"""Tests the manipulation of the default values."""
Expand Down
9 changes: 5 additions & 4 deletions tests/test_workflow_raster_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import rioxarray
import pyarrow as pa
import xarray as xr
import websockets.protocol
from geoengine.types import RasterBandDescriptor
import geoengine as ge
from . import UrllibMocker
Expand All @@ -34,9 +35,9 @@ async def __aexit__(self, *args):
pass

@property
def open(self) -> bool:
def state(self) -> websockets.protocol.State:
'''Mock open impl'''
return len(self.__tiles) > 0
return websockets.protocol.State.OPEN if len(self.__tiles) > 0 else websockets.protocol.State.CLOSED

async def recv(self):
return self.__tiles.pop()
Expand Down Expand Up @@ -130,7 +131,7 @@ def test_streaming_workflow(self):
resolution=ge.SpatialResolution(45.0, 22.5),
)

with unittest.mock.patch("websockets.client.connect", return_value=MockWebsocket()):
with unittest.mock.patch("websockets.asyncio.client.connect", return_value=MockWebsocket()):
async def inner1():
tiles = []

Expand All @@ -141,7 +142,7 @@ async def inner1():

asyncio.run(inner1())

with unittest.mock.patch("websockets.client.connect", return_value=MockWebsocket()):
with unittest.mock.patch("websockets.asyncio.client.connect", return_value=MockWebsocket()):
async def inner2():
array = await workflow.raster_stream_into_xarray(query_rect)
assert array.shape == (2, 1, 8, 8) # time, band, y, x
Expand Down
9 changes: 5 additions & 4 deletions tests/test_workflow_vector_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import geopandas as gpd
import numpy as np
import pandas as pd
import websockets.protocol
import geoengine as ge
from . import UrllibMocker

Expand Down Expand Up @@ -51,9 +52,9 @@ async def __aexit__(self, *args):
pass

@property
def open(self) -> bool:
def state(self) -> websockets.protocol.State:
'''Mock open impl'''
return len(self.__chunks) > 0
return websockets.protocol.State.OPEN if len(self.__chunks) > 0 else websockets.protocol.State.CLOSED

async def recv(self):
return self.__chunks.pop(0)
Expand Down Expand Up @@ -152,7 +153,7 @@ def test_streaming_workflow(self):
resolution=ge.SpatialResolution(0.5, 0.5),
)

with unittest.mock.patch("websockets.client.connect", return_value=MockWebsocket()):
with unittest.mock.patch("websockets.asyncio.client.connect", return_value=MockWebsocket()):
async def inner1():
chunks = []

Expand All @@ -163,7 +164,7 @@ async def inner1():

asyncio.run(inner1())

with unittest.mock.patch("websockets.client.connect", return_value=MockWebsocket()):
with unittest.mock.patch("websockets.asyncio.client.connect", return_value=MockWebsocket()):
async def inner2():
data_frame = await workflow.vector_stream_into_geopandas(query_rect)

Expand Down