from typing import Any, Dict, List, Optional

import marshmallow as ma
+import numpy as np
import pandas as pd
import pkg_resources
import yaml
@@ -204,31 +205,25 @@ def _validate_bundle_state(self):
        training_predictions_column_name = None
        validation_predictions_column_name = None
        if "training" in self._bundle_resources:
-            with open(
-                f"{self.bundle_path}/training/dataset_config.yaml",
-                "r",
-                encoding="UTF-8",
-            ) as stream:
-                training_dataset_config = yaml.safe_load(stream)
-
+            training_dataset_config = utils.load_dataset_config_from_bundle(
+                bundle_path=self.bundle_path, label="training"
+            )
            training_predictions_column_name = training_dataset_config.get(
                "predictionsColumnName"
            )

        if "validation" in self._bundle_resources:
-            with open(
-                f"{self.bundle_path}/validation/dataset_config.yaml",
-                "r",
-                encoding="UTF-8",
-            ) as stream:
-                validation_dataset_config = yaml.safe_load(stream)
-
+            validation_dataset_config = utils.load_dataset_config_from_bundle(
+                bundle_path=self.bundle_path, label="validation"
+            )
            validation_predictions_column_name = validation_dataset_config.get(
                "predictionsColumnName"
            )

        if "model" in self._bundle_resources:
-            model_config = self._load_model_config_from_bundle()
+            model_config = utils.load_model_config_from_bundle(
+                bundle_path=self.bundle_path
+            )
            model_type = model_config.get("modelType")
            if (
                training_predictions_column_name is None
@@ -306,17 +301,21 @@ def _validate_bundle_resources(self):

        if "model" in self._bundle_resources and not self._skip_model_validation:
            model_config_file_path = f"{self.bundle_path}/model/model_config.yaml"
-            model_config = self._load_model_config_from_bundle()
+            model_config = utils.load_model_config_from_bundle(
+                bundle_path=self.bundle_path
+            )

            if model_config["modelType"] == "shell":
                model_validator = ModelValidator(
                    model_config_file_path=model_config_file_path
                )
            elif model_config["modelType"] == "full":
                # Use data from the validation as test data
-                validation_dataset_df = self._load_dataset_from_bundle("validation")
-                validation_dataset_config = self._load_dataset_config_from_bundle(
-                    "validation"
+                validation_dataset_df = utils.load_dataset_from_bundle(
+                    bundle_path=self.bundle_path, label="validation"
+                )
+                validation_dataset_config = utils.load_dataset_config_from_bundle(
+                    bundle_path=self.bundle_path, label="validation"
                )

                sample_data = None
@@ -350,60 +349,6 @@ def _validate_bundle_resources(self):
        # Add the bundle resources failed validations to the list of all failed validations
        self.failed_validations.extend(bundle_resources_failed_validations)

-    def _load_dataset_from_bundle(self, label: str) -> pd.DataFrame:
-        """Loads a dataset from a commit bundle.
-
-        Parameters
-        ----------
-        label : str
-            The type of the dataset. Can be either "training" or "validation".
-
-        Returns
-        -------
-        pd.DataFrame
-            The dataset.
-        """
-        dataset_file_path = f"{self.bundle_path}/{label}/dataset.csv"
-
-        dataset_df = pd.read_csv(dataset_file_path)
-
-        return dataset_df
-
-    def _load_dataset_config_from_bundle(self, label: str) -> Dict[str, Any]:
-        """Loads a dataset config from a commit bundle.
-
-        Parameters
-        ----------
-        label : str
-            The type of the dataset. Can be either "training" or "validation".
-
-        Returns
-        -------
-        Dict[str, Any]
-            The dataset config.
-        """
-        dataset_config_file_path = f"{self.bundle_path}/{label}/dataset_config.yaml"
-
-        with open(dataset_config_file_path, "r", encoding="UTF-8") as stream:
-            dataset_config = yaml.safe_load(stream)
-
-        return dataset_config
-
-    def _load_model_config_from_bundle(self) -> Dict[str, Any]:
-        """Loads a model config from a commit bundle.
-
-        Returns
-        -------
-        Dict[str, Any]
-            The model config.
-        """
-        model_config_file_path = f"{self.bundle_path}/model/model_config.yaml"
-
-        with open(model_config_file_path, "r", encoding="UTF-8") as stream:
-            model_config = yaml.safe_load(stream)
-
-        return model_config
-

    def _validate_resource_consistency(self):
        """Validates that the resources in the bundle are consistent with each other.
@@ -419,10 +364,14 @@ def _validate_resource_consistency(self):
        # Loading the relevant configs
        model_config = {}
        if "model" in self._bundle_resources:
-            model_config = self._load_model_config_from_bundle()
-        training_dataset_config = self._load_dataset_config_from_bundle("training")
-        validation_dataset_config = self._load_dataset_config_from_bundle(
-            "validation"
+            model_config = utils.load_model_config_from_bundle(
+                bundle_path=self.bundle_path
+            )
+        training_dataset_config = utils.load_dataset_config_from_bundle(
+            bundle_path=self.bundle_path, label="training"
+        )
+        validation_dataset_config = utils.load_dataset_config_from_bundle(
+            bundle_path=self.bundle_path, label="validation"
        )
        model_feature_names = model_config.get("featureNames")
        model_class_names = model_config.get("classNames")
@@ -1113,6 +1062,8 @@ def __init__(
        self.sample_data = sample_data
        self._use_runner = use_runner
        self.failed_validations = []
+        self.model_config = None
+        self.model_output = None

    def validate(self) -> List[str]:
        """Runs all model validations.
@@ -1300,6 +1251,8 @@ def _validate_model_config(self):
        if model_config_failed_validations:
            logger.error("`model_config.yaml` failed validations:")
            _list_failed_validation_messages(model_config_failed_validations)
+        else:
+            self.model_config = model_config

        # Add the `model_config.yaml` failed validations to the list of all failed validations
        self.failed_validations.extend(model_config_failed_validations)
@@ -1359,7 +1312,9 @@ def _validate_prediction_interface(self):
        # Test `predict_proba` function
        try:
            with utils.HidePrints():
-                ml_model.predict_proba(self.sample_data)
+                self.model_output = ml_model.predict_proba(
+                    self.sample_data
+                )
        except Exception as exc:
            exception_stack = utils.get_exception_stacktrace(exc)
            prediction_interface_failed_validations.append(
@@ -1368,6 +1323,9 @@ def _validate_prediction_interface(self):
                f"\t{exception_stack}"
            )

+        if self.model_output is not None:
+            self._validate_model_output()
+
        # Print results of the validation
        if prediction_interface_failed_validations:
            logger.error("`prediction_interface.py` failed validations:")
@@ -1401,6 +1359,48 @@ def _validate_model_runner(self):
        # Add the model runner failed validations to the list of all failed validations
        self.failed_validations.extend(model_runner_failed_validations)

+    def _validate_model_output(self):
+        """Validates the model output.
+
+        Checks if the model output is an array-like object with shape (n_samples, n_classes).
+        Also checks if the model output is a probability distribution.
+        """
+        model_output_failed_validations = []
+
+        # Check if the model output is an array-like object
+        if not isinstance(self.model_output, np.ndarray):
+            model_output_failed_validations.append(
+                "The output of the `predict_proba` method in the `prediction_interface.py` "
+                "file is not an array-like object. It should be a numpy array of shape "
+                "(n_samples, n_classes)."
+            )
+        elif self.model_config is not None:
+            # Check if the model output has the correct shape
+            num_rows = len(self.sample_data)
+            num_classes = len(self.model_config.get("classes"))
+            if self.model_output.shape != (num_rows, num_classes):
+                model_output_failed_validations.append(
+                    "The output of the `predict_proba` method in the `prediction_interface.py` "
+                    "file has the wrong shape. It should be a numpy array of shape "
+                    f"({num_rows}, {num_classes}). The current output has shape "
+                    f"{self.model_output.shape}."
+                )
+            # Check if the model output is a probability distribution
+            elif not np.allclose(self.model_output.sum(axis=1), 1, atol=0.05):
+                model_output_failed_validations.append(
+                    "The output of the `predict_proba` method in the `prediction_interface.py` "
+                    "file is not a probability distribution. The sum of the probabilities for "
+                    "each sample should be equal to 1."
+                )
+
+        # Print results of the validation
+        if model_output_failed_validations:
+            logger.error("Model output failed validations:")
+            _list_failed_validation_messages(model_output_failed_validations)
+
+        # Add the model output failed validations to the list of all failed validations
+        self.failed_validations.extend(model_output_failed_validations)
+

class ProjectValidator:
    """Validates the project.