Skip to content
Merged
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ python3 -m mypy tests
Using the config file `mypy.ini`, you can suppress missing stub errors for external libraries.
You can ignore a library by adding two lines to the config file. For example, suppressing matplotlib would look like this:

```
```ini
[mypy-matplotlib.*]
ignore_missing_imports = True

Expand Down
2 changes: 1 addition & 1 deletion geoengine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .layers import Layer, LayerCollection, LayerListing, LayerCollectionListing, \
LayerId, LayerCollectionId, LayerProviderId, \
layer_collection, layer
from .ml import register_ml_model, MlModelConfig
from .ml import register_ml_model, MlModelConfig, MlModelName
from .permissions import add_permission, remove_permission, add_role, remove_role, assign_role, revoke_role, \
ADMIN_ROLE_ID, REGISTERED_USER_ROLE_ID, ANONYMOUS_USER_ROLE_ID, Permission, Resource, UserId, RoleId
from .tasks import Task, TaskId
Expand Down
6 changes: 3 additions & 3 deletions geoengine/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def to_api_enum(self) -> geoengine_openapi_client.OgrSourceErrorSpec:


class DatasetName:
'''A wrapper for a dataset id'''
'''A wrapper for a dataset name'''

__dataset_name: str

Expand All @@ -266,7 +266,7 @@ def __init__(self, dataset_name: str) -> None:

@classmethod
def from_response(cls, response: geoengine_openapi_client.CreateDatasetHandler200Response) -> DatasetName:
'''Parse a http response to an `DatasetId`'''
'''Parse a http response to an `DatasetName`'''
return DatasetName(response.dataset_name)

def __str__(self) -> str:
Expand All @@ -276,7 +276,7 @@ def __repr__(self) -> str:
return str(self)

def __eq__(self, other) -> bool:
'''Checks if two dataset ids are equal'''
'''Checks if two dataset names are equal'''
if not isinstance(other, self.__class__):
return False

Expand Down
39 changes: 37 additions & 2 deletions geoengine/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
Util functions for machine learning
'''

from __future__ import annotations
from pathlib import Path
import tempfile
from dataclasses import dataclass
import geoengine_openapi_client.models
from onnx import TypeProto, TensorProto, ModelProto
from onnx.helper import tensor_dtype_to_string
from geoengine_openapi_client.models import MlModelMetadata, MlModel, RasterDataType
Expand All @@ -23,10 +25,42 @@ class MlModelConfig:
description: str = "My Ml Model Description"


class MlModelName:
'''A wrapper for an MlModel name'''

__ml_model_name: str

def __init__(self, ml_model_name: str) -> None:
self.__ml_model_name = ml_model_name

@classmethod
def from_response(cls, response: geoengine_openapi_client.models.MlModelNameResponse) -> MlModelName:
'''Parse a http response to an `DatasetName`'''
return MlModelName(response.ml_model_name)

def __str__(self) -> str:
return self.__ml_model_name

def __repr__(self) -> str:
return str(self)

def __eq__(self, other) -> bool:
'''Checks if two dataset names are equal'''
if not isinstance(other, self.__class__):
return False

return self.__ml_model_name == other.__ml_model_name # pylint: disable=protected-access

def to_api_dict(self) -> geoengine_openapi_client.models.MlModelNameResponse:
return geoengine_openapi_client.models.MlModelNameResponse(
ml_model_name=str(self.__ml_model_name)
)


def register_ml_model(onnx_model: ModelProto,
model_config: MlModelConfig,
upload_timeout: int = 3600,
register_timeout: int = 60):
register_timeout: int = 60) -> MlModelName:
'''Uploads an onnx file and registers it as an ml model'''

validate_model_config(
Expand Down Expand Up @@ -55,7 +89,8 @@ def register_ml_model(onnx_model: ModelProto,

model = MlModel(name=model_config.name, upload=str(upload_id), metadata=model_config.metadata,
display_name=model_config.display_name, description=model_config.description)
ml_api.add_ml_model(model, _request_timeout=register_timeout)
res_name = ml_api.add_ml_model(model, _request_timeout=register_timeout)
return MlModelName.from_response(res_name)


def validate_model_config(onnx_model: ModelProto, *,
Expand Down
22 changes: 17 additions & 5 deletions geoengine/permissions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from enum import Enum

import ast
from typing import Dict, Literal, Any
from typing import Dict, Literal, Any, Union
from uuid import UUID

import geoengine_openapi_client
Expand All @@ -15,6 +15,7 @@
from geoengine.datasets import DatasetName
from geoengine.error import GeoEngineException
from geoengine.layers import LayerCollectionId, LayerId
from geoengine.ml import MlModelName


class RoleId:
Expand Down Expand Up @@ -82,7 +83,7 @@ def __repr__(self) -> str:
class Resource:
'''A wrapper for a resource id'''

def __init__(self, resource_type: Literal['dataset', 'layer', 'layerCollection'],
def __init__(self, resource_type: Literal['dataset', 'layer', 'layerCollection', 'mlModel'],
resource_id: str) -> None:
'''Create a resource id'''
self.__type = resource_type
Expand All @@ -99,9 +100,18 @@ def from_layer_collection_id(cls, layer_collection_id: LayerCollectionId) -> Res
return Resource('layerCollection', str(layer_collection_id))

@classmethod
def from_dataset_name(cls, dataset_name: DatasetName) -> Resource:
'''Create a resource id from a dataset id'''
return Resource('dataset', str(dataset_name))
def from_dataset_name(cls, dataset_name: Union[DatasetName, str]) -> Resource:
'''Create a resource id from a dataset name'''
if isinstance(dataset_name, DatasetName):
dataset_name = str(dataset_name)
return Resource('dataset', dataset_name)

@classmethod
def from_ml_model_name(cls, ml_model_name: Union[MlModelName, str]) -> Resource:
'''Create a resource from an ml model name'''
if isinstance(ml_model_name, MlModelName):
ml_model_name = str(ml_model_name)
return Resource('mlModel', ml_model_name)

def to_api_dict(self) -> geoengine_openapi_client.Resource:
'''Convert to a dict for the API'''
Expand All @@ -115,6 +125,8 @@ def to_api_dict(self) -> geoengine_openapi_client.Resource:
inner = geoengine_openapi_client.ProjectResource(type="project", id=self.__id)
elif self.__type == "dataset":
inner = geoengine_openapi_client.DatasetResource(type="dataset", id=self.__id)
elif self.__type == "mlModel":
inner = geoengine_openapi_client.MlModelResource(type="mlModel", id=self.__id)

return geoengine_openapi_client.Resource(inner)

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package_dir =
packages = find:
python_requires = >=3.9
install_requires =
geoengine-openapi-client == 0.0.18
geoengine-openapi-client @ git+https://github.com/geo-engine/openapi-client@ml_and_dataset_name_as_resource_name#subdirectory=python
geopandas >=0.9,<0.15
matplotlib >=3.5,<3.8
numpy >=1.21,<2
Expand Down
60 changes: 23 additions & 37 deletions tests/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from geoengine_openapi_client.models import MlModelMetadata, RasterDataType
import geoengine as ge
from . import UrllibMocker
from tests.ge_test import GeoEngineTestInstance


class WorkflowStorageTests(unittest.TestCase):
Expand All @@ -24,43 +24,19 @@ def test_uploading_onnx_model(self):

onnx_clf = to_onnx(clf, training_x[:1], options={'zipmap': False}, target_opset=9)

with UrllibMocker() as m:
session_id = "c4983c3e-9b53-47ae-bda9-382223bd5081"
request_headers = {'Authorization': f'Bearer {session_id}'}
# TODO: use `enterContext(cm)` instead of `with cm:` in Python 3.11
with GeoEngineTestInstance() as ge_instance:
ge_instance.wait_for_ready()

m.post('http://mock-instance/anonymous', json={
"id": session_id,
"project": None,
"view": None
})
ge.initialize(ge_instance.address())

upload_id = "c314ff6d-3e37-41b4-b9b2-3669f13f7369"
session = ge.get_session()
model_name = f"{session.user_id}:foo"

m.post('http://mock-instance/upload', json={
"id": upload_id
}, request_headers=request_headers)

m.post('http://mock-instance/ml/models',
expected_request_body={
"description": "A simple decision tree model",
"displayName": "Decision Tree",
"metadata": {
"fileName": "model.onnx",
"inputType": "F32",
"numInputBands": 2,
"outputType": "I64"
},
"name": "foo",
"upload": upload_id
},
request_headers=request_headers)

ge.initialize("http://mock-instance")

ge.register_ml_model(
res_name = ge.register_ml_model(
onnx_model=onnx_clf,
model_config=ge.ml.MlModelConfig(
name="foo",
name=model_name,
metadata=MlModelMetadata(
file_name="model.onnx",
input_type=RasterDataType.F32,
Expand All @@ -71,12 +47,22 @@ def test_uploading_onnx_model(self):
description="A simple decision tree model",
)
)
self.assertEqual(str(res_name), model_name)

# Now test permission setting
ge.add_permission(
ge.REGISTERED_USER_ROLE_ID, ge.Resource.from_ml_model_name(res_name), ge.Permission.READ
)
ge.remove_permission(
ge.REGISTERED_USER_ROLE_ID, ge.Resource.from_ml_model_name(res_name), ge.Permission.READ
)

# failing tests
with self.assertRaises(ge.InputException) as exception:
ge.register_ml_model(
_res_name = ge.register_ml_model(
onnx_model=onnx_clf,
model_config=ge.ml.MlModelConfig(
name="foo",
name=model_name,
metadata=MlModelMetadata(
file_name="model.onnx",
input_type=RasterDataType.F32,
Expand All @@ -93,10 +79,10 @@ def test_uploading_onnx_model(self):
)

with self.assertRaises(ge.InputException) as exception:
ge.register_ml_model(
_res_name = ge.register_ml_model(
onnx_model=onnx_clf,
model_config=ge.ml.MlModelConfig(
name="foo",
name=model_name,
metadata=MlModelMetadata(
file_name="model.onnx",
input_type=RasterDataType.F64,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class UploadTests(unittest.TestCase):
'''Test runner regarding upload functionality'''

def setUp(self) -> None:
ge.reset(False)
ge.reset(logout=False)

def test_upload(self):
# TODO: use `enterContext(cm)` instead of `with cm:` in Python 3.11
Expand Down
Loading