Skip to content
88 changes: 85 additions & 3 deletions src/sasctl/_services/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@ class ModelManagement(Service):
# TODO: set ds2MultiType
@classmethod
def publish_model(
cls, model, destination, name=None, force=False, reload_model_table=False
cls,
model,
destination,
model_version="latest",
name=None,
force=False,
reload_model_table=False,
):
"""

Expand All @@ -38,6 +44,8 @@ def publish_model(
The name or id of the model, or a dictionary representation of the model.
destination : str
Name of destination to publish the model to.
model_version : str or dict, optional
Provide the version id, name, or dict to publish. Defaults to 'latest'.
name : str, optional
Provide a custom name for the published model. Defaults to None.
force : bool, optional
Expand Down Expand Up @@ -68,6 +76,23 @@ def publish_model(

# TODO: Verify allowed formats by destination type.
# As of 19w04 MAS throws HTTP 500 if name is in invalid format.
if model_version != "latest":
if isinstance(model_version, dict) and "modelVersionName" in model_version:
model_version_name = model_version["modelVersionName"]
elif (
isinstance(model_version, dict)
and "modelVersionName" not in model_version
):
raise ValueError("Model version is not recognized.")
elif isinstance(model_version, str) and cls.is_uuid(model_version):
model_version_name = mr.get_model_or_version(model, model_version)[
"modelVersionName"
]
else:
model_version_name = model_version
else:
model_version_name = ""

model_name = name or "{}_{}".format(
model_obj["name"].replace(" ", ""), model_obj["id"]
).replace("-", "")
Expand All @@ -79,6 +104,7 @@ def publish_model(
{
"modelName": mp._publish_name(model_name),
"sourceUri": model_uri.get("uri"),
"modelVersionID": model_version_name,
"publishLevel": "model",
}
],
Expand All @@ -104,6 +130,7 @@ def create_performance_definition(
table_prefix,
project=None,
models=None,
modelVersions=None,
library_name="Public",
name=None,
description=None,
Expand Down Expand Up @@ -136,6 +163,8 @@ def create_performance_definition(
The name or id of the model(s), or a dictionary representation of the model(s). For
multiple models, input a list of model names, or a list of dictionaries. If no models are specified, all
models in the project specified will be used. Defaults to None.
modelVersions: str, list, optional
The name of the model version(s). Defaults to None, so all models are latest.
library_name : str
The library containing the input data, default is 'Public'.
name : str, optional
Expand Down Expand Up @@ -239,10 +268,13 @@ def create_performance_definition(
"property set." % project.name
)

# Creating the new array of modelIds with version names appended
updated_models = cls.check_model_versions(models, modelVersions)

request = {
"projectId": project.id,
"name": name or project.name + " Performance",
"modelIds": [model.id for model in models],
"modelIds": updated_models,
"championMonitored": monitor_champion,
"challengerMonitored": monitor_challenger,
"maxBins": max_bins,
Expand Down Expand Up @@ -279,7 +311,6 @@ def create_performance_definition(
for v in project.get("variables", [])
if v.get("role") == "output"
]

return cls.post(
"/performanceTasks",
json=request,
Expand All @@ -288,6 +319,57 @@ def create_performance_definition(
},
)

@classmethod
def check_model_versions(cls, models, modelVersions):
"""
Checking if the model version(s) are valid and append to model id accordingly.

Parameters
----------
models: list of str
List of models.
modelVersions : list of str
List of model versions associated with models.

Returns
-------
String list
"""
if not modelVersions:
return [model.id for model in models]

updated_models = []
if not isinstance(modelVersions, list):
modelVersions = [modelVersions]

if len(models) < len(modelVersions):
raise ValueError(
"There are too many versions for the amount of models specified."
)

modelVersions = modelVersions + [""] * (len(models) - len(modelVersions))
for model, modelVersionName in zip(models, modelVersions):

if (
isinstance(modelVersionName, dict)
and "modelVersionName" in modelVersionName
):

modelVersionName = modelVersionName["modelVersionName"]
elif (
isinstance(modelVersionName, dict)
and "modelVersionName" not in modelVersionName
):

raise ValueError("Model version is not recognized.")

if modelVersionName != "":
updated_models.append(model.id + ":" + modelVersionName)
else:
updated_models.append(model.id)

return updated_models

@classmethod
def execute_performance_definition(cls, definition):
"""Launches a job to run a performance definition.
Expand Down
3 changes: 2 additions & 1 deletion src/sasctl/_services/model_publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from .model_repository import ModelRepository
from .service import Service
from ..utils.decorators import deprecated


class ModelPublish(Service):
Expand Down Expand Up @@ -90,7 +91,7 @@ def delete_destination(cls, item):

return cls.delete("/destinations/{name}".format(name=item))

@classmethod
@deprecated("Use publish_model in model_management.py instead.", "1.11.5")
def publish_model(cls, model, destination, name=None, code=None, notes=None):
"""Publish a model to an existing publishing destination.

Expand Down
51 changes: 44 additions & 7 deletions src/sasctl/_services/score_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def create_score_definition(
description: str = "",
server_name: str = "cas-shared-default",
library_name: str = "Public",
model_version: str = "latest",
model_version: Union[str, dict] = "latest",
):
"""Creates the score definition service.

Expand All @@ -69,7 +69,7 @@ def create_score_definition(
library_name: str, optional
The library within the CAS server the table exists in. Defaults to "Public".
model_version: str, optional
The user-chosen version of the model with the specified model_id. Defaults to "latest".
The user-chosen version of the model. Deafaults to "latest".

Returns
-------
Expand Down Expand Up @@ -116,7 +116,7 @@ def create_score_definition(
table = cls._cas_management.get_table(table_name, library_name, server_name)
if not table and not table_file:
raise HTTPError(
f"This table may not exist in CAS. Please include the `table_file` argument in the function call if it doesn't exist."
"This table may not exist in CAS. Include the `table_file` argument."
)
elif not table and table_file:
cls._cas_management.upload_file(
Expand All @@ -125,16 +125,19 @@ def create_score_definition(
table = cls._cas_management.get_table(table_name, library_name, server_name)
if not table:
raise HTTPError(
f"The file failed to upload properly or another error occurred."
"The file failed to upload properly or another error occurred."
)
# Checks if the inputted table exists, and if not, uploads a file to create a new table

object_uri, model_version = cls.check_model_version(model_id, model_version)
# Checks if the model version is valid and how to find the name

save_score_def = {
"name": model_name, # used to be score_def_name
"description": description,
"objectDescriptor": {
"uri": f"/modelManagement/models/{model_id}",
"name": f"{model_name}({model_version})",
"uri": object_uri,
"name": f"{model_name} ({model_version})",
"type": f"{object_descriptor_type}",
},
"inputData": {
Expand All @@ -149,7 +152,7 @@ def create_score_definition(
"projectUri": f"/modelRepository/projects/{model_project_id}",
"projectVersionUri": f"/modelRepository/projects/{model_project_id}/projectVersions/{model_project_version_id}",
"publishDestination": "",
"versionedModel": f"{model_name}({model_version})",
"versionedModel": f"{model_name} ({model_version})",
},
"mappings": inputMapping,
}
Expand All @@ -161,3 +164,37 @@ def create_score_definition(
"/definitions", data=json.dumps(save_score_def), headers=headers_score_def
)
# The response information of the score definition can be seen as a JSON as well as a RestOBJ

@classmethod
def check_model_version(cls, model_id: str, model_version: Union[str, dict]):
"""Checks if the model version is valid.

Parameters
----------
model_version : str or dict
The model version to check.

Returns
-------
String tuple
"""
if model_version != "latest":

if isinstance(model_version, dict) and "modelVersionName" in model_version:
model_version = model_version["modelVersionName"]
elif (
isinstance(model_version, dict)
and "modelVersionName" not in model_version
):
raise ValueError("Model version cannot be found.")
elif isinstance(model_version, str) and cls.is_uuid(model_version):
model_version = cls._model_repository.get_model_or_version(
model_id, model_version
)["modelVersionName"]

object_uri = f"/modelManagement/models/{model_id}/versions/@{model_version}"

else:
object_uri = f"/modelManagement/models/{model_id}"

return object_uri, model_version
Loading
Loading