Skip to content

Commit

Permalink
add unit test and fix null cols for impute
Browse files Browse the repository at this point in the history
  • Loading branch information
jitingxu1 committed Sep 13, 2024
1 parent 6582682 commit 87a79c7
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
3 changes: 3 additions & 0 deletions ibis_ml/steps/_impute.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import math
from typing import TYPE_CHECKING, Any

import ibis.expr.types as ir
Expand All @@ -14,6 +15,8 @@


def _fillna(col, val):
if val is None or col.type().is_floating() and math.isnan(val):
raise ValueError(f"Cannot fill column {col.get_name()!r} with `None` or `NaN`")
if col.type().is_floating():
return (col.isnull() | col.isnan()).ifelse(val, col) # noqa: PD003
else:
Expand Down
56 changes: 56 additions & 0 deletions tests/test_impute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import ibis
import numpy as np
import pandas as pd
import pandas.testing as tm
import pytest

import ibis_ml as ml


@pytest.mark.parametrize(
("mode", "col_name", "expected"),
[
("mean", "floating_col", 1.0),
("median", "floating_col", 0.0),
("mode", "floating_col", 0.0),
("mean", "int_col", 1),
("median", "int_col", 0),
("mode", "int_col", 0),
("mode", "string_col", "a"),
]
)
def test_impute(mode, col_name, expected):
mode_class = getattr(ml, f"Impute{mode.capitalize()}")
step = mode_class(col_name)
train_table = ibis.memtable(
{
"floating_col": [0.0, 0.0, 3.0, None, np.nan],
"int_col": [0, 0, 3, None, None],
"string_col": ["a", "a", "c", None, None],
"null_col": [None]*5,
}
)
test_table = ibis.memtable(
{
col_name: [None],
}
)
step.fit_table(train_table, ml.core.Metadata())
result = step.transform_table(test_table)
expected = pd.DataFrame(
{
col_name: [expected],
}
)
tm.assert_frame_equal(result.execute(), expected, check_dtype=False)

# null col will raise a ValueError
test_table = ibis.memtable(
{
"null_col": [None],
}
)
with pytest.raises(ValueError):
step = mode_class("null_col")
step.fit_table(train_table, ml.core.Metadata())
step.transform_table(test_table)

0 comments on commit 87a79c7

Please sign in to comment.