Skip to content

Commit

Permalink
fix(utils): fix possible name collision in train_test_split (ibis-p…
Browse files Browse the repository at this point in the history
…roject#142)

Co-authored-by: Deepyaman Datta <[email protected]>
  • Loading branch information
jitingxu1 and deepyaman authored Sep 13, 2024
1 parent a183f44 commit 6582682
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions ibis_ml/utils/_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,28 @@ def train_test_split(
random.seed(random_seed)

# Generate a random 256-bit key
random_key = str(random.getrandbits(256))
random_str = str(random.getrandbits(256))

if isinstance(unique_key, str):
unique_key = [unique_key]

# Append random string to the name to avoid collision
combined_key = f"combined_key_{random_str}"
train_flag = f"train_{random_str}"

table = table.mutate(
combined_key=ibis.literal(",").join(
table[col].cast("str") for col in unique_key
)
**{
combined_key: ibis.literal(",").join(
table[col].cast("str") for col in unique_key
)
}
).mutate(
train=(_.combined_key + random_key).hash().abs() % num_buckets
< int((1 - test_size) * num_buckets)
**{
train_flag: (_[combined_key] + random_str).hash().abs() % num_buckets
< int((1 - test_size) * num_buckets)
}
)

return table[table.train].drop(["combined_key", "train"]), table[~table.train].drop(
["combined_key", "train"]
)
return table[table[train_flag]].drop([combined_key, train_flag]), table[
~table[train_flag]
].drop([combined_key, train_flag])

0 comments on commit 6582682

Please sign in to comment.