diff --git a/strix/agents/base_agent.py b/strix/agents/base_agent.py index 5d0afb51..ca2fa326 100644 --- a/strix/agents/base_agent.py +++ b/strix/agents/base_agent.py @@ -348,6 +348,7 @@ async def _initialize_sandbox_and_state(self, task: str) -> None: self.state.add_message("user", task) async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool: + await self._record_agent_checkpoint() response = await self.llm.generate(self.state.get_conversation_history()) content_stripped = (response.content or "").strip() @@ -513,3 +514,18 @@ def _check_agent_messages(self, state: AgentState) -> None: # noqa: PLR0912 logger = logging.getLogger(__name__) logger.warning(f"Error checking agent messages: {e}") return + + async def _record_agent_checkpoint(self): + from strix.checkpoint.models import AgentCheckpointInfo + from strix.checkpoint.manager import record_execution_checkpoint + from strix.tools.agents_graph.agents_graph_actions import _agent_messages, _root_agent_id + + + checkpoint_info = AgentCheckpointInfo( + agent_state=self.state, + prompt_modules=self.llm_config.prompt_modules, + pending_agent_messages=_agent_messages, + is_root_agent=(self.state.agent_id == _root_agent_id) + ) + + await record_execution_checkpoint(checkpoint_info) diff --git a/strix/checkpoint/__init__.py b/strix/checkpoint/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/strix/checkpoint/file_store.py b/strix/checkpoint/file_store.py new file mode 100644 index 00000000..2ff6e43d --- /dev/null +++ b/strix/checkpoint/file_store.py @@ -0,0 +1,212 @@ +import json +import os +import tempfile +from pathlib import Path +from typing import Any, Optional +from datetime import datetime + +from logging import getLogger + +from strix.agents.state import AgentState +from strix.checkpoint.models import ( + StrixExecutionCheckpoint, + CheckpointVersionInfo, + TracerCheckpointInfo, + AgentGraphCheckpointInfo, + AgentCheckpointInfo, +) +from strix.checkpoint.store import CheckpointStore + +logger = getLogger(__name__) + + +ISO = "%Y-%m-%dT%H:%M:%S.%fZ" + + +class CheckpointFileStore(CheckpointStore): + """Simple filesystem-backed checkpoint store. + + Layout under `file_path` (a directory): + - version_info.json + - agents/agent-{agent_id}.json # one file per agent + - tracer/tracer-{uuid}.json # tracer checkpoints (we keep only last by file mtime) + - execution_checkpoint.json # optional consolidated StrixExecutionCheckpoint + + Guarantees: + - Writes are atomic (temp file + os.replace). + - Storing a checkpoint for an agent will overwrite that agent's file (keeps last). + + Notes / assumptions: + - The caller provides `agent_state_opaque` and `tracer_state_opaque` as JSON-serializable strings + (already serialized representations). This class will store them as-is under the per-agent + / tracer files. When loading, it will attempt to parse them back into the Pydantic models + expected by StrixExecutionCheckpoint. If parsing fails, the raw strings are left as-is where + parsing is not possible. + - load() will prefer a consolidated execution_checkpoint.json if present. Otherwise it will + assemble an execution checkpoint from version_info.json + the newest tracer file + all + agent files. + """ + + def __init__(self, file_path: Path): + self.file_path = Path(file_path) + self.agents_dir = self.file_path / "agents" + self.tracer_dir = self.file_path / "tracer" + self.file_path.mkdir(parents=True, exist_ok=True) + self.agents_dir.mkdir(parents=True, exist_ok=True) + self.tracer_dir.mkdir(parents=True, exist_ok=True) + + # -- small helpers + def _atomic_write(self, path: Path, data: bytes) -> None: + dirpath = path.parent + dirpath.mkdir(parents=True, exist_ok=True) + fd, tmp_path = tempfile.mkstemp(dir=dirpath) + try: + with os.fdopen(fd, "wb") as f: + f.write(data) + os.replace(tmp_path, str(path)) + finally: + # if something failed and tmp still exists, try to clean + try: + if os.path.exists(tmp_path): + os.remove(tmp_path) + except Exception: + pass + + def _read_json(self, path: Path) -> Optional[dict[str, Any]]: + try: + with path.open("r", encoding="utf-8") as f: + return json.load(f) + except FileNotFoundError: + return None + + # -- public API + def store_version_info(self, version_info: CheckpointVersionInfo) -> None: + """Store version_info to version_info.json (atomic).""" + payload = version_info.model_dump_json(indent=4) + self._atomic_write( self.file_path / "version_info.json", payload.encode("utf-8")) + + def store_checkpoint(self, agent_checkpoint: AgentCheckpointInfo, tracer_checkpoint: TracerCheckpointInfo) -> None: + """Store a checkpoint for a single agent. + + Behavior: + - Writes agent file at agents/agent-{agent_id}.json (overwrites previous file atomically). + - If tracer_state_opaque is provided, writes a new tracer file tracer/tracer-{uuid}.json + (we keep them as separate timestamped files; load() will pick the newest one). + """ + agent_path = self.agents_dir / f"state-{agent_checkpoint.agent_state.agent_id}.json" + agent_checkpoint_payload = agent_checkpoint.model_dump_json(indent=4) + self._atomic_write(agent_path, agent_checkpoint_payload.encode("utf-8")) + + tracer_path = self.tracer_dir / "tracer-latest.json" + tracer_checkpoint_payload = tracer_checkpoint.model_dump_json(indent=4) + self._atomic_write(tracer_path, tracer_checkpoint_payload.encode("utf-8")) + + + def load(self) -> Optional[StrixExecutionCheckpoint]: + """ + Load a StrixExecutionCheckpoint with the following logic: + + 1. Try to load version_info.json + 2. Try to load tracer-latest.json + 3. Load all state-{agent_id}.json and deduce: + - root_agent_id + - latest pending_agent_messages + - agent_states (AgentState objects) + 4. Build AgentGraphCheckpointInfo + 5. Assemble and return StrixExecutionCheckpoint + """ + + # --- 1. VERSION INFO --- + version_info_raw = self._read_json(self.file_path / "version_info.json") + version_info = None + if version_info_raw: + try: + version_info = CheckpointVersionInfo.model_validate(version_info_raw) + except Exception as e: + logger.warning(f"Invalid version_info.json: {e}") + + # --- 2. TRACER INFO --- + tracer_info_raw = self._read_json(self.tracer_dir / "tracer-latest.json") + tracer_info = None + if tracer_info_raw: + try: + tracer_info = TracerCheckpointInfo.model_validate(tracer_info_raw) + except Exception as e: + logger.warning(f"Invalid tracer-latest.json: {e}") + + # --- 3. LOAD ALL AGENTS --- + agent_files = sorted(self.agents_dir.glob("state-*.json")) + + if not agent_files: + logger.warning("No agents found in checkpoint") + return None + + decoded_agent_states: dict[str, AgentState] = {} + root_agent_id: str | None = None + + latest_pending: dict[str, list[dict[str, Any]]] = {} + latest_created_at: datetime | None = None + + for p in agent_files: + agent_id = p.stem.replace("state-", "") + raw = self._read_json(p) + if not raw: + continue + + try: + # raw is expected to follow AgentCheckpointInfo structure + agent_cp = AgentCheckpointInfo.model_validate(raw) + except Exception as e: + logger.warning(f"Invalid agent checkpoint for {agent_id}: {e}") + continue + + # 1. Capture agent_state + decoded_agent_states[agent_id] = agent_cp.agent_state + + # 2. Root agent? + if agent_cp.is_root_agent and root_agent_id is None: + root_agent_id = agent_id + + # 3. Pending messages: choose by latest created_at + if agent_cp.pending_agent_messages: + created_at = agent_cp.created_at + if latest_created_at is None or created_at > latest_created_at: + latest_created_at = created_at + latest_pending = agent_cp.pending_agent_messages + + # No valid agent states recovered? Bail out. + if not decoded_agent_states: + logger.warning("No valid agent states found in checkpoint") + return None + + # If root agent was never detected, fallback to the first one + if root_agent_id is None: + root_agent_id = next(iter(decoded_agent_states.keys())) + + # --- 4. BUILD GRAPH INFO --- + graph_info = AgentGraphCheckpointInfo( + root_agent_id=root_agent_id, + pending_agent_messages=latest_pending or {}, + agent_states=decoded_agent_states, + ) + + # --- 5. FINAL CHECKS --- + if not tracer_info: + logger.warning("Missing tracer_info; cannot build full checkpoint") + return None + + if not version_info: + logger.warning("Missing version_info; cannot build full checkpoint") + return None + + # static checkpoint_id for file-based restoration + checkpoint_id = version_info.checkpoint_id + run_name = version_info.run_name + + return StrixExecutionCheckpoint( + checkpoint_id=checkpoint_id, + run_name=run_name, + version_info=version_info, + tracer_info=tracer_info, + graph_info=graph_info, + ) \ No newline at end of file diff --git a/strix/checkpoint/manager.py b/strix/checkpoint/manager.py new file mode 100644 index 00000000..663810ab --- /dev/null +++ b/strix/checkpoint/manager.py @@ -0,0 +1,157 @@ +from typing import TYPE_CHECKING, Optional +import uuid +from importlib.metadata import version +from pathlib import Path +from datetime import datetime + +from strix.agents.state import AgentState +from strix.checkpoint.models import StrixExecutionCheckpoint, CheckpointVersionInfo, AgentCheckpointInfo +from strix.telemetry import get_global_tracer +from strix.checkpoint.store import CheckpointStore +from strix.checkpoint.file_store import CheckpointFileStore +from strix.checkpoint.sqlite_store import CheckpointSQLiteStore + +if TYPE_CHECKING: + from strix.telemetry.tracer import Tracer + + +_resumed_execution : StrixExecutionCheckpoint | None = None +_active_checkpoint_store : CheckpointStore | None = None +_active_checkpoint_path: Path | None = None + +def resume_checkpoint(checkpoint_file_path: str) -> StrixExecutionCheckpoint: + """ + Resume execution from a checkpoint file. + This should be called before any other resume function in this module. + """ + global _resumed_execution + _resumed_execution = CheckpointSQLiteStore(Path(checkpoint_file_path)).load() + if not _resumed_execution: + raise RuntimeError(f"Failed to load checkpoint from file: {checkpoint_file_path}") + + return _resumed_execution + + +def resume_tracer(tracer: Optional["Tracer"] = None) -> None: + if _resumed_execution is None: + raise RecursionError("No loaded checkpoint to resume tracer from") + + if tracer is None: + tracer = get_global_tracer() + + if tracer: + tracer.restore_state_from_checkpoint(_resumed_execution.tracer_info) + +def resume_root_agent_state_from_checkpoint() -> AgentState: + if _resumed_execution is None: + raise RuntimeError("No loaded checkpoint to resume main agent checkpoint from") + + root_agent_id = _resumed_execution.graph_info.root_agent_id + + if root_agent_id is None: + raise RuntimeError("No root agent id found in checkpoint") + + root_agent_state = _resumed_execution.graph_info.agent_states[root_agent_id].model_copy() + + root_agent_state.add_message("user", _make_resume_agent_prompt(_resumed_execution)) + + # Delete the agent_id reference from state allow the new instance generate its own unique id + root_agent_state_modified = AgentState.model_construct(**root_agent_state.model_dump(exclude={"agent_id"})) + + return root_agent_state_modified + + +def _make_resume_agent_prompt(checkpoint: StrixExecutionCheckpoint) -> str: + # Defensive: guard against missing data/None keys + root_agent_id = getattr(checkpoint.graph_info, "root_agent_id", None) + agent_states = getattr(checkpoint.graph_info, "agent_states", {}) + checkpoint_start_time = None + + if root_agent_id and root_agent_id in agent_states: + root_agent_state = agent_states[root_agent_id] + checkpoint_start_time = getattr(root_agent_state, "start_time", None) + + try: + dt_start = datetime.fromisoformat(checkpoint_start_time) if checkpoint_start_time else None + execution_age = "" + if dt_start: + delta = datetime.now(dt_start.tzinfo) - dt_start + hrs = delta.total_seconds() // 3600 + mins = (delta.total_seconds() % 3600) // 60 + if hrs >= 1: + execution_age = f"{int(hrs)} hour(s) and {int(mins)} minute(s) ago" + elif mins > 0: + execution_age = f"{int(mins)} minute(s) ago" + else: + execution_age = "just now" + else: + execution_age = "an unknown time ago" + except Exception: + execution_age = "an unknown time ago" + + + detailed_agent_info: list[str] = [] + if agent_states: + for aid, ast in agent_states.items(): + name = getattr(ast, "agent_name", "[unknown]") + iterations = getattr(ast, "iteration", 0) + task = getattr(ast, "task", "[no task]") + short_task = (task[:80] + "...") if task and len(task) > 80 else task + detailed_agent_info.append( + f" - id: {aid}\n" + f" name: {name}\n" + f" iterations: {iterations}\n" + f" task: {short_task}" + ) + agents_summary = ( + "\n- Agents detected in checkpoint:\n" + + ("\n".join(detailed_agent_info) if detailed_agent_info else " [none]") + ) + + resume_prompt = ( + f"This Strix agent has been resumed from a saved checkpoint." + f"\n\n" + f"Checkpoint details:\n" + f"- Original start time: {checkpoint_start_time or '[unknown]'}" + f" ({execution_age})" + f"{agents_summary}" + f"\n\n" + f"WARNING:\n" + f"- Only YOU (the main/root agent) are now running. Any subagents or parallel agents from the original execution (if any) are no longer running. " + f"They will NOT resume unless explicitly re-created.\n" + f"- DO NOT assume any previous network, environment, tools server, or system state are available. " + f"Connectivity, underlying tools, or even the machine itself may have changed since the original execution." + f"\n- You must proceed as if you have just started, but you can use prior messages and agent state restored here as context." + f"\n\n" + f"Continue only after thoroughly verifying or reacquiring any information you need — nothing from the prior state is guaranteed valid except the restored data." + ) + + return resume_prompt + + +def initialize_execution_recording(results_dir: Path, run_name: str) -> None: + global _active_checkpoint_store + global _active_checkpoint_path + version_info = CheckpointVersionInfo(strix_version=version("strix-agent"), run_name=run_name, checkpoint_id=str(uuid.uuid4())) + + checkpoint_path = results_dir / "strix_checkpoint.db" + _active_checkpoint_store = CheckpointSQLiteStore(checkpoint_path) + _active_checkpoint_store.store_version_info(version_info) + _active_checkpoint_path = checkpoint_path + +def get_active_checkpoint_path() -> Optional[Path]: + global _active_checkpoint_path + return _active_checkpoint_path + + +async def record_execution_checkpoint(agent_checkpoint_info: AgentCheckpointInfo) -> None: + + + tracer = get_global_tracer() + if tracer is None: + raise RuntimeError("No tracer found to record checkpoint") + + tracer_checkpoint_info = tracer.record_state_to_checkpoint() + + if _active_checkpoint_store: + _active_checkpoint_store.store_checkpoint(agent_checkpoint_info, tracer_checkpoint_info) diff --git a/strix/checkpoint/models.py b/strix/checkpoint/models.py new file mode 100644 index 00000000..8c3358b3 --- /dev/null +++ b/strix/checkpoint/models.py @@ -0,0 +1,38 @@ +from datetime import UTC, datetime +from strix.agents.state import AgentState +from typing import Any +from pydantic import BaseModel, Field + + +class CheckpointVersionInfo(BaseModel): + strix_version: str + run_name: str + checkpoint_id: str + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) + + +class TracerCheckpointInfo(BaseModel): + agents: dict[str, dict[str, Any]] + tool_executions: dict[int, dict[str, Any]] + chat_messages: list[dict[str, Any]] + +class AgentCheckpointInfo(BaseModel): + agent_state: AgentState + prompt_modules: list[str] | None = None + pending_agent_messages: dict[str, list[dict[str, Any]]] = {} + is_root_agent: bool = False + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) + + +class AgentGraphCheckpointInfo(BaseModel): + root_agent_id: str | None = None + pending_agent_messages: dict[str, list[dict[str, Any]]] = {} + agent_states: dict[str, AgentState] = {} + + +class StrixExecutionCheckpoint(BaseModel): + checkpoint_id: str + run_name: str | None = None + version_info: CheckpointVersionInfo | None = None + tracer_info: TracerCheckpointInfo + graph_info: AgentGraphCheckpointInfo \ No newline at end of file diff --git a/strix/checkpoint/sqlite_store.py b/strix/checkpoint/sqlite_store.py new file mode 100644 index 00000000..ad189aed --- /dev/null +++ b/strix/checkpoint/sqlite_store.py @@ -0,0 +1,361 @@ +import os +import sqlite3 +import threading +from pathlib import Path +from typing import Any, Optional, TypeVar +from datetime import UTC, datetime +from logging import getLogger +from pydantic import BaseModel + +from strix.agents.state import AgentState +from strix.checkpoint.models import ( + StrixExecutionCheckpoint, + CheckpointVersionInfo, + TracerCheckpointInfo, + AgentGraphCheckpointInfo, + AgentCheckpointInfo, +) +from strix.checkpoint.store import CheckpointStore + +logger = getLogger(__name__) + +T = TypeVar("T", bound=BaseModel) + + +class CheckpointSQLiteStore(CheckpointStore): + """SQLite-backed checkpoint store with configurable retention modes. + + Environment variable STRIX_CHECKPOINT_LEVEL controls behavior: + - "none": No checkpointing performed + - "latest": Only maintain the latest checkpoint (overwrites) + - "all": Keep all checkpoints with auto-incrementing version counter + + Schema: + - version_info: stores CheckpointVersionInfo + - tracer_info: stores TracerCheckpointInfo + - agent_checkpoints: stores AgentCheckpointInfo (one row per agent) + + All tables include a 'version' column for tracking checkpoint iterations. + """ + + def __init__(self, file_path: Path): + self.file_path = Path(file_path) + self.file_path.parent.mkdir(parents=True, exist_ok=True) + + self.checkpoint_level = os.getenv("STRIX_CHECKPOINT_LEVEL", "latest").lower() + if self.checkpoint_level not in ("none", "latest", "all"): + logger.warning( + f"Invalid STRIX_CHECKPOINT_LEVEL={self.checkpoint_level}, defaulting to 'latest'" + ) + self.checkpoint_level = "latest" + + self.conn = sqlite3.connect(str(self.file_path), check_same_thread=False) + self.conn.row_factory = sqlite3.Row + self._lock = threading.Lock() # Thread-safe access to database + self._optimize_for_writes() + self._init_schema() + + def _optimize_for_writes(self) -> None: + """Configure SQLite for optimal write performance while maintaining crash safety.""" + with self._lock: + # DELETE mode: single file database (no WAL/SHM files) + # Journal file is temporary and automatically deleted after transactions + self.conn.execute("PRAGMA journal_mode=DELETE") + # NORMAL synchronous: good balance of speed and safety + # Ensures data is written to disk before returning, preventing corruption + self.conn.execute("PRAGMA synchronous=NORMAL") + # Increase cache size for better performance (64MB) + self.conn.execute("PRAGMA cache_size=-65536") + # Disable foreign keys for faster inserts (not used anyway) + self.conn.execute("PRAGMA foreign_keys=OFF") + # Optimize for write-heavy workloads + self.conn.execute("PRAGMA temp_store=MEMORY") + # Standard page size + self.conn.execute("PRAGMA page_size=4096") + self.conn.commit() + + def _init_schema(self) -> None: + """Initialize database schema with three tables.""" + with self._lock: + with self.conn: + # Version info table + # Add 'latest_id' column for "latest" mode to enable INSERT OR REPLACE + self.conn.execute(""" + CREATE TABLE IF NOT EXISTS version_info ( + version INTEGER PRIMARY KEY AUTOINCREMENT, + latest_id INTEGER DEFAULT 1, + checkpoint_id TEXT NOT NULL, + run_name TEXT NOT NULL, + strix_version TEXT NOT NULL, + created_at TEXT NOT NULL, + data TEXT NOT NULL + ) + """) + + # Add UNIQUE constraint for atomic replace in "latest" mode + if self.checkpoint_level == "latest": + self.conn.execute(""" + CREATE UNIQUE INDEX IF NOT EXISTS idx_version_latest_id + ON version_info(latest_id) + """) + + # Tracer info table + # Add 'latest_id' column for "latest" mode to enable INSERT OR REPLACE + self.conn.execute(""" + CREATE TABLE IF NOT EXISTS tracer_info ( + version INTEGER PRIMARY KEY AUTOINCREMENT, + latest_id INTEGER DEFAULT 1, + created_at TEXT NOT NULL, + data TEXT NOT NULL + ) + """) + + # Add UNIQUE constraint for atomic replace in "latest" mode + if self.checkpoint_level == "latest": + self.conn.execute(""" + CREATE UNIQUE INDEX IF NOT EXISTS idx_tracer_latest_id + ON tracer_info(latest_id) + """) + + # Agent checkpoints table + self.conn.execute(""" + CREATE TABLE IF NOT EXISTS agent_checkpoints ( + version INTEGER PRIMARY KEY AUTOINCREMENT, + agent_id TEXT NOT NULL, + is_root_agent INTEGER NOT NULL, + created_at TEXT NOT NULL, + data TEXT NOT NULL + ) + """) + + # Index for faster lookups + self.conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_agent_id + ON agent_checkpoints(agent_id, version DESC) + """) + + # Add UNIQUE constraint for atomic replace operations in "latest" mode + # This enables INSERT OR REPLACE to work atomically without DELETE + INSERT + if self.checkpoint_level == "latest": + self.conn.execute(""" + CREATE UNIQUE INDEX IF NOT EXISTS idx_agent_id_unique + ON agent_checkpoints(agent_id) + """) + + def store_version_info(self, version_info: CheckpointVersionInfo) -> None: + """Store version info according to checkpoint level.""" + if self.checkpoint_level == "none": + return + + data = version_info.model_dump_json() + created_at = version_info.created_at.isoformat() + + with self._lock: + with self.conn: + if self.checkpoint_level == "latest": + # Use INSERT OR REPLACE for atomic operation + self.conn.execute( + """INSERT OR REPLACE INTO version_info + (latest_id, checkpoint_id, run_name, strix_version, created_at, data) + VALUES (1, ?, ?, ?, ?, ?)""", + ( + version_info.checkpoint_id, + version_info.run_name, + version_info.strix_version, + created_at, + data, + ), + ) + else: + self.conn.execute( + """INSERT INTO version_info + (checkpoint_id, run_name, strix_version, created_at, data) + VALUES (?, ?, ?, ?, ?)""", + ( + version_info.checkpoint_id, + version_info.run_name, + version_info.strix_version, + created_at, + data, + ), + ) + self.conn.commit() + + def store_checkpoint( + self, + agent_checkpoint: AgentCheckpointInfo, + tracer_checkpoint: TracerCheckpointInfo + ) -> None: + """Store agent and tracer checkpoints according to checkpoint level. + + Thread-safe: This method can be called from multiple threads safely. + """ + if self.checkpoint_level == "none": + return + + # Serialize outside lock to reduce lock contention + # No indent for faster serialization and smaller storage + tracer_data = tracer_checkpoint.model_dump_json() + tracer_created_at = datetime.now(UTC).isoformat() + agent_data = agent_checkpoint.model_dump_json() + agent_created_at = agent_checkpoint.created_at.isoformat() + agent_id = agent_checkpoint.agent_state.agent_id + is_root = 1 if agent_checkpoint.is_root_agent else 0 + + with self._lock: + with self.conn: + # Store tracer checkpoint - use INSERT OR REPLACE in "latest" mode + if self.checkpoint_level == "latest": + self.conn.execute( + """INSERT OR REPLACE INTO tracer_info (latest_id, created_at, data) + VALUES (1, ?, ?)""", + (tracer_created_at, tracer_data), + ) + else: + self.conn.execute( + """INSERT INTO tracer_info (created_at, data) VALUES (?, ?)""", + (tracer_created_at, tracer_data), + ) + + # Store agent checkpoint - use INSERT OR REPLACE in "latest" mode + if self.checkpoint_level == "latest": + self.conn.execute( + """INSERT OR REPLACE INTO agent_checkpoints + (agent_id, is_root_agent, created_at, data) + VALUES (?, ?, ?, ?)""", + (agent_id, is_root, agent_created_at, agent_data), + ) + else: + self.conn.execute( + """INSERT INTO agent_checkpoints + (agent_id, is_root_agent, created_at, data) + VALUES (?, ?, ?, ?)""", + (agent_id, is_root, agent_created_at, agent_data), + ) + + self.conn.commit() + + def _load_latest_json(self, table: str, model_class: type[T]) -> Optional[T]: + """Load the latest JSON data from a table and parse it with the given model.""" + # Validate table name to prevent SQL injection + valid_tables = {"version_info", "tracer_info", "agent_checkpoints"} + if table not in valid_tables: + raise ValueError(f"Invalid table name: {table}") + + with self._lock: + cursor = self.conn.execute( + f"SELECT data FROM {table} ORDER BY version DESC LIMIT 1" + ) + row = cursor.fetchone() + + if not row: + logger.warning(f"No {table} found in checkpoint") + return None + + try: + return model_class.model_validate_json(row["data"]) + except Exception as e: + logger.warning(f"Invalid {table}: {e}") + return None + + def load(self) -> Optional[StrixExecutionCheckpoint]: + """Load the latest execution checkpoint from SQLite database. + + Returns: + StrixExecutionCheckpoint if all required data exists, None otherwise. + """ + if self.checkpoint_level == "none": + logger.info("Checkpoint level is 'none', no checkpoint to load") + return None + + # --- 1. LOAD VERSION INFO (latest) --- + version_info = self._load_latest_json("version_info", CheckpointVersionInfo) + if not version_info: + return None + + # --- 2. LOAD TRACER INFO (latest) --- + tracer_info = self._load_latest_json("tracer_info", TracerCheckpointInfo) + if not tracer_info: + return None + + # --- 3. LOAD ALL AGENTS (latest version per agent_id) --- + with self._lock: + if self.checkpoint_level == "latest": + # Simple: get all rows (there's only one per agent) + cursor = self.conn.execute( + "SELECT data, agent_id FROM agent_checkpoints" + ) + else: # "all" - get latest version for each agent_id + cursor = self.conn.execute(""" + SELECT data, agent_id + FROM agent_checkpoints + WHERE version IN ( + SELECT MAX(version) + FROM agent_checkpoints + GROUP BY agent_id + ) + """) + + rows = cursor.fetchall() + if not rows: + logger.warning("No agent checkpoints found") + return None + + decoded_agent_states: dict[str, AgentState] = {} + root_agent_id: str | None = None + latest_pending: dict[str, list[dict[str, Any]]] = {} + latest_created_at: datetime | None = None + + for row in rows: + try: + agent_cp = AgentCheckpointInfo.model_validate_json(row["data"]) + except Exception as e: + logger.warning(f"Invalid agent checkpoint for {row['agent_id']}: {e}") + continue + + agent_id = agent_cp.agent_state.agent_id + decoded_agent_states[agent_id] = agent_cp.agent_state + + if agent_cp.is_root_agent and root_agent_id is None: + root_agent_id = agent_id + + if agent_cp.pending_agent_messages: + created_at = agent_cp.created_at + if latest_created_at is None or created_at > latest_created_at: + latest_created_at = created_at + latest_pending = agent_cp.pending_agent_messages + + if not decoded_agent_states: + logger.warning("No valid agent states found") + return None + + if root_agent_id is None: + root_agent_id = next(iter(decoded_agent_states.keys())) + + # --- 4. BUILD GRAPH INFO --- + graph_info = AgentGraphCheckpointInfo( + root_agent_id=root_agent_id, + pending_agent_messages=latest_pending or {}, + agent_states=decoded_agent_states, + ) + + # --- 5. ASSEMBLE CHECKPOINT --- + return StrixExecutionCheckpoint( + checkpoint_id=version_info.checkpoint_id, + run_name=version_info.run_name, + version_info=version_info, + tracer_info=tracer_info, + graph_info=graph_info, + ) + + def close(self) -> None: + """Close the database connection.""" + with self._lock: + if self.conn: + self.conn.close() + + def __del__(self): + """Ensure connection is closed on deletion.""" + if hasattr(self, "conn"): + self.close() + \ No newline at end of file diff --git a/strix/checkpoint/store.py b/strix/checkpoint/store.py new file mode 100644 index 00000000..a636fc47 --- /dev/null +++ b/strix/checkpoint/store.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Optional + +from strix.checkpoint.models import ( + StrixExecutionCheckpoint, + AgentCheckpointInfo, + TracerCheckpointInfo, + CheckpointVersionInfo +) + + +class CheckpointStore(ABC): + """Abstract base class defining the interface for checkpoint storage implementations. + + This interface is implemented by concrete checkpoint stores like CheckpointFile + (filesystem-backed) and potentially other implementations (e.g., SQLite, database-backed). + + The interface provides three main operations: + 1. Store version information about the checkpoint + 2. Store checkpoint data for agents and tracer state + 3. Load a complete execution checkpoint + """ + + @abstractmethod + def store_version_info(self, version_info: CheckpointVersionInfo) -> None: + """Store version information to the checkpoint store. + + Args: + version_info: The version information to store, including strix version, + run name, and creation timestamp. + """ + pass + + @abstractmethod + def store_checkpoint(self, agent_checkpoint: AgentCheckpointInfo, tracer_checkpoint: TracerCheckpointInfo) -> None: + """Store a checkpoint for a single agent. This might be called multiple times + for a single execution, with the latest state being the one that is stored. + + Args: + agent_checkpoint: The agent checkpoint to store. + """ + pass + + @abstractmethod + def load(self) -> Optional[StrixExecutionCheckpoint]: + """Load a complete execution checkpoint from the store. + + Returns: + A StrixExecutionCheckpoint if sufficient data exists to construct one, + None otherwise. + + The implementation should attempt to reconstruct a complete checkpoint from + stored version info, tracer state, and agent states. + """ + pass + + diff --git a/strix/interface/cli.py b/strix/interface/cli.py index c9bc78ff..7b262fd3 100644 --- a/strix/interface/cli.py +++ b/strix/interface/cli.py @@ -10,6 +10,7 @@ from strix.agents.StrixAgent import StrixAgent from strix.llm.config import LLMConfig from strix.telemetry.tracer import Tracer, set_global_tracer +from strix.checkpoint.manager import resume_tracer, resume_root_agent_state_from_checkpoint from .utils import get_severity_color @@ -71,7 +72,7 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915 } llm_config = LLMConfig() - agent_config = { + agent_config: dict[str, Any] = { "llm_config": llm_config, "max_iterations": 300, "non_interactive": True, @@ -81,6 +82,10 @@ async def run_cli(args: Any) -> None: # noqa: PLR0915 agent_config["local_sources"] = args.local_sources tracer = Tracer(args.run_name) + if args.resume: + resume_tracer(tracer) + agent_config["state"] = resume_root_agent_state_from_checkpoint() + tracer.set_scan_config(scan_config) def display_vulnerability(report_id: str, title: str, content: str, severity: str) -> None: diff --git a/strix/interface/main.py b/strix/interface/main.py index ef8e6f86..d2dc8576 100644 --- a/strix/interface/main.py +++ b/strix/interface/main.py @@ -34,6 +34,7 @@ ) from strix.runtime.docker_runtime import STRIX_IMAGE from strix.telemetry.tracer import get_global_tracer +from strix.checkpoint.manager import get_active_checkpoint_path, resume_checkpoint, initialize_execution_recording logging.getLogger().setLevel(logging.ERROR) @@ -307,6 +308,12 @@ def parse_arguments() -> argparse.Namespace: "Default is interactive mode with TUI." ), ) + parser.add_argument( + "--resume", + type=str, + metavar="CHECKPOINT_FILE", + help="Resume a previous scan by providing the path to strix checkpoint file", + ) args = parser.parse_args() @@ -383,6 +390,13 @@ def display_completion_message(args: argparse.Namespace, results_path: Path) -> results_text.append(str(results_path), style="bold yellow") panel_parts.extend(["\n\n", results_text]) + checkpoint_path = get_active_checkpoint_path() + if checkpoint_path: + checkpoint_text = Text() + checkpoint_text.append("📀 Execution Checkpoint: ", style="bold yellow") + checkpoint_text.append(str(checkpoint_path), style="bold white") + panel_parts.extend(["\n", checkpoint_text]) + panel_content = Text.assemble(*panel_parts) border_style = "green" if scan_completed else "yellow" @@ -470,12 +484,18 @@ def main() -> None: args.local_sources = collect_local_sources(args.targets_info) + if args.resume: + resume_checkpoint(args.resume) + + results_path = Path("agent_runs") / args.run_name + initialize_execution_recording(results_dir=results_path, run_name=args.run_name) + if args.non_interactive: asyncio.run(run_cli(args)) else: asyncio.run(run_tui(args)) - results_path = Path("agent_runs") / args.run_name + display_completion_message(args, results_path) if args.non_interactive: diff --git a/strix/interface/tui.py b/strix/interface/tui.py index ff0a255c..12077f18 100644 --- a/strix/interface/tui.py +++ b/strix/interface/tui.py @@ -33,6 +33,7 @@ from strix.agents.StrixAgent import StrixAgent from strix.llm.config import LLMConfig from strix.telemetry.tracer import Tracer, set_global_tracer +from strix.checkpoint.manager import resume_tracer, resume_root_agent_state_from_checkpoint def escape_markup(text: str) -> str: @@ -280,6 +281,9 @@ def __init__(self, args: argparse.Namespace): self.agent_config = self._build_agent_config(args) self.tracer = Tracer(self.scan_config["run_name"]) + if args.resume: + resume_tracer(self.tracer) + self.tracer.set_scan_config(self.scan_config) set_global_tracer(self.tracer) @@ -320,13 +324,16 @@ def _build_scan_config(self, args: argparse.Namespace) -> dict[str, Any]: def _build_agent_config(self, args: argparse.Namespace) -> dict[str, Any]: llm_config = LLMConfig() - config = { + config: dict[str, Any] = { "llm_config": llm_config, "max_iterations": 300, } if getattr(args, "local_sources", None): config["local_sources"] = args.local_sources + + if args.resume: + config["state"] = resume_root_agent_state_from_checkpoint() return config diff --git a/strix/telemetry/tracer.py b/strix/telemetry/tracer.py index 15a4b423..32e66cd9 100644 --- a/strix/telemetry/tracer.py +++ b/strix/telemetry/tracer.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Optional from uuid import uuid4 +from strix.checkpoint.models import TracerCheckpointInfo if TYPE_CHECKING: @@ -321,3 +322,19 @@ def get_total_llm_stats(self) -> dict[str, Any]: def cleanup(self) -> None: self.save_run_data() + + def record_state_to_checkpoint(self) -> TracerCheckpointInfo: + return TracerCheckpointInfo( + agents=self.agents, + tool_executions=self.tool_executions, + chat_messages=self.chat_messages, + ) + + def restore_state_from_checkpoint(self, tracer_checkpoint_info: TracerCheckpointInfo) -> None: + for agent_id, agent_data in tracer_checkpoint_info.agents.items(): + agent_data["name"] = f"{agent_data['name']} [Restored]" + self.agents[agent_id] = agent_data + + self.tool_executions = tracer_checkpoint_info.tool_executions + self.chat_messages = tracer_checkpoint_info.chat_messages + diff --git a/strix/tools/agents_graph/agents_graph_actions.py b/strix/tools/agents_graph/agents_graph_actions.py index 2e384c01..a8ca9cc2 100644 --- a/strix/tools/agents_graph/agents_graph_actions.py +++ b/strix/tools/agents_graph/agents_graph_actions.py @@ -3,6 +3,7 @@ from typing import Any, Literal from strix.tools.registry import register_tool +from strix.checkpoint.models import StrixExecutionCheckpoint _agent_graph: dict[str, Any] = { @@ -619,3 +620,8 @@ def wait_for_message( "Waiting timeout reached", ], } + + +def resume_execution_from_checkpoint(checkpoint: StrixExecutionCheckpoint) -> dict[str, Any]: + # HERE we should initialize and start all strix subagents + pass