Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions examples/ml_pipeline.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"import geoengine as ge\n",
"from geoengine.ml import MlModelConfig\n",
"\n",
"from geoengine_openapi_client.models import MlModelMetadata, RasterDataType\n",
"from geoengine_openapi_client.models import MlModelMetadata, RasterDataType, TensorShape3D\n",
"\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"import numpy as np\n",
Expand Down Expand Up @@ -88,8 +88,9 @@
"metadata = MlModelMetadata(\n",
" file_name=\"model.onnx\",\n",
" input_type=RasterDataType.F32,\n",
" num_input_bands=2,\n",
" output_type=RasterDataType.I64,\n",
" input_shape=TensorShape3D(y=1, x=1, bands=2),\n",
" output_shape=TensorShape3D(y=1, x=1, bands=1)\n",
")\n",
"\n",
"model_config = MlModelConfig(\n",
Expand Down
71 changes: 61 additions & 10 deletions geoengine/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import geoengine_openapi_client.models
from onnx import TypeProto, TensorProto, ModelProto
from onnx.helper import tensor_dtype_to_string
from geoengine_openapi_client.models import MlModelMetadata, MlModel, RasterDataType
from geoengine_openapi_client.models import MlModelMetadata, MlModel, RasterDataType, TensorShape3D
import geoengine_openapi_client
from geoengine.auth import get_session
from geoengine.datasets import UploadId
Expand Down Expand Up @@ -67,7 +67,8 @@ def register_ml_model(onnx_model: ModelProto,
onnx_model,
input_type=model_config.metadata.input_type,
output_type=model_config.metadata.output_type,
num_input_bands=model_config.metadata.num_input_bands,
input_shape=model_config.metadata.input_shape,
out_shape=model_config.metadata.output_shape
)

session = get_session()
Expand All @@ -93,10 +94,12 @@ def register_ml_model(onnx_model: ModelProto,
return MlModelName.from_response(res_name)


# pylint: disable=too-many-branches,too-many-statements
def validate_model_config(onnx_model: ModelProto, *,
input_type: RasterDataType,
output_type: RasterDataType,
num_input_bands: int):
input_shape: TensorShape3D,
out_shape: TensorShape3D):
'''Validates the model config. Raises an exception if the model config is invalid'''

def check_data_type(data_type: TypeProto, expected_type: RasterDataType, prefix: 'str'):
Expand All @@ -115,18 +118,66 @@ def check_data_type(data_type: TypeProto, expected_type: RasterDataType, prefix:
raise InputException('Models with multiple inputs are not supported')
check_data_type(model_inputs[0].type, input_type, 'input')

dims = model_inputs[0].type.tensor_type.shape.dim
if len(dims) != 2:
raise InputException('Only 2D input tensors are supported')
if not dims[1].dim_value:
raise InputException('Dimension 1 of the input tensor must have a length')
if dims[1].dim_value != num_input_bands:
raise InputException(f'Model input has {dims[1].dim_value} bands, but {num_input_bands} bands are expected')
dim = model_inputs[0].type.tensor_type.shape.dim

if len(dim) == 2:
if not dim[1].dim_value:
raise InputException('Dimension 1 of a 1D input tensor must have a length')
if dim[1].dim_value != input_shape.bands:
raise InputException(f'Model input has {dim[1].dim_value} bands, but {input_shape.bands} are expected')
elif len(dim) == 4:
if not dim[1].dim_value:
raise InputException('Dimension 1 of the a 3D input tensor must have a length')
if not dim[2].dim_value:
raise InputException('Dimension 2 of the a 3D input tensor must have a length')
if not dim[3].dim_value:
raise InputException('Dimension 3 of the a 3D input tensor must have a length')
if dim[1].dim_value != input_shape.y:
raise InputException(f'Model input has {dim[1].dim_value} y size, but {input_shape.y} are expected')
if dim[2].dim_value != input_shape.x:
raise InputException(f'Model input has {dim[2].dim_value} x size, but {input_shape.x} are expected')
if dim[3].dim_value != input_shape.bands:
raise InputException(f'Model input has {dim[3].dim_value} bands, but {input_shape.bands} are expected')
else:
raise InputException('Only 1D and 3D input tensors are supported')

if len(model_outputs) < 1:
raise InputException('Models with no outputs are not supported')
check_data_type(model_outputs[0].type, output_type, 'output')

dim = model_outputs[0].type.tensor_type.shape.dim
if len(dim) == 1:
pass # this is a happens if there is only a single out? so shape would be [-1]
elif len(dim) == 2:
if not dim[1].dim_value:
raise InputException('Dimension 1 of a 1D input tensor must have a length')
if dim[1].dim_value != 1:
raise InputException(f'Model output has {dim[1].dim_value} bands, but {out_shape.bands} are expected')
elif len(dim) == 3:
if not dim[1].dim_value:
raise InputException('Dimension 1 of a 3D input tensor must have a length')
if not dim[2].dim_value:
raise InputException('Dimension 2 of a 3D input tensor must have a length')
if dim[1].dim_value != out_shape.y:
raise InputException(f'Model output has {dim[1].dim_value} y size, but {out_shape.y} are expected')
if dim[2].dim_value != out_shape.x:
raise InputException(f'Model output has {dim[2].dim_value} x size, but {out_shape.x} are expected')
elif len(dim) == 4:
if not dim[1].dim_value:
raise InputException('Dimension 1 of the a 3D input tensor must have a length')
if not dim[2].dim_value:
raise InputException('Dimension 2 of the a 3D input tensor must have a length')
if not dim[3].dim_value:
raise InputException('Dimension 3 of the a 3D input tensor must have a length')
if dim[1].dim_value != out_shape.y:
raise InputException(f'Model output has {dim[1].dim_value} y size, but {out_shape.y} are expected')
if dim[2].dim_value != out_shape.x:
raise InputException(f'Model output has {dim[2].dim_value} x size, but {out_shape.x} are expected')
if dim[3].dim_value != out_shape.bands:
raise InputException(f'Model output has {dim[3].dim_value} bands, but {out_shape.bands} are expected')
else:
raise InputException('Only 1D and 3D output tensors are supported')


RASTER_TYPE_TO_ONNX_TYPE = {
RasterDataType.F32: TensorProto.FLOAT,
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package_dir =
packages = find:
python_requires = >=3.9
install_requires =
geoengine-openapi-client == 0.0.19
geoengine-openapi-client @ git+https://github.com/geo-engine/openapi-client@ml-model-input-output-shape#subdirectory=python
geopandas >=0.9,<0.15
matplotlib >=3.5,<3.8
numpy >=1.21,<2.1
Expand Down
16 changes: 10 additions & 6 deletions tests/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sklearn.ensemble import RandomForestClassifier
from skl2onnx import to_onnx
import numpy as np
from geoengine_openapi_client.models import MlModelMetadata, RasterDataType
from geoengine_openapi_client.models import MlModelMetadata, RasterDataType, TensorShape3D
import geoengine as ge
from tests.ge_test import GeoEngineTestInstance

Expand Down Expand Up @@ -40,8 +40,9 @@ def test_uploading_onnx_model(self):
metadata=MlModelMetadata(
file_name="model.onnx",
input_type=RasterDataType.F32,
num_input_bands=2,
output_type=RasterDataType.I64,
input_shape=TensorShape3D(y=1, x=1, bands=2),
output_shape=TensorShape3D(y=1, x=1, bands=1)
),
display_name="Decision Tree",
description="A simple decision tree model",
Expand Down Expand Up @@ -77,16 +78,17 @@ def test_uploading_onnx_model(self):
metadata=MlModelMetadata(
file_name="model.onnx",
input_type=RasterDataType.F32,
num_input_bands=4,
output_type=RasterDataType.I64,
input_shape=TensorShape3D(y=1, x=1, bands=4),
output_shape=TensorShape3D(y=1, x=1, bands=1)
),
display_name="Decision Tree",
description="A simple decision tree model",
)
)
self.assertEqual(
str(exception.exception),
'Model input has 2 bands, but 4 bands are expected'
'Model input has 2 bands, but 4 are expected'
)

with self.assertRaises(ge.InputException) as exception:
Expand All @@ -97,8 +99,9 @@ def test_uploading_onnx_model(self):
metadata=MlModelMetadata(
file_name="model.onnx",
input_type=RasterDataType.F64,
num_input_bands=2,
output_type=RasterDataType.I64,
input_shape=TensorShape3D(y=1, x=1, bands=2),
output_shape=TensorShape3D(y=1, x=1, bands=1)
),
display_name="Decision Tree",
description="A simple decision tree model",
Expand All @@ -117,8 +120,9 @@ def test_uploading_onnx_model(self):
metadata=MlModelMetadata(
file_name="model.onnx",
input_type=RasterDataType.F32,
num_input_bands=2,
output_type=RasterDataType.I32,
input_shape=TensorShape3D(y=1, x=1, bands=2),
output_shape=TensorShape3D(y=1, x=1, bands=1)
),
display_name="Decision Tree",
description="A simple decision tree model",
Expand Down
Loading