Skip to content

Commit aa7b451

Browse files
committed
fix for single process
1 parent 2d96e47 commit aa7b451

File tree

12 files changed

+140
-26
lines changed

12 files changed

+140
-26
lines changed

eval/chat_benchmarks/HumanEval/eval_instruct.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
112112
self.logger.info("Generating responses for Human Eval...")
113113
outputs = self.compute(model, all_instances)
114114

115-
if model.accelerator.process_index != 0:
115+
is_main_process = lm.accelerator.process_index == 0 if hasattr(lm, 'accelerator') else lm.world_size <= 1
116+
if not is_main_process:
116117
continue
117118

118119
generated_examples = []

eval/chat_benchmarks/IFEval/eval_instruct.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
115115
self.logger.info("Generating responses...")
116116
outputs = self.compute(model, all_instances)
117117

118-
if model.accelerator.process_index != 0:
118+
is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
119+
if not is_main_process:
119120
return None
120121

121122
generated_examples = []

eval/chat_benchmarks/MBPP/eval_instruct.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
161161
outputs = self.compute(model, all_instances)
162162

163163
# Return None early for non-primary ranks
164-
if model.accelerator.process_index != 0:
164+
is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
165+
if not is_main_process:
165166
return None
166167

167168
generated_examples = []

eval/chat_benchmarks/MTBench/eval_instruct.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,8 @@ def get_model_answers(self, model: LM, model_id: str, questions: List[Dict[str,
151151
all_convs[q_idx].append({"role": "assistant", "content": output})
152152
all_choices[q_idx]["turns"].append(output)
153153

154-
if model.accelerator.process_index != 0:
154+
is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
155+
if not is_main_process:
155156
continue
156157

157158
# Save completed conversations

eval/chat_benchmarks/MixEval/eval_instruct.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,18 +132,20 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
132132
out_dict = {}
133133

134134
self.logger.info("Generating responses for MixEval...")
135+
is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
136+
135137
for split in splits:
136138
self.args.split = split
137139
all_results = self._eval_split(model, split)
138-
if model.accelerator.process_index == 0:
140+
if is_main_process:
139141
response_file = self._get_response_file()
140142
with open(response_file, "w") as f:
141143
for result in all_results:
142144
f.write(json.dumps(result) + "\n")
143145
out_dict[split] = all_results
144146

145147
# Only return results on rank 0
146-
if model.world_size > 1 and model.accelerator.process_index != 0:
148+
if not is_main_process:
147149
return None
148150
return out_dict
149151

@@ -192,7 +194,8 @@ def _eval_split(self, model: LM, split: str) -> List[Dict[str, Any]]:
192194
for idx in list(range(len(eval_dataset.raw_inputs))):
193195
eval_dataset.raw_inputs[idx]["response"] = all_responses[idx]
194196

195-
if model.accelerator.process_index == 0:
197+
is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
198+
if is_main_process:
196199
with open(response_file, "w") as f:
197200
for item in eval_dataset.raw_inputs:
198201
json_line = json.dumps(item)
@@ -243,7 +246,8 @@ def run_benchmark(self, model: LM) -> Dict[str, Any]:
243246
generation_results = self.generate_responses(model)
244247

245248
# Only evaluate on rank 0
246-
if model.world_size > 1 and model.accelerator.process_index != 0:
249+
is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
250+
if not is_main_process:
247251
return None
248252

249253
evaluation_results = self.evaluate_responses(generation_results)

eval/chat_benchmarks/RepoBench/eval_instruct.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
5959
if self.legacy_mode:
6060
return self._generate_responses_legacy(model)
6161

62-
if model.accelerator.process_index == 0:
62+
is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
63+
if is_main_process:
6364
temp_dir_obj = tempfile.TemporaryDirectory()
6465
temp_dir = temp_dir_obj.name
6566

@@ -77,10 +78,13 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
7778
all_instances = []
7879
# Split dataset across ranks for parallel construction
7980
# Get subset of dataset for this rank using the same slicing strategy as the compute function
80-
chunk_size = len(dataset) // model.world_size
81-
start = model.accelerator.process_index * chunk_size
82-
end = start + chunk_size if model.accelerator.process_index < model.world_size - 1 else len(dataset)
83-
rank_dataset = dataset.select(range(start, end))
81+
if hasattr(model, 'accelerator'):
82+
chunk_size = len(dataset) // model.world_size
83+
start = model.accelerator.process_index * chunk_size
84+
end = start + chunk_size if model.accelerator.process_index < model.world_size - 1 else len(dataset)
85+
rank_dataset = dataset.select(range(start, end))
86+
else:
87+
rank_dataset = list(islice(dataset, model.rank, len(dataset), model.world_size))
8488

8589
# Process examples for this rank's shard
8690
for idx, example in enumerate(rank_dataset):
@@ -103,7 +107,8 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
103107
outputs = self.compute(model, all_instances, do_slice=False)
104108

105109
# Only rank 0 should save the results
106-
if model.accelerator.process_indexlerator.process_index != 0:
110+
is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
111+
if not is_main_process:
107112
continue
108113

109114
generated_examples = []
@@ -121,7 +126,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
121126
for ex in generated_examples:
122127
fw.write(json.dumps(ex) + "\n")
123128

124-
if model.accelerator.process_index == 0:
129+
if is_main_process:
125130
return {"temp_dir_obj": temp_dir_obj}
126131

127132
def _generate_responses_legacy(self, model: LM) -> Dict[str, Any]:

eval/chat_benchmarks/WildBench/eval_instruct.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,8 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
196196
outputs = self.compute(model, all_instances)
197197

198198
# Return None early for non-primary ranks
199-
if model.accelerator.process_index != 0:
199+
is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
200+
if not is_main_process:
200201
return None
201202

202203
outputs = [[output] for output in outputs]

eval/chat_benchmarks/alpaca_eval/eval_instruct.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
117117
self.logger.info("Generating responses for Alpaca Eval...")
118118
outputs = self.compute(model, all_instances)
119119

120-
if model.accelerator.process_index != 0:
120+
is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
121+
if not is_main_process:
121122
return None
122123

123124
model_outputs = []

eval/chat_benchmarks/alpaca_eval/src/alpaca_eval/leaderboards/data_AlpacaEval_2/weighted_alpaca_eval_gpt4_turbo_leaderboard.csv

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
,win_rate,standard_error,n_wins,n_wins_base,n_draws,n_total,discrete_win_rate,mode,avg_length,length_controlled_winrate,lc_standard_error
2-
Shopee-SlimMoA-v1,75.6142865980535,1.27062740591947,621,184,0,805,77.1428571428572,community,1994,77.4515432873834,0.430175221492396
3-
blendaxai-gm-l6-vo31,69.11033492869565,1.3280735654354865,562,242,1,805,69.87577639751554,community,1809,76.91981221023656,0.5725365663132986
2+
Shopee-SlimMoA-v1,75.61428659805350,1.2706274059194700,621,184,0,805,77.14285714285720,community,1994,77.4515432873834,0.43017522149239600
3+
blendaxai-gm-l6-vo31,69.11033492869565,1.3280735654354863,562,242,1,805,69.87577639751554,community,1809,76.91981221023656,0.5725365663132986
44
gemma-2-9b-it-WPO-HB,77.82503168985093,1.2355857177790277,640,163,2,805,79.62732919254658,community,2285,76.72506842726064,0.4242603928637889
55
blendaxai-gm-l3-v35,73.41035740244067,1.254951147343878,607,196,2,805,75.527950310559,community,2186,73.37270365010379,0.6163911450738288
66
gemma-2-9b-it-SimPO,65.86422561532919,1.423459922555078,540,264,1,805,67.14285714285714,community,1833,72.3508446939842,0.5167873784867067
7-
model_hf_model_args_pretrained=mlfoundations-dev__gemma-simpo-reproduction,67.35102937013792,1.4210070002869848,557,247,1,805,69.25465838509317,community,1950,71.18995900084634,0.5756949353655318
87
openpipe-moa-gpt-4-turbo-v1,63.15493451236265,1.422980098799326,515,283,7,805,64.40993788819875,community,1856,68.37866250336802,0.7309418614587613
98
gemma-2-9b-it-DPO,65.35922380122982,1.402802336467638,536,268,1,805,66.64596273291924,community,2016,67.6620382198043,0.6605613085864308
109
Together-MoA,59.8688062333292,1.434305604543079,490,314,1,805,60.93167701863354,community,1825,65.37996976852163,0.7392392836781445
@@ -23,7 +22,7 @@ gpt4_1106_preview_verbose,64.30360147101865,1.3348590089025316,525,268,12,805,65
2322
gpt-4o-mini-2024-07-18,44.65413862507926,1.4572395578449813,350,451,4,805,43.72670807453416,minimal,1861,50.727144855901976,0.8284734951761676
2423
Storm-7B,50.26886905528583,1.4728176780737183,397,408,0,805,49.31677018633541,community,2045,50.45110959343775,
2524
gpt4_1106_preview,50.0,0.0,0,0,805,805,50.0,minimal,2049,50.0,
26-
REBEL-Llama-3-8B-Instruct-Armo,48.43655307668638,1.480341435123528,394,410,1,805,49.00621118012423,community,1965,49.31429353685712,0.7061879308002301
25+
REBEL-Llama-3-8B-Instruct-Armo,48.43655307668638,1.480341435123528,394,410,1,805,49.006211180124225,community,1965,49.314293536857114,0.7061879308002301
2726
Infinity-Instruct-7M-Gen-Llama3_1-70B,37.46327383827497,1.4734130373862548,299,501,5,805,37.453416149068325,community,1654,46.10043331712677,0.822439983375277
2827
Llama-3-Instruct-8B-SimPO-ExPO,40.63285400856655,1.4439449942168028,325,479,1,805,40.43478260869565,community,1765,45.78021783946177,
2928
Llama-3-Instruct-8B-SimPO,40.52977498461182,1.422574464675002,319,485,1,805,39.68944099378882,community,1825,44.65131348921881,0.8800655791760451
@@ -209,5 +208,4 @@ oasst-sft-pythia-12b,1.790114083180124,0.3985580883049341,13,790,2,805,1.7391304
209208
guanaco-13b,3.469596859739131,0.5518606725700214,22,780,3,805,2.919254658385093,verified,1774,3.003787329611614,
210209
guanaco-7b,2.880002266173913,0.5202924149314048,21,783,1,805,2.670807453416149,verified,1364,2.871116813131697,
211210
Qwen1.5-1.8B-Chat,3.70555681579365,0.5811750995496215,27,774,3,804,3.544776119402985,verified,2673,2.588498849185137,
212-
baichuan-13b-chat,1.9921455615279504,0.4176985079331233,14,790,1,805,1.8012422360248446,community,1727,2.062170253598568,
213-
model_hf_model_args_pretrained=mlfoundations-dev__gemma-oh-preferences,0.005260368511326853,0.0018774672393365112,0,805,0,805,0.0,community,196,0.010252829751292214,0.0007495965900756891
211+
baichuan-13b-chat,1.9921455615279504,0.4176985079331233,14,790,1,805,1.8012422360248446,community,1727,2.062170253598568,

eval/chat_benchmarks/alpaca_eval/src/alpaca_eval/metrics/weights/weighted_alpaca_eval_gpt4_turbo/length_controlled_v1/baseline_gpt4_1106_preview.csv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,4 +187,4 @@ Shopee-SlimMoA-v1,-0.6930943742294789,0.5778443790027642,1.4506276222723822
187187
blendaxai-gm-l6-vo31,-1.4827230167114802,0.8256378421072179,1.5942312525409852
188188
REBEL-Llama-3-8B-Instruct-Armo,-1.0427168605260002,0.6464073051877255,0.0395191056877229
189189
model_hf_model_args_pretrained=mlfoundations-dev__gemma-simpo-reproduction,-1.1818376919023723,0.6835318362039150,1.1479555832649320
190-
model_hf_model_args_pretrained=mlfoundations-dev__gemma-oh-preferences,-1.8345282763259563,0.7434213717748921,-9.8937244442602008
190+
model_hf_model_args_pretrained=mlfoundations-dev__gemma-oh-preferences,-1.8345282763259563,0.7434213717748921,-9.8937244442602008

0 commit comments

Comments
 (0)