@@ -270,7 +270,7 @@ def predict(
270
270
dataset : mlrun .DataItem ,
271
271
drop_columns : Union [str , List [str ], int , List [int ]] = None ,
272
272
label_columns : Optional [Union [str , List [str ]]] = None ,
273
- dataset_name : Optional [str ] = None ,
273
+ result_set : Optional [str ] = None ,
274
274
):
275
275
"""
276
276
Predicting dataset by a model.
@@ -279,11 +279,13 @@ def predict(
279
279
:param model: The model Store path.
280
280
:param dataset: The dataset to predict the model on. Can be either a URI, a FeatureVector or a
281
281
sample in a shape of a list/dict.
282
+ When passing a sample, pass the dataset as a field in `params` instead of `inputs`.
282
283
:param drop_columns: str/int or a list of strings/ints that represent the column names/indices to drop.
283
284
When the dataset is a list/dict this parameter should be represented by integers.
284
285
:param label_columns: The target label(s) of the column(s) in the dataset. for Regression or
285
286
Classification tasks.
286
- :param dataset_name: The file name of the prediction result. Default to 'prediction'.
287
+ :param result_set: The db key to set name of the prediction result and the filename.
288
+ Default to 'prediction'.
287
289
"""
288
290
# Get dataset by URL or by FeatureVector:
289
291
dataset , label_columns = _get_dataframe (
@@ -330,10 +332,10 @@ def predict(
330
332
)
331
333
raise ValueError
332
334
333
- artifact_name = 'prediction'
335
+ artifact_name = result_set or 'prediction'
334
336
labels_inside_df = set (label_columns ) & set (dataset .columns .tolist ())
335
337
if labels_inside_df :
336
338
context .logger .error (f"The labels: { labels_inside_df } are already existed in the dataframe" )
337
339
raise ValueError
338
340
pred_df = pd .concat ([dataset , pd .DataFrame (y_pred , columns = label_columns )], axis = 1 )
339
- context .log_dataset (artifact_name , pred_df , db_key = dataset_name or artifact_name )
341
+ context .log_dataset (artifact_name , pred_df , db_key = result_set )
0 commit comments