This repository has been archived by the owner on Feb 11, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy path split.py
76 lines (59 loc) · 1.96 KB
/
split.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from functools import partial
from pathlib import Path
import multiprocessing
import glob
import tqdm
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
def str2bool(v):
    """Parse a command-line style boolean flag value.

    Intended for use as an ``argparse`` ``type=`` converter.

    Args:
        v: A ``bool`` (returned unchanged) or a string. "yes", "true", "t",
            "y", "1" map to True; "no", "false", "f", "n", "0" map to False.
            Matching is case-insensitive.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is a string matching none of
            the accepted spellings.
    """
    # Imported locally so the error path also works when this module is
    # imported rather than run as a script (the file's own ``import argparse``
    # only executes under the ``__main__`` guard).
    import argparse

    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
def save(df, path, print_fun=print):
    """Write ``df`` to ``path`` as CSV (without the index) and report it.

    Args:
        df: DataFrame to write via ``DataFrame.to_csv``.
        path: Destination file path.
        print_fun: Callable that receives the confirmation message;
            defaults to ``print``.
    """
    destination = path
    df.to_csv(destination, index=False)
    message = f"> df saved to {destination}"
    print_fun(message)
def check_save(df, path, **kwargs):
    """Save ``df`` to ``path`` via ``save``, asking before overwriting.

    If ``path`` already exists, the user is prompted on stdin. Only "y" or
    "yes" (case-insensitive, surrounding whitespace ignored) overwrites the
    file. Any other answer, or end-of-input on a non-interactive stdin,
    leaves the existing file untouched.

    Args:
        df: DataFrame to write; forwarded to ``save``.
        path: Destination file, as a ``pathlib.Path`` or a string.
        **kwargs: Extra keyword arguments forwarded to ``save``
            (e.g. ``print_fun``).
    """
    # Accept plain strings as well as Path objects; ``.exists()`` needs a Path.
    path = Path(path)
    if path.exists():
        try:
            answer = input(f"> ! csv already exists. overwrite {path}? [y/N] \n")
        except EOFError:
            # No stdin to read from: fall back to the prompt's default (N).
            answer = ""
        if answer.strip().lower() not in ("y", "yes"):
            return
    save(df, path, **kwargs)
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Split asr-dataset.csv into train/valid/test CSV files."
    )
    parser.add_argument("path", type=str, help="path to the dataset")
    parser.add_argument(
        "--split",
        default=0.05,
        type=float,
        help=(
            "fraction of the dataset used for EACH of the valid and test splits "
            "(e.g. 0.05 gives a 0.90/0.05/0.05 train/valid/test split)"
        ),
    )
    parser.add_argument(
        "--seed", default=42, type=int, help="random seed used while splitting"
    )
    args = parser.parse_args()

    # The first split uses test_size = 2 * split, which train_test_split
    # requires to lie strictly inside (0, 1). Reject bad values up front with
    # a clear CLI error instead of a confusing sklearn traceback.
    if not 0.0 < args.split < 0.5:
        parser.error(f"--split must be in (0, 0.5), got {args.split}")

    root = Path(args.path)
    df_path = root / "asr-dataset.csv"
    path_train = root / "asr-dataset-train.csv"
    path_valid = root / "asr-dataset-valid.csv"
    path_test = root / "asr-dataset-test.csv"

    if not df_path.exists():
        raise FileNotFoundError(f"asr-dataset.csv does not exist: {df_path}")
    df = pd.read_csv(df_path)
    print(f"> df loaded from {df_path}")

    # First, hold out 2 * split of the rows; train keeps the rest.
    train, non_train = train_test_split(
        df, test_size=args.split * 2.0, random_state=args.seed
    )
    # Then divide the held-out rows evenly into valid and test.
    valid, test = train_test_split(non_train, test_size=0.5, random_state=args.seed)

    # Each save prompts before overwriting an existing file.
    check_save(train, path_train)
    check_save(valid, path_valid)
    check_save(test, path_test)