Skip to content

Commit a45f4f4

Browse files
authored
fix: prepare_data — close double-split and missing sample_weight_all on non-shuffle splits (#1554)
1 parent c8f4763 commit a45f4f4

2 files changed

Lines changed: 24 additions & 1 deletion

File tree

flaml/automl/task/generic_task.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,7 @@ def prepare_data(
880880
y_train_all = np.concatenate([y_train_all, y_train_all[:n][rare_index]])
881881
count += rare_count
882882
logger.info(f"class {label} augmented from {rare_count} to {count}")
883+
state.sample_weight_all = sample_weight_full
883884
SHUFFLE_SPLIT_TYPES = ["uniform", "stratified"]
884885
if is_spark_dataframe:
885886
# no need to shuffle pyspark dataframe
@@ -952,7 +953,7 @@ def prepare_data(
952953
) = self._split_pyspark(state, X_train_all, y_train_all, split_ratio)
953954
else:
954955
X_train, X_val, y_train, y_val = self._split_pyspark(state, X_train_all, y_train_all, split_ratio)
955-
if split_type == "group":
956+
elif split_type == "group":
956957
gss = GroupShuffleSplit(n_splits=1, test_size=split_ratio, random_state=RANDOM_SEED)
957958
for train_idx, val_idx in gss.split(X_train_all, y_train_all, state.groups_all):
958959
if data_is_df:

test/automl/test_split.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,28 @@ def test_time():
5252
_test(split_type="time")
5353

5454

55+
def test_time_split_with_sample_weight():
56+
"""Regression for #887 and #1553: split_type='time' + sample_weight + holdout."""
57+
from sklearn.datasets import make_classification
58+
59+
X, y = make_classification(n_samples=200, n_features=10, n_informative=5, random_state=42)
60+
sample_weight = np.ones(len(y))
61+
automl = AutoML()
62+
automl.fit(
63+
X_train=X,
64+
y_train=y,
65+
sample_weight=sample_weight,
66+
split_type="time",
67+
retrain_full=True,
68+
eval_method="holdout",
69+
task="classification",
70+
time_budget=-1,
71+
max_iter=2,
72+
estimator_list=["lgbm"],
73+
)
74+
assert automl.model is not None
75+
76+
5577
def test_groups_for_classification_task():
5678
from sklearn.externals._arff import ArffException
5779

0 commit comments

Comments
 (0)