Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions modules/agents/codex/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,9 @@ async def _start_or_resume_thread(
"threadId": persisted,
"developerInstructions": self._build_thread_developer_instructions(request),
}
model_provider = await self._resolve_resume_model_provider_override(transport, request, persisted)
if model_provider:
resume_params["modelProvider"] = model_provider
resp = await transport.send_request(
"thread/resume",
resume_params,
Expand All @@ -585,6 +588,72 @@ async def _start_or_resume_thread(

return await self._start_thread(transport, request)

async def _resolve_resume_model_provider_override(
self,
transport: CodexTransport,
request: AgentRequest,
thread_id: str,
) -> Optional[str]:
"""Return a provider override only when a persisted thread is stale.

Codex preserves a thread's latest model / reasoning effort on resume
unless the client sends a model/provider override. Vibe Remote only
needs to override the provider after the user changes Codex auth mode
or project provider settings, so inspect the stored thread first and
leave normal resumes on Codex's persisted fallback path.
"""
current_provider = await self._read_effective_model_provider(transport, request)
if not current_provider:
return None

try:
resp = await transport.send_request(
"thread/read",
{
"threadId": thread_id,
"includeTurns": False,
},
)
except Exception as exc:
logger.warning("Failed to read Codex thread %s provider before resume: %s", thread_id, exc)
return None

thread_obj = resp.get("thread") if isinstance(resp, dict) else None
if not isinstance(thread_obj, dict) and isinstance(resp, dict) and resp.get("id") == thread_id:
thread_obj = resp
stored_provider = thread_obj.get("modelProvider") if isinstance(thread_obj, dict) else None
if not isinstance(stored_provider, str) or not stored_provider.strip():
return None

if stored_provider.strip() == current_provider:
return None
return current_provider
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep persisted provider when resuming cross-provider threads

Only overriding modelProvider based on current config when it differs from the stored thread provider can break valid cross-provider resumes: this path now forces thread/resume onto the current provider for any mismatch, even when the persisted thread was created under a different provider and must continue there. In those cases, resume can fail or misroute history because the provider affinity encoded in the stored thread is discarded at resume time.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in db97d6c. The provider override is now limited to transitions between Vibe Remote-managed Codex provider IDs (openai and openai-managed). Other mismatches are treated as intentional cross-provider sessions and resume without modelProvider, preserving Codex persisted provider affinity. Added a regression test for an unmanaged anthropic thread while the current config points at openai-managed.


async def _read_effective_model_provider(
self,
transport: CodexTransport,
request: AgentRequest,
) -> Optional[str]:
"""Ask Codex app-server for the provider it resolves for this request."""
params: Dict[str, Any] = {"includeLayers": False}
working_path = getattr(request, "working_path", None)
if working_path:
params["cwd"] = working_path

try:
resp = await transport.send_request("config/read", params)
except Exception as exc:
logger.warning("Failed to read effective Codex model provider before resume: %s", exc)
return None

config_obj = resp.get("config") if isinstance(resp, dict) else None
if not isinstance(config_obj, dict):
return None
model_provider = config_obj.get("model_provider")
if isinstance(model_provider, str) and model_provider.strip():
return model_provider.strip()
return None

def _build_thread_developer_instructions(self, request: AgentRequest) -> Optional[str]:
"""Build Codex thread-level developer instructions for start/resume.

Expand Down
179 changes: 176 additions & 3 deletions tests/test_codex_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,15 @@ async def test_resume_thread_refreshes_developer_instructions_without_appending(
subagent_model=None,
subagent_reasoning_effort=None,
)
transport = SimpleNamespace(send_request=AsyncMock(return_value={"thread": {"id": "thread-existing"}}))
transport = SimpleNamespace(
send_request=AsyncMock(
side_effect=[
{"config": {"model_provider": "openai"}},
{"thread": {"id": "thread-existing", "modelProvider": "openai"}},
{"thread": {"id": "thread-existing"}},
]
)
)

with patch.object(
_MODULE,
Expand All @@ -899,8 +907,8 @@ async def test_resume_thread_refreshes_developer_instructions_without_appending(
thread_id = await agent._start_or_resume_thread(transport, request)

self.assertEqual(thread_id, "thread-existing")
transport.send_request.assert_awaited_once()
method, params = transport.send_request.await_args.args
self.assertEqual(transport.send_request.await_count, 3)
method, params = transport.send_request.await_args_list[2].args
self.assertEqual(method, "thread/resume")
self.assertEqual(params["threadId"], "thread-existing")
developer_instructions = params["developerInstructions"]
Expand All @@ -913,6 +921,171 @@ async def test_resume_thread_refreshes_developer_instructions_without_appending(
)
self.assertNotIn("Channel-level session key:", developer_instructions)

async def test_resume_thread_does_not_force_model_provider_when_thread_matches_config(self):
    """Resume omits modelProvider when the stored thread already matches config."""
    codex_agent = object.__new__(CodexAgent)
    codex_agent.controller = SimpleNamespace(
        config=SimpleNamespace(platform="slack", reply_enhancements=True)
    )
    codex_agent.codex_config = SimpleNamespace(default_model=None)
    codex_agent._session_mgr = SimpleNamespace(set_thread_id=Mock())
    codex_agent.sessions = SimpleNamespace(
        get_agent_session_id=Mock(return_value="thread-existing"),
    )
    slack_context = SimpleNamespace(
        platform="slack",
        platform_specific={"is_dm": False},
        user_id="U1",
        channel_id="C1",
        thread_id="171717.123",
    )
    req = SimpleNamespace(
        working_path="/tmp/work",
        context=slack_context,
        base_session_id="session-1",
        session_key="slack::channel::C1::thread::171717.123",
        subagent_name=None,
        subagent_model=None,
        subagent_reasoning_effort=None,
    )
    # config/read, thread/read, then thread/resume — in that order.
    responses = [
        {"config": {"model_provider": "openai-managed"}},
        {"thread": {"id": "thread-existing", "modelProvider": "openai-managed"}},
        {"thread": {"id": "thread-existing"}},
    ]
    transport = SimpleNamespace(send_request=AsyncMock(side_effect=responses))

    resumed = await codex_agent._start_or_resume_thread(transport, req)

    self.assertEqual(resumed, "thread-existing")
    calls = transport.send_request.await_args_list
    self.assertEqual(calls[0].args[0], "config/read")
    self.assertEqual(calls[0].args[1]["cwd"], "/tmp/work")
    self.assertEqual(calls[1].args[0], "thread/read")
    resume_method, resume_params = calls[2].args
    self.assertEqual(resume_method, "thread/resume")
    self.assertNotIn("modelProvider", resume_params)

async def test_resume_thread_overrides_stale_session_model_provider(self):
    """A managed-provider mismatch against the stored thread forces an override."""
    codex_agent = object.__new__(CodexAgent)
    codex_agent.controller = SimpleNamespace(
        config=SimpleNamespace(platform="slack", reply_enhancements=True)
    )
    codex_agent.codex_config = SimpleNamespace(default_model=None)
    codex_agent._session_mgr = SimpleNamespace(set_thread_id=Mock())
    codex_agent.sessions = SimpleNamespace(
        get_agent_session_id=Mock(return_value="thread-existing"),
    )
    slack_context = SimpleNamespace(
        platform="slack",
        platform_specific={"is_dm": False},
        user_id="U1",
        channel_id="C1",
        thread_id="171717.123",
    )
    req = SimpleNamespace(
        working_path="/tmp/work",
        context=slack_context,
        base_session_id="session-1",
        session_key="slack::channel::C1::thread::171717.123",
        subagent_name=None,
        subagent_model=None,
        subagent_reasoning_effort=None,
    )
    # Thread was persisted under "openai" while config now resolves
    # "openai-managed" — resume must carry the override.
    responses = [
        {"config": {"model_provider": "openai-managed"}},
        {"thread": {"id": "thread-existing", "modelProvider": "openai"}},
        {"thread": {"id": "thread-existing"}},
    ]
    transport = SimpleNamespace(send_request=AsyncMock(side_effect=responses))

    resumed = await codex_agent._start_or_resume_thread(transport, req)

    self.assertEqual(resumed, "thread-existing")
    calls = transport.send_request.await_args_list
    self.assertEqual(calls[0].args[0], "config/read")
    self.assertEqual(calls[0].args[1]["cwd"], "/tmp/work")
    self.assertEqual(calls[1].args[0], "thread/read")
    resume_method, resume_params = calls[2].args
    self.assertEqual(resume_method, "thread/resume")
    self.assertEqual(resume_params["modelProvider"], "openai-managed")

async def test_resume_thread_omits_model_provider_when_provider_read_fails(self):
    """A failing thread/read degrades gracefully: resume without an override."""
    codex_agent = object.__new__(CodexAgent)
    codex_agent.controller = SimpleNamespace(
        config=SimpleNamespace(platform="slack", reply_enhancements=True)
    )
    codex_agent.codex_config = SimpleNamespace(default_model=None)
    codex_agent._session_mgr = SimpleNamespace(set_thread_id=Mock())
    codex_agent.sessions = SimpleNamespace(
        get_agent_session_id=Mock(return_value="thread-existing"),
    )
    slack_context = SimpleNamespace(
        platform="slack",
        platform_specific={"is_dm": False},
        user_id="U1",
        channel_id="C1",
        thread_id="171717.123",
    )
    req = SimpleNamespace(
        working_path="/tmp/work",
        context=slack_context,
        base_session_id="session-1",
        session_key="slack::channel::C1::thread::171717.123",
        subagent_name=None,
        subagent_model=None,
        subagent_reasoning_effort=None,
    )
    # The thread/read probe blows up; config/read succeeded beforehand.
    responses = [
        {"config": {"model_provider": "openai-managed"}},
        RuntimeError("thread/read unavailable"),
        {"thread": {"id": "thread-existing"}},
    ]
    transport = SimpleNamespace(send_request=AsyncMock(side_effect=responses))

    resumed = await codex_agent._start_or_resume_thread(transport, req)

    self.assertEqual(resumed, "thread-existing")
    resume_method, resume_params = transport.send_request.await_args_list[2].args
    self.assertEqual(resume_method, "thread/resume")
    self.assertNotIn("modelProvider", resume_params)

async def test_resume_thread_omits_model_provider_when_config_read_fails(self):
    """A failing config/read skips the override probe and resumes cleanly."""
    codex_agent = object.__new__(CodexAgent)
    codex_agent.controller = SimpleNamespace(
        config=SimpleNamespace(platform="slack", reply_enhancements=True)
    )
    codex_agent.codex_config = SimpleNamespace(default_model=None)
    codex_agent._session_mgr = SimpleNamespace(set_thread_id=Mock())
    codex_agent.sessions = SimpleNamespace(
        get_agent_session_id=Mock(return_value="thread-existing"),
    )
    slack_context = SimpleNamespace(
        platform="slack",
        platform_specific={"is_dm": False},
        user_id="U1",
        channel_id="C1",
        thread_id="171717.123",
    )
    req = SimpleNamespace(
        working_path="/tmp/work",
        context=slack_context,
        base_session_id="session-1",
        session_key="slack::channel::C1::thread::171717.123",
        subagent_name=None,
        subagent_model=None,
        subagent_reasoning_effort=None,
    )
    # config/read fails outright, so thread/read is never attempted and
    # the resume call is the second request overall.
    responses = [
        RuntimeError("config/read unavailable"),
        {"thread": {"id": "thread-existing"}},
    ]
    transport = SimpleNamespace(send_request=AsyncMock(side_effect=responses))

    resumed = await codex_agent._start_or_resume_thread(transport, req)

    self.assertEqual(resumed, "thread-existing")
    resume_method, resume_params = transport.send_request.await_args_list[1].args
    self.assertEqual(resume_method, "thread/resume")
    self.assertNotIn("modelProvider", resume_params)

async def test_resume_thread_keeps_system_prompt_injection_when_quick_replies_are_disabled(self):
agent = object.__new__(CodexAgent)
agent.controller = SimpleNamespace(config=SimpleNamespace(platform="slack", reply_enhancements=False))
Expand Down
Loading