11"""Main STaR Loop"""
22
3+ from copy import deepcopy
34from datasets import Dataset , DatasetDict , load_dataset
45from inference import generate_predictions
56from train import train
@@ -21,13 +22,13 @@ def main():
2122 ds [split ] = ds [split ].add_column (name = "text" , column = texts )
2223
2324 model_name = args .model_name_or_path
24- ds [ "train" ] = ds [ "train" ]. select ( range ( 10 ) )
25+ output_dir = deepcopy ( args . output_dir )
2526 for i in range (args .iteration ):
2627 # sample
2728 all_samples = generate_predictions (
2829 model_name , ds ["train" ], args .temperature , args .n
2930 )
30- ds ["train" ].add_column (name = "sample" , column = all_samples ).to_json (f"{ args . output_dir } /data/samples-iter{ i } .json" )
31+ ds ["train" ].add_column (name = "sample" , column = all_samples ).to_json (f"{ output_dir } /data/samples-iter{ i } .json" )
3132 assert len (ds ["train" ]) == len (all_samples )
3233
3334 # verify and construct the training set
@@ -43,10 +44,10 @@ def main():
4344 passed_examples .append (example )
4445 break
4546 raw_datasets = DatasetDict ({"train" : Dataset .from_list (passed_examples ), "validation" : ds ["validation" ]})
46- raw_datasets ["train" ].to_json (f"{ args . output_dir } /data/verified-samples-iter{ i } .json" )
47+ raw_datasets ["train" ].to_json (f"{ output_dir } /data/verified-samples-iter{ i } .json" )
4748
4849 # train
49- args .output_dir = f"{ args . output_dir } /models-iter{ i } "
50+ args .output_dir = f"{ output_dir } /models-iter{ i } "
5051 train (raw_datasets , model_name , args )
5152 model_name = args .output_dir
5253
0 commit comments