Skip to content

Commit 8932b6d

Browse files
authored
Merge pull request SWE-agent#193 from princeton-nlp/allow-run-locally
Allow to run on local repos
2 parents fbcd1b1 + abd42d9 commit 8932b6d

File tree

21 files changed

+2130
-111
lines changed

21 files changed

+2130
-111
lines changed

run.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
import os
44
import re
5+
import subprocess
56
import traceback
67
from typing import Any, Dict, Optional
78
import rich.console
@@ -42,22 +43,20 @@
4243
@dataclass(frozen=True)
4344
class ActionsArguments(FlattenedAccess, FrozenSerializable):
4445
"""Run real-life actions (opening PRs, etc.) if we can solve the issue."""
45-
open_pr: bool = False # Open a PR with the patch if we can solve the issue
46-
# Skip action if there are already commits claiming to fix the issue. Please only
47-
# set this to False if you are sure the commits are not fixes or if this is your
48-
# own repository!
46+
# Open a PR with the patch if we can solve the issue
47+
open_pr: bool = False
48+
# When working with local repository: Apply patch
49+
apply_patch_locally: bool = False
50+
# Option to be used with open_pr: Skip action if there are already commits claiming
51+
# to fix the issue. Please only set this to False if you are sure the commits are
52+
# not fixes or if this is your own repository!
4953
skip_if_commits_reference_issue: bool = True
50-
# For PRs: If you want to push the branch to a fork (e.g., because you lack
51-
# permissions to push to the main repo), set this to the URL of the fork.
54+
# OBSOLETE. Do not use, will raise error.
5255
push_gh_repo_url: str = ""
5356

5457
def __post_init__(self):
55-
if not self.skip_if_commits_reference_issue and self.push_gh_repo_url:
56-
raise ValueError(
57-
"Overriding `skip_if_commits_reference_issue` when you are "
58-
"pushing to a fork is not supported. You should manually "
59-
"apply the patch to the forked repository."
60-
)
58+
if self.push_gh_repo_url:
59+
raise ValueError("push_gh_repo_url is obsolete. Use repo_path instead")
6160

6261
@dataclass(frozen=True)
6362
class ScriptArguments(FlattenedAccess, FrozenSerializable):
@@ -118,6 +117,7 @@ def main(args: ScriptArguments):
118117
# Get info, patch information
119118
issue = getattr(env, "query", None)
120119
files = []
120+
assert env.record is not None # mypy
121121
if "patch" in env.record:
122122
files = "\n".join(
123123
[f"- {x.path}" for x in PatchSet(env.record["patch"]).modified_files]
@@ -147,9 +147,11 @@ def main(args: ScriptArguments):
147147
return_type="info_trajectory",
148148
)
149149
save_predictions(traj_dir, instance_id, info)
150-
save_patch(traj_dir, instance_id, info)
150+
patch_path = save_patch(traj_dir, instance_id, info)
151151
if args.actions.open_pr and should_open_pr(args, info, token=env._github_token):
152152
env.open_pr(trajectory=trajectory, push_gh_repo_url=args.actions.push_gh_repo_url)
153+
if args.actions.apply_patch_locally and patch_path is not None and env.record["repo_type"] == "local":
154+
apply_patch(Path(args.environment.repo_path), patch_file=patch_path)
153155

154156
except KeyboardInterrupt:
155157
logger.info("Exiting InterCode environment...")
@@ -281,6 +283,21 @@ def save_patch(traj_dir: Path, instance_id: str, info) -> Optional[Path]:
281283
return patch_output_file
282284

283285

286+
def apply_patch(local_dir: Path, patch_file: Path) -> None:
    """Apply a patch file to a local directory using ``git apply``.

    A failing ``git apply`` run is logged as an error and the function
    returns normally; invalid arguments raise instead.

    Args:
        local_dir: Directory in which ``git apply`` is executed.
        patch_file: Patch file to apply.

    Raises:
        NotADirectoryError: If ``local_dir`` is not an existing directory.
        FileNotFoundError: If ``patch_file`` does not exist.
    """
    # Explicit checks rather than ``assert``: assertions are removed when
    # Python runs with -O, which would silently skip the validation.
    if not local_dir.is_dir():
        raise NotADirectoryError(f"Cannot apply patch: {local_dir} is not a directory")
    if not patch_file.exists():
        raise FileNotFoundError(f"Cannot apply patch: {patch_file} does not exist")
    # resolve() is important: the command runs with cwd=local_dir, so a
    # relative patch path would otherwise be interpreted relative to that
    # directory instead of the caller's working directory.
    cmd = ["git", "apply", str(patch_file.resolve())]
    try:
        subprocess.run(cmd, cwd=local_dir, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # CalledProcessError: git exited non-zero (patch did not apply).
        # OSError: the git executable could not be started at all.
        logger.error(f"Failed to apply patch {patch_file} to {local_dir}: {e}")
        return
    logger.info(f"Applied patch {patch_file} to {local_dir}")
299+
300+
284301
def _print_patch_message(patch_output_file: Path):
285302
console = rich.console.Console()
286303
msg = [

run_replay.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import yaml
66

77
from argparse import ArgumentParser
8-
from sweagent.environment.utils import is_from_github_url
98
from typing import Any, Dict, List
109
import run as runscript
1110

@@ -66,16 +65,15 @@ def create_task_instances_tmp_file(data: List[Dict[str, Any]]) -> str:
6665
print(json.dumps(d), file=f, end="\n", flush=True)
6766
return tmp_path
6867

69-
is_github = False
68+
is_other = False
7069
if data_path.endswith(".jsonl"):
7170
replay_task_instances_path = create_task_instances_tmp_file([json.loads(x) for x in open(data_path, "r").readlines()])
7271
elif data_path.endswith(".json"):
7372
replay_task_instances_path = create_task_instances_tmp_file(json.load(open(data_path)))
74-
elif is_from_github_url(data_path):
75-
is_github = True
76-
replay_task_instances_path = data_path
7773
else:
78-
raise ValueError("--data_path must be a .json or .jsonl")
74+
# Assume data_path is a GitHub URL or a local path
75+
is_other = True
76+
replay_task_instances_path = data_path
7977

8078
# Call run.py via subprocess
8179
run_args = [
@@ -86,7 +84,7 @@ def create_task_instances_tmp_file(data: List[Dict[str, Any]]) -> str:
8684
"--replay_path", replay_action_trajs_path,
8785
*forward_args,
8886
]
89-
if is_github:
87+
if is_other:
9088
# Not sure if this only applies to github urls for data_path
9189
run_args.extend(["--skip_existing", "False"])
9290
if suffix is not None:
@@ -95,11 +93,8 @@ def create_task_instances_tmp_file(data: List[Dict[str, Any]]) -> str:
9593
runscript.main(script_args)
9694

9795
os.remove(replay_action_trajs_path)
98-
try:
96+
if not is_other:
9997
os.remove(replay_task_instances_path)
100-
except FileNotFoundError:
101-
pass
102-
10398

10499
def main(
105100
traj_path: str,

sweagent/environment/swe_env.py

Lines changed: 60 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919
from simple_parsing.helpers.serialization.serializable import FrozenSerializable
2020
import yaml
2121
from sweagent.environment.utils import (
22+
copy_anything_to_container,
2223
copy_file_to_container,
2324
format_trajectory_markdown,
2425
get_container,
2526
get_gh_issue_data,
2627
get_instances,
27-
is_from_github_url,
2828
parse_gh_issue_url,
2929
parse_gh_repo_url,
3030
read_with_timeout,
@@ -53,6 +53,9 @@
5353
class EnvironmentArguments(FrozenSerializable):
5454
"""Configure data sources and setup instructions for the environment in which we solve the tasks.
5555
"""
56+
# Source of issue statement/problem statement. To run over a batch of issues: Path to a data file
57+
# (`json`, `jsonl`) or directory. To run over single issue: github issue url or path to markdown file
58+
# with problem statement.
5659
data_path: str
5760
image_name: str
5861
split: str = "dev"
@@ -62,11 +65,13 @@ class EnvironmentArguments(FrozenSerializable):
6265
timeout: int = 35
6366
verbose: bool = False
6467
no_mirror: bool = False
65-
# Custom environment setup. Currently only used when data_path is a GitHub URL.
68+
# Custom environment setup. Currently only used when data_path points to a single issue.
6669
# This needs to be either a string pointing to a yaml file (with yaml, yml file extension)
6770
# or a shell script (with sh extension).
6871
# See https://github.com/princeton-nlp/SWE-agent/pull/153 for more information
6972
environment_setup: Optional[str] = None
73+
# Only used when running on single issue. Path to local repository or github repository.
74+
repo_path: str = ""
7075

7176

7277
class SWEEnv(gym.Env):
@@ -84,7 +89,6 @@ def __init__(self, args: EnvironmentArguments):
8489
self.logger = logger
8590
self.persistent = args.container_name is not None
8691
self.returncode = None
87-
self.is_from_github_url = is_from_github_url(args.data_path)
8892
if not self.args.verbose:
8993
self.logger.disabled = True
9094

@@ -107,7 +111,9 @@ def __init__(self, args: EnvironmentArguments):
107111

108112
# Load Task Instances
109113
self.data_path = self.args.data_path
110-
self.data = get_instances(self.data_path, self.args.base_commit, self.args.split, token=self._github_token)
114+
self.data = get_instances(self.data_path, self.args.base_commit, self.args.split, token=self._github_token, repo_path=self.args.repo_path)
115+
#: Instance we're currently processing. Gets set in self.reset.
116+
self.record = None
111117
self.logger.info(f"💽 Loaded dataset from {self.data_path}")
112118

113119
# Establish connection with execution container
@@ -119,7 +125,48 @@ def __init__(self, args: EnvironmentArguments):
119125
self.idx = 0
120126
self.clean_multi_line_functions = lambda x: x
121127

122-
def reset(self, index: int = None, apply_test_patch: bool = False) -> Tuple[str, dict]:
128+
@property
def _repo_name(self) -> str:
    """Folder name used for the local copy of the current record's repository.

    Derived from the record's ``repo`` entry by replacing every ``/``
    with ``__``.
    """
    current = self.record
    assert current is not None
    return "__".join(current["repo"].split("/"))
133+
134+
def _copy_repo(self) -> str:
    """Clone/copy the repository/codebase of the current record into the container.

    * ``repo_type == "local"``: the path in ``record["repo"]`` (without its
      ``local://`` prefix) is copied to ``/<repo name>`` and ownership of the
      copy is changed to ``root:root``.
    * ``repo_type == "github"``: the repository is fetched with ``git clone``,
      from the ``swe-bench`` organisation when mirroring is enabled and the
      problem statement source is ``"swe-bench"``, otherwise from
      ``record["repo"]`` itself. The clone target is a relative name, so it
      is created in the container shell's current directory.

    Returns:
        folder name of clone
    """
    assert self.record is not None  # mypy
    if self.record["repo_type"] == "local":
        copy_anything_to_container(
            self.container_obj,
            self.record["repo"].removeprefix("local://"),
            "/" + self._repo_name,
        )
        # chown the same absolute path the copy was written to, so this does
        # not depend on the shell's current working directory.
        self.communicate_with_handling(
            input=f"chown -R root:root /{self._repo_name}",
            error_msg="Failed to change permissions on copied repository",
        )
        return self._repo_name
    assert self.record["repo_type"] == "github"
    # NOTE(review): the token becomes part of the clone command line.
    token_prefix = f"{self._github_token}@" if self._github_token else ""
    # fixme: This if statement is brittle and should probably be replaced with better logic
    if not self.args.no_mirror and self.record["problem_statement_source"] == "swe-bench":
        self.logger.info(f"{self._repo_name} not found in container, cloning...")
        self.communicate_with_handling(
            input=f"git clone https://{token_prefix}github.com/swe-bench/{self._repo_name}.git",
            error_msg="Failed to clone repository from mirror",
            timeout_duration=LONG_TIMEOUT,
        )
    else:
        self.logger.info("Trying to clone from non-mirror...")
        self.communicate_with_handling(
            input=f"git clone https://{token_prefix}github.com/{self.record['repo']}.git {self._repo_name}",
            error_msg="Failed to clone repository from non-mirror",
            timeout_duration=LONG_TIMEOUT,
        )
    return self._repo_name
168+
169+
def reset(self, index: Optional[int] = None, apply_test_patch: bool = False) -> Tuple[Optional[str], dict]:
123170
"""
124171
Function to reset container between each task instance.
125172
* Clones instance's repository
@@ -151,30 +198,13 @@ def reset(self, index: int = None, apply_test_patch: bool = False) -> Tuple[str,
151198
# Clone repository if not already cloned
152199
self.communicate(input="cd /")
153200
folders = self.communicate(input="ls").split("\n")
154-
repo_name = self.record["repo"].replace("/", "__")
155-
if repo_name not in folders:
156-
token_prefix = ""
157-
if self._github_token:
158-
token_prefix = f"{self._github_token}@"
159-
if not self.args.no_mirror and not self.is_from_github_url:
160-
self.logger.info(f"{repo_name} not found in container, cloning...")
161-
self.communicate_with_handling(
162-
input=f"git clone https://{token_prefix}github.com/swe-bench/{repo_name}.git",
163-
error_msg="Failed to clone repository from mirror",
164-
timeout_duration=LONG_TIMEOUT,
165-
)
166-
else:
167-
logger.info(f"Trying to clone from non-mirror...")
168-
self.communicate_with_handling(
169-
input=f"git clone https://{token_prefix}github.com/{self.record['repo']}.git {repo_name}",
170-
error_msg="Failed to clone repository from non-mirror",
171-
timeout_duration=LONG_TIMEOUT,
172-
)
201+
if self._repo_name not in folders:
202+
self._copy_repo()
173203

174204
# Clean repository of any modifications + Checkout base commit
175205
for cmd in [
176206
"echo -n > /root/files_to_edit.txt",
177-
f"cd {repo_name}",
207+
f"cd {self._repo_name}",
178208
"export ROOT=$(pwd -P)",
179209
"git status",
180210
"git restore .",
@@ -559,14 +589,15 @@ def install_env(self) -> None:
559589
"""
560590
Creates conda environment and installs third party dependencies to allow code execution
561591
"""
562-
if self.is_from_github_url and self.args.environment_setup is None:
592+
assert self.record is not None # mypy
593+
if (self.record["problem_statement_source"] != "swe-bench" or \
594+
self.record["repo_type"] == "local") and self.args.environment_setup is None:
563595
logger.warning((
564596
"install_environment is set to True, but the data path is a GitHub URL "
565597
"without an environment config file (environment_config key/flag). "
566598
"Skipping conda environment installation."
567599
))
568600
return
569-
repo_name = self.record["repo"].replace("/", "__")
570601
if self.args.environment_setup is not None:
571602
assert isinstance(self.args.environment_setup, (str, os.PathLike))
572603
if Path(self.args.environment_setup).suffix in [".yml", ".yaml"]:
@@ -592,7 +623,7 @@ def install_env(self) -> None:
592623
)
593624
raise ValueError(msg) from e
594625
# Create environment if does not exist yet
595-
env_name = f"{repo_name}__{self.record['version']}"
626+
env_name = f"{self._repo_name}__{self.record['version']}"
596627
env_check = self.communicate(
597628
f"conda env list | grep {env_name}", timeout_duration=LONG_TIMEOUT
598629
)
@@ -676,7 +707,7 @@ def install_env(self) -> None:
676707
pre_install_cmd,
677708
error_msg="Pre-install commands failed to execute successfully",
678709
)
679-
self.logger.info(f"Installing {repo_name} at base commit...")
710+
self.logger.info(f"Installing {self._repo_name} at base commit...")
680711
if "install" in install_configs:
681712
install_cmd = install_configs["install"]
682713
self.communicate_with_handling(

0 commit comments

Comments
 (0)