Skip to content

Commit 46258e8

Browse files
gustavocidornelas
authored and whoseoyster committed
Completes OPEN-3915 Validate predict_proba output for full models
1 parent 5b0c06e commit 46258e8

File tree

2 files changed

+137
-78
lines changed

2 files changed

+137
-78
lines changed

openlayer/utils.py

+59
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
import sys
88
import traceback
99
import warnings
10+
from typing import Any, Dict
1011

12+
import pandas as pd
1113
import yaml
1214

1315

@@ -136,3 +138,60 @@ def list_resources_in_bundle(bundle_path: str) -> list:
136138
if resource in VALID_RESOURCES:
137139
resources.append(resource)
138140
return resources
141+
142+
143+
def load_dataset_from_bundle(bundle_path: str, label: str) -> pd.DataFrame:
    """Loads a dataset from a commit bundle.

    Parameters
    ----------
    bundle_path : str
        Path to the root directory of the commit bundle.
    label : str
        The type of the dataset. Can be either "training" or "validation".

    Returns
    -------
    pd.DataFrame
        The dataset.
    """
    # Datasets are laid out as <bundle>/<label>/dataset.csv inside the bundle
    dataset_file_path = f"{bundle_path}/{label}/dataset.csv"

    dataset_df = pd.read_csv(dataset_file_path)

    return dataset_df
161+
162+
163+
def load_dataset_config_from_bundle(bundle_path: str, label: str) -> Dict[str, Any]:
    """Loads a dataset config from a commit bundle.

    Parameters
    ----------
    bundle_path : str
        Path to the root directory of the commit bundle.
    label : str
        The type of the dataset. Can be either "training" or "validation".

    Returns
    -------
    Dict[str, Any]
        The dataset config.
    """
    # Configs are laid out as <bundle>/<label>/dataset_config.yaml inside the bundle
    dataset_config_file_path = f"{bundle_path}/{label}/dataset_config.yaml"

    with open(dataset_config_file_path, "r", encoding="UTF-8") as stream:
        # safe_load: bundle configs are plain data, never arbitrary Python objects
        dataset_config = yaml.safe_load(stream)

    return dataset_config
182+
183+
184+
def load_model_config_from_bundle(bundle_path: str) -> Dict[str, Any]:
    """Loads a model config from a commit bundle.

    Parameters
    ----------
    bundle_path : str
        Path to the root directory of the commit bundle.

    Returns
    -------
    Dict[str, Any]
        The model config.
    """
    # The model config always lives at <bundle>/model/model_config.yaml
    model_config_file_path = f"{bundle_path}/model/model_config.yaml"

    with open(model_config_file_path, "r", encoding="UTF-8") as stream:
        # safe_load: bundle configs are plain data, never arbitrary Python objects
        model_config = yaml.safe_load(stream)

    return model_config

openlayer/validators.py

+78-78
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Any, Dict, List, Optional
1919

2020
import marshmallow as ma
21+
import numpy as np
2122
import pandas as pd
2223
import pkg_resources
2324
import yaml
@@ -204,31 +205,25 @@ def _validate_bundle_state(self):
204205
training_predictions_column_name = None
205206
validation_predictions_column_name = None
206207
if "training" in self._bundle_resources:
207-
with open(
208-
f"{self.bundle_path}/training/dataset_config.yaml",
209-
"r",
210-
encoding="UTF-8",
211-
) as stream:
212-
training_dataset_config = yaml.safe_load(stream)
213-
208+
training_dataset_config = utils.load_dataset_config_from_bundle(
209+
bundle_path=self.bundle_path, label="training"
210+
)
214211
training_predictions_column_name = training_dataset_config.get(
215212
"predictionsColumnName"
216213
)
217214

218215
if "validation" in self._bundle_resources:
219-
with open(
220-
f"{self.bundle_path}/validation/dataset_config.yaml",
221-
"r",
222-
encoding="UTF-8",
223-
) as stream:
224-
validation_dataset_config = yaml.safe_load(stream)
225-
216+
validation_dataset_config = utils.load_dataset_config_from_bundle(
217+
bundle_path=self.bundle_path, label="validation"
218+
)
226219
validation_predictions_column_name = validation_dataset_config.get(
227220
"predictionsColumnName"
228221
)
229222

230223
if "model" in self._bundle_resources:
231-
model_config = self._load_model_config_from_bundle()
224+
model_config = utils.load_model_config_from_bundle(
225+
bundle_path=self.bundle_path
226+
)
232227
model_type = model_config.get("modelType")
233228
if (
234229
training_predictions_column_name is None
@@ -306,17 +301,21 @@ def _validate_bundle_resources(self):
306301

307302
if "model" in self._bundle_resources and not self._skip_model_validation:
308303
model_config_file_path = f"{self.bundle_path}/model/model_config.yaml"
309-
model_config = self._load_model_config_from_bundle()
304+
model_config = utils.load_model_config_from_bundle(
305+
bundle_path=self.bundle_path
306+
)
310307

311308
if model_config["modelType"] == "shell":
312309
model_validator = ModelValidator(
313310
model_config_file_path=model_config_file_path
314311
)
315312
elif model_config["modelType"] == "full":
316313
# Use data from the validation as test data
317-
validation_dataset_df = self._load_dataset_from_bundle("validation")
318-
validation_dataset_config = self._load_dataset_config_from_bundle(
319-
"validation"
314+
validation_dataset_df = utils.load_dataset_from_bundle(
315+
bundle_path=self.bundle_path, label="validation"
316+
)
317+
validation_dataset_config = utils.load_dataset_config_from_bundle(
318+
bundle_path=self.bundle_path, label="validation"
320319
)
321320

322321
sample_data = None
@@ -350,60 +349,6 @@ def _validate_bundle_resources(self):
350349
# Add the bundle resources failed validations to the list of all failed validations
351350
self.failed_validations.extend(bundle_resources_failed_validations)
352351

353-
def _load_dataset_from_bundle(self, label: str) -> pd.DataFrame:
354-
"""Loads a dataset from a commit bundle.
355-
356-
Parameters
357-
----------
358-
label : str
359-
The type of the dataset. Can be either "training" or "validation".
360-
361-
Returns
362-
-------
363-
pd.DataFrame
364-
The dataset.
365-
"""
366-
dataset_file_path = f"{self.bundle_path}/{label}/dataset.csv"
367-
368-
dataset_df = pd.read_csv(dataset_file_path)
369-
370-
return dataset_df
371-
372-
def _load_dataset_config_from_bundle(self, label: str) -> Dict[str, Any]:
373-
"""Loads a dataset config from a commit bundle.
374-
375-
Parameters
376-
----------
377-
label : str
378-
The type of the dataset. Can be either "training" or "validation".
379-
380-
Returns
381-
-------
382-
Dict[str, Any]
383-
The dataset config.
384-
"""
385-
dataset_config_file_path = f"{self.bundle_path}/{label}/dataset_config.yaml"
386-
387-
with open(dataset_config_file_path, "r", encoding="UTF-8") as stream:
388-
dataset_config = yaml.safe_load(stream)
389-
390-
return dataset_config
391-
392-
def _load_model_config_from_bundle(self) -> Dict[str, Any]:
393-
"""Loads a model config from a commit bundle.
394-
395-
Returns
396-
-------
397-
Dict[str, Any]
398-
The model config.
399-
"""
400-
model_config_file_path = f"{self.bundle_path}/model/model_config.yaml"
401-
402-
with open(model_config_file_path, "r", encoding="UTF-8") as stream:
403-
model_config = yaml.safe_load(stream)
404-
405-
return model_config
406-
407352
def _validate_resource_consistency(self):
408353
"""Validates that the resources in the bundle are consistent with each other.
409354
@@ -419,10 +364,14 @@ def _validate_resource_consistency(self):
419364
# Loading the relevant configs
420365
model_config = {}
421366
if "model" in self._bundle_resources:
422-
model_config = self._load_model_config_from_bundle()
423-
training_dataset_config = self._load_dataset_config_from_bundle("training")
424-
validation_dataset_config = self._load_dataset_config_from_bundle(
425-
"validation"
367+
model_config = utils.load_model_config_from_bundle(
368+
bundle_path=self.bundle_path
369+
)
370+
training_dataset_config = utils.load_dataset_config_from_bundle(
371+
bundle_path=self.bundle_path, label="training"
372+
)
373+
validation_dataset_config = utils.load_dataset_config_from_bundle(
374+
bundle_path=self.bundle_path, label="validation"
426375
)
427376
model_feature_names = model_config.get("featureNames")
428377
model_class_names = model_config.get("classNames")
@@ -1113,6 +1062,8 @@ def __init__(
11131062
self.sample_data = sample_data
11141063
self._use_runner = use_runner
11151064
self.failed_validations = []
1065+
self.model_config = None
1066+
self.model_output = None
11161067

11171068
def validate(self) -> List[str]:
11181069
"""Runs all model validations.
@@ -1300,6 +1251,8 @@ def _validate_model_config(self):
13001251
if model_config_failed_validations:
13011252
logger.error("`model_config.yaml` failed validations:")
13021253
_list_failed_validation_messages(model_config_failed_validations)
1254+
else:
1255+
self.model_config = model_config
13031256

13041257
# Add the `model_config.yaml` failed validations to the list of all failed validations
13051258
self.failed_validations.extend(model_config_failed_validations)
@@ -1359,7 +1312,9 @@ def _validate_prediction_interface(self):
13591312
# Test `predict_proba` function
13601313
try:
13611314
with utils.HidePrints():
1362-
ml_model.predict_proba(self.sample_data)
1315+
self.model_output = ml_model.predict_proba(
1316+
self.sample_data
1317+
)
13631318
except Exception as exc:
13641319
exception_stack = utils.get_exception_stacktrace(exc)
13651320
prediction_interface_failed_validations.append(
@@ -1368,6 +1323,9 @@ def _validate_prediction_interface(self):
13681323
f"\t {exception_stack}"
13691324
)
13701325

1326+
if self.model_output is not None:
1327+
self._validate_model_output()
1328+
13711329
# Print results of the validation
13721330
if prediction_interface_failed_validations:
13731331
logger.error("`prediction_interface.py` failed validations:")
@@ -1401,6 +1359,48 @@ def _validate_model_runner(self):
14011359
# Add the model runner failed validations to the list of all failed validations
14021360
self.failed_validations.extend(model_runner_failed_validations)
14031361

1362+
def _validate_model_output(self):
    """Validates the model output.

    Checks that the output of `predict_proba` is an array-like object with
    shape (n_samples, n_classes), and that each row is a probability
    distribution (i.e., sums to 1).
    """
    model_output_failed_validations = []

    # Check if the model output is an array-like object
    if not isinstance(self.model_output, np.ndarray):
        model_output_failed_validations.append(
            "The output of the `predict_proba` method in the `prediction_interface.py` "
            "file is not an array-like object. It should be a numpy array of shape "
            "(n_samples, n_classes)."
        )
    elif self.model_config is not None:
        # Check if the model output has the correct shape
        num_rows = len(self.sample_data)
        # The model config stores the class names under `classNames`
        # (consistent with the other validators), not `classes`. Guard
        # against a missing key so `len` never receives None.
        class_names = self.model_config.get("classNames") or []
        num_classes = len(class_names)
        if self.model_output.shape != (num_rows, num_classes):
            model_output_failed_validations.append(
                "The output of the `predict_proba` method in the `prediction_interface.py` "
                "file has the wrong shape. It should be a numpy array of shape "
                f"({num_rows}, {num_classes}). The current output has shape "
                f"{self.model_output.shape}"
            )
        # Check if the model output is a probability distribution
        elif not np.allclose(self.model_output.sum(axis=1), 1, atol=0.05):
            model_output_failed_validations.append(
                "The output of the `predict_proba` method in the `prediction_interface.py` "
                "file is not a probability distribution. The sum of the probabilities for "
                "each sample should be equal to 1."
            )

    # Print results of the validation
    if model_output_failed_validations:
        logger.error("Model output failed validations:")
        _list_failed_validation_messages(model_output_failed_validations)

    # Add the model output failed validations to the list of all failed validations
    self.failed_validations.extend(model_output_failed_validations)
1403+
14041404

14051405
class ProjectValidator:
14061406
"""Validates the project.

0 commit comments

Comments
 (0)