Skip to content

Commit 376ff46

Browse files
Yonatan Shelach
authored and committed
added option to change the artifact name and the filename
1 parent f01e5c0 commit 376ff46

File tree

3 files changed

+19
-13
lines changed

3 files changed

+19
-13
lines changed

auto_trainer/auto_trainer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def predict(
270270
dataset: mlrun.DataItem,
271271
drop_columns: Union[str, List[str], int, List[int]] = None,
272272
label_columns: Optional[Union[str, List[str]]] = None,
273-
dataset_name: Optional[str] = None,
273+
result_set: Optional[str] = None,
274274
):
275275
"""
276276
Predicting dataset by a model.
@@ -279,11 +279,13 @@ def predict(
279279
:param model: The model Store path.
280280
:param dataset: The dataset to predict the model on. Can be either a URI, a FeatureVector or a
281281
sample in a shape of a list/dict.
282+
When passing a sample, pass the dataset as a field in `params` instead of `inputs`.
282283
:param drop_columns: str/int or a list of strings/ints that represent the column names/indices to drop.
283284
When the dataset is a list/dict this parameter should be represented by integers.
284285
:param label_columns: The target label(s) of the column(s) in the dataset, for Regression or
285286
Classification tasks.
286-
:param dataset_name: The file name of the prediction result. Default to 'prediction'.
287+
:param result_set: The db key under which the prediction result is logged; also used as the artifact name and the filename.
288+
Defaults to 'prediction'.
287289
"""
288290
# Get dataset by URL or by FeatureVector:
289291
dataset, label_columns = _get_dataframe(
@@ -330,10 +332,10 @@ def predict(
330332
)
331333
raise ValueError
332334

333-
artifact_name = 'prediction'
335+
artifact_name = result_set or 'prediction'
334336
labels_inside_df = set(label_columns) & set(dataset.columns.tolist())
335337
if labels_inside_df:
336338
context.logger.error(f"The labels: {labels_inside_df} are already existed in the dataframe")
337339
raise ValueError
338340
pred_df = pd.concat([dataset, pd.DataFrame(y_pred, columns=label_columns)], axis=1)
339-
context.log_dataset(artifact_name, pred_df, db_key=dataset_name or artifact_name)
341+
context.log_dataset(artifact_name, pred_df, db_key=result_set)

0 commit comments

Comments
 (0)