diff --git a/geoengine/ml.py b/geoengine/ml.py index d9ef6c01..de355300 100644 --- a/geoengine/ml.py +++ b/geoengine/ml.py @@ -19,7 +19,6 @@ class MlModelConfig: '''Configuration for an ml model''' name: str metadata: MlModelMetadata - file_name: str = "model.onnx" display_name: str = "My Ml Model" description: str = "My Ml Model Description" @@ -41,7 +40,7 @@ def register_ml_model(onnx_model: ModelProto, with geoengine_openapi_client.ApiClient(session.configuration) as api_client: with tempfile.TemporaryDirectory() as temp_dir: - file_name = Path(temp_dir) / model_config.file_name + file_name = Path(temp_dir) / model_config.metadata.file_name with open(file_name, 'wb') as file: file.write(onnx_model.SerializeToString()) @@ -74,12 +73,6 @@ def check_data_type(data_type: TypeProto, expected_type: RasterDataType, prefix: raise InputException(f'Model {prefix} type `{elem_type_str}` does not match the ' f'expected type `{expected_type}`') - for domain in onnx_model.opset_import: - if domain.domain != '': - continue - if domain.version != 9: - raise InputException('Only ONNX models with opset version 9 are supported') - model_inputs = onnx_model.graph.input model_outputs = onnx_model.graph.output diff --git a/tests/test_ml.py b/tests/test_ml.py index 88a7865c..5fe03080 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -23,7 +23,6 @@ def test_uploading_onnx_model(self): clf.fit(training_x, training_y) onnx_clf = to_onnx(clf, training_x[:1], options={'zipmap': False}, target_opset=9) - onnx_clf2 = to_onnx(clf, training_x[:1], options={'zipmap': False}, target_opset=12) with UrllibMocker() as m: session_id = "c4983c3e-9b53-47ae-bda9-382223bd5081" @@ -132,23 +131,3 @@ def test_uploading_onnx_model(self): str(exception.exception), 'Model output type `TensorProto.INT64` does not match the expected type `RasterDataType.I32`' ) - - with self.assertRaises(ge.InputException) as exception: - ge.register_ml_model( - onnx_model=onnx_clf2, - model_config=ge.ml.MlModelConfig( - name="foo", - metadata=MlModelMetadata( - file_name="model.onnx", - input_type=RasterDataType.F32, - num_input_bands=2, - output_type=RasterDataType.I64, - ), - display_name="Decision Tree", - description="A simple decision tree model", - ) - ) - self.assertEqual( - str(exception.exception), - 'Only ONNX models with opset version 9 are supported' - )