Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions strix/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ async def _initialize_sandbox_and_state(self, task: str) -> None:
self.state.add_message("user", task)

async def _process_iteration(self, tracer: Optional["Tracer"]) -> bool:
await self._record_agent_checkpoint()
response = await self.llm.generate(self.state.get_conversation_history())

content_stripped = (response.content or "").strip()
Expand Down Expand Up @@ -513,3 +514,18 @@ def _check_agent_messages(self, state: AgentState) -> None: # noqa: PLR0912
logger = logging.getLogger(__name__)
logger.warning(f"Error checking agent messages: {e}")
return

async def _record_agent_checkpoint(self):
from strix.checkpoint.models import AgentCheckpointInfo
from strix.checkpoint.manager import record_execution_checkpoint
from strix.tools.agents_graph.agents_graph_actions import _agent_messages, _root_agent_id


checkpoint_info = AgentCheckpointInfo(
agent_state=self.state,
prompt_modules=self.llm_config.prompt_modules,
pending_agent_messages=_agent_messages,
is_root_agent=(self.state.agent_id == _root_agent_id)
)

await record_execution_checkpoint(checkpoint_info)
Empty file added strix/checkpoint/__init__.py
Empty file.
212 changes: 212 additions & 0 deletions strix/checkpoint/file_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import json
import os
import tempfile
from pathlib import Path
from typing import Any, Optional
from datetime import datetime

from logging import getLogger

from strix.agents.state import AgentState
from strix.checkpoint.models import (
StrixExecutionCheckpoint,
CheckpointVersionInfo,
TracerCheckpointInfo,
AgentGraphCheckpointInfo,
AgentCheckpointInfo,
)
from strix.checkpoint.store import CheckpointStore

logger = getLogger(__name__)


ISO = "%Y-%m-%dT%H:%M:%S.%fZ"


class CheckpointFileStore(CheckpointStore):
"""Simple filesystem-backed checkpoint store.

Layout under `file_path` (a directory):
- version_info.json
- agents/agent-{agent_id}.json # one file per agent
- tracer/tracer-{uuid}.json # tracer checkpoints (we keep only last by file mtime)
- execution_checkpoint.json # optional consolidated StrixExecutionCheckpoint

Guarantees:
- Writes are atomic (temp file + os.replace).
- Storing a checkpoint for an agent will overwrite that agent's file (keeps last).

Notes / assumptions:
- The caller provides `agent_state_opaque` and `tracer_state_opaque` as JSON-serializable strings
(already serialized representations). This class will store them as-is under the per-agent
/ tracer files. When loading, it will attempt to parse them back into the Pydantic models
expected by StrixExecutionCheckpoint. If parsing fails, the raw strings are left as-is where
parsing is not possible.
- load() will prefer a consolidated execution_checkpoint.json if present. Otherwise it will
assemble an execution checkpoint from version_info.json + the newest tracer file + all
agent files.
"""

def __init__(self, file_path: Path):
self.file_path = Path(file_path)
self.agents_dir = self.file_path / "agents"
self.tracer_dir = self.file_path / "tracer"
self.file_path.mkdir(parents=True, exist_ok=True)
self.agents_dir.mkdir(parents=True, exist_ok=True)
self.tracer_dir.mkdir(parents=True, exist_ok=True)

# -- small helpers
def _atomic_write(self, path: Path, data: bytes) -> None:
dirpath = path.parent
dirpath.mkdir(parents=True, exist_ok=True)
fd, tmp_path = tempfile.mkstemp(dir=dirpath)
try:
with os.fdopen(fd, "wb") as f:
f.write(data)
os.replace(tmp_path, str(path))
finally:
# if something failed and tmp still exists, try to clean
try:
if os.path.exists(tmp_path):
os.remove(tmp_path)
except Exception:
pass

def _read_json(self, path: Path) -> Optional[dict[str, Any]]:
try:
with path.open("r", encoding="utf-8") as f:
return json.load(f)
except FileNotFoundError:
return None

# -- public API
def store_version_info(self, version_info: CheckpointVersionInfo) -> None:
"""Store version_info to version_info.json (atomic)."""
payload = version_info.model_dump_json(indent=4)
self._atomic_write( self.file_path / "version_info.json", payload.encode("utf-8"))

def store_checkpoint(self, agent_checkpoint: AgentCheckpointInfo, tracer_checkpoint: TracerCheckpointInfo) -> None:
"""Store a checkpoint for a single agent.

Behavior:
- Writes agent file at agents/agent-{agent_id}.json (overwrites previous file atomically).
- If tracer_state_opaque is provided, writes a new tracer file tracer/tracer-{uuid}.json
(we keep them as separate timestamped files; load() will pick the newest one).
"""
agent_path = self.agents_dir / f"state-{agent_checkpoint.agent_state.agent_id}.json"
agent_checkpoint_payload = agent_checkpoint.model_dump_json(indent=4)
self._atomic_write(agent_path, agent_checkpoint_payload.encode("utf-8"))

tracer_path = self.tracer_dir / "tracer-latest.json"
tracer_checkpoint_payload = tracer_checkpoint.model_dump_json(indent=4)
self._atomic_write(tracer_path, tracer_checkpoint_payload.encode("utf-8"))


def load(self) -> Optional[StrixExecutionCheckpoint]:
"""
Load a StrixExecutionCheckpoint with the following logic:

1. Try to load version_info.json
2. Try to load tracer-latest.json
3. Load all state-{agent_id}.json and deduce:
- root_agent_id
- latest pending_agent_messages
- agent_states (AgentState objects)
4. Build AgentGraphCheckpointInfo
5. Assemble and return StrixExecutionCheckpoint
"""

# --- 1. VERSION INFO ---
version_info_raw = self._read_json(self.file_path / "version_info.json")
version_info = None
if version_info_raw:
try:
version_info = CheckpointVersionInfo.model_validate(version_info_raw)
except Exception as e:
logger.warning(f"Invalid version_info.json: {e}")

# --- 2. TRACER INFO ---
tracer_info_raw = self._read_json(self.tracer_dir / "tracer-latest.json")
tracer_info = None
if tracer_info_raw:
try:
tracer_info = TracerCheckpointInfo.model_validate(tracer_info_raw)
except Exception as e:
logger.warning(f"Invalid tracer-latest.json: {e}")

# --- 3. LOAD ALL AGENTS ---
agent_files = sorted(self.agents_dir.glob("state-*.json"))

if not agent_files:
logger.warning("No agents found in checkpoint")
return None

decoded_agent_states: dict[str, AgentState] = {}
root_agent_id: str | None = None

latest_pending: dict[str, list[dict[str, Any]]] = {}
latest_created_at: datetime | None = None

for p in agent_files:
agent_id = p.stem.replace("state-", "")
raw = self._read_json(p)
if not raw:
continue

try:
# raw is expected to follow AgentCheckpointInfo structure
agent_cp = AgentCheckpointInfo.model_validate(raw)
except Exception as e:
logger.warning(f"Invalid agent checkpoint for {agent_id}: {e}")
continue

# 1. Capture agent_state
decoded_agent_states[agent_id] = agent_cp.agent_state

# 2. Root agent?
if agent_cp.is_root_agent and root_agent_id is None:
root_agent_id = agent_id

# 3. Pending messages: choose by latest created_at
if agent_cp.pending_agent_messages:
created_at = agent_cp.created_at
if latest_created_at is None or created_at > latest_created_at:
latest_created_at = created_at
latest_pending = agent_cp.pending_agent_messages

# No valid agent states recovered? Bail out.
if not decoded_agent_states:
logger.warning("No valid agent states found in checkpoint")
return None

# If root agent was never detected, fallback to the first one
if root_agent_id is None:
root_agent_id = next(iter(decoded_agent_states.keys()))

# --- 4. BUILD GRAPH INFO ---
graph_info = AgentGraphCheckpointInfo(
root_agent_id=root_agent_id,
pending_agent_messages=latest_pending or {},
agent_states=decoded_agent_states,
)

# --- 5. FINAL CHECKS ---
if not tracer_info:
logger.warning("Missing tracer_info; cannot build full checkpoint")
return None

if not version_info:
logger.warning("Missing version_info; cannot build full checkpoint")
return None

# static checkpoint_id for file-based restoration
checkpoint_id = version_info.checkpoint_id
run_name = version_info.run_name

return StrixExecutionCheckpoint(
checkpoint_id=checkpoint_id,
run_name=run_name,
version_info=version_info,
tracer_info=tracer_info,
graph_info=graph_info,
)
157 changes: 157 additions & 0 deletions strix/checkpoint/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
from typing import TYPE_CHECKING, Optional
import uuid
from importlib.metadata import version
from pathlib import Path
from datetime import datetime

from strix.agents.state import AgentState
from strix.checkpoint.models import StrixExecutionCheckpoint, CheckpointVersionInfo, AgentCheckpointInfo
from strix.telemetry import get_global_tracer
from strix.checkpoint.store import CheckpointStore
from strix.checkpoint.file_store import CheckpointFileStore
from strix.checkpoint.sqlite_store import CheckpointSQLiteStore

if TYPE_CHECKING:
from strix.telemetry.tracer import Tracer


_resumed_execution : StrixExecutionCheckpoint | None = None
_active_checkpoint_store : CheckpointStore | None = None
_active_checkpoint_path: Path | None = None

def resume_checkpoint(checkpoint_file_path: str) -> StrixExecutionCheckpoint:
"""
Resume execution from a checkpoint file.
This should be called before any other resume function in this module.
"""
global _resumed_execution
_resumed_execution = CheckpointSQLiteStore(Path(checkpoint_file_path)).load()
if not _resumed_execution:
raise RuntimeError(f"Failed to load checkpoint from file: {checkpoint_file_path}")

return _resumed_execution


def resume_tracer(tracer: Optional["Tracer"] = None) -> None:
if _resumed_execution is None:
raise RecursionError("No loaded checkpoint to resume tracer from")

if tracer is None:
tracer = get_global_tracer()

if tracer:
tracer.restore_state_from_checkpoint(_resumed_execution.tracer_info)

def resume_root_agent_state_from_checkpoint() -> AgentState:
if _resumed_execution is None:
raise RuntimeError("No loaded checkpoint to resume main agent checkpoint from")

root_agent_id = _resumed_execution.graph_info.root_agent_id

if root_agent_id is None:
raise RuntimeError("No root agent id found in checkpoint")

root_agent_state = _resumed_execution.graph_info.agent_states[root_agent_id].model_copy()

root_agent_state.add_message("user", _make_resume_agent_prompt(_resumed_execution))

# Delete the agent_id reference from state allow the new instance generate its own unique id
root_agent_state_modified = AgentState.model_construct(**root_agent_state.model_dump(exclude={"agent_id"}))

return root_agent_state_modified


def _make_resume_agent_prompt(checkpoint: StrixExecutionCheckpoint) -> str:
# Defensive: guard against missing data/None keys
root_agent_id = getattr(checkpoint.graph_info, "root_agent_id", None)
agent_states = getattr(checkpoint.graph_info, "agent_states", {})
checkpoint_start_time = None

if root_agent_id and root_agent_id in agent_states:
root_agent_state = agent_states[root_agent_id]
checkpoint_start_time = getattr(root_agent_state, "start_time", None)

try:
dt_start = datetime.fromisoformat(checkpoint_start_time) if checkpoint_start_time else None
execution_age = ""
if dt_start:
delta = datetime.now(dt_start.tzinfo) - dt_start
hrs = delta.total_seconds() // 3600
mins = (delta.total_seconds() % 3600) // 60
if hrs >= 1:
execution_age = f"{int(hrs)} hour(s) and {int(mins)} minute(s) ago"
elif mins > 0:
execution_age = f"{int(mins)} minute(s) ago"
else:
execution_age = "just now"
else:
execution_age = "an unknown time ago"
except Exception:
execution_age = "an unknown time ago"


detailed_agent_info: list[str] = []
if agent_states:
for aid, ast in agent_states.items():
name = getattr(ast, "agent_name", "[unknown]")
iterations = getattr(ast, "iteration", 0)
task = getattr(ast, "task", "[no task]")
short_task = (task[:80] + "...") if task and len(task) > 80 else task
detailed_agent_info.append(
f" - id: {aid}\n"
f" name: {name}\n"
f" iterations: {iterations}\n"
f" task: {short_task}"
)
agents_summary = (
"\n- Agents detected in checkpoint:\n" +
("\n".join(detailed_agent_info) if detailed_agent_info else " [none]")
)

resume_prompt = (
f"This Strix agent has been resumed from a saved checkpoint."
f"\n\n"
f"Checkpoint details:\n"
f"- Original start time: {checkpoint_start_time or '[unknown]'}"
f" ({execution_age})"
f"{agents_summary}"
f"\n\n"
f"WARNING:\n"
f"- Only YOU (the main/root agent) are now running. Any subagents or parallel agents from the original execution (if any) are no longer running. "
f"They will NOT resume unless explicitly re-created.\n"
f"- DO NOT assume any previous network, environment, tools server, or system state are available. "
f"Connectivity, underlying tools, or even the machine itself may have changed since the original execution."
f"\n- You must proceed as if you have just started, but you can use prior messages and agent state restored here as context."
f"\n\n"
f"Continue only after thoroughly verifying or reacquiring any information you need — nothing from the prior state is guaranteed valid except the restored data."
)

return resume_prompt


def initialize_execution_recording(results_dir: Path, run_name: str) -> None:
global _active_checkpoint_store
global _active_checkpoint_path
version_info = CheckpointVersionInfo(strix_version=version("strix-agent"), run_name=run_name, checkpoint_id=str(uuid.uuid4()))

checkpoint_path = results_dir / "strix_checkpoint.db"
_active_checkpoint_store = CheckpointSQLiteStore(checkpoint_path)
_active_checkpoint_store.store_version_info(version_info)
_active_checkpoint_path = checkpoint_path

def get_active_checkpoint_path() -> Optional[Path]:
global _active_checkpoint_path
return _active_checkpoint_path


async def record_execution_checkpoint(agent_checkpoint_info: AgentCheckpointInfo) -> None:


tracer = get_global_tracer()
if tracer is None:
raise RuntimeError("No tracer found to record checkpoint")

tracer_checkpoint_info = tracer.record_state_to_checkpoint()

if _active_checkpoint_store:
_active_checkpoint_store.store_checkpoint(agent_checkpoint_info, tracer_checkpoint_info)
Loading