From df8de2d51eb868c2811b9b47e0bbc6343218d507 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Dr=C3=B6nner?= Date: Tue, 19 Nov 2024 15:23:07 +0100 Subject: [PATCH 1/4] In the end, there can be only one file_name --- geoengine/ml.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/geoengine/ml.py b/geoengine/ml.py index d9ef6c01..ac7a1fe7 100644 --- a/geoengine/ml.py +++ b/geoengine/ml.py @@ -19,7 +19,6 @@ class MlModelConfig: '''Configuration for an ml model''' name: str metadata: MlModelMetadata - file_name: str = "model.onnx" display_name: str = "My Ml Model" description: str = "My Ml Model Description" @@ -41,7 +40,7 @@ def register_ml_model(onnx_model: ModelProto, with geoengine_openapi_client.ApiClient(session.configuration) as api_client: with tempfile.TemporaryDirectory() as temp_dir: - file_name = Path(temp_dir) / model_config.file_name + file_name = Path(temp_dir) / model_config.metadata.file_name with open(file_name, 'wb') as file: file.write(onnx_model.SerializeToString()) @@ -74,11 +73,6 @@ def check_data_type(data_type: TypeProto, expected_type: RasterDataType, prefix: raise InputException(f'Model {prefix} type `{elem_type_str}` does not match the ' f'expected type `{expected_type}`') - for domain in onnx_model.opset_import: - if domain.domain != '': - continue - if domain.version != 9: - raise InputException('Only ONNX models with opset version 9 are supported') model_inputs = onnx_model.graph.input model_outputs = onnx_model.graph.output From e7fcf8fd7b5e0903789595ce3a3092dcdd3bdbf1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Dr=C3=B6nner?= Date: Tue, 19 Nov 2024 15:25:21 +0100 Subject: [PATCH 2/4] remove a single blank line --- geoengine/ml.py | 1 - 1 file changed, 1 deletion(-) diff --git a/geoengine/ml.py b/geoengine/ml.py index ac7a1fe7..de355300 100644 --- a/geoengine/ml.py +++ b/geoengine/ml.py @@ -73,7 +73,6 @@ def check_data_type(data_type: TypeProto, expected_type: RasterDataType, prefix: raise InputException(f'Model {prefix} type `{elem_type_str}` does not match the ' f'expected type `{expected_type}`') - model_inputs = onnx_model.graph.input model_outputs = onnx_model.graph.output From 4dc3465de1386ac063ce3b2a564cb53f592cdf4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Dr=C3=B6nner?= Date: Tue, 19 Nov 2024 15:40:25 +0100 Subject: [PATCH 3/4] remove opset test --- tests/test_ml.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/tests/test_ml.py b/tests/test_ml.py index 88a7865c..77b4fe63 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -132,23 +132,3 @@ def test_uploading_onnx_model(self): str(exception.exception), 'Model output type `TensorProto.INT64` does not match the expected type `RasterDataType.I32`' ) - - with self.assertRaises(ge.InputException) as exception: - ge.register_ml_model( - onnx_model=onnx_clf2, - model_config=ge.ml.MlModelConfig( - name="foo", - metadata=MlModelMetadata( - file_name="model.onnx", - input_type=RasterDataType.F32, - num_input_bands=2, - output_type=RasterDataType.I64, - ), - display_name="Decision Tree", - description="A simple decision tree model", - ) - ) - self.assertEqual( - str(exception.exception), - 'Only ONNX models with opset version 9 are supported' - ) From b75074e6a66592ed7e6eb4aef386678be31fada9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20Dr=C3=B6nner?= Date: Tue, 19 Nov 2024 17:31:39 +0100 Subject: [PATCH 4/4] remove unused variable --- tests/test_ml.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_ml.py b/tests/test_ml.py index 77b4fe63..5fe03080 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -23,7 +23,6 @@ def test_uploading_onnx_model(self): clf.fit(training_x, training_y) onnx_clf = to_onnx(clf, training_x[:1], options={'zipmap': False}, target_opset=9) - onnx_clf2 = to_onnx(clf, training_x[:1], options={'zipmap': False}, target_opset=12) with UrllibMocker() as m: session_id = "c4983c3e-9b53-47ae-bda9-382223bd5081"