From 5e78804d78b4f0344a51ec6d7b22aa2bfa2d9ce7 Mon Sep 17 00:00:00 2001 From: jitingxu1 Date: Tue, 17 Sep 2024 15:40:43 -0700 Subject: [PATCH] change error to warning for fill col with None or nan --- ibis_ml/steps/_impute.py | 10 ++++++++-- tests/test_impute.py | 15 +++++++++++---- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/ibis_ml/steps/_impute.py b/ibis_ml/steps/_impute.py index 7d0bfbf..033b85d 100644 --- a/ibis_ml/steps/_impute.py +++ b/ibis_ml/steps/_impute.py @@ -10,13 +10,19 @@ if TYPE_CHECKING: from collections.abc import Iterable +import warnings _DOCS_PAGE_NAME = "imputation" def _fillna(col, val): - if val is None or col.type().is_floating() and math.isnan(val): - raise ValueError(f"Cannot fill column {col.get_name()!r} with `None` or `NaN`") + if val is None or (col.type().is_numeric() and math.isnan(val)): + warnings.warn( + "Imputation requires at least one non-missing value in " + f"column {col.get_name()!r}", + UserWarning, + stacklevel=2, + ) if col.type().is_floating(): return (col.isnull() | col.isnan()).ifelse(val, col) # noqa: PD003 else: diff --git a/tests/test_impute.py b/tests/test_impute.py index 96b96fe..33dec8d 100644 --- a/tests/test_impute.py +++ b/tests/test_impute.py @@ -1,3 +1,5 @@ +import math + import ibis import numpy as np import pandas as pd @@ -50,10 +52,15 @@ def test_fillna(train_table): expected = pd.DataFrame({"floating_col": [0]}) tm.assert_frame_equal(result.execute(), expected, check_dtype=False) - # test _fillna with None - step = ml.FillNA("floating_col", None) + +@pytest.mark.parametrize("val", [None, math.nan]) +def test_fillna_with_none(train_table, val): + step = ml.FillNA("floating_col", val) step.fit_table(train_table, ml.core.Metadata()) - with pytest.raises( - ValueError, match="Cannot fill column 'floating_col' with `None` or `NaN`" + test_table = ibis.memtable({"floating_col": [1.0, None]}) + with pytest.warns( + UserWarning, + match="Imputation requires at least one non-missing value in " + "column 'floating_col'", ): step.transform_table(test_table)