11import git
22import os
3+ import re
34import sys
45import traceback
56from datasets import load_dataset
1112 Files ,
1213 RUN_PYTEST_LOG_DIR ,
1314 RepoInstance ,
15+ SimpleInstance ,
1416)
1517from commit0 .harness .spec import make_spec
1618from commit0 .harness .utils import (
@@ -46,7 +48,7 @@ def main(
4648 Tests are run either locally through docker
4749 or remotely through Modal.
4850 """
49- dataset : Iterator [RepoInstance ] = load_dataset (dataset_name , split = dataset_split ) # type: ignore
51+ dataset : Iterator [Union [ RepoInstance , SimpleInstance ] ] = load_dataset (dataset_name , split = dataset_split ) # type: ignore
5052 spec = None
5153 example = None
5254 repo_name = None
@@ -56,10 +58,13 @@ def main(
5658 if "swe" in dataset_name .lower ():
5759 repo_name = example ["instance_id" ]
5860 dataset_type = "swebench"
61+ elif "humaneval" in dataset_name .lower ():
62+ repo_name = example ["instance_id" ]
63+ dataset_type = "simple"
5964 else :
6065 repo_name = example ["repo" ].split ("/" )[- 1 ]
6166 dataset_type = "commit0"
62- if repo_name in os .path .basename (repo_or_repo_dir ):
67+ if repo_name in os .path .basename (repo_or_repo_dir ) or repo_or_repo_dir . endswith ( repo_name ) :
6368 spec = make_spec (example , dataset_type )
6469 break
6570 assert spec is not None , "No spec available"
@@ -73,46 +78,61 @@ def main(
7378 log_file = log_dir / "run_pytest.log"
7479 logger = setup_logger (repo_name , log_file , verbose = verbose )
7580
76- try :
77- local_repo = git .Repo (repo_or_repo_dir )
78- logger .info (f"Loaded a git repo from { repo_or_repo_dir } " )
79- except (git .exc .NoSuchPathError , git .exc .InvalidGitRepositoryError ): # type: ignore
80- repo_dir = os .path .join (base_dir , repo_name )
81- logger .error (f"{ repo_or_repo_dir } is not a git dir, trying { repo_dir } again" )
81+ if dataset_type != "simple" : # if dataset_type is not simple, load git repo
8282 try :
83- local_repo = git .Repo (repo_dir )
84- logger .info (f"Retried succeeded. Loaded a git repo from { repo_dir } " )
85- except git .exc .NoSuchPathError : # type: ignore
86- raise Exception (
87- f"{ repo_dir } and { repo_or_repo_dir } are not git directories.\n Usage: commit0 test {{repo_dir}} {{branch}} {{test_ids}}"
88- )
89- except Exception as e :
90- raise e
91- commit_id = ""
92- if branch == "reference" :
93- commit_id = example ["reference_commit" ]
94- else :
95- # Check if it's a local branch
96- if branch in local_repo .branches :
97- commit_id = local_repo .commit (branch ).hexsha
83+ local_repo = git .Repo (repo_or_repo_dir )
84+ logger .info (f"Loaded a git repo from { repo_or_repo_dir } " )
85+ except (git .exc .NoSuchPathError , git .exc .InvalidGitRepositoryError ): # type: ignore
86+ repo_dir = os .path .join (base_dir , repo_name )
87+ logger .error (f"{ repo_or_repo_dir } is not a git dir, trying { repo_dir } again" )
88+ try :
89+ local_repo = git .Repo (repo_dir )
90+ logger .info (f"Retried succeeded. Loaded a git repo from { repo_dir } " )
91+ except git .exc .NoSuchPathError : # type: ignore
92+ raise Exception (
93+ f"{ repo_dir } and { repo_or_repo_dir } are not git directories.\n Usage: commit0 test {{repo_dir}} {{branch}} {{test_ids}}"
94+ )
95+ except Exception as e :
96+ raise e
97+ commit_id = ""
98+ if branch == "reference" :
99+ commit_id = example ["reference_commit" ]
98100 else :
99- found_remote_branch = False
100- for remote in local_repo .remotes :
101- remote .fetch () # Fetch latest updates from each remote
101+ # Check if it's a local branch
102+ if branch in local_repo .branches :
103+ commit_id = local_repo .commit (branch ).hexsha
104+ else :
105+ found_remote_branch = False
106+ for remote in local_repo .remotes :
107+ remote .fetch () # Fetch latest updates from each remote
102108
103- # Check if the branch exists in this remote
104- for ref in remote .refs :
105- if (
106- ref .remote_head == branch
107- ): # Compare branch name without remote prefix
108- commit_id = local_repo .commit (ref .name ).hexsha
109- found_remote_branch = True
110- break # Branch found, no need to keep checking this remote
111- if found_remote_branch :
112- break # Stop checking other remotes if branch is found
113- if not found_remote_branch :
114- raise Exception (f"Branch { branch } does not exist locally or remotely." )
115- if "swe" in dataset_name .lower ():
109+ # Check if the branch exists in this remote
110+ for ref in remote .refs :
111+ if (
112+ ref .remote_head == branch
113+ ): # Compare branch name without remote prefix
114+ commit_id = local_repo .commit (ref .name ).hexsha
115+ found_remote_branch = True
116+ break # Branch found, no need to keep checking this remote
117+ if found_remote_branch :
118+ break # Stop checking other remotes if branch is found
119+ if not found_remote_branch :
120+ raise Exception (f"Branch { branch } does not exist locally or remotely." )
121+ if dataset_type == "simple" :
122+ if branch == "reference" :
123+ patch = example ["prompt" ] + "\n \n " + example ["canonical_solution" ] + "\n \n " + example ["test" ]
124+ else :
125+ solution = open (test_ids ).read ()
126+ pattern = r"```python\n(.*?)```"
127+ matches = re .finditer (pattern , solution , re .DOTALL )
128+ matches = [match .group (1 ).strip () for match in matches ]
129+ if len (matches ) > 0 :
130+ solution = "\n \n " .join (matches )
131+ else :
132+ solution = example ["prompt" ] + "\n \n " + solution
133+ patch = solution + "\n \n " + example ["test" ]
134+ patch = patch + "\n \n " + f"check({ example ['entry_point' ]} )"
135+ elif "swe" in dataset_name .lower ():
116136 if branch == "reference" :
117137 patch = example ["test" ]["patch" ] + "\n \n " + example ["test" ]["test_patch" ]
118138 else :
@@ -127,12 +147,15 @@ def main(
127147 patch_file = Path (log_dir / "patch.diff" )
128148 patch_file .write_text (patch , encoding = "utf-8" , errors = "ignore" )
129149
130- # make eval file
131- if coverage :
132- coverage_text = f" --cov={ example ['src_dir' ]} --cov-branch --cov-report json"
150+ if dataset_type != "simple" :
151+ # make eval file
152+ if coverage :
153+ coverage_text = f" --cov={ example ['src_dir' ]} --cov-branch --cov-report json"
154+ else :
155+ coverage_text = ""
156+ eval_script = spec .eval_script .format (test_ids = test_ids , coverage = coverage_text )
133157 else :
134- coverage_text = ""
135- eval_script = spec .eval_script .format (test_ids = test_ids , coverage = coverage_text )
158+ eval_script = spec .eval_script
136159 eval_file = Path (log_dir / "eval.sh" )
137160 eval_file .write_text (eval_script )
138161
0 commit comments