Skip to content

Commit 903b7c2

Browse files
committed
feat: wire TestGenerator into mine command
- Added OpenRouterClient and TestGenerator imports to mine.py - Pass test_generator to SwePipelineConfig - Fixed await on _create_sandbox in pipeline.py - Fixed DockerSandbox to enter DockerClient context - Fixed test mocking for async _create_sandbox
1 parent 22cbe33 commit 903b7c2

4 files changed

Lines changed: 31 additions & 7 deletions

File tree

src/swe_forge/cli/mine.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626

2727
from swe_forge.export.jsonl import export_jsonl
28+
from swe_forge.llm.openrouter import OpenRouterClient
2829
from swe_forge.swe.github_api import GitHubClient
2930
from swe_forge.swe.gharchive import GhArchiveClient
3031
from swe_forge.swe.models import SweTask
@@ -34,6 +35,7 @@
3435
SwePipelineConfig,
3536
SwePipelineEventType,
3637
)
38+
from swe_forge.swe.test_generator import TestGenerator
3739
from swe_forge.swe.concurrency import set_docker_containers_limit
3840

3941
logger = logging.getLogger(__name__)
@@ -278,7 +280,7 @@ def mine(
278280

279281
# Run the pipeline
280282
try:
281-
result = asyncio.run(_run_pipeline(github_token, config, repo, verbose))
283+
result = asyncio.run(_run_pipeline(github_token, config, repo, verbose, model))
282284

283285
if result.tasks:
284286
# Export results
@@ -417,6 +419,7 @@ async def _run_pipeline(
417419
config: SwePipelineConfig,
418420
repo_filter: Optional[str],
419421
verbose: bool,
422+
model: str = "moonshotai/kimi-k2.5",
420423
):
421424
"""Run the SWE pipeline with progress tracking."""
422425
from dataclasses import dataclass, field
@@ -426,6 +429,15 @@ class PipelineResult:
426429
tasks: list = field(default_factory=list)
427430
benchmark_metrics: object = None
428431

432+
# Create LLM client and TestGenerator for test generation
433+
openrouter_key = os.environ.get("OPENROUTER_API_KEY", "")
434+
llm_client = None
435+
test_generator = None
436+
if openrouter_key:
437+
llm_client = OpenRouterClient(api_key=openrouter_key, default_model=model)
438+
test_generator = TestGenerator(llm=llm_client, model=model)
439+
config.test_generator = test_generator
440+
429441
async with GitHubClient(token=token) as gh_client:
430442
gh_archive_client = GhArchiveClient(token=token) if not repo_filter else None
431443

@@ -577,7 +589,6 @@ async def _run_complete_mining(
577589
):
578590
"""Run the complete mining pipeline."""
579591
from swe_forge.pipeline import CompleteMiningPipeline
580-
from swe_forge.llm.openrouter import OpenRouterClient
581592

582593
llm_client = None
583594
if openrouter_key:

src/swe_forge/execution/sandbox.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def __init__(
119119
self._state = SandboxState()
120120
self._manager: ContainerManager | None = None
121121
self._semaphore_acquired: bool = False
122+
self._client_own_connection: bool = False
122123

123124
@classmethod
124125
def from_spec(
@@ -159,6 +160,13 @@ async def __aenter__(self) -> "DockerSandbox":
159160
await sem.acquire()
160161
self._semaphore_acquired = True
161162

163+
# Enter DockerClient's async context if not already entered
164+
if self._client._docker is None:
165+
await self._client.__aenter__()
166+
self._client_own_connection = True
167+
else:
168+
self._client_own_connection = False
169+
162170
spec = ContainerSpec(
163171
name=self._container_name,
164172
image=self._config.image,
@@ -179,6 +187,8 @@ async def __aenter__(self) -> "DockerSandbox":
179187
self._manager = None
180188
sem.release()
181189
self._semaphore_acquired = False
190+
if self._client_own_connection:
191+
await self._client.__aexit__(None, None, None)
182192
raise
183193

184194
logger.info(
@@ -202,6 +212,9 @@ async def __aexit__(
202212
self._manager = None
203213
self._state.container_id = None
204214

215+
if self._client_own_connection:
216+
await self._client.__aexit__(exc_type, exc_val, exc_tb)
217+
205218
if self._semaphore_acquired:
206219
get_docker_semaphore().release()
207220
self._semaphore_acquired = False

src/swe_forge/swe/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ async def _run_test_generation(
458458
sandbox: DockerSandbox | None = None
459459

460460
try:
461-
sandbox = self._create_sandbox(repo_url, enriched.base_commit)
461+
sandbox = await self._create_sandbox(repo_url, enriched.base_commit)
462462

463463
async with sandbox:
464464
await sandbox.setup_workspace(repo_url, enriched.base_commit)

tests/test_swe/test_pipeline_integration.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,14 @@ def sample_enriched_pr() -> EnrichedPullRequest:
130130
def _bind_mock_sandbox(pipeline: SwePipeline, mock_sandbox: MockSandbox):
131131
"""Bind mock sandbox to pipeline's _create_sandbox.
132132
133-
The source code calls _create_sandbox without await, so we bind a
134-
sync function that returns the mock directly (not a coroutine).
133+
The source code calls _create_sandbox with await, so we bind an
134+
async function that returns the mock.
135135
"""
136136

137-
def create_sandbox_sync(self, repo_url: str, base_commit: str):
137+
async def create_sandbox_async(self, repo_url: str, base_commit: str):
138138
return mock_sandbox
139139

140-
pipeline._create_sandbox = create_sandbox_sync.__get__(pipeline, type(pipeline))
140+
pipeline._create_sandbox = create_sandbox_async.__get__(pipeline, type(pipeline))
141141

142142

143143
class TestGeneratorCalledWhenConfigured:

0 commit comments

Comments
 (0)