@@ -22,10 +22,6 @@ class KWArgsPrefixes:
2222 PREDICT = "PREDICT_"
2323
2424
25- def _more_than_one (arg_list : List ) -> bool :
26- return len ([1 for arg in arg_list if arg is not None ]) > 1
27-
28-
2925def _get_sub_dict_by_prefix (src : Dict , prefix_key : str ) -> Dict [str , Any ]:
3026 """
3127 Collect all the keys from the given dict that starts with the given prefix and creates a new dictionary with these
@@ -164,9 +160,6 @@ def train(
164160 drop_columns = drop_columns ,
165161 )
166162
167- # Remove labels from sample set:
168- sample_set = sample_set .drop (label_columns , axis = 1 , errors = "ignore" )
169-
170163 # Parsing kwargs:
171164 # TODO: Use in xgb or lgbm train function.
172165 train_kwargs = _get_sub_dict_by_prefix (
@@ -209,9 +202,9 @@ def train(
209202 context = context ,
210203 tag = tag ,
211204 sample_set = sample_set ,
212- y_column = label_columns ,
205+ y_columns = label_columns ,
213206 test_set = test_set ,
214- X_test = x_test ,
207+ x_test = x_test ,
215208 y_test = y_test ,
216209 artifacts = context .artifacts ,
217210 )
@@ -277,18 +270,22 @@ def predict(
277270 dataset : mlrun .DataItem ,
278271 drop_columns : Union [str , List [str ], int , List [int ]] = None ,
279272 label_columns : Optional [Union [str , List [str ]]] = None ,
273+ result_set : Optional [str ] = None ,
280274):
281275 """
282276 Predicting dataset by a model.
283277
284278 :param context: MLRun context.
285279 :param model: The model Store path.
286- :param dataset: The dataset to evaluate the model on. Can be either a URI, a FeatureVector or a
280+ :param dataset: The dataset to predict the model on. Can be either a URI, a FeatureVector or a
287281 sample in a shape of a list/dict.
282+ When passing a sample, pass the dataset as a field in `params` instead of `inputs`.
288283 :param drop_columns: str/int or a list of strings/ints that represent the column names/indices to drop.
289284 When the dataset is a list/dict this parameter should be represented by integers.
290285 :param label_columns: The target label(s) of the column(s) in the dataset. for Regression or
291286 Classification tasks.
287+ :param result_set: The db key to set name of the prediction result and the filename.
288+ Default to 'prediction'.
292289 """
293290 # Get dataset by URL or by FeatureVector:
294291 dataset , label_columns = _get_dataframe (
@@ -307,20 +304,38 @@ def predict(
307304 model_handler = AutoMLRun .load_model (model_path = model , context = context )
308305
309306 # Dropping label columns if necessary:
310- if label_columns and all (label in dataset .columns for label in label_columns ):
311- dataset = dataset .drop (label_columns , axis = 1 )
307+ if not label_columns :
308+ label_columns = []
309+ elif isinstance (label_columns , str ):
310+ label_columns = [label_columns ]
312311
313312 # Predicting:
314313 context .logger .info (f"making prediction by '{ model_handler .model_name } '" )
315314 y_pred = model_handler .model .predict (dataset , ** predict_kwargs )
316315
317- if not label_columns :
318- if len (y_pred .shape ) == 1 or y_pred .shape [1 ] == 1 :
319- label_columns = ["predicted_labels" ]
316+ # Preparing and validating label columns for the dataframe of the prediction result:
317+ num_predicted = 1 if len (y_pred .shape ) == 1 else y_pred .shape [1 ]
318+
319+ if num_predicted > len (label_columns ):
320+ if num_predicted == 1 :
321+ label_columns = ["predicted labels" ]
320322 else :
321- label_columns = [f"predicted_label_{ i } " for i in range (y_pred .shape [1 ])]
322- elif isinstance (label_columns , str ):
323- label_columns = [label_columns ]
323+ label_columns .extend (
324+ [
325+ f"predicted_label_{ i + 1 + len (label_columns )} "
326+ for i in range (num_predicted - len (label_columns ))
327+ ]
328+ )
329+ elif num_predicted < len (label_columns ):
330+ context .logger .error (
331+ f"number of predicted labels: { num_predicted } is smaller than number of label columns: { len (label_columns )} "
332+ )
333+ raise ValueError
324334
335+ artifact_name = result_set or 'prediction'
336+ labels_inside_df = set (label_columns ) & set (dataset .columns .tolist ())
337+ if labels_inside_df :
338+ context .logger .error (f"The labels: { labels_inside_df } are already existed in the dataframe" )
339+ raise ValueError
325340 pred_df = pd .concat ([dataset , pd .DataFrame (y_pred , columns = label_columns )], axis = 1 )
326- context .log_dataset ("prediction" , pred_df )
341+ context .log_dataset (artifact_name , pred_df , db_key = result_set )
0 commit comments