@@ -22,10 +22,6 @@ class KWArgsPrefixes:
     PREDICT = "PREDICT_"
 
 
-def _more_than_one(arg_list: List) -> bool:
-    return len([1 for arg in arg_list if arg is not None]) > 1
-
-
 def _get_sub_dict_by_prefix(src: Dict, prefix_key: str) -> Dict[str, Any]:
     """
     Collect all the keys from the given dict that starts with the given prefix and creates a new dictionary with these
@@ -164,9 +160,6 @@ def train(
         drop_columns=drop_columns,
     )
 
-    # Remove labels from sample set:
-    sample_set = sample_set.drop(label_columns, axis=1, errors="ignore")
-
     # Parsing kwargs:
     # TODO: Use in xgb or lgbm train function.
     train_kwargs = _get_sub_dict_by_prefix(
@@ -209,9 +202,9 @@ def train(
         context=context,
         tag=tag,
         sample_set=sample_set,
-        y_column=label_columns,
+        y_columns=label_columns,
         test_set=test_set,
-        X_test=x_test,
+        x_test=x_test,
         y_test=y_test,
         artifacts=context.artifacts,
     )
@@ -277,18 +270,22 @@ def predict(
     dataset: mlrun.DataItem,
     drop_columns: Union[str, List[str], int, List[int]] = None,
     label_columns: Optional[Union[str, List[str]]] = None,
+    result_set: Optional[str] = None,
 ):
     """
     Predicting dataset by a model.
 
     :param context:       MLRun context.
     :param model:         The model Store path.
-    :param dataset:       The dataset to evaluate the model on. Can be either a URI, a FeatureVector or a
+    :param dataset:       The dataset to predict on. Can be either a URI, a FeatureVector or a
                           sample in a shape of a list/dict.
+                          When passing a sample, pass the dataset as a field in `params` instead of `inputs`.
     :param drop_columns:  str/int or a list of strings/ints that represent the column names/indices to drop.
                           When the dataset is a list/dict this parameter should be represented by integers.
     :param label_columns: The target label(s) of the column(s) in the dataset. for Regression or
                           Classification tasks.
+    :param result_set:    The db key to set as the name of the prediction result and the filename.
+                          Defaults to 'prediction'.
     """
     # Get dataset by URL or by FeatureVector:
     dataset, label_columns = _get_dataframe(
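For orientation (not part of the diff itself), here is a minimal, hypothetical sketch of how the updated `predict` handler could be invoked with the new `result_set` parameter. The function name, file name, and store URIs are placeholder assumptions; only `handler`, `inputs`, and `params` mirror the signature shown above.

# Hypothetical usage sketch -- file name, project, and store URIs are placeholders.
import mlrun

# Wrap the module containing the `predict` handler as an MLRun job.
fn = mlrun.code_to_function(
    name="auto-trainer",
    filename="auto_trainer.py",  # assumed file name for this module
    kind="job",
    image="mlrun/mlrun",
)

# Run the handler: the dataset is passed as an input URI, while scalar
# parameters (including the new `result_set` db key) go through `params`.
run = fn.run(
    handler="predict",
    inputs={"dataset": "store://datasets/my-project/test-set"},  # placeholder URI
    params={
        "model": "store://models/my-project/my-model",  # placeholder URI
        "label_columns": "label",
        "result_set": "my_prediction",  # logged artifact name and db key
    },
    local=True,
)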
@@ -307,20 +304,38 @@ def predict(
     model_handler = AutoMLRun.load_model(model_path=model, context=context)
 
     # Dropping label columns if necessary:
-    if label_columns and all(label in dataset.columns for label in label_columns):
-        dataset = dataset.drop(label_columns, axis=1)
+    if not label_columns:
+        label_columns = []
+    elif isinstance(label_columns, str):
+        label_columns = [label_columns]
 
     # Predicting:
     context.logger.info(f"making prediction by '{model_handler.model_name}'")
     y_pred = model_handler.model.predict(dataset, **predict_kwargs)
 
-    if not label_columns:
-        if len(y_pred.shape) == 1 or y_pred.shape[1] == 1:
-            label_columns = ["predicted_labels"]
+    # Preparing and validating label columns for the dataframe of the prediction result:
+    num_predicted = 1 if len(y_pred.shape) == 1 else y_pred.shape[1]
+
+    if num_predicted > len(label_columns):
+        if num_predicted == 1:
+            label_columns = ["predicted labels"]
         else:
-            label_columns = [f"predicted_label_{i}" for i in range(y_pred.shape[1])]
-    elif isinstance(label_columns, str):
-        label_columns = [label_columns]
+            label_columns.extend(
+                [
+                    f"predicted_label_{i + 1 + len(label_columns)}"
+                    for i in range(num_predicted - len(label_columns))
+                ]
+            )
+    elif num_predicted < len(label_columns):
+        context.logger.error(
+            f"The number of predicted labels ({num_predicted}) is smaller than the number of label columns ({len(label_columns)})."
+        )
+        raise ValueError
 
+    artifact_name = result_set or "prediction"
+    labels_inside_df = set(label_columns) & set(dataset.columns.tolist())
+    if labels_inside_df:
+        context.logger.error(f"The labels {labels_inside_df} already exist in the dataframe.")
+        raise ValueError
     pred_df = pd.concat([dataset, pd.DataFrame(y_pred, columns=label_columns)], axis=1)
-    context.log_dataset("prediction", pred_df)
+    context.log_dataset(artifact_name, pred_df, db_key=result_set)
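As a sanity check on the new label-column logic above, the following standalone sketch reproduces its padding and validation behavior outside of MLRun. The helper name `resolve_label_columns` is invented for illustration; the default column names and error conditions mirror the added code.

# Standalone sketch of the label-column handling added in this diff.
from typing import List

import numpy as np


def resolve_label_columns(y_pred: np.ndarray, label_columns: List[str]) -> List[str]:
    # One predicted column for 1-D output, otherwise one per output column:
    num_predicted = 1 if len(y_pred.shape) == 1 else y_pred.shape[1]

    if num_predicted > len(label_columns):
        if num_predicted == 1:
            return ["predicted labels"]
        # Pad with generated names; the comprehension is evaluated before the
        # concatenation, so numbering continues from the given columns.
        return list(label_columns) + [
            f"predicted_label_{i + 1 + len(label_columns)}"
            for i in range(num_predicted - len(label_columns))
        ]
    if num_predicted < len(label_columns):
        raise ValueError(
            f"The number of predicted labels ({num_predicted}) is smaller than "
            f"the number of label columns ({len(label_columns)})."
        )
    return list(label_columns)


# Two output columns but only one provided name -> one generated name is appended.
print(resolve_label_columns(np.zeros((4, 2)), ["age"]))  # ['age', 'predicted_label_2']
# 1-D output with no names -> the default single column name.
print(resolve_label_columns(np.zeros(4), []))            # ['predicted labels']

The added duplicate-column check in the diff (labels already present in the dataframe) then guards the `pd.concat` call from producing repeated column names.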