Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions geoengine/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ class MlModelConfig:
'''Configuration for an ml model'''
name: str
metadata: MlModelMetadata
file_name: str = "model.onnx"
display_name: str = "My Ml Model"
description: str = "My Ml Model Description"

Expand All @@ -41,7 +40,7 @@ def register_ml_model(onnx_model: ModelProto,

with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
with tempfile.TemporaryDirectory() as temp_dir:
file_name = Path(temp_dir) / model_config.file_name
file_name = Path(temp_dir) / model_config.metadata.file_name

with open(file_name, 'wb') as file:
file.write(onnx_model.SerializeToString())
Expand Down Expand Up @@ -74,12 +73,6 @@ def check_data_type(data_type: TypeProto, expected_type: RasterDataType, prefix:
raise InputException(f'Model {prefix} type `{elem_type_str}` does not match the '
f'expected type `{expected_type}`')

for domain in onnx_model.opset_import:
if domain.domain != '':
continue
if domain.version != 9:
raise InputException('Only ONNX models with opset version 9 are supported')

model_inputs = onnx_model.graph.input
model_outputs = onnx_model.graph.output

Expand Down
21 changes: 0 additions & 21 deletions tests/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def test_uploading_onnx_model(self):
clf.fit(training_x, training_y)

onnx_clf = to_onnx(clf, training_x[:1], options={'zipmap': False}, target_opset=9)
onnx_clf2 = to_onnx(clf, training_x[:1], options={'zipmap': False}, target_opset=12)

with UrllibMocker() as m:
session_id = "c4983c3e-9b53-47ae-bda9-382223bd5081"
Expand Down Expand Up @@ -132,23 +131,3 @@ def test_uploading_onnx_model(self):
str(exception.exception),
'Model output type `TensorProto.INT64` does not match the expected type `RasterDataType.I32`'
)

with self.assertRaises(ge.InputException) as exception:
ge.register_ml_model(
onnx_model=onnx_clf2,
model_config=ge.ml.MlModelConfig(
name="foo",
metadata=MlModelMetadata(
file_name="model.onnx",
input_type=RasterDataType.F32,
num_input_bands=2,
output_type=RasterDataType.I64,
),
display_name="Decision Tree",
description="A simple decision tree model",
)
)
self.assertEqual(
str(exception.exception),
'Only ONNX models with opset version 9 are supported'
)
Loading