Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions geoengine/colorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,13 @@ def from_response(response: geoengine_openapi_client.Colorizer) -> Colorizer:
raise TypeError("Unknown colorizer type")


def rgba_from_list(values: list[int]) -> Rgba:
"""Convert a list of integers to an RGBA tuple."""
if len(values) != 4:
raise ValueError(f"Expected a list of 4 integers, got {len(values)} instead.")
return (values[0], values[1], values[2], values[3])


@dataclass
class LinearGradientColorizer(Colorizer):
'''A linear gradient colorizer.'''
Expand All @@ -242,10 +249,10 @@ def from_response_linear(response: geoengine_openapi_client.LinearGradient) -> L
"""Create a colorizer from a response."""
breakpoints = [ColorBreakpoint.from_response(breakpoint) for breakpoint in response.breakpoints]
return LinearGradientColorizer(
no_data_color=response.no_data_color,
no_data_color=rgba_from_list(response.no_data_color),
breakpoints=breakpoints,
over_color=response.over_color,
under_color=response.under_color,
over_color=rgba_from_list(response.over_color),
under_color=rgba_from_list(response.under_color),
)

def to_api_dict(self) -> geoengine_openapi_client.Colorizer:
Expand Down Expand Up @@ -273,9 +280,9 @@ def from_response_logarithmic(
breakpoints = [ColorBreakpoint.from_response(breakpoint) for breakpoint in response.breakpoints]
return LogarithmicGradientColorizer(
breakpoints=breakpoints,
no_data_color=response.no_data_color,
over_color=response.over_color,
under_color=response.under_color,
no_data_color=rgba_from_list(response.no_data_color),
over_color=rgba_from_list(response.over_color),
under_color=rgba_from_list(response.under_color),
)

def to_api_dict(self) -> geoengine_openapi_client.Colorizer:
Expand All @@ -300,16 +307,16 @@ def from_response_palette(response: geoengine_openapi_client.PaletteColorizer) -
"""Create a colorizer from a response."""

return PaletteColorizer(
colors={float(k): v for k, v in response.colors.items()},
no_data_color=response.no_data_color,
default_color=response.default_color,
colors={float(k): rgba_from_list(v) for k, v in response.colors.items()},
no_data_color=rgba_from_list(response.no_data_color),
default_color=rgba_from_list(response.default_color),
)

def to_api_dict(self) -> geoengine_openapi_client.Colorizer:
"""Return the colorizer as a dictionary."""
return geoengine_openapi_client.Colorizer(geoengine_openapi_client.PaletteColorizer(
type='palette',
colors=self.colors,
colors={str(k): v for k, v in self.colors.items()},
default_color=self.default_color,
no_data_color=self.no_data_color,
))
16 changes: 11 additions & 5 deletions geoengine/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,14 @@ def upload_dataframe(
ints = [key for (key, value) in columns.items() if value.data_type == 'int']
texts = [key for (key, value) in columns.items() if value.data_type == 'text']

result_descriptor = VectorResultDescriptor(
data_type=vector_type,
spatial_reference=df.crs.to_string(),
columns=columns,
).to_api_dict().actual_instance
if not isinstance(result_descriptor, geoengine_openapi_client.TypedVectorResultDescriptor):
raise TypeError('Expected TypedVectorResultDescriptor')

create = geoengine_openapi_client.CreateDataset(
data_path=geoengine_openapi_client.DataPath(geoengine_openapi_client.DataPathOneOf1(
upload=str(upload_id)
Expand Down Expand Up @@ -494,11 +502,9 @@ def upload_dataframe(
),
on_error=on_error.to_api_enum(),
),
result_descriptor=VectorResultDescriptor(
data_type=vector_type,
spatial_reference=df.crs.to_string(),
columns=columns,
).to_api_dict().actual_instance
result_descriptor=geoengine_openapi_client.VectorResultDescriptor.from_dict(
result_descriptor.to_dict()
)
)
)
)
Expand Down
2 changes: 1 addition & 1 deletion geoengine/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, response: Union[geoengine_openapi_client.ApiException, Dict[s
super().__init__()

if isinstance(response, geoengine_openapi_client.ApiException):
obj = json.loads(response.body)
obj = json.loads(response.body) if response.body else {'error': 'unknown', 'message': 'unknown'}
else:
obj = response

Expand Down
6 changes: 4 additions & 2 deletions geoengine/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,12 @@ def check_data_type(data_type: TypeProto, expected_type: RasterDataType, prefix:
if not data_type.tensor_type:
raise InputException('Only tensor input types are supported')
elem_type = data_type.tensor_type.elem_type
if elem_type != RASTER_TYPE_TO_ONNX_TYPE[expected_type]:
expected_tensor_type = RASTER_TYPE_TO_ONNX_TYPE[expected_type]
if elem_type != expected_tensor_type:
elem_type_str = tensor_dtype_to_string(elem_type)
expected_type_str = tensor_dtype_to_string(expected_tensor_type)
raise InputException(f'Model {prefix} type `{elem_type_str}` does not match the '
f'expected type `{expected_type}`')
f'expected type `{expected_type_str}`')

model_inputs = onnx_model.graph.input
model_outputs = onnx_model.graph.output
Expand Down
25 changes: 24 additions & 1 deletion geoengine/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,29 @@ def __repr__(self) -> str:
return f'{self.name}: {self.measurement}'


def literal_raster_data_type(
data_type: geoengine_openapi_client.RasterDataType
) -> Literal['U8', 'U16', 'U32', 'U64', 'I8', 'I16', 'I32', 'I64', 'F32', 'F64']:
'''Convert a `RasterDataType` to a literal'''

data_type_map: dict[
geoengine_openapi_client.RasterDataType,
Literal['U8', 'U16', 'U32', 'U64', 'I8', 'I16', 'I32', 'I64', 'F32', 'F64']
] = {
geoengine_openapi_client.RasterDataType.U8: 'U8',
geoengine_openapi_client.RasterDataType.U16: 'U16',
geoengine_openapi_client.RasterDataType.U32: 'U32',
geoengine_openapi_client.RasterDataType.U64: 'U64',
geoengine_openapi_client.RasterDataType.I8: 'I8',
geoengine_openapi_client.RasterDataType.I16: 'I16',
geoengine_openapi_client.RasterDataType.I32: 'I32',
geoengine_openapi_client.RasterDataType.I64: 'I64',
geoengine_openapi_client.RasterDataType.F32: 'F32',
geoengine_openapi_client.RasterDataType.F64: 'F64',
}
return data_type_map[data_type]


class RasterResultDescriptor(ResultDescriptor):
'''
A raster result descriptor
Expand Down Expand Up @@ -701,7 +724,7 @@ def from_response_raster(
response: geoengine_openapi_client.TypedRasterResultDescriptor) -> RasterResultDescriptor:
'''Parse a raster result descriptor from an http response'''
spatial_ref = response.spatial_reference
data_type = response.data_type.value
data_type = literal_raster_data_type(response.data_type)
bands = [RasterBandDescriptor.from_response(band) for band in response.bands]

time_bounds = None
Expand Down
8 changes: 6 additions & 2 deletions geoengine/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,8 +984,10 @@ def data_usage(offset: int = 0, limit: int = 10) -> List[geoengine_openapi_clien
response = user_api.data_usage_handler(offset=offset, limit=limit)

# create dataframe from response
usage_dicts = [data_usage.dict(by_alias=True) for data_usage in response]
usage_dicts = [data_usage.model_dump(by_alias=True) for data_usage in response]
df = pd.DataFrame(usage_dicts)
if 'timestamp' in df.columns:
df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True)

return df

Expand All @@ -1005,7 +1007,9 @@ def data_usage_summary(granularity: geoengine_openapi_client.UsageSummaryGranula
offset=offset, limit=limit)

# create dataframe from response
usage_dicts = [data_usage.dict(by_alias=True) for data_usage in response]
usage_dicts = [data_usage.model_dump(by_alias=True) for data_usage in response]
df = pd.DataFrame(usage_dicts)
if 'timestamp' in df.columns:
df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True)

return df
9 changes: 5 additions & 4 deletions geoengine/workflow_builder/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,11 +694,12 @@ def from_operator_dict(cls, operator_dict: Dict[str, Any]) -> 'Expression':

output_band = None
if "outputBand" in operator_dict["params"] and operator_dict["params"]["outputBand"] is not None:
output_band = RasterBandDescriptor.from_response(
geoengine_openapi_client.RasterBandDescriptor.from_dict(
operator_dict["params"]["outputBand"]
)
raster_band_descriptor = geoengine_openapi_client.RasterBandDescriptor.from_dict(
operator_dict["params"]["outputBand"]
)
if raster_band_descriptor is None:
raise ValueError("Invalid output band")
output_band = RasterBandDescriptor.from_response(raster_band_descriptor)

return Expression(
expression=operator_dict["params"]["expression"],
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package_dir =
packages = find:
python_requires = >=3.9
install_requires =
geoengine-openapi-client == 0.0.19
geoengine-openapi-client == 0.0.21
geopandas >=0.9,<0.15
matplotlib >=3.5,<3.8
numpy >=1.21,<2.1
Expand All @@ -34,7 +34,7 @@ install_requires =
websockets >= 10.0,<11
xarray >=0.19,<2024.12
urllib3 >= 2.0, < 2.3
pydantic >= 1.10.5, < 2
pydantic >= 2.10.6, < 2.11
skl2onnx >=1.17,<2

[options.extras_require]
Expand Down
1 change: 1 addition & 0 deletions tests/ge_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def _start(self) -> None:
'GEOENGINE__POSTGRES__PASSWORD': POSTGRES_PASSWORD,
'GEOENGINE__POSTGRES__SCHEMA': self.db_schema,
'GEOENGINE__LOGGING__LOG_SPEC': GE_LOG_SPEC,
'GEOENGINE__POSTGRES__CLEAR_DATABASE_ON_START': 'true',
'PATH': os.environ['PATH'],
},
stderr=subprocess.PIPE,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_uploading_onnx_model(self):
)
self.assertEqual(
str(exception.exception),
'Model input type `TensorProto.FLOAT` does not match the expected type `RasterDataType.F64`'
'Model input type `TensorProto.FLOAT` does not match the expected type `TensorProto.DOUBLE`'
)

with self.assertRaises(ge.InputException) as exception:
Expand All @@ -126,5 +126,5 @@ def test_uploading_onnx_model(self):
)
self.assertEqual(
str(exception.exception),
'Model output type `TensorProto.INT64` does not match the expected type `RasterDataType.I32`'
'Model output type `TensorProto.INT64` does not match the expected type `TensorProto.INT32`'
)
45 changes: 23 additions & 22 deletions tests/test_workflow_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,29 @@ def setUp(self) -> None:
def test_storing_workflow(self):

expected_request_text = {
'name': None,
'displayName': 'Foo',
'description': 'Bar',
'query': {
'spatialBounds': {
'upperLeftCoordinate': {
'x': -180.0,
'y': 90.0
},
'lowerRightCoordinate': {
'x': 180.0,
'y': -90.0
}
},
'timeInterval': {
'start': 1396353600000,
'end': 1396353600000,
},
'spatialResolution': {
'x': 1.8,
'y': 1.8
}
"asCog": True,
"description": "Bar",
"displayName": "Foo",
"name": None,
"query": {
"spatialBounds": {
"lowerRightCoordinate": {
"x": 180,
"y": -90
},
"upperLeftCoordinate": {
"x": -180,
"y": 90
}
},
"spatialResolution": {
"x": 1.8,
"y": 1.8
},
"timeInterval": {
"end": 1396353600000,
"start": 1396353600000
}
}
}

Expand Down