Skip to content

Commit 57c8782

Browse files
committed
🔨 [refactor] Integrate agent classes into Kaggle solver for enhanced task management and debugging
1 parent 24b21ff commit 57c8782

File tree

1 file changed

+69
-46
lines changed

1 file changed

+69
-46
lines changed

exp/kaggle_solver.py

Lines changed: 69 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from langgraph.func import task, entrypoint
3535

3636
from mle.cli import console
37+
from mle.function import execute_command
3738
from mle.model import load_model
3839
from mle.agents import (
3940
WorkflowCache,
@@ -57,6 +58,11 @@ class KaggleState:
5758
console: Console | None = None
5859
cache: WorkflowCache | None = None
5960
model: Any | None = None
61+
# Agents
62+
advisor: AdviseAgent | None = None
63+
planner: PlanAgent | None = None
64+
coder: CodeAgent | None = None
65+
debugger: DebugAgent | None = None
6066

6167
# run‑time ---------------------------------------------------------------
6268
resume_step: int | None = None
@@ -70,10 +76,15 @@ class KaggleState:
7076
coding_plan: Dict[str, Any] | None = None
7177
is_auto_mode: bool | None = None
7278

73-
# coding loop ------------------------------------------------------------
79+
# coding ------------------------------------------------------------
7480
current_task: Dict[str, Any] | None = None
7581
code_report: Dict[str, Any] | None = None
76-
final_code_reports: List[Dict[str, Any]] = field(default_factory=list)
82+
debug_attempt: int = 0
83+
debug_max_attempt: int = 5
84+
85+
# output files -----------------------------------------------------------
86+
submission: str = "submission.csv"
87+
sample_submission: str = None
7788

7889

7990
# -----------------------------------------------------------------------------
@@ -105,6 +116,12 @@ def init_node(inputs: dict) -> KaggleState:
105116
f"Competition with ID '{state.competition_id}' not found in MLE Bench"
106117
)
107118

119+
# Load agents
120+
state.advisor = AdviseAgent(model=state.model, working_dir=state.work_dir, console=state.console)
121+
state.planner = PlanAgent(model=state.model, working_dir=state.work_dir, console=state.console)
122+
state.coder = CodeAgent(model=state.model, working_dir=state.work_dir, console=state.console)
123+
state.debugger = DebugAgent(model=state.model, console=state.console)
124+
108125
return state
109126

110127

@@ -137,14 +154,15 @@ def overview_summary_node(state: KaggleState) -> KaggleState:
137154

138155
@task
139156
def advisor_report_node(state: KaggleState) -> KaggleState:
140-
cache, con = state.cache, state.console
141-
with cache(step=3, name="MLE advisor agent provides a high‑level report") as ca:
157+
with (state.cache(step=3, name="MLE advisor agent provides a high‑level report") as ca):
142158
state.advisor_report = ca.resume("advisor_report")
143159
if state.advisor_report is None:
144-
advisor = AdviseAgent(model=state.model, working_dir=state.work_dir, console=state.console)
145-
state.advisor_report = advisor.interact(
146-
f"[green]Competition Requirement:[/green] {state.ml_requirement}\n"
147-
f"Dataset path: {state.dataset_path}"
160+
with console.status("MLE Agent is processing the kaggle competition overview..."):
161+
requirements = state.ml_requirement + f"\nDataset path: {state.dataset_path}" \
162+
+ f"\nSUBMISSION FILE PATH: {state.submission}\n"
163+
164+
state.advisor_report = state.advisor.suggest(
165+
requirements, return_raw=True
148166
)
149167
ca.store("advisor_report", state.advisor_report)
150168
return state
@@ -155,54 +173,51 @@ def plan_generation_node(state: KaggleState) -> KaggleState:
155173
with state.cache(step=4, name="MLE plan agent generates a dev plan") as ca:
156174
state.coding_plan = ca.resume("coding_plan")
157175
if state.coding_plan is None:
158-
planner = PlanAgent(model=state.model, working_dir=state.work_dir, console=state.console)
159-
state.coding_plan = planner.interact(state.advisor_report)
176+
state.coding_plan = state.planner.interact(state.advisor_report)
160177
ca.store("coding_plan", state.coding_plan)
161178
return state
162179

163180

164-
def _ensure_auto_mode(state: KaggleState):
165-
if state.is_auto_mode is None:
166-
state.is_auto_mode = questionary.confirm(
167-
"MLE developer is about to start to code.\nChoose to debug or not? (No = code‑only mode)"
168-
).ask()
181+
@task
182+
def coder_read_requirement(state: KaggleState) -> KaggleState:
183+
state.coder.read_requirement(state.advisor_report)
184+
return state
169185

170186

171187
@task
172188
def code_task_node(state: KaggleState) -> KaggleState:
173-
_ensure_auto_mode(state)
174-
coder = CodeAgent(state.model, state.work_dir, state.console)
175-
coder.read_requirement(state.advisor_report)
176-
177-
tasks: Iterable[Dict[str, Any]] = state.coding_plan.get("tasks", [])
178-
if not tasks:
179-
return state
180-
state.current_task = tasks.pop(0)
181-
state.code_report = coder.interact(state.current_task)
189+
state.code_report = state.coder.code(state.current_task)
182190
return state
183191

184192

185-
def debug_decision(state: KaggleState) -> str:
186-
_ensure_auto_mode(state)
187-
needs_debug = (
188-
state.is_auto_mode and state.code_report and
189-
str(state.code_report.get("debug", "")).lower() == "true"
190-
)
191-
return "debug" if needs_debug else "done"
192-
193-
194193
@task
195-
def debug_loop_node(state: KaggleState) -> KaggleState:
196-
debugger = DebugAgent(state.model, state.console)
197-
coder = CodeAgent(state.model, state.work_dir, state.console)
194+
def debug(state: KaggleState) -> KaggleState:
198195
# TODO: save the code to a file, create a venvironment, and run it
199196
# collect the run logs and errors
200-
while True:
201-
with state.console.status("Debugging code …"):
202-
debug_report = debugger.analyze(state.code_report)
203-
if debug_report.get("status") == "success":
204-
break
205-
state.code_report = coder.debug(state.current_task, debug_report)
197+
with console.status("MLE Debug Agent is executing and debugging the code..."):
198+
running_cmd = state.code_report.get('command')
199+
logs = execute_command(running_cmd)
200+
debug_report = state.debugger.analyze_with_log(running_cmd, logs)
201+
state.code_report = state.coder.debug(state.current_task, debug_report)
202+
return state
203+
204+
205+
@task
206+
def check_submission_file(state: KaggleState) -> KaggleState:
207+
if not os.path.exists(state.submission):
208+
console.log(f"The submission file ({state.submission}) is not found. Please check the code.")
209+
state.code_report = state.coder.debug(
210+
state.current_task,
211+
{
212+
"status": "error",
213+
"changes": [
214+
f"make sure the submission file is generated in {state.submission}",
215+
f"make sure the submission file is in the correct format. You can refer to the example "
216+
f"submission file: {state.sample_submission}"
217+
],
218+
"suggestion": f"Please update the code related to generating the submission file."
219+
}
220+
)
206221
return state
207222

208223

@@ -218,15 +233,23 @@ def kaggle_solver(inputs: dict) -> KaggleState:
218233
overview_summary_node,
219234
advisor_report_node,
220235
plan_generation_node,
236+
coder_read_requirement,
221237
):
222238
state = step_fn(state).result()
223239

224240
# coding plan loop
225-
while state.coding_plan and state.coding_plan.get("tasks"):
241+
while state.coding_plan and (tasks := state.coding_plan.get("tasks")):
242+
state.current_task = tasks.pop(0)
226243
state = code_task_node(state).result()
227-
route = debug_decision(state)
228-
if route == "debug":
229-
state = debug_loop_node(state).result()
244+
while True:
245+
if state.debug_attempt > state.debug_max_attempt:
246+
console.log(
247+
f"Debug the code failed with max {state.debug_max_attempt} attempts. Please check the code manually."
248+
)
249+
break
250+
251+
state = debug(state).result()
252+
state.debug_attempt += 1
230253

231254
# finished
232255
if state.console:

0 commit comments

Comments
 (0)