Skip to content

Commit

Permalink
train_test_split
Browse files Browse the repository at this point in the history
  • Loading branch information
jitingxu1 committed Jun 27, 2024
1 parent 3c2cc32 commit acdcf48
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 0 deletions.
1 change: 1 addition & 0 deletions ibis_ml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from ibis_ml.steps import *
from ibis_ml.utils._pprint import _pprint_recipe, _pprint_step, _safe_repr
from ibis_ml.utils._train_test_split import train_test_split

# Add support for `Recipe`s and `Step`s to the built-in `PrettyPrinter`.
pprint.PrettyPrinter._dispatch[Recipe.__repr__] = _pprint_recipe # noqa: SLF001
Expand Down
37 changes: 37 additions & 0 deletions ibis_ml/utils/_train_test_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import random

import ibis
import ibis.expr.types as ir
from ibis import _


def train_test_split(
table: ir.Table,
unique_key: str | list[str],
test_size: float = 0.25,
random_state=42,
) -> tuple[ir.Table, ir.Table]:

if not (0 < test_size < 1):
raise ValueError("test size should be a float between 0 and 1.")

# Set the random seed for reproducibility
random.seed(random_state)
# Generate a random 256-bit key
random_key = str(random.getrandbits(256))
# set the number of buckets
num_buckets = 100000

if isinstance(unique_key, str):
unique_key = [unique_key]

table = table.mutate(
combined_key=ibis.literal("").join(table[col].cast("str") for col in unique_key)
).mutate(
train=(_.combined_key + random_key).hash().abs() % num_buckets
< int((1 - test_size) * num_buckets)
)

return table[table.train].drop(["combined_key"]), table[~table.train].drop(
["combined_key"]
)
32 changes: 32 additions & 0 deletions tests/test_train_test_split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import ibis
import pandas.testing as tm

import ibis_ml as ml


def test_train_test_split():
N = 100
test_size = 0.25
table = ibis.memtable({"key1": range(N)})

train_table, test_table = ml.train_test_split(
table, unique_key="key1", test_size=test_size, random_state=42
)

# Check counts and overlaps in train and test dataset
assert train_table.count().execute() + test_table.count().execute() == N
assert train_table.intersect(test_table).count().execute() == 0

# Check reproducibility
reproduced_train_table, reproduced_test_table = ml.train_test_split(
table, unique_key="key1", test_size=test_size, random_state=42
)
tm.assert_frame_equal(train_table.execute(), reproduced_train_table.execute())
tm.assert_frame_equal(test_table.execute(), reproduced_test_table.execute())

# make sure it could generate different data with different random state
different_train_table, different_test_table = ml.train_test_split(
table, unique_key="key1", test_size=test_size, random_state=0
)
assert not train_table.execute().equals(different_train_table.execute())
assert not test_table.execute().equals(different_test_table.execute())

0 comments on commit acdcf48

Please sign in to comment.