Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions modules/agents/codex/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
from modules.agents.codex.session import CodexSessionManager
from modules.agents.codex.transport import CodexTransport
from modules.agents.codex.turn_state import CodexTurnRegistry
from vibe.codex_config import LEGACY_MANAGED_PROVIDER_IDS, MANAGED_PROVIDER_ID

logger = logging.getLogger(__name__)

_CODEX_MANAGED_PROVIDER_IDS = frozenset((MANAGED_PROVIDER_ID, *LEGACY_MANAGED_PROVIDER_IDS))


class CodexAgent(BaseAgent):
"""Codex CLI integration via persistent ``codex app-server`` subprocess.
Expand Down Expand Up @@ -563,6 +566,9 @@ async def _start_or_resume_thread(
"threadId": persisted,
"developerInstructions": self._build_thread_developer_instructions(request),
}
model_provider = await self._resolve_resume_model_provider_override(transport, request, persisted)
if model_provider:
resume_params["modelProvider"] = model_provider
resp = await transport.send_request(
"thread/resume",
resume_params,
Expand All @@ -585,6 +591,80 @@ async def _start_or_resume_thread(

return await self._start_thread(transport, request)

async def _resolve_resume_model_provider_override(
self,
transport: CodexTransport,
request: AgentRequest,
thread_id: str,
) -> Optional[str]:
"""Return a provider override only when a persisted thread is stale.

Codex preserves a thread's latest model / reasoning effort on resume
unless the client sends a model/provider override. Vibe Remote only
needs to override the provider after the user changes Codex auth mode
between Vibe Remote-managed OAuth/API-key providers, so inspect the
stored thread first and leave normal resumes on Codex's persisted
fallback path.
"""
current_provider = await self._read_effective_model_provider(transport, request)
if not current_provider:
return None

try:
resp = await transport.send_request(
"thread/read",
{
"threadId": thread_id,
"includeTurns": False,
},
)
except Exception as exc:
logger.warning("Failed to read Codex thread %s provider before resume: %s", thread_id, exc)
return None

thread_obj = resp.get("thread") if isinstance(resp, dict) else None
if not isinstance(thread_obj, dict) and isinstance(resp, dict) and resp.get("id") == thread_id:
thread_obj = resp
stored_provider = thread_obj.get("modelProvider") if isinstance(thread_obj, dict) else None
if not isinstance(stored_provider, str) or not stored_provider.strip():
return None

stored_provider = stored_provider.strip()
if stored_provider == current_provider:
return None
if not self._is_managed_provider_transition(stored_provider, current_provider):
return None
return current_provider

@staticmethod
def _is_managed_provider_transition(stored_provider: str, current_provider: str) -> bool:
return {stored_provider, current_provider}.issubset(_CODEX_MANAGED_PROVIDER_IDS)

async def _read_effective_model_provider(
self,
transport: CodexTransport,
request: AgentRequest,
) -> Optional[str]:
"""Ask Codex app-server for the provider it resolves for this request."""
params: Dict[str, Any] = {"includeLayers": False}
working_path = getattr(request, "working_path", None)
if working_path:
params["cwd"] = working_path

try:
resp = await transport.send_request("config/read", params)
except Exception as exc:
logger.warning("Failed to read effective Codex model provider before resume: %s", exc)
return None

config_obj = resp.get("config") if isinstance(resp, dict) else None
if not isinstance(config_obj, dict):
return None
model_provider = config_obj.get("model_provider")
if isinstance(model_provider, str) and model_provider.strip():
return model_provider.strip()
return None

def _build_thread_developer_instructions(self, request: AgentRequest) -> Optional[str]:
"""Build Codex thread-level developer instructions for start/resume.

Expand Down
219 changes: 216 additions & 3 deletions tests/test_codex_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,15 @@ async def test_resume_thread_refreshes_developer_instructions_without_appending(
subagent_model=None,
subagent_reasoning_effort=None,
)
transport = SimpleNamespace(send_request=AsyncMock(return_value={"thread": {"id": "thread-existing"}}))
transport = SimpleNamespace(
send_request=AsyncMock(
side_effect=[
{"config": {"model_provider": "openai"}},
{"thread": {"id": "thread-existing", "modelProvider": "openai"}},
{"thread": {"id": "thread-existing"}},
]
)
)

with patch.object(
_MODULE,
Expand All @@ -899,8 +907,8 @@ async def test_resume_thread_refreshes_developer_instructions_without_appending(
thread_id = await agent._start_or_resume_thread(transport, request)

self.assertEqual(thread_id, "thread-existing")
transport.send_request.assert_awaited_once()
method, params = transport.send_request.await_args.args
self.assertEqual(transport.send_request.await_count, 3)
method, params = transport.send_request.await_args_list[2].args
self.assertEqual(method, "thread/resume")
self.assertEqual(params["threadId"], "thread-existing")
developer_instructions = params["developerInstructions"]
Expand All @@ -913,6 +921,211 @@ async def test_resume_thread_refreshes_developer_instructions_without_appending(
)
self.assertNotIn("Channel-level session key:", developer_instructions)

async def test_resume_thread_does_not_force_model_provider_when_thread_matches_config(self):
agent = object.__new__(CodexAgent)
agent.controller = SimpleNamespace(config=SimpleNamespace(platform="slack", reply_enhancements=True))
agent.codex_config = SimpleNamespace(default_model=None)
agent._session_mgr = SimpleNamespace(set_thread_id=Mock())
agent.sessions = SimpleNamespace(
get_agent_session_id=Mock(return_value="thread-existing"),
)
request = SimpleNamespace(
working_path="/tmp/work",
context=SimpleNamespace(
platform="slack",
platform_specific={"is_dm": False},
user_id="U1",
channel_id="C1",
thread_id="171717.123",
),
base_session_id="session-1",
session_key="slack::channel::C1::thread::171717.123",
subagent_name=None,
subagent_model=None,
subagent_reasoning_effort=None,
)
transport = SimpleNamespace(
send_request=AsyncMock(
side_effect=[
{"config": {"model_provider": "openai-managed"}},
{"thread": {"id": "thread-existing", "modelProvider": "openai-managed"}},
{"thread": {"id": "thread-existing"}},
]
)
)

thread_id = await agent._start_or_resume_thread(transport, request)

self.assertEqual(thread_id, "thread-existing")
self.assertEqual(transport.send_request.await_args_list[0].args[0], "config/read")
self.assertEqual(transport.send_request.await_args_list[0].args[1]["cwd"], "/tmp/work")
self.assertEqual(transport.send_request.await_args_list[1].args[0], "thread/read")
method, params = transport.send_request.await_args_list[2].args
self.assertEqual(method, "thread/resume")
self.assertNotIn("modelProvider", params)

async def test_resume_thread_overrides_stale_session_model_provider(self):
agent = object.__new__(CodexAgent)
agent.controller = SimpleNamespace(config=SimpleNamespace(platform="slack", reply_enhancements=True))
agent.codex_config = SimpleNamespace(default_model=None)
agent._session_mgr = SimpleNamespace(set_thread_id=Mock())
agent.sessions = SimpleNamespace(
get_agent_session_id=Mock(return_value="thread-existing"),
)
request = SimpleNamespace(
working_path="/tmp/work",
context=SimpleNamespace(
platform="slack",
platform_specific={"is_dm": False},
user_id="U1",
channel_id="C1",
thread_id="171717.123",
),
base_session_id="session-1",
session_key="slack::channel::C1::thread::171717.123",
subagent_name=None,
subagent_model=None,
subagent_reasoning_effort=None,
)
transport = SimpleNamespace(
send_request=AsyncMock(
side_effect=[
{"config": {"model_provider": "openai-managed"}},
{"thread": {"id": "thread-existing", "modelProvider": "openai"}},
{"thread": {"id": "thread-existing"}},
]
)
)

thread_id = await agent._start_or_resume_thread(transport, request)

self.assertEqual(thread_id, "thread-existing")
self.assertEqual(transport.send_request.await_args_list[0].args[0], "config/read")
self.assertEqual(transport.send_request.await_args_list[0].args[1]["cwd"], "/tmp/work")
self.assertEqual(transport.send_request.await_args_list[1].args[0], "thread/read")
method, params = transport.send_request.await_args_list[2].args
self.assertEqual(method, "thread/resume")
self.assertEqual(params["modelProvider"], "openai-managed")

async def test_resume_thread_preserves_unmanaged_cross_provider_session(self):
agent = object.__new__(CodexAgent)
agent.controller = SimpleNamespace(config=SimpleNamespace(platform="slack", reply_enhancements=True))
agent.codex_config = SimpleNamespace(default_model=None)
agent._session_mgr = SimpleNamespace(set_thread_id=Mock())
agent.sessions = SimpleNamespace(
get_agent_session_id=Mock(return_value="thread-existing"),
)
request = SimpleNamespace(
working_path="/tmp/work",
context=SimpleNamespace(
platform="slack",
platform_specific={"is_dm": False},
user_id="U1",
channel_id="C1",
thread_id="171717.123",
),
base_session_id="session-1",
session_key="slack::channel::C1::thread::171717.123",
subagent_name=None,
subagent_model=None,
subagent_reasoning_effort=None,
)
transport = SimpleNamespace(
send_request=AsyncMock(
side_effect=[
{"config": {"model_provider": "openai-managed"}},
{"thread": {"id": "thread-existing", "modelProvider": "anthropic"}},
{"thread": {"id": "thread-existing"}},
]
)
)

thread_id = await agent._start_or_resume_thread(transport, request)

self.assertEqual(thread_id, "thread-existing")
method, params = transport.send_request.await_args_list[2].args
self.assertEqual(method, "thread/resume")
self.assertNotIn("modelProvider", params)

async def test_resume_thread_omits_model_provider_when_provider_read_fails(self):
agent = object.__new__(CodexAgent)
agent.controller = SimpleNamespace(config=SimpleNamespace(platform="slack", reply_enhancements=True))
agent.codex_config = SimpleNamespace(default_model=None)
agent._session_mgr = SimpleNamespace(set_thread_id=Mock())
agent.sessions = SimpleNamespace(
get_agent_session_id=Mock(return_value="thread-existing"),
)
request = SimpleNamespace(
working_path="/tmp/work",
context=SimpleNamespace(
platform="slack",
platform_specific={"is_dm": False},
user_id="U1",
channel_id="C1",
thread_id="171717.123",
),
base_session_id="session-1",
session_key="slack::channel::C1::thread::171717.123",
subagent_name=None,
subagent_model=None,
subagent_reasoning_effort=None,
)
transport = SimpleNamespace(
send_request=AsyncMock(
side_effect=[
{"config": {"model_provider": "openai-managed"}},
RuntimeError("thread/read unavailable"),
{"thread": {"id": "thread-existing"}},
]
)
)

thread_id = await agent._start_or_resume_thread(transport, request)

self.assertEqual(thread_id, "thread-existing")
method, params = transport.send_request.await_args_list[2].args
self.assertEqual(method, "thread/resume")
self.assertNotIn("modelProvider", params)

async def test_resume_thread_omits_model_provider_when_config_read_fails(self):
agent = object.__new__(CodexAgent)
agent.controller = SimpleNamespace(config=SimpleNamespace(platform="slack", reply_enhancements=True))
agent.codex_config = SimpleNamespace(default_model=None)
agent._session_mgr = SimpleNamespace(set_thread_id=Mock())
agent.sessions = SimpleNamespace(
get_agent_session_id=Mock(return_value="thread-existing"),
)
request = SimpleNamespace(
working_path="/tmp/work",
context=SimpleNamespace(
platform="slack",
platform_specific={"is_dm": False},
user_id="U1",
channel_id="C1",
thread_id="171717.123",
),
base_session_id="session-1",
session_key="slack::channel::C1::thread::171717.123",
subagent_name=None,
subagent_model=None,
subagent_reasoning_effort=None,
)
transport = SimpleNamespace(
send_request=AsyncMock(
side_effect=[
RuntimeError("config/read unavailable"),
{"thread": {"id": "thread-existing"}},
]
)
)

thread_id = await agent._start_or_resume_thread(transport, request)

self.assertEqual(thread_id, "thread-existing")
method, params = transport.send_request.await_args_list[1].args
self.assertEqual(method, "thread/resume")
self.assertNotIn("modelProvider", params)

async def test_resume_thread_keeps_system_prompt_injection_when_quick_replies_are_disabled(self):
agent = object.__new__(CodexAgent)
agent.controller = SimpleNamespace(config=SimpleNamespace(platform="slack", reply_enhancements=False))
Expand Down
Loading