Skip to content

Commit 686d4e4

Browse files
committed
fix cli related baseline issue
1 parent 078b1b9 commit 686d4e4

File tree

4 files changed

+16
-34
lines changed

4 files changed

+16
-34
lines changed

baselines/class_types.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,6 @@
11
from dataclasses import dataclass
22

33

4-
@dataclass
5-
class Commit0Config:
6-
base_dir: str
7-
dataset_name: str
8-
dataset_split: str
9-
repo_split: str
10-
num_workers: int
11-
12-
134
@dataclass
145
class AgentConfig:
156
agent_name: str

baselines/configs/agent.yaml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,11 @@ defaults:
33
- base
44
- _self_
55

6-
commit0_config:
7-
repo_split: minitorch
8-
96
agent_config:
107
use_user_prompt: false
118
use_repo_info: false
129
use_unit_tests_info: false
1310
use_spec_info: false
1411
use_lint_info: true
1512
pre_commit_config_path: .pre-commit-config.yaml
16-
run_tests: false
13+
run_tests: true

baselines/configs/base.yaml

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,6 @@
11
defaults:
22
- _self_
33

4-
5-
6-
commit0_config:
7-
base_dir: repos
8-
dataset_name: "wentingzhao/commit0_docstring"
9-
dataset_split: "test"
10-
repo_split: "simpy"
11-
num_workers: 10
12-
134
agent_config:
145
agent_name: "aider"
156
model_name: "claude-3-5-sonnet-20240620"

baselines/run_agent.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
from typing import Optional, Type
1616
from types import TracebackType
1717
from hydra.core.config_store import ConfigStore
18-
from baselines.class_types import AgentConfig, Commit0Config
18+
from baselines.class_types import AgentConfig
1919
from commit0.harness.constants import SPLIT
2020
from commit0.harness.get_pytest_ids import main as get_tests
2121
from commit0.harness.constants import RUN_AIDER_LOG_DIR, RepoInstance
2222
from tqdm import tqdm
23+
from commit0.cli import read_commit0_dot_file
2324

2425

2526
class DirContext:
@@ -40,7 +41,7 @@ def __exit__(
4041

4142

4243
def run_agent_for_repo(
43-
commit0_config: Commit0Config,
44+
repo_base_dir: str,
4445
agent_config: AgentConfig,
4546
example: RepoInstance,
4647
) -> None:
@@ -55,7 +56,7 @@ def run_agent_for_repo(
5556
test_files_str = get_tests(repo_name, verbose=0)
5657
test_files = sorted(list(set([i.split(":")[0] for i in test_files_str])))
5758

58-
repo_path = os.path.join(commit0_config.base_dir, repo_name)
59+
repo_path = os.path.join(repo_base_dir, repo_name)
5960
repo_path = os.path.abspath(repo_path)
6061
try:
6162
local_repo = Repo(repo_path)
@@ -82,13 +83,15 @@ def run_agent_for_repo(
8283
local_repo.git.reset("--hard", example["base_commit"])
8384
target_edit_files = get_target_edit_files(repo_path)
8485
with DirContext(repo_path):
85-
if commit0_config is None or agent_config is None:
86+
if agent_config is None:
8687
raise ValueError("Invalid input")
8788

8889
if agent_config.run_tests:
8990
# when unit test feedback is available, iterate over test files
9091
for test_file in test_files:
91-
test_cmd = f"python -m commit0 test {repo_path} {run_id} {test_file}"
92+
test_cmd = (
93+
f"python -m commit0 test {repo_path} {test_file} --branch {run_id}"
94+
)
9295
test_file_name = test_file.replace(".py", "").replace("/", "__")
9396
log_dir = RUN_AIDER_LOG_DIR / "with_tests" / test_file_name
9497
lint_cmd = get_lint_cmd(local_repo, agent_config.use_lint_info)
@@ -119,26 +122,26 @@ def main() -> None:
119122
Will run in parallel for each repo.
120123
"""
121124
cs = ConfigStore.instance()
122-
cs.store(name="user", node=Commit0Config)
123125
cs.store(name="user", node=AgentConfig)
124126
hydra.initialize(version_base=None, config_path="configs")
125127
config = hydra.compose(config_name="agent")
126-
commit0_config = Commit0Config(**config.commit0_config)
127128
agent_config = AgentConfig(**config.agent_config)
128129

130+
commit0_config = read_commit0_dot_file(".commit0.yaml")
131+
129132
dataset = load_dataset(
130-
commit0_config.dataset_name, split=commit0_config.dataset_split
133+
commit0_config["dataset_name"], split=commit0_config["dataset_split"]
131134
)
132135
filtered_dataset = [
133136
example
134137
for example in dataset
135-
if commit0_config.repo_split == "all"
138+
if commit0_config["repo_split"] == "all"
136139
or (
137140
isinstance(example, dict)
138141
and "repo" in example
139142
and isinstance(example["repo"], str)
140143
and example["repo"].split("/")[-1]
141-
in SPLIT.get(commit0_config.repo_split, [])
144+
in SPLIT.get(commit0_config["repo_split"], [])
142145
)
143146
]
144147
assert len(filtered_dataset) > 0, "No examples available"
@@ -149,14 +152,14 @@ def main() -> None:
149152
with tqdm(
150153
total=len(filtered_dataset), smoothing=0, desc="Running Aider for repos"
151154
) as pbar:
152-
with multiprocessing.Pool(processes=commit0_config.num_workers) as pool:
155+
with multiprocessing.Pool(processes=10) as pool:
153156
results = []
154157

155158
# Use apply_async to submit jobs and add progress bar updates
156159
for example in filtered_dataset:
157160
result = pool.apply_async(
158161
run_agent_for_repo,
159-
args=(commit0_config, agent_config, example),
162+
args=(commit0_config["base_dir"], agent_config, example),
160163
callback=lambda _: pbar.update(
161164
1
162165
), # Update progress bar on task completion

0 commit comments

Comments
 (0)