Skip to content

Commit 08bbbde

Browse files
authored
feat: specify ml model nodata handling (#236)
* adapt to updated ml model metadata from backend
* ruff ruff
* don't use alias field names
1 parent 0bd83d9 commit 08bbbde

File tree

6 files changed

+108
-40
lines changed

6 files changed

+108
-40
lines changed

.github/.backend_git_ref

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
7aadfa383e6eee63442e366890dfb1160114caed
1+
9fcd0e8d520b3e7679d29c969263345ea190ec46

examples/interactive_ml/app/Simple Random Forest Two-Class Classifier on Sentinel-2 Images.ipynb

Lines changed: 22 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313
},
1414
{
1515
"cell_type": "code",
16-
"execution_count": 2,
16+
"execution_count": null,
1717
"metadata": {},
1818
"outputs": [],
1919
"source": [
@@ -34,7 +34,15 @@
3434
"import ipyvuetify as vue\n",
3535
"import numpy as np\n",
3636
"import xarray as xr\n",
37-
"from geoengine_openapi_client.models import MlModelMetadata, RasterDataType\n",
37+
"from geoengine_openapi_client.models import (\n",
38+
" MlModelInputNoDataHandling,\n",
39+
" MlModelInputNoDataHandlingVariant,\n",
40+
" MlModelMetadata,\n",
41+
" MlModelOutputNoDataHandling,\n",
42+
" MlModelOutputNoDataHandlingVariant,\n",
43+
" MlTensorShape3D,\n",
44+
" RasterDataType,\n",
45+
")\n",
3846
"from matplotlib import pyplot as plt\n",
3947
"from matplotlib.patches import Circle\n",
4048
"from onnx.checker import check_model\n",
@@ -335,11 +343,18 @@
335343
" onnx_model=onnx_clf,\n",
336344
" model_config=ge.ml.MlModelConfig(\n",
337345
" name=model_name,\n",
346+
" file_name=\"model.onnx\",\n",
338347
" metadata=MlModelMetadata(\n",
339-
" file_name=\"model.onnx\",\n",
340-
" input_type=RasterDataType.F32,\n",
341-
" num_input_bands=4,\n",
342-
" output_type=RasterDataType.U8,\n",
348+
" inputType=RasterDataType.F32,\n",
349+
" outputType=RasterDataType.U8,\n",
350+
" inputShape=MlTensorShape3D(x=1, y=1, bands=4),\n",
351+
" outputShape=MlTensorShape3D(x=1, y=1, bands=1),\n",
352+
" inputNoDataHandling=MlModelInputNoDataHandling(\n",
353+
" variant=MlModelInputNoDataHandlingVariant.SKIPIFNODATA\n",
354+
" ),\n",
355+
" outputNoDataHandling=MlModelOutputNoDataHandling(\n",
356+
" variant=MlModelOutputNoDataHandlingVariant.NANISNODATA\n",
357+
" ),\n",
343358
" ),\n",
344359
" display_name=\"Decision Tree\",\n",
345360
" description=\"A simple decision tree model\",\n",
@@ -813,7 +828,7 @@
813828
"name": "python",
814829
"nbconvert_exporter": "python",
815830
"pygments_lexer": "ipython3",
816-
"version": "3.10.12"
831+
"version": "3.12.3"
817832
}
818833
},
819834
"nbformat": 4,

examples/ml_pipeline.ipynb

Lines changed: 28 additions & 8 deletions
Large diffs are not rendered by default.

geoengine/ml.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99
from pathlib import Path
1010

1111
import geoengine_openapi_client
12-
import geoengine_openapi_client.models
1312
from geoengine_openapi_client.models import MlModel, MlModelMetadata, MlTensorShape3D, RasterDataType
1413
from onnx import ModelProto, TensorProto, TypeProto
1514
from onnx.helper import tensor_dtype_to_string
@@ -24,6 +23,7 @@ class MlModelConfig:
2423
"""Configuration for an ml model"""
2524

2625
name: str
26+
file_name: str
2727
metadata: MlModelMetadata
2828
display_name: str = "My Ml Model"
2929
description: str = "My Ml Model Description"
@@ -47,7 +47,7 @@ def register_ml_model(
4747

4848
with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
4949
with tempfile.TemporaryDirectory() as temp_dir:
50-
file_name = Path(temp_dir) / model_config.metadata.file_name
50+
file_name = Path(temp_dir) / model_config.file_name
5151

5252
with open(file_name, "wb") as file:
5353
file.write(onnx_model.SerializeToString())
@@ -61,6 +61,7 @@ def register_ml_model(
6161

6262
model = MlModel(
6363
name=model_config.name,
64+
file_name=model_config.file_name,
6465
upload=str(upload_id),
6566
metadata=model_config.metadata,
6667
display_name=model_config.display_name,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@ readme = { file = "README.md", content-type = "text/markdown" }
1616
license-files = ["LICENSE"]
1717
requires-python = ">=3.10"
1818
dependencies = [
19-
"geoengine-openapi-client == 0.0.25",
19+
"geoengine-openapi-client == 0.0.26",
2020
"geopandas >=1.0,<2.0",
2121
"matplotlib >=3.5,<3.11",
2222
"numpy >=1.21,<2.4",

tests/test_ml.py

Lines changed: 53 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,15 @@
33
import unittest
44

55
import numpy as np
6-
from geoengine_openapi_client.models import MlModelMetadata, MlTensorShape3D, RasterDataType
6+
from geoengine_openapi_client.models import (
7+
MlModelInputNoDataHandling,
8+
MlModelInputNoDataHandlingVariant,
9+
MlModelMetadata,
10+
MlModelOutputNoDataHandling,
11+
MlModelOutputNoDataHandlingVariant,
12+
MlTensorShape3D,
13+
RasterDataType,
14+
)
715
from onnx import TensorShapeProto as TSP
816
from skl2onnx import to_onnx
917
from sklearn.ensemble import RandomForestClassifier
@@ -84,12 +92,18 @@ def test_uploading_onnx_model(self):
8492
onnx_model=onnx_clf,
8593
model_config=ge.ml.MlModelConfig(
8694
name=model_name,
95+
file_name="model.onnx",
8796
metadata=MlModelMetadata(
88-
file_name="model.onnx",
89-
input_type=RasterDataType.F32,
90-
output_type=RasterDataType.I64,
91-
input_shape=MlTensorShape3D(y=1, x=1, bands=2),
92-
output_shape=MlTensorShape3D(y=1, x=1, bands=1),
97+
inputType=RasterDataType.F32,
98+
outputType=RasterDataType.I64,
99+
inputShape=MlTensorShape3D(y=1, x=1, bands=2),
100+
outputShape=MlTensorShape3D(y=1, x=1, bands=1),
101+
inputNoDataHandling=MlModelInputNoDataHandling(
102+
variant=MlModelInputNoDataHandlingVariant.SKIPIFNODATA
103+
),
104+
outputNoDataHandling=MlModelOutputNoDataHandling(
105+
variant=MlModelOutputNoDataHandlingVariant.NANISNODATA
106+
),
93107
),
94108
display_name="Decision Tree",
95109
description="A simple decision tree model",
@@ -120,12 +134,18 @@ def test_uploading_onnx_model(self):
120134
onnx_model=onnx_clf,
121135
model_config=ge.ml.MlModelConfig(
122136
name=model_name,
137+
file_name="model.onnx",
123138
metadata=MlModelMetadata(
124-
file_name="model.onnx",
125-
input_type=RasterDataType.F32,
126-
output_type=RasterDataType.I64,
127-
input_shape=MlTensorShape3D(y=1, x=1, bands=4),
128-
output_shape=MlTensorShape3D(y=1, x=1, bands=1),
139+
inputType=RasterDataType.F32,
140+
outputType=RasterDataType.I64,
141+
inputShape=MlTensorShape3D(y=1, x=1, bands=4),
142+
outputShape=MlTensorShape3D(y=1, x=1, bands=1),
143+
inputNoDataHandling=MlModelInputNoDataHandling(
144+
variant=MlModelInputNoDataHandlingVariant.SKIPIFNODATA
145+
),
146+
outputNoDataHandling=MlModelOutputNoDataHandling(
147+
variant=MlModelOutputNoDataHandlingVariant.NANISNODATA
148+
),
129149
),
130150
display_name="Decision Tree",
131151
description="A simple decision tree model",
@@ -140,12 +160,18 @@ def test_uploading_onnx_model(self):
140160
onnx_model=onnx_clf,
141161
model_config=ge.ml.MlModelConfig(
142162
name=model_name,
163+
file_name="model.onnx",
143164
metadata=MlModelMetadata(
144-
file_name="model.onnx",
145-
input_type=RasterDataType.F64,
146-
output_type=RasterDataType.I64,
147-
input_shape=MlTensorShape3D(y=1, x=1, bands=2),
148-
output_shape=MlTensorShape3D(y=1, x=1, bands=1),
165+
inputType=RasterDataType.F64,
166+
outputType=RasterDataType.I64,
167+
inputShape=MlTensorShape3D(y=1, x=1, bands=2),
168+
outputShape=MlTensorShape3D(y=1, x=1, bands=1),
169+
inputNoDataHandling=MlModelInputNoDataHandling(
170+
variant=MlModelInputNoDataHandlingVariant.SKIPIFNODATA
171+
),
172+
outputNoDataHandling=MlModelOutputNoDataHandling(
173+
variant=MlModelOutputNoDataHandlingVariant.NANISNODATA
174+
),
149175
),
150176
display_name="Decision Tree",
151177
description="A simple decision tree model",
@@ -161,12 +187,18 @@ def test_uploading_onnx_model(self):
161187
onnx_model=onnx_clf,
162188
model_config=ge.ml.MlModelConfig(
163189
name="foo",
190+
file_name="model.onnx",
164191
metadata=MlModelMetadata(
165-
file_name="model.onnx",
166-
input_type=RasterDataType.F32,
167-
output_type=RasterDataType.I32,
168-
input_shape=MlTensorShape3D(y=1, x=1, bands=2),
169-
output_shape=MlTensorShape3D(y=1, x=1, bands=1),
192+
inputType=RasterDataType.F32,
193+
outputType=RasterDataType.I32,
194+
inputShape=MlTensorShape3D(y=1, x=1, bands=2),
195+
outputShape=MlTensorShape3D(y=1, x=1, bands=1),
196+
inputNoDataHandling=MlModelInputNoDataHandling(
197+
variant=MlModelInputNoDataHandlingVariant.SKIPIFNODATA
198+
),
199+
outputNoDataHandling=MlModelOutputNoDataHandling(
200+
variant=MlModelOutputNoDataHandlingVariant.NANISNODATA
201+
),
170202
),
171203
display_name="Decision Tree",
172204
description="A simple decision tree model",

0 commit comments

Comments
 (0)