Skip to content

Commit 25ccb7b

Browse files
Merge pull request #203 from geo-engine/fix_ml_filename_issue
In the end, there can be only one file_name
2 parents 7d5cba8 + b75074e commit 25ccb7b

File tree

2 files changed

+1
-29
lines changed

2 files changed

+1
-29
lines changed

geoengine/ml.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ class MlModelConfig:
1919
'''Configuration for an ml model'''
2020
name: str
2121
metadata: MlModelMetadata
22-
file_name: str = "model.onnx"
2322
display_name: str = "My Ml Model"
2423
description: str = "My Ml Model Description"
2524

@@ -41,7 +40,7 @@ def register_ml_model(onnx_model: ModelProto,
4140

4241
with geoengine_openapi_client.ApiClient(session.configuration) as api_client:
4342
with tempfile.TemporaryDirectory() as temp_dir:
44-
file_name = Path(temp_dir) / model_config.file_name
43+
file_name = Path(temp_dir) / model_config.metadata.file_name
4544

4645
with open(file_name, 'wb') as file:
4746
file.write(onnx_model.SerializeToString())
@@ -74,12 +73,6 @@ def check_data_type(data_type: TypeProto, expected_type: RasterDataType, prefix:
7473
raise InputException(f'Model {prefix} type `{elem_type_str}` does not match the '
7574
f'expected type `{expected_type}`')
7675

77-
for domain in onnx_model.opset_import:
78-
if domain.domain != '':
79-
continue
80-
if domain.version != 9:
81-
raise InputException('Only ONNX models with opset version 9 are supported')
82-
8376
model_inputs = onnx_model.graph.input
8477
model_outputs = onnx_model.graph.output
8578

tests/test_ml.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ def test_uploading_onnx_model(self):
2323
clf.fit(training_x, training_y)
2424

2525
onnx_clf = to_onnx(clf, training_x[:1], options={'zipmap': False}, target_opset=9)
26-
onnx_clf2 = to_onnx(clf, training_x[:1], options={'zipmap': False}, target_opset=12)
2726

2827
with UrllibMocker() as m:
2928
session_id = "c4983c3e-9b53-47ae-bda9-382223bd5081"
@@ -132,23 +131,3 @@ def test_uploading_onnx_model(self):
132131
str(exception.exception),
133132
'Model output type `TensorProto.INT64` does not match the expected type `RasterDataType.I32`'
134133
)
135-
136-
with self.assertRaises(ge.InputException) as exception:
137-
ge.register_ml_model(
138-
onnx_model=onnx_clf2,
139-
model_config=ge.ml.MlModelConfig(
140-
name="foo",
141-
metadata=MlModelMetadata(
142-
file_name="model.onnx",
143-
input_type=RasterDataType.F32,
144-
num_input_bands=2,
145-
output_type=RasterDataType.I64,
146-
),
147-
display_name="Decision Tree",
148-
description="A simple decision tree model",
149-
)
150-
)
151-
self.assertEqual(
152-
str(exception.exception),
153-
'Only ONNX models with opset version 9 are supported'
154-
)

0 commit comments

Comments
 (0)