Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 49 additions & 9 deletions agent/prototyper.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from datetime import timedelta
from typing import Optional

from helper.error_classifier import BuildErrorClassifier
import logger
from agent.base_agent import BaseAgent
from data_prep import project_targets
Expand Down Expand Up @@ -367,15 +368,54 @@ def _generate_prompt_from_build_result(
# Preference 7: New fuzz target + both `build.sh`s cannot compile. No need
# to mention the default build.sh.
# return build_result
builder = prompt_builder.PrototyperFixerTemplateBuilder(
model=self.llm,
benchmark=build_result.benchmark,
build_result=build_result,
compile_log=compile_log,
initial=prompt.get())
prompt = builder.build(example_pair=[],
project_dir=self.inspect_tool.project_dir)
return build_result, prompt
rag_enabled = False
try:
rag_enabled = bool(getattr(self, 'args', None)) and bool(getattr(self.args, 'rag_classifier', False))
except Exception:
rag_enabled = False
if rag_enabled:
# Use RAG-based classifier to build a targeted prompt.
error_classifier = BuildErrorClassifier("helper/error_patterns.yaml")
classification = error_classifier.classify(compile_log)
logger.debug("=== Compilation Log Start ===\n%s\n=== Compilation Log End ===", compile_log, trial=build_result.trial)

if classification:
logger.info("RAG match: identified build error type %s", classification["type"], trial=build_result.trial)
builder = prompt_builder.PrototyperErrorClassifierTemplateBuilder(
model=self.llm,
benchmark=build_result.benchmark,
build_result=build_result,
compile_log=compile_log,
error_classifier=error_classifier,
initial=prompt.get()
)
prompt = builder.build(project_dir=self.inspect_tool.project_dir)
return build_result, prompt

# If RAG could not classify, fall back to generic fixer template.
logger.warning("RAG match: classification failed, no error type matched", trial=build_result.trial)
builder = prompt_builder.PrototyperFixerTemplateBuilder(
model=self.llm,
benchmark=build_result.benchmark,
build_result=build_result,
compile_log=compile_log,
initial=prompt.get()
)
prompt = builder.build(example_pair=[], project_dir=self.inspect_tool.project_dir)
return build_result, prompt

else:
# RAG disabled -> always use the generic fixer template.
logger.info("RAG classifier disabled (no --rag-classifier flag); using FixerTemplateBuilder.", trial=build_result.trial)
builder = prompt_builder.PrototyperFixerTemplateBuilder(
model=self.llm,
benchmark=build_result.benchmark,
build_result=build_result,
compile_log=compile_log,
initial=prompt.get()
)
prompt = builder.build(example_pair=[], project_dir=self.inspect_tool.project_dir)
return build_result, prompt

def _container_handle_conclusion(self, cur_round: int, response: str,
build_result: BuildResult,
Expand Down
2 changes: 1 addition & 1 deletion ci/k8s/pr-exp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ spec:
name: results-volume
env:
- name: LLM_NUM_EXP
value: '40'
value: '20'
- name: LLM_NUM_EVA
value: '10'
- name: VERTEX_AI_LOCATIONS
Expand Down
51 changes: 51 additions & 0 deletions helper/error_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import re
import yaml

class BuildErrorClassifier:
def __init__(self, error_db_path: str):
with open(error_db_path, 'r') as f:
self.error_db = yaml.safe_load(f)

def classify(self, compile_log: str) -> dict | None:
for error_type, data in self.error_db.items():
for pattern in data.get("patterns", []):
if re.search(pattern, compile_log, re.IGNORECASE):
return {
"type": error_type,
"good": data.get("good", []),
"bad": data.get("bad", []),
}
return None

def _find_first_error_msg(self, compile_log: str) -> str | None:
match = re.search(r"<stderr>(.*?)</stderr>", compile_log, re.DOTALL)
if match:
compile_log = match.group(1).strip()
else:
return None

lines = compile_log.splitlines()
for i, line in enumerate(lines):
if any(kw in line.lower() for kw in ('error:', 'fatal error', 'undefined reference')):
return '\n'.join(lines[i:])
return None

def trim_and_classify_err_msg(self, compile_log:str) -> dict | None:
compile_log = self._find_first_error_msg(compile_log)
if not compile_log:
return None
for error_type, data in self.error_db.items():
for pattern in data.get("patterns", []):
try:
match = re.search(pattern, compile_log, re.IGNORECASE)
except Exception:
print(f"Error with pattern: {pattern}")
continue
if match:
return {
"type": error_type,
"trimmed_msg": compile_log.strip()}
return {
"type": "unknown",
"trimmed_msg": compile_log.strip()}

Loading
Loading