@@ -51,17 +51,23 @@ def main(
5151 dataset : Iterator [Union [RepoInstance , SimpleInstance ]] = load_dataset (
5252 dataset_name , split = dataset_split
5353 ) # type: ignore
54+ dataset_name = dataset_name .lower ()
5455 spec = None
5556 example = None
5657 repo_name = None
5758 dataset_type = None
5859 for example in dataset :
5960 if repo_or_repo_dir .endswith ("/" ):
6061 repo_or_repo_dir = repo_or_repo_dir [:- 1 ]
61- if "swe" in dataset_name . lower () :
62+ if "swe" in dataset_name :
6263 repo_name = example ["instance_id" ]
6364 dataset_type = "swebench"
64- elif "humaneval" in dataset_name .lower ():
65+ elif (
66+ "humaneval" in dataset_name
67+ or "mbpp" in dataset_name
68+ or "bigcodebench" in dataset_name
69+ or "codecontests" in dataset_name
70+ ):
6571 repo_name = example ["instance_id" ]
6672 dataset_type = "simple"
6773 else :
@@ -130,7 +136,7 @@ def main(
130136 )
131137
132138 # make patch file
133- if "swe" in dataset_name . lower () :
139+ if "swe" in dataset_name :
134140 if branch == "reference" :
135141 patch = (
136142 example ["test" ]["patch" ] + "\n \n " + example ["test" ]["test_patch" ]
@@ -164,7 +170,7 @@ def main(
164170 + example ["test" ]
165171 )
166172 else :
167- solution = open ( test_ids ). read ()
173+ solution = test_ids
168174 prompt = example ["prompt" ] if "prompt" in example .keys () else ""
169175 matches = extract_code_blocks (solution )
170176 if len (matches ) > 0 :
0 commit comments