File tree Expand file tree Collapse file tree 4 files changed +17
-13
lines changed Expand file tree Collapse file tree 4 files changed +17
-13
lines changed Original file line number Diff line number Diff line change @@ -257,18 +257,19 @@ def test(
257257 if reference :
258258 branch = "reference"
259259 else :
260- if "humaneval" not in commit0_config ["dataset_name" ].split ("/" )[- 1 ].lower ():
260+ dataset_name = commit0_config ["dataset_name" ].lower ()
261+ if "humaneval" in dataset_name or "mbpp" in dataset_name or "bigcodebench" in dataset_name or "codecontests" in dataset_name :
262+ branch = repo_or_repo_path
263+ else :
261264 if branch is None and not reference :
262265 git_path = os .path .join (
263266 commit0_config ["base_dir" ], repo_or_repo_path .split ("/" )[- 1 ]
264267 )
265268 branch = get_active_branch (git_path )
266- else :
267- branch = test_ids
268269
269270 if stdin :
270271 # Read test names from stdin
271- test_ids = sys .stdin .read (). strip ()
272+ test_ids = sys .stdin .read ()
272273 elif test_ids is None :
273274 typer .echo ("Error: test_ids must be provided or use --stdin option" , err = True )
274275 raise typer .Exit (code = 1 )
Original file line number Diff line number Diff line change @@ -25,14 +25,15 @@ def main(
2525 dataset_name , split = dataset_split
2626 ) # type: ignore
2727 specs = []
28- if "swe" in dataset_name .lower ():
28+ dataset_name = dataset_name .lower ()
29+ if "swe" in dataset_name :
2930 dataset_type = "swebench"
30- elif "humaneval" in dataset_name . lower () :
31+ elif "humaneval" in dataset_name or "mbpp" in dataset_name or "bigcodebench" in dataset_name or "codecontests" in dataset_name :
3132 dataset_type = "simple"
3233 else :
3334 dataset_type = "commit0"
3435 for example in dataset :
35- if "swe" in dataset_name . lower () or dataset_type == "simple" :
36+ if "swe" in dataset_name or dataset_type == "simple" :
3637 if split != "all" and split not in example ["instance_id" ]:
3738 continue
3839 else :
Original file line number Diff line number Diff line change @@ -51,17 +51,18 @@ def main(
5151 dataset : Iterator [Union [RepoInstance , SimpleInstance ]] = load_dataset (
5252 dataset_name , split = dataset_split
5353 ) # type: ignore
54+ dataset_name = dataset_name .lower ()
5455 spec = None
5556 example = None
5657 repo_name = None
5758 dataset_type = None
5859 for example in dataset :
5960 if repo_or_repo_dir .endswith ("/" ):
6061 repo_or_repo_dir = repo_or_repo_dir [:- 1 ]
61- if "swe" in dataset_name . lower () :
62+ if "swe" in dataset_name :
6263 repo_name = example ["instance_id" ]
6364 dataset_type = "swebench"
64- elif "humaneval" in dataset_name . lower () :
65+ elif "humaneval" in dataset_name or "mbpp" in dataset_name or "bigcodebench" in dataset_name or "codecontests" in dataset_name :
6566 repo_name = example ["instance_id" ]
6667 dataset_type = "simple"
6768 else :
@@ -130,7 +131,7 @@ def main(
130131 )
131132
132133 # make patch file
133- if "swe" in dataset_name . lower () :
134+ if "swe" in dataset_name :
134135 if branch == "reference" :
135136 patch = (
136137 example ["test" ]["patch" ] + "\n \n " + example ["test" ]["test_patch" ]
@@ -164,7 +165,7 @@ def main(
164165 + example ["test" ]
165166 )
166167 else :
167- solution = open ( test_ids ). read ()
168+ solution = test_ids
168169 prompt = example ["prompt" ] if "prompt" in example .keys () else ""
169170 matches = extract_code_blocks (solution )
170171 if len (matches ) > 0 :
Original file line number Diff line number Diff line change @@ -23,12 +23,13 @@ def main(
2323 base_dir : str ,
2424) -> None :
2525 dataset : Iterator [RepoInstance ] = load_dataset (dataset_name , split = dataset_split ) # type: ignore
26- if "humaneval" in dataset_name .lower ():
26+ dataset_name = dataset_name .lower ()
27+ if "humaneval" in dataset_name or "mbpp" in dataset_name or "bigcodebench" in dataset_name or "codecontests" in dataset_name :
2728 return
2829 for example in dataset :
2930 repo_name = example ["repo" ].split ("/" )[- 1 ]
3031 clone_url = f"https://github.com/{ example ['repo' ]} .git"
31- if "swe" in dataset_name . lower () :
32+ if "swe" in dataset_name :
3233 if repo_split != "all" and repo_split not in example ["instance_id" ]:
3334 continue
3435 clone_dir = os .path .abspath (os .path .join (base_dir , example ["instance_id" ]))
You can’t perform that action at this time.
0 commit comments