Skip to content

Commit 6549106

Browse files
authored
Merge pull request #461 from yonishelach/auto_trainer
[Auto-trainer] Update function according to `apply_mlrun`
2 parents 601cbc7 + 376ff46 commit 6549106

File tree

4 files changed

+57
-34
lines changed

4 files changed

+57
-34
lines changed

auto_trainer/auto_trainer.py

100644 → 100755
+34 -19
Original file line number | Diff line number | Diff line change
@@ -22,10 +22,6 @@ class KWArgsPrefixes:
2222
PREDICT = "PREDICT_"
2323

2424

25-
def _more_than_one(arg_list: List) -> bool:
26-
return len([1 for arg in arg_list if arg is not None]) > 1
27-
28-
2925
def _get_sub_dict_by_prefix(src: Dict, prefix_key: str) -> Dict[str, Any]:
3026
"""
3127
Collect all the keys from the given dict that starts with the given prefix and creates a new dictionary with these
@@ -164,9 +160,6 @@ def train(
164160
drop_columns=drop_columns,
165161
)
166162

167-
# Remove labels from sample set:
168-
sample_set = sample_set.drop(label_columns, axis=1, errors="ignore")
169-
170163
# Parsing kwargs:
171164
# TODO: Use in xgb or lgbm train function.
172165
train_kwargs = _get_sub_dict_by_prefix(
@@ -209,9 +202,9 @@ def train(
209202
context=context,
210203
tag=tag,
211204
sample_set=sample_set,
212-
y_column=label_columns,
205+
y_columns=label_columns,
213206
test_set=test_set,
214-
X_test=x_test,
207+
x_test=x_test,
215208
y_test=y_test,
216209
artifacts=context.artifacts,
217210
)
@@ -277,18 +270,22 @@ def predict(
277270
dataset: mlrun.DataItem,
278271
drop_columns: Union[str, List[str], int, List[int]] = None,
279272
label_columns: Optional[Union[str, List[str]]] = None,
273+
result_set: Optional[str] = None,
280274
):
281275
"""
282276
Predicting dataset by a model.
283277
284278
:param context: MLRun context.
285279
:param model: The model Store path.
286-
:param dataset: The dataset to evaluate the model on. Can be either a URI, a FeatureVector or a
280+
:param dataset: The dataset to predict the model on. Can be either a URI, a FeatureVector or a
287281
sample in a shape of a list/dict.
282+
When passing a sample, pass the dataset as a field in `params` instead of `inputs`.
288283
:param drop_columns: str/int or a list of strings/ints that represent the column names/indices to drop.
289284
When the dataset is a list/dict this parameter should be represented by integers.
290285
:param label_columns: The target label(s) of the column(s) in the dataset. for Regression or
291286
Classification tasks.
287+
:param result_set: The db key to set name of the prediction result and the filename.
288+
Default to 'prediction'.
292289
"""
293290
# Get dataset by URL or by FeatureVector:
294291
dataset, label_columns = _get_dataframe(
@@ -307,20 +304,38 @@ def predict(
307304
model_handler = AutoMLRun.load_model(model_path=model, context=context)
308305

309306
# Dropping label columns if necessary:
310-
if label_columns and all(label in dataset.columns for label in label_columns):
311-
dataset = dataset.drop(label_columns, axis=1)
307+
if not label_columns:
308+
label_columns = []
309+
elif isinstance(label_columns, str):
310+
label_columns = [label_columns]
312311

313312
# Predicting:
314313
context.logger.info(f"making prediction by '{model_handler.model_name}'")
315314
y_pred = model_handler.model.predict(dataset, **predict_kwargs)
316315

317-
if not label_columns:
318-
if len(y_pred.shape) == 1 or y_pred.shape[1] == 1:
319-
label_columns = ["predicted_labels"]
316+
# Preparing and validating label columns for the dataframe of the prediction result:
317+
num_predicted = 1 if len(y_pred.shape) == 1 else y_pred.shape[1]
318+
319+
if num_predicted > len(label_columns):
320+
if num_predicted == 1:
321+
label_columns = ["predicted labels"]
320322
else:
321-
label_columns = [f"predicted_label_{i}" for i in range(y_pred.shape[1])]
322-
elif isinstance(label_columns, str):
323-
label_columns = [label_columns]
323+
label_columns.extend(
324+
[
325+
f"predicted_label_{i + 1 + len(label_columns)}"
326+
for i in range(num_predicted - len(label_columns))
327+
]
328+
)
329+
elif num_predicted < len(label_columns):
330+
context.logger.error(
331+
f"number of predicted labels: {num_predicted} is smaller than number of label columns: {len(label_columns)}"
332+
)
333+
raise ValueError
324334

335+
artifact_name = result_set or 'prediction'
336+
labels_inside_df = set(label_columns) & set(dataset.columns.tolist())
337+
if labels_inside_df:
338+
context.logger.error(f"The labels: {labels_inside_df} are already existed in the dataframe")
339+
raise ValueError
325340
pred_df = pd.concat([dataset, pd.DataFrame(y_pred, columns=label_columns)], axis=1)
326-
context.log_dataset("prediction", pred_df)
341+
context.log_dataset(artifact_name, pred_df, db_key=result_set)

0 commit comments

Comments
 (0)