diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 0000000..b1385c1 --- /dev/null +++ b/PLAN.md @@ -0,0 +1,40 @@ +# SystemReminders Capability + +Closes #83 + +## Problem + +Long-running agents suffer from *instruction fade-out* -- the phenomenon where agents progressively ignore system prompt guidelines after many turns of tool use. A single system prompt at the start of a session is insufficient for maintaining behavioral consistency across extended interactions. + +## Solution + +A `SystemReminders` capability that injects periodic `SystemPromptPart` entries into model conversations via the `before_model_request` hook. This is a focused first implementation that provides the core mechanism for periodic reminders, which more advanced features (trigger-based reminders, cooldowns, priorities) can be layered on top of. + +## Design + +### Two kinds of reminders + +- **Static** (`Reminder`): a fixed message string injected every N model requests (configurable `interval`). +- **Dynamic** (callable): a sync or async function receiving `RunContext` and returning `str | None`. Called on every model request; returns `None` to skip injection. + +### Injection mechanism + +Reminder parts are appended as `SystemPromptPart` entries to the last `ModelRequest` in the message history. This places them close to the model's attention window without creating separate messages. + +### Per-run isolation + +`for_run()` returns a fresh instance with a reset request counter, ensuring concurrent runs on the same agent don't interfere with each other. + +### Not spec-serializable + +`get_serialization_name()` returns `None` because dynamic reminders take callables which cannot be serialized. + +## Files + +- `src/pydantic_harness/system_reminders.py` -- `Reminder`, `DynamicReminder`, `AsyncDynamicReminder`, `SystemReminders` +- `src/pydantic_harness/__init__.py` -- public exports +- `tests/test_system_reminders.py` -- 27 tests covering all code paths + +## Future work + +The issue (#83) describes a richer system with trigger-based reminders (loop detection, token budget warnings, post-compaction re-injection), cooldowns, fire limits, priority ordering, and template substitution. This implementation provides the foundational interval-based and dynamic-callable mechanisms that those features can build on. diff --git a/src/pydantic_harness/__init__.py b/src/pydantic_harness/__init__.py index 9d728b6..391ed9a 100644 --- a/src/pydantic_harness/__init__.py +++ b/src/pydantic_harness/__init__.py @@ -7,4 +7,11 @@ # Each capability module is imported and re-exported here. # Capabilities are listed alphabetically. -__all__: list[str] = [] +from pydantic_harness.system_reminders import AsyncDynamicReminder, DynamicReminder, Reminder, SystemReminders + +__all__: list[str] = [ + 'AsyncDynamicReminder', + 'DynamicReminder', + 'Reminder', + 'SystemReminders', +] diff --git a/src/pydantic_harness/system_reminders.py b/src/pydantic_harness/system_reminders.py new file mode 100644 index 0000000..c26bd6f --- /dev/null +++ b/src/pydantic_harness/system_reminders.py @@ -0,0 +1,204 @@ +"""System reminders capability for periodic behavioral steering. + +Provides the [`SystemReminders`][pydantic_harness.SystemReminders] capability, +which injects periodic system messages into model conversations to counteract +instruction fade-out in long-running agent sessions. + +Example usage:: + + from pydantic_ai import Agent + from pydantic_harness import SystemReminders, Reminder + + reminders = SystemReminders( + reminders=[ + Reminder('Remember to use the provided tools.', interval=3), + Reminder('Always verify your work before responding.', interval=5), + ], + ) + agent = Agent('openai:gpt-5', capabilities=[reminders]) +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from pydantic_ai.capabilities.abstract import AbstractCapability +from pydantic_ai.messages import ModelRequest, SystemPromptPart +from pydantic_ai.tools import AgentDepsT, RunContext + +if TYPE_CHECKING: + from pydantic_ai.models import ModelRequestContext + + +@dataclass +class Reminder: + """A static reminder to inject periodically during an agent run. + + Args: + content: The reminder text to inject as a system prompt part. + interval: Inject this reminder every N model requests. For example, + ``interval=3`` means the reminder fires on the 3rd, 6th, 9th, etc. + model request within a single run. + trigger: An optional predicate receiving the current + [`RunContext`][pydantic_ai.tools.RunContext]. When provided, the + reminder only fires when the trigger returns ``True`` *and* the + interval condition is met. + max_fires: Maximum number of times this reminder may fire within a + single run. ``None`` means no limit. + tag: When set, wrap the content in XML tags: ``content``. + For example, ``tag='system-reminder'`` produces + ``content``. + """ + + content: str + interval: int = 1 + trigger: Callable[[RunContext[Any]], bool] | None = None + max_fires: int | None = None + tag: str | None = None + + def __post_init__(self) -> None: # noqa: D105 + if self.interval < 1: + raise ValueError(f'interval must be >= 1, got {self.interval}') + if self.max_fires is not None and self.max_fires < 1: + raise ValueError(f'max_fires must be >= 1, got {self.max_fires}') + + def render_content(self) -> str: + """Return the content, wrapped in XML tags if ``tag`` is set.""" + if self.tag is not None: + return f'<{self.tag}>{self.content}' + return self.content + + +DynamicReminder = Callable[[RunContext[Any]], str | None] +"""A callable that returns reminder text (or None to skip) based on the current run context. + +Dynamic reminders are called on every model request, giving full control +over when and what to inject. +""" + +AsyncDynamicReminder = Callable[[RunContext[Any]], Awaitable[str | None]] +"""An async callable variant of [`DynamicReminder`][pydantic_harness.system_reminders.DynamicReminder].""" + + +@dataclass +class SystemReminders(AbstractCapability[AgentDepsT]): + r"""Capability that injects periodic system reminders into model conversations. + + System reminders counteract *instruction fade-out* -- the phenomenon where + agents progressively ignore system prompt guidelines after many turns of + tool use. Reminders are injected as [`SystemPromptPart`][pydantic_ai.messages.SystemPromptPart] + entries appended to the last [`ModelRequest`][pydantic_ai.messages.ModelRequest] + in the message history before each model call. + + Supports two kinds of reminders: + + - **Static** ([`Reminder`][pydantic_harness.Reminder]): a fixed message + injected every N model requests within a run. + - **Dynamic** (callable): a function receiving + [`RunContext`][pydantic_ai.tools.RunContext] and returning a string to inject + (or ``None`` to skip). Called on every model request. + + Per-run state (the model request counter) is isolated via + [`for_run`][pydantic_ai.capabilities.AbstractCapability.for_run], so + concurrent runs on the same agent don't interfere with each other. + + Example:: + + reminders = SystemReminders( + reminders=[ + Reminder('Stay focused on the user\'s original request.', interval=5), + ], + dynamic_reminders=[ + lambda ctx: 'Wrap up soon.' if ctx.run_step > 20 else None, + ], + ) + """ + + reminders: list[Reminder] = field(default_factory=list[Reminder]) + """Static reminders to inject at fixed intervals.""" + + dynamic_reminders: list[DynamicReminder | AsyncDynamicReminder] = field( + default_factory=list[DynamicReminder | AsyncDynamicReminder] + ) + """Dynamic reminders evaluated on every model request.""" + + _request_count: int = field(default=0, init=False, repr=False) + _fire_counts: list[int] = field(default_factory=list[int], init=False, repr=False) + + def __post_init__(self) -> None: # noqa: D105 + if not self.reminders and not self.dynamic_reminders: + raise ValueError('At least one static or dynamic reminder must be provided.') + self._fire_counts = [0] * len(self.reminders) + + async def for_run(self, ctx: RunContext[AgentDepsT]) -> SystemReminders[AgentDepsT]: + """Return a fresh instance with a reset request counter for per-run isolation.""" + return SystemReminders( + reminders=self.reminders, + dynamic_reminders=self.dynamic_reminders, + ) + + async def before_model_request( + self, + ctx: RunContext[AgentDepsT], + request_context: ModelRequestContext, + ) -> ModelRequestContext: + """Inject applicable reminders into the message history before the model call.""" + self._request_count += 1 + + parts_to_inject: list[SystemPromptPart] = [] + + # Evaluate static reminders based on interval, trigger, and max_fires. + for idx, reminder in enumerate(self.reminders): + if self._request_count % reminder.interval != 0: + continue + if reminder.trigger is not None and not reminder.trigger(ctx): + continue + if reminder.max_fires is not None and self._fire_counts[idx] >= reminder.max_fires: + continue + self._fire_counts[idx] += 1 + parts_to_inject.append(SystemPromptPart(content=reminder.render_content())) + + # Evaluate dynamic reminders. + for dynamic in self.dynamic_reminders: + result = dynamic(ctx) + if isinstance(result, Awaitable): + result = await result + if result is not None: + parts_to_inject.append(SystemPromptPart(content=result)) + + if parts_to_inject: + _inject_into_last_request(request_context.messages, parts_to_inject) + + return request_context + + @classmethod + def get_serialization_name(cls) -> str | None: # noqa: D102 + return None # Not spec-serializable (dynamic reminders take callables) + + +def _inject_into_last_request( + messages: list[Any], + parts: list[SystemPromptPart], +) -> None: + """Append system prompt parts to the last ModelRequest in the message list. + + If no ModelRequest exists yet, prepend one containing just the reminder parts. + """ + for i in range(len(messages) - 1, -1, -1): + msg = messages[i] + if isinstance(msg, ModelRequest): + # ModelRequest.parts is a Sequence; we need to produce a new list + # with the reminder parts appended. + messages[i] = ModelRequest( + parts=[*msg.parts, *parts], + timestamp=msg.timestamp, + instructions=msg.instructions, + kind=msg.kind, + run_id=msg.run_id, + metadata=msg.metadata, + ) + return + # No existing request -- create one with just the reminder parts. + messages.append(ModelRequest(parts=parts)) diff --git a/tests/test_system_reminders.py b/tests/test_system_reminders.py new file mode 100644 index 0000000..087f65a --- /dev/null +++ b/tests/test_system_reminders.py @@ -0,0 +1,679 @@ +"""Tests for the SystemReminders capability.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import pytest +from pydantic_ai.messages import ModelRequest, ModelResponse, SystemPromptPart, TextPart, UserPromptPart +from pydantic_ai.models import ModelRequestContext, ModelRequestParameters + +from pydantic_harness import Reminder, SystemReminders +from pydantic_harness.system_reminders import AsyncDynamicReminder, DynamicReminder + + +def _make_run_context(*, run_step: int = 1) -> Any: + """Create a minimal RunContext-like object for testing.""" + ctx = MagicMock() + ctx.run_step = run_step + return ctx + + +def _make_request_context( + messages: list[Any] | None = None, +) -> ModelRequestContext: + """Create a ModelRequestContext with the given messages.""" + if messages is None: + messages = [ModelRequest.user_text_prompt('hello')] + return ModelRequestContext( + model=MagicMock(), + messages=messages, + model_settings=None, + model_request_parameters=ModelRequestParameters(), + ) + + +def _dynamic_content(ctx: Any) -> str | None: + return 'dynamic content' + + +def _returns_none(ctx: Any) -> str | None: + return None + + +def _returns_dynamic(ctx: Any) -> str | None: + return 'dynamic' + + +# --- Reminder validation --- + + +class TestReminderValidation: + def test_valid_reminder(self) -> None: + r = Reminder('test', interval=3) + assert r.content == 'test' + assert r.interval == 3 + + def test_default_interval(self) -> None: + r = Reminder('test') + assert r.interval == 1 + + def test_zero_interval_raises(self) -> None: + with pytest.raises(ValueError, match='interval must be >= 1'): + Reminder('test', interval=0) + + def test_negative_interval_raises(self) -> None: + with pytest.raises(ValueError, match='interval must be >= 1'): + Reminder('test', interval=-1) + + def test_zero_max_fires_raises(self) -> None: + with pytest.raises(ValueError, match='max_fires must be >= 1'): + Reminder('test', max_fires=0) + + def test_negative_max_fires_raises(self) -> None: + with pytest.raises(ValueError, match='max_fires must be >= 1'): + Reminder('test', max_fires=-2) + + +# --- SystemReminders validation --- + + +class TestSystemRemindersValidation: + def test_requires_at_least_one_reminder(self) -> None: + with pytest.raises(ValueError, match='At least one'): + SystemReminders() + + def test_static_reminders_only(self) -> None: + sr = SystemReminders(reminders=[Reminder('test')]) + assert len(sr.reminders) == 1 + assert len(sr.dynamic_reminders) == 0 + + def test_dynamic_reminders_only(self) -> None: + sr = SystemReminders(dynamic_reminders=[_returns_dynamic]) + assert len(sr.reminders) == 0 + assert len(sr.dynamic_reminders) == 1 + + def test_both_kinds(self) -> None: + sr = SystemReminders( + reminders=[Reminder('static')], + dynamic_reminders=[_returns_dynamic], + ) + assert len(sr.reminders) == 1 + assert len(sr.dynamic_reminders) == 1 + + +# --- for_run isolation --- + + +class TestForRun: + @pytest.mark.anyio + async def test_for_run_returns_fresh_instance(self) -> None: + sr = SystemReminders(reminders=[Reminder('test')]) + ctx = _make_run_context() + fresh = await sr.for_run(ctx) + assert fresh is not sr + + @pytest.mark.anyio + async def test_for_run_preserves_config(self) -> None: + reminders = [Reminder('a', interval=2), Reminder('b', interval=5)] + dynamic: list[DynamicReminder | AsyncDynamicReminder] = [_returns_dynamic] + sr = SystemReminders(reminders=reminders, dynamic_reminders=dynamic) + ctx = _make_run_context() + fresh = await sr.for_run(ctx) + assert fresh.reminders is reminders + assert fresh.dynamic_reminders is dynamic + + @pytest.mark.anyio + async def test_for_run_resets_counter(self) -> None: + sr = SystemReminders(reminders=[Reminder('test')]) + # Simulate some requests to increment counter. + sr._request_count = 5 # pyright: ignore[reportPrivateUsage] + ctx = _make_run_context() + fresh = await sr.for_run(ctx) + assert fresh._request_count == 0 # pyright: ignore[reportPrivateUsage] + + @pytest.mark.anyio + async def test_for_run_resets_fire_counts(self) -> None: + sr = SystemReminders(reminders=[Reminder('test', max_fires=5)]) + sr._fire_counts = [3] # pyright: ignore[reportPrivateUsage] + ctx = _make_run_context() + fresh = await sr.for_run(ctx) + assert fresh._fire_counts == [0] # pyright: ignore[reportPrivateUsage] + + +# --- Static reminder injection --- + + +class TestStaticReminders: + @pytest.mark.anyio + async def test_interval_1_fires_every_request(self) -> None: + sr = SystemReminders(reminders=[Reminder('always')]) + ctx = _make_run_context() + + for _ in range(3): + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + + last_msg = req_ctx.messages[-1] + assert isinstance(last_msg, ModelRequest) + system_parts = [p for p in last_msg.parts if isinstance(p, SystemPromptPart)] + assert len(system_parts) == 1 + assert system_parts[0].content == 'always' + + @pytest.mark.anyio + async def test_interval_3_fires_on_3rd_request(self) -> None: + sr = SystemReminders(reminders=[Reminder('every third', interval=3)]) + ctx = _make_run_context() + + # Requests 1 and 2: no injection. + for _ in range(2): + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + last_msg = req_ctx.messages[-1] + assert isinstance(last_msg, ModelRequest) + system_parts = [p for p in last_msg.parts if isinstance(p, SystemPromptPart)] + assert len(system_parts) == 0 + + # Request 3: injection. + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + last_msg = req_ctx.messages[-1] + assert isinstance(last_msg, ModelRequest) + system_parts = [p for p in last_msg.parts if isinstance(p, SystemPromptPart)] + assert len(system_parts) == 1 + assert system_parts[0].content == 'every third' + + @pytest.mark.anyio + async def test_multiple_reminders_different_intervals(self) -> None: + sr = SystemReminders( + reminders=[ + Reminder('every 2', interval=2), + Reminder('every 3', interval=3), + ], + ) + ctx = _make_run_context() + + # Request 1: none. + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == [] + + # Request 2: "every 2" only. + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == ['every 2'] + + # Request 3: "every 3" only. + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == ['every 3'] + + # Request 4: "every 2" only. + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == ['every 2'] + + # Request 5: none. + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == [] + + # Request 6: both. + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == ['every 2', 'every 3'] + + +# --- Dynamic reminder injection --- + + +class TestDynamicReminders: + @pytest.mark.anyio + async def test_sync_dynamic_returning_string(self) -> None: + sr = SystemReminders(dynamic_reminders=[_dynamic_content]) + ctx = _make_run_context() + + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == ['dynamic content'] + + @pytest.mark.anyio + async def test_sync_dynamic_returning_none_skips(self) -> None: + sr = SystemReminders(dynamic_reminders=[_returns_none]) + ctx = _make_run_context() + + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == [] + + @pytest.mark.anyio + async def test_async_dynamic_reminder(self) -> None: + async def async_reminder(ctx: Any) -> str | None: + return 'async content' + + sr = SystemReminders(dynamic_reminders=[async_reminder]) + ctx = _make_run_context() + + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == ['async content'] + + @pytest.mark.anyio + async def test_async_dynamic_returning_none(self) -> None: + async def async_none(ctx: Any) -> str | None: + return None + + sr = SystemReminders(dynamic_reminders=[async_none]) + ctx = _make_run_context() + + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == [] + + @pytest.mark.anyio + async def test_dynamic_receives_run_context(self) -> None: + def step_check(ctx: Any) -> str | None: + if ctx.run_step > 10: + return 'wrap up' + return None + + sr = SystemReminders(dynamic_reminders=[step_check]) + + # Low step: no reminder. + ctx_low = _make_run_context(run_step=5) + req_ctx = _make_request_context() + await sr.before_model_request(ctx_low, req_ctx) + assert _system_contents(req_ctx) == [] + + # High step: reminder fires. + ctx_high = _make_run_context(run_step=15) + req_ctx = _make_request_context() + await sr.before_model_request(ctx_high, req_ctx) + assert _system_contents(req_ctx) == ['wrap up'] + + +# --- Mixed static and dynamic --- + + +class TestMixedReminders: + @pytest.mark.anyio + async def test_static_and_dynamic_combined(self) -> None: + sr = SystemReminders( + reminders=[Reminder('static', interval=1)], + dynamic_reminders=[_returns_dynamic], + ) + ctx = _make_run_context() + + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == ['static', 'dynamic'] + + @pytest.mark.anyio + async def test_static_fires_dynamic_skips(self) -> None: + sr = SystemReminders( + reminders=[Reminder('static', interval=1)], + dynamic_reminders=[_returns_none], + ) + ctx = _make_run_context() + + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == ['static'] + + +# --- Message injection behavior --- + + +class TestMessageInjection: + @pytest.mark.anyio + async def test_appends_to_last_model_request(self) -> None: + """Reminder parts are appended to the last ModelRequest.""" + sr = SystemReminders(reminders=[Reminder('reminder')]) + ctx = _make_run_context() + + messages: list[Any] = [ + ModelRequest(parts=[UserPromptPart('first')]), + ModelRequest(parts=[UserPromptPart('second')]), + ] + req_ctx = _make_request_context(messages) + await sr.before_model_request(ctx, req_ctx) + + # First request unchanged. + first = req_ctx.messages[0] + assert isinstance(first, ModelRequest) + assert len(first.parts) == 1 + + # Second request has reminder appended. + second = req_ctx.messages[1] + assert isinstance(second, ModelRequest) + assert len(second.parts) == 2 + assert isinstance(second.parts[1], SystemPromptPart) + assert second.parts[1].content == 'reminder' + + @pytest.mark.anyio + async def test_preserves_existing_parts(self) -> None: + """Existing parts on the ModelRequest are preserved.""" + sr = SystemReminders(reminders=[Reminder('reminder')]) + ctx = _make_run_context() + + req_ctx = _make_request_context( + [ + ModelRequest(parts=[UserPromptPart('user msg'), SystemPromptPart(content='existing')]), + ] + ) + await sr.before_model_request(ctx, req_ctx) + + msg = req_ctx.messages[0] + assert isinstance(msg, ModelRequest) + assert len(msg.parts) == 3 + assert isinstance(msg.parts[0], UserPromptPart) + assert isinstance(msg.parts[1], SystemPromptPart) + assert msg.parts[1].content == 'existing' + assert isinstance(msg.parts[2], SystemPromptPart) + assert msg.parts[2].content == 'reminder' + + @pytest.mark.anyio + async def test_creates_request_when_none_exists(self) -> None: + """If no ModelRequest exists, a new one is created.""" + sr = SystemReminders(reminders=[Reminder('orphan')]) + ctx = _make_run_context() + + req_ctx = _make_request_context(messages=[]) + await sr.before_model_request(ctx, req_ctx) + + assert len(req_ctx.messages) == 1 + msg = req_ctx.messages[0] + assert isinstance(msg, ModelRequest) + assert len(msg.parts) == 1 + assert isinstance(msg.parts[0], SystemPromptPart) + assert msg.parts[0].content == 'orphan' + + @pytest.mark.anyio + async def test_skips_model_response_to_find_last_request(self) -> None: + """When the last message is a ModelResponse, skip it to find the ModelRequest.""" + sr = SystemReminders(reminders=[Reminder('reminder')]) + ctx = _make_run_context() + + messages: list[Any] = [ + ModelRequest(parts=[UserPromptPart('hello')]), + ModelResponse(parts=[TextPart(content='hi')]), + ] + req_ctx = _make_request_context(messages) + await sr.before_model_request(ctx, req_ctx) + + # The ModelRequest should have the reminder appended. + first = req_ctx.messages[0] + assert isinstance(first, ModelRequest) + assert len(first.parts) == 2 + assert isinstance(first.parts[1], SystemPromptPart) + assert first.parts[1].content == 'reminder' + + # The ModelResponse should be unchanged. + second = req_ctx.messages[1] + assert isinstance(second, ModelResponse) + + @pytest.mark.anyio + async def test_no_injection_when_nothing_fires(self) -> None: + """Messages are untouched when no reminders fire.""" + sr = SystemReminders(reminders=[Reminder('skip', interval=3)]) + ctx = _make_run_context() + + original_msg = ModelRequest(parts=[UserPromptPart('hello')]) + req_ctx = _make_request_context([original_msg]) + await sr.before_model_request(ctx, req_ctx) + + # Request 1: interval=3 doesn't fire. + msg = req_ctx.messages[0] + assert isinstance(msg, ModelRequest) + assert len(msg.parts) == 1 + + +# --- Condition-triggered reminders --- + + +class TestTrigger: + @pytest.mark.anyio + async def test_trigger_true_fires(self) -> None: + sr = SystemReminders(reminders=[Reminder('triggered', trigger=lambda ctx: True)]) + ctx = _make_run_context() + + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == ['triggered'] + + @pytest.mark.anyio + async def test_trigger_false_suppresses(self) -> None: + sr = SystemReminders(reminders=[Reminder('suppressed', trigger=lambda ctx: False)]) + ctx = _make_run_context() + + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == [] + + @pytest.mark.anyio + async def test_trigger_receives_run_context(self) -> None: + """Trigger predicate receives the RunContext and can inspect run_step.""" + + def high_step_trigger(ctx: Any) -> bool: + return ctx.run_step > 10 # type: ignore[no-any-return] + + sr = SystemReminders(reminders=[Reminder('late warning', trigger=high_step_trigger)]) + + # Low step: trigger is False, no injection. + ctx_low = _make_run_context(run_step=5) + req_ctx = _make_request_context() + await sr.before_model_request(ctx_low, req_ctx) + assert _system_contents(req_ctx) == [] + + # High step: trigger is True, fires. + ctx_high = _make_run_context(run_step=15) + req_ctx = _make_request_context() + await sr.before_model_request(ctx_high, req_ctx) + assert _system_contents(req_ctx) == ['late warning'] + + @pytest.mark.anyio + async def test_trigger_combined_with_interval(self) -> None: + """Trigger must return True AND interval must match for the reminder to fire.""" + sr = SystemReminders( + reminders=[Reminder('combo', interval=2, trigger=lambda ctx: True)], + ) + ctx = _make_run_context() + + # Request 1: interval doesn't match (1 % 2 != 0). + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == [] + + # Request 2: interval matches and trigger is True. + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == ['combo'] + + @pytest.mark.anyio + async def test_trigger_false_blocks_even_when_interval_matches(self) -> None: + sr = SystemReminders( + reminders=[Reminder('blocked', interval=1, trigger=lambda ctx: False)], + ) + ctx = _make_run_context() + + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == [] + + +# --- Max fires --- + + +class TestMaxFires: + @pytest.mark.anyio + async def test_max_fires_limits_injections(self) -> None: + sr = SystemReminders(reminders=[Reminder('limited', max_fires=2)]) + ctx = _make_run_context() + + # First two requests: fires. + for _ in range(2): + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == ['limited'] + + # Third request: max reached, no injection. + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == [] + + @pytest.mark.anyio + async def test_max_fires_none_means_unlimited(self) -> None: + sr = SystemReminders(reminders=[Reminder('unlimited', max_fires=None)]) + ctx = _make_run_context() + + for _ in range(10): + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == ['unlimited'] + + @pytest.mark.anyio + async def test_max_fires_with_interval(self) -> None: + """max_fires counts actual fires, not interval-eligible requests.""" + sr = SystemReminders(reminders=[Reminder('capped', interval=2, max_fires=1)]) + ctx = _make_run_context() + + # Request 1: interval doesn't match. + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == [] + + # Request 2: fires (first and only fire). + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == ['capped'] + + # Request 3: interval doesn't match. + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == [] + + # Request 4: interval matches but max_fires exhausted. + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == [] + + @pytest.mark.anyio + async def test_max_fires_per_reminder_independence(self) -> None: + """Each reminder tracks its own fire count independently.""" + sr = SystemReminders( + reminders=[ + Reminder('once', max_fires=1), + Reminder('twice', max_fires=2), + ], + ) + ctx = _make_run_context() + + # Request 1: both fire. + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == ['once', 'twice'] + + # Request 2: 'once' exhausted, 'twice' still has one left. + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == ['twice'] + + # Request 3: both exhausted. + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == [] + + +# --- XML tag wrapping --- + + +class TestTagWrapping: + def test_render_content_without_tag(self) -> None: + r = Reminder('plain content') + assert r.render_content() == 'plain content' + + def test_render_content_with_tag(self) -> None: + r = Reminder('reminder text', tag='system-reminder') + assert r.render_content() == 'reminder text' + + @pytest.mark.anyio + async def test_tag_wrapping_in_injection(self) -> None: + sr = SystemReminders(reminders=[Reminder('stay focused', tag='system-reminder')]) + ctx = _make_run_context() + + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == ['stay focused'] + + @pytest.mark.anyio + async def test_mixed_tagged_and_untagged(self) -> None: + sr = SystemReminders( + reminders=[ + Reminder('tagged', tag='hint'), + Reminder('untagged'), + ], + ) + ctx = _make_run_context() + + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == ['tagged', 'untagged'] + + +# --- Combined features --- + + +class TestCombinedFeatures: + @pytest.mark.anyio + async def test_trigger_and_max_fires_and_tag(self) -> None: + """All three new features work together on a single reminder.""" + fires: list[bool] = [True, True, True, False] + call_idx = 0 + + def toggling_trigger(ctx: Any) -> bool: + nonlocal call_idx + result = fires[call_idx] if call_idx < len(fires) else False + call_idx += 1 + return result + + sr = SystemReminders( + reminders=[Reminder('combo', trigger=toggling_trigger, max_fires=2, tag='note')], + ) + ctx = _make_run_context() + + # Request 1: trigger True, fires (1/2). + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == ['combo'] + + # Request 2: trigger True, fires (2/2). + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == ['combo'] + + # Request 3: trigger True, but max_fires exhausted. + req_ctx = _make_request_context() + await sr.before_model_request(ctx, req_ctx) + assert _system_contents(req_ctx) == [] + + +# --- Serialization --- + + +class TestSerialization: + def test_not_serializable(self) -> None: + assert SystemReminders.get_serialization_name() is None + + +# --- Helpers --- + + +def _system_contents(req_ctx: ModelRequestContext) -> list[str]: + """Extract system prompt contents from the last ModelRequest in a request context.""" + if not req_ctx.messages: # pragma: no cover + return [] + last = req_ctx.messages[-1] + if not isinstance(last, ModelRequest): # pragma: no cover + return [] + return [p.content for p in last.parts if isinstance(p, SystemPromptPart)]