sample_guided_rollouts.py
import argparse
import os
import pickle

from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

def apply_template(problem, solution, tokenizer):
    """Render a (problem, partial solution) pair as a prompt that continues the solution."""
    messages = [
        {"role": "user", "content": problem},
        {"role": "assistant", "content": solution},
    ]
    # continue_final_message=True omits the end-of-turn marker after the assistant
    # message, so generation resumes the partial solution instead of opening a new turn.
    return tokenizer.apply_chat_template(
        messages, add_generation_prompt=False, continue_final_message=True, tokenize=False
    )
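# Illustrative only (the exact markers depend on the model's chat template; a
# Qwen-style template is assumed here): the rendered prompt looks roughly like
#   <|im_start|>user\n{problem}<|im_end|>\n<|im_start|>assistant\n{solution}
# with no trailing <|im_end|>, so sampling picks up exactly where the guided
# attempt left off.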

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_dataset_name", type=str)
    parser.add_argument("--subset", type=str, default=None)
    parser.add_argument("--dataset_split", type=str)
    parser.add_argument("--dataset_start", type=int)
    parser.add_argument("--dataset_end", type=int)
    parser.add_argument("--output_path", type=str)
    parser.add_argument("-K", type=int)
    parser.add_argument("--model", type=str)
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--top_k", type=int, default=20)
    parser.add_argument("--max_tokens", type=int, default=4096)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--tensor_parallel_size", type=int, default=1)
    args = parser.parse_args()
    print(args.dataset_start, args.dataset_end)
    print(args.model)
    print(args.output_path)
    print("Top K:", args.top_k)
    os.makedirs(args.output_path, exist_ok=True)

    # The dataset must provide a "problem" column and an
    # "intervention_guided_attempt" column (the partial guided solution).
    dataset = load_dataset(args.input_dataset_name, split=args.dataset_split, data_dir=args.subset)
    print(f"Loaded dataset with {len(dataset)} examples.")

    # Clamp the requested range to the dataset length and validate it.
    args.dataset_end = min(args.dataset_end, len(dataset))
    if args.dataset_start >= len(dataset):
        raise ValueError(f"dataset_start {args.dataset_start} must be less than the length of the dataset {len(dataset)}")
    elif args.dataset_start >= args.dataset_end:
        raise ValueError(f"dataset_start {args.dataset_start} must be less than dataset_end {args.dataset_end}")
    # Prefix caching lets vLLM reuse KV cache across prompts that share a prefix;
    # max_model_len bounds prompt plus generated tokens.
    llm = LLM(model=args.model, tensor_parallel_size=args.tensor_parallel_size, enable_prefix_caching=True, max_model_len=32768)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    solutions = {}
    for i in tqdm(range(args.dataset_start, args.dataset_end, args.batch_size)):
        batch = dataset[i : min(i + args.batch_size, args.dataset_end)]
        batch_problems = batch["problem"]
        batch_student_solutions = batch["intervention_guided_attempt"]
        convs = [
            apply_template(problem, solution, tokenizer)
            for problem, solution in zip(batch_problems, batch_student_solutions)
        ]
        # n=args.K draws K independent continuations per guided prefix.
        completions = llm.generate(
            prompts=convs,
            sampling_params=SamplingParams(
                n=args.K,
                temperature=args.temperature,
                max_tokens=args.max_tokens,
                top_p=args.top_p,
                top_k=args.top_k,
            ),
        )
        for j, completion in enumerate(completions):
            solutions[i + j] = completion
    with open(os.path.join(args.output_path, f'guided_pass_at_{args.K}_{args.dataset_start}_{args.dataset_end}_{args.dataset_split}.pkl'), 'wb') as f:
        pickle.dump(solutions, f)
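
# Example invocation (dataset and model names below are placeholders, not from
# the repo):
#   python sample_guided_rollouts.py \
#       --input_dataset_name your-org/guided-attempts --dataset_split train \
#       --dataset_start 0 --dataset_end 1000 --output_path rollouts \
#       -K 8 --model Qwen/Qwen2.5-7B-Instruct
#
# Each saved pickle maps dataset index -> vllm RequestOutput; the K sampled
# continuations for index i are [o.text for o in solutions[i].outputs].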