Skip to content

Commit 72b7b48

Browse files
committed
Do not introduce dataclass
1 parent aabc3d8 commit 72b7b48

File tree

7 files changed

+62
-34
lines changed

7 files changed

+62
-34
lines changed

sweagent/environment/utils.py

Lines changed: 48 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import tempfile
1313
import time
1414
import traceback
15-
import dataclasses
1615

1716
from datasets import load_dataset, load_from_disk
1817
from ghapi.all import GhApi
@@ -390,30 +389,6 @@ def get_problem_statement_from_github_issue(owner: str, repo: str, issue_number:
390389
return f"{title}\n{body}\n"
391390

392391

393-
@dataclasses.dataclass
394-
class Instance:
395-
repo: str
396-
base_commit: str
397-
version: str
398-
problem_statement: str
399-
instance_id: str
400-
# todo: This field is only needed while swe_env is using some questionable logic
401-
# to determine whether to clone from a mirror or not. This should be removed in the future.
402-
# Values: 'swe-bench' (loaded from json/jsonl for swe-bench style inference),
403-
# 'online' (loaded from github issue or similar) or 'local' (loaded from local file)
404-
problem_statement_source: str = "swe-bench"
405-
repo_type: str = "github"
406-
407-
def _validate(self):
408-
if self.repo_type not in {"github", "local"}:
409-
raise ValueError(f"Invalid repo type: {self.repo_type=}")
410-
if self.repo_type == "github" and self.repo.count("/") != 1:
411-
raise ValueError(f"Invalid repo format for {self.repo_type=}: {self.repo=}")
412-
413-
def __post_init__(self):
414-
self._validate()
415-
416-
417392
class InstanceBuilder:
418393
def __init__(self, token: Optional[str] = None):
419394
"""This helper class is used to build the data for an instance object,
@@ -481,7 +456,41 @@ def set_repo_info(self, repo: str, base_commit: Optional[str] = None):
481456
else:
482457
raise ValueError(f"Could not determine repo path from {repo=}.")
483458

484-
def build(self) -> Instance: return Instance(**self.args)
459+
def set_from_dict(self, instance_dict: Dict[str, Any]):
460+
self.args |= instance_dict
461+
462+
def set_missing_fields(self):
463+
# todo: This field is only needed while swe_env is using some questionable logic
464+
# to determine whether to clone from a mirror or not. This should be removed in the future.
465+
# Values: 'swe-bench' (loaded from json/jsonl for swe-bench style inference),
466+
# 'online' (loaded from github issue or similar) or 'local' (loaded from local file)
467+
if "problem_statement_source" not in self.args:
468+
self.args["problem_statement_source"] = "swe-bench"
469+
if "repo_type" not in self.args:
470+
self.args["repo_type"] = "github"
471+
472+
def validate(self):
473+
required_fields = [
474+
"problem_statement",
475+
"instance_id",
476+
"repo",
477+
"repo_type",
478+
"base_commit",
479+
"version",
480+
"problem_statement_source",
481+
]
482+
if not all(x in self.args for x in required_fields):
483+
missing = set(required_fields) - set(self.args.keys())
484+
raise ValueError(f"Missing required fields: {missing=}")
485+
if self.args["repo_type"] not in {"github", "local"}:
486+
raise ValueError(f"Invalid repo type: {self.args['repo_type']=}")
487+
if self.args["repo_type"] == "github" and self.args["repo"].count("/") != 1:
488+
raise ValueError(f"Invalid repo format for {self.args['repo_type']=}: {self.args['repo']=}")
489+
490+
def build(self) -> Dict[str, Any]:
491+
self.set_missing_fields()
492+
self.validate()
493+
return self.args
485494

486495

487496
def get_instances(
@@ -501,17 +510,22 @@ def get_instances(
501510
Returns:
502511
List of instances as dictionaries
503512
"""
504-
def set_missing_keys(instances):
505-
return [dataclasses.asdict(Instance(**inst)) for inst in instances]
513+
def instance_from_dict(instances):
514+
ib = InstanceBuilder(token=token)
515+
ib.set_from_dict(instances)
516+
return ib.build()
517+
518+
def postproc_instance_list(instances):
519+
return [instance_from_dict(x) for x in instances]
506520

507521

508522
# If file_path is a directory, attempt load from disk
509523
if os.path.isdir(file_path):
510524
try:
511525
dataset_or_dict = load_from_disk(file_path)
512526
if isinstance(dataset_or_dict, dict):
513-
return set_missing_keys(dataset_or_dict[split])
514-
return set_missing_keys(dataset_or_dict)
527+
return postproc_instance_list(dataset_or_dict[split])
528+
return postproc_instance_list(dataset_or_dict)
515529
except FileNotFoundError:
516530
# Raised by load_from_disk if the directory is not a dataset directory
517531
pass
@@ -527,24 +541,24 @@ def set_missing_keys(instances):
527541
else:
528542
raise ValueError(f"Could not determine repo path from {file_path=}, {repo_path=}")
529543

530-
return [dataclasses.asdict(ib.build())]
544+
return [ib.build()]
531545

532546
if base_commit is not None:
533547
raise ValueError("base_commit must be None if data_path is not a github issue url")
534548

535549
# If file_path is a file, load the file
536550
if file_path.endswith(".json"):
537-
return set_missing_keys(json.load(open(file_path)))
551+
return postproc_instance_list(json.load(open(file_path)))
538552
if file_path.endswith(".jsonl"):
539-
return set_missing_keys([json.loads(x) for x in open(file_path, 'r').readlines()])
553+
return postproc_instance_list([json.loads(x) for x in open(file_path, 'r').readlines()])
540554

541555
if repo_path:
542556
msg = "repo_path must be empty if data_path is not a github url or local repo url"
543557
raise ValueError(msg)
544558

545559
# Attempt load from HF datasets as a last resort
546560
try:
547-
return set_missing_keys(load_dataset(file_path, split=split))
561+
return postproc_instance_list(load_dataset(file_path, split=split))
548562
except:
549563
raise ValueError(
550564
f"Could not load instances from {file_path}. "

tests/test_data/data_sources/debug_20240322.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
[{"instance_id": "swe-bench__humaneval-30", "problem_statement": "I have a function that needs implementing, can you help?", "created_at": "2023110716", "version": "1.0", "test_patch": "diff --git a/test.py b/test.py\nnew file mode 100644\nindex 0000000..52ecda2\n--- /dev/null\n+++ b/test.py\n@@ -0,0 +1,13 @@\n+from main import get_positive\n+\n+\n+METADATA = {}\n+\n+\n+def check(candidate):\n+ assert candidate([-1, -2, 4, 5, 6]) == [4, 5, 6]\n+ assert candidate([5, 3, -5, 2, 3, 3, 9, 0, 123, 1, -10]) == [5, 3, 2, 3, 3, 9, 123, 1]\n+ assert candidate([-1, -2]) == []\n+ assert candidate([]) == []\n+\n+check(get_positive)\n", "base_commit": "0880311", "base_commit_with_tests": "b2e380b", "environment_setup_commit": null, "hints_text": null, "repo": "swe-bench/humaneval", "FAIL_TO_PASS": "", "PASS_TO_PASS": ""}, {"instance_id": "swe-bench__humaneval-85", "problem_statement": "I have a function that needs implementing, can you help?", "created_at": "2023110716", "version": "1.0", "test_patch": "diff --git a/test.py b/test.py\nnew file mode 100644\nindex 0000000..13d6e1f\n--- /dev/null\n+++ b/test.py\n@@ -0,0 +1,12 @@\n+from main import add\n+def check(candidate):\n+\n+ # Check some simple cases\n+ assert candidate([4, 88]) == 88\n+ assert candidate([4, 5, 6, 7, 2, 122]) == 122\n+ assert candidate([4, 0, 6, 7]) == 0\n+ assert candidate([4, 4, 6, 8]) == 12\n+\n+ # Check some edge cases that are easy to work out by hand.\n+ \n+check(add)\n", "base_commit": "2de55bc", "base_commit_with_tests": "c8c997b", "environment_setup_commit": null, "hints_text": null, "repo": "swe-bench/humaneval", "FAIL_TO_PASS": "", "PASS_TO_PASS": ""}, {"instance_id": "swe-bench__humaneval-22", "problem_statement": "I have a function that needs implementing, can you help?", "created_at": "2023110716", "version": "1.0", "test_patch": "diff --git a/test.py b/test.py\nnew file mode 100644\nindex 0000000..d881459\n--- /dev/null\n+++ b/test.py\n@@ -0,0 +1,14 @@\n+from main import filter_integers\n+\n+\n+METADATA = {\n+ 'author': 'jt',\n+ 'dataset': 'test'\n+}\n+\n+\n+def check(candidate):\n+ assert candidate([]) == []\n+ assert candidate([4, {}, [], 23.2, 9, 'adasd']) == [4, 9]\n+ assert candidate([3, 'c', 3, 3, 'a', 'b']) == [3, 3, 3]\n+check(filter_integers)\n", "base_commit": "f0dbe5e", "base_commit_with_tests": "55cc474", "environment_setup_commit": null, "hints_text": null, "repo": "swe-bench/humaneval", "FAIL_TO_PASS": "", "PASS_TO_PASS": ""}, {"instance_id": "swe-bench__humaneval-104", "problem_statement": "I have a function that needs implementing, can you help?", "created_at": "2023110716", "version": "1.0", "test_patch": "diff --git a/test.py b/test.py\nnew file mode 100644\nindex 0000000..617da5a\n--- /dev/null\n+++ b/test.py\n@@ -0,0 +1,13 @@\n+from main import unique_digits\n+def check(candidate):\n+\n+ # Check some simple cases\n+ assert candidate([15, 33, 1422, 1]) == [1, 15, 33]\n+ assert candidate([152, 323, 1422, 10]) == []\n+ assert candidate([12345, 2033, 111, 151]) == [111, 151]\n+ assert candidate([135, 103, 31]) == [31, 135]\n+\n+ # Check some edge cases that are easy to work out by hand.\n+ assert True\n+\n+check(unique_digits)\n", "base_commit": "b52ee85", "base_commit_with_tests": "4a92a50", "environment_setup_commit": null, "hints_text": null, "repo": "swe-bench/humaneval", "FAIL_TO_PASS": "", "PASS_TO_PASS": ""}, {"instance_id": "swe-bench__humaneval-0", "problem_statement": "I have a function that needs implementing, can you help?", "created_at": "2023110716", "version": "1.0", "test_patch": "diff --git a/test.py b/test.py\nnew file mode 100644\nindex 0000000..2d57340\n--- /dev/null\n+++ b/test.py\n@@ -0,0 +1,19 @@\n+from main import has_close_elements\n+\n+\n+METADATA = {\n+ 'author': 'jt',\n+ 'dataset': 'test'\n+}\n+\n+\n+def check(candidate):\n+ assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.3) == True\n+ assert candidate([1.0, 2.0, 3.9, 4.0, 5.0, 2.2], 0.05) == False\n+ assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.95) == True\n+ assert candidate([1.0, 2.0, 5.9, 4.0, 5.0], 0.8) == False\n+ assert candidate([1.0, 2.0, 3.0, 4.0, 5.0, 2.0], 0.1) == True\n+ assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 1.0) == True\n+ assert candidate([1.1, 2.2, 3.1, 4.1, 5.1], 0.5) == False\n+\n+check(has_close_elements)\n", "base_commit": "afba737", "base_commit_with_tests": "c7e41b2", "environment_setup_commit": null, "hints_text": null, "repo": "swe-bench/humaneval", "FAIL_TO_PASS": "", "PASS_TO_PASS": ""}]

tests/test_data/data_sources/swe-bench-dev-easy.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

tests/test_data/data_sources/swe-bench-lite-test.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

tests/test_env.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ def test_execute_environment(tmp_path, test_env_args):
6666
test_env_args = dataclasses.replace(test_env_args, environment_setup=env_config_path)
6767
env = SWEEnv(test_env_args)
6868
env.reset()
69+
70+
71+
@pytest.mark.slow
6972
def test_open_pr(test_env_args):
7073
env = SWEEnv(test_env_args)
7174
env.reset()

tests/test_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,10 @@ def test_get_instance_gh_issue_gh_repo(tmp_path):
140140
assert "SyntaxError" in instance["problem_statement"]
141141
assert len(instance["base_commit"]) > 10
142142
assert instance["version"]
143+
144+
145+
def test_load_instances(test_data_path, caplog):
146+
test_data_sources = test_data_path / "data_sources"
147+
examples = list(test_data_sources.iterdir())
148+
for example in examples:
149+
get_instances(file_path=str(example))

0 commit comments

Comments
 (0)