Skip to content

Commit 9808864

Browse files
authored
Add user command event types (#6246)
adding new user command event, logic in TUI to render user command events
1 parent e743d25 commit 9808864

File tree

7 files changed

+411
-66
lines changed

7 files changed

+411
-66
lines changed

codex-rs/core/src/event_mapping.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use tracing::warn;
1414
use uuid::Uuid;
1515

1616
use crate::user_instructions::UserInstructions;
17+
use crate::user_shell_command::is_user_shell_command_text;
1718

1819
fn is_session_prefix(text: &str) -> bool {
1920
let trimmed = text.trim_start();
@@ -31,7 +32,7 @@ fn parse_user_message(message: &[ContentItem]) -> Option<UserMessageItem> {
3132
for content_item in message.iter() {
3233
match content_item {
3334
ContentItem::InputText { text } => {
34-
if is_session_prefix(text) {
35+
if is_session_prefix(text) || is_user_shell_command_text(text) {
3536
return None;
3637
}
3738
content.push(UserInput::Text { text: text.clone() });
@@ -197,7 +198,14 @@ mod tests {
197198
text: "# AGENTS.md instructions for test_directory\n\n<INSTRUCTIONS>\ntest_text\n</INSTRUCTIONS>".to_string(),
198199
}],
199200
},
200-
];
201+
ResponseItem::Message {
202+
id: None,
203+
role: "user".to_string(),
204+
content: vec![ContentItem::InputText {
205+
text: "<user_shell_command>echo 42</user_shell_command>".to_string(),
206+
}],
207+
},
208+
];
201209

202210
for item in items {
203211
let turn_item = parse_turn_item(&item);

codex-rs/core/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ mod function_tool;
8181
mod state;
8282
mod tasks;
8383
mod user_notification;
84+
mod user_shell_command;
8485
pub mod util;
8586

8687
pub use apply_patch::CODEX_APPLY_PATCH_ARG1;

codex-rs/core/src/tasks/user_shell.rs

Lines changed: 128 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,35 @@
11
use std::sync::Arc;
2+
use std::time::Duration;
23

34
use async_trait::async_trait;
4-
use codex_protocol::models::ShellToolCallParams;
5+
use codex_async_utils::CancelErr;
6+
use codex_async_utils::OrCancelExt;
57
use codex_protocol::user_input::UserInput;
6-
use tokio::sync::Mutex;
78
use tokio_util::sync::CancellationToken;
89
use tracing::error;
910
use uuid::Uuid;
1011

1112
use crate::codex::TurnContext;
13+
use crate::exec::ExecToolCallOutput;
14+
use crate::exec::SandboxType;
15+
use crate::exec::StdoutStream;
16+
use crate::exec::StreamOutput;
17+
use crate::exec::execute_exec_env;
18+
use crate::exec_env::create_env;
19+
use crate::parse_command::parse_command;
1220
use crate::protocol::EventMsg;
21+
use crate::protocol::ExecCommandBeginEvent;
22+
use crate::protocol::ExecCommandEndEvent;
23+
use crate::protocol::SandboxPolicy;
1324
use crate::protocol::TaskStartedEvent;
25+
use crate::sandboxing::ExecEnv;
1426
use crate::state::TaskKind;
15-
use crate::tools::context::ToolPayload;
16-
use crate::tools::parallel::ToolCallRuntime;
17-
use crate::tools::router::ToolCall;
18-
use crate::tools::router::ToolRouter;
19-
use crate::turn_diff_tracker::TurnDiffTracker;
27+
use crate::tools::format_exec_output_str;
28+
use crate::user_shell_command::user_shell_command_record_item;
2029

2130
use super::SessionTask;
2231
use super::SessionTaskContext;
2332

24-
const USER_SHELL_TOOL_NAME: &str = "local_shell";
25-
2633
#[derive(Clone)]
2734
pub(crate) struct UserShellCommandTask {
2835
command: String,
@@ -78,34 +85,126 @@ impl SessionTask for UserShellCommandTask {
7885
}
7986
};
8087

81-
let params = ShellToolCallParams {
88+
let call_id = Uuid::new_v4().to_string();
89+
let raw_command = self.command.clone();
90+
91+
let parsed_cmd = parse_command(&shell_invocation);
92+
session
93+
.send_event(
94+
turn_context.as_ref(),
95+
EventMsg::ExecCommandBegin(ExecCommandBeginEvent {
96+
call_id: call_id.clone(),
97+
command: shell_invocation.clone(),
98+
cwd: turn_context.cwd.clone(),
99+
parsed_cmd,
100+
is_user_shell_command: true,
101+
}),
102+
)
103+
.await;
104+
105+
let exec_env = ExecEnv {
82106
command: shell_invocation,
83-
workdir: None,
107+
cwd: turn_context.cwd.clone(),
108+
env: create_env(&turn_context.shell_environment_policy),
84109
timeout_ms: None,
110+
sandbox: SandboxType::None,
85111
with_escalated_permissions: None,
86112
justification: None,
113+
arg0: None,
87114
};
88115

89-
let tool_call = ToolCall {
90-
tool_name: USER_SHELL_TOOL_NAME.to_string(),
91-
call_id: Uuid::new_v4().to_string(),
92-
payload: ToolPayload::LocalShell { params },
93-
};
116+
let stdout_stream = Some(StdoutStream {
117+
sub_id: turn_context.sub_id.clone(),
118+
call_id: call_id.clone(),
119+
tx_event: session.get_tx_event(),
120+
});
94121

95-
let router = Arc::new(ToolRouter::from_config(&turn_context.tools_config, None));
96-
let tracker = Arc::new(Mutex::new(TurnDiffTracker::new()));
97-
let runtime = ToolCallRuntime::new(
98-
Arc::clone(&router),
99-
Arc::clone(&session),
100-
Arc::clone(&turn_context),
101-
Arc::clone(&tracker),
102-
);
122+
let sandbox_policy = SandboxPolicy::DangerFullAccess;
123+
let exec_result = execute_exec_env(exec_env, &sandbox_policy, stdout_stream)
124+
.or_cancel(&cancellation_token)
125+
.await;
103126

104-
if let Err(err) = runtime
105-
.handle_tool_call(tool_call, cancellation_token)
106-
.await
107-
{
108-
error!("user shell command failed: {err:?}");
127+
match exec_result {
128+
Err(CancelErr::Cancelled) => {
129+
let aborted_message = "command aborted by user".to_string();
130+
let exec_output = ExecToolCallOutput {
131+
exit_code: -1,
132+
stdout: StreamOutput::new(String::new()),
133+
stderr: StreamOutput::new(aborted_message.clone()),
134+
aggregated_output: StreamOutput::new(aborted_message.clone()),
135+
duration: Duration::ZERO,
136+
timed_out: false,
137+
};
138+
let output_items = [user_shell_command_record_item(&raw_command, &exec_output)];
139+
session
140+
.record_conversation_items(turn_context.as_ref(), &output_items)
141+
.await;
142+
session
143+
.send_event(
144+
turn_context.as_ref(),
145+
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
146+
call_id,
147+
stdout: String::new(),
148+
stderr: aborted_message.clone(),
149+
aggregated_output: aborted_message.clone(),
150+
exit_code: -1,
151+
duration: Duration::ZERO,
152+
formatted_output: aborted_message,
153+
}),
154+
)
155+
.await;
156+
}
157+
Ok(Ok(output)) => {
158+
session
159+
.send_event(
160+
turn_context.as_ref(),
161+
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
162+
call_id: call_id.clone(),
163+
stdout: output.stdout.text.clone(),
164+
stderr: output.stderr.text.clone(),
165+
aggregated_output: output.aggregated_output.text.clone(),
166+
exit_code: output.exit_code,
167+
duration: output.duration,
168+
formatted_output: format_exec_output_str(&output),
169+
}),
170+
)
171+
.await;
172+
173+
let output_items = [user_shell_command_record_item(&raw_command, &output)];
174+
session
175+
.record_conversation_items(turn_context.as_ref(), &output_items)
176+
.await;
177+
}
178+
Ok(Err(err)) => {
179+
error!("user shell command failed: {err:?}");
180+
let message = format!("execution error: {err:?}");
181+
let exec_output = ExecToolCallOutput {
182+
exit_code: -1,
183+
stdout: StreamOutput::new(String::new()),
184+
stderr: StreamOutput::new(message.clone()),
185+
aggregated_output: StreamOutput::new(message.clone()),
186+
duration: Duration::ZERO,
187+
timed_out: false,
188+
};
189+
session
190+
.send_event(
191+
turn_context.as_ref(),
192+
EventMsg::ExecCommandEnd(ExecCommandEndEvent {
193+
call_id,
194+
stdout: exec_output.stdout.text.clone(),
195+
stderr: exec_output.stderr.text.clone(),
196+
aggregated_output: exec_output.aggregated_output.text.clone(),
197+
exit_code: exec_output.exit_code,
198+
duration: exec_output.duration,
199+
formatted_output: format_exec_output_str(&exec_output),
200+
}),
201+
)
202+
.await;
203+
let output_items = [user_shell_command_record_item(&raw_command, &exec_output)];
204+
session
205+
.record_conversation_items(turn_context.as_ref(), &output_items)
206+
.await;
207+
}
109208
}
110209
None
111210
}
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
use std::time::Duration;
2+
3+
use codex_protocol::models::ContentItem;
4+
use codex_protocol::models::ResponseItem;
5+
6+
use crate::exec::ExecToolCallOutput;
7+
use crate::tools::format_exec_output_str;
8+
9+
pub const USER_SHELL_COMMAND_OPEN: &str = "<user_shell_command>";
10+
pub const USER_SHELL_COMMAND_CLOSE: &str = "</user_shell_command>";
11+
12+
pub fn is_user_shell_command_text(text: &str) -> bool {
13+
let trimmed = text.trim_start();
14+
let lowered = trimmed.to_ascii_lowercase();
15+
lowered.starts_with(USER_SHELL_COMMAND_OPEN)
16+
}
17+
18+
fn format_duration_line(duration: Duration) -> String {
19+
let duration_seconds = duration.as_secs_f64();
20+
format!("Duration: {duration_seconds:.4} seconds")
21+
}
22+
23+
fn format_user_shell_command_body(command: &str, exec_output: &ExecToolCallOutput) -> String {
24+
let mut sections = Vec::new();
25+
sections.push("<command>".to_string());
26+
sections.push(command.to_string());
27+
sections.push("</command>".to_string());
28+
sections.push("<result>".to_string());
29+
sections.push(format!("Exit code: {}", exec_output.exit_code));
30+
sections.push(format_duration_line(exec_output.duration));
31+
sections.push("Output:".to_string());
32+
sections.push(format_exec_output_str(exec_output));
33+
sections.push("</result>".to_string());
34+
sections.join("\n")
35+
}
36+
37+
pub fn format_user_shell_command_record(command: &str, exec_output: &ExecToolCallOutput) -> String {
38+
let body = format_user_shell_command_body(command, exec_output);
39+
format!("{USER_SHELL_COMMAND_OPEN}\n{body}\n{USER_SHELL_COMMAND_CLOSE}")
40+
}
41+
42+
pub fn user_shell_command_record_item(
43+
command: &str,
44+
exec_output: &ExecToolCallOutput,
45+
) -> ResponseItem {
46+
ResponseItem::Message {
47+
id: None,
48+
role: "user".to_string(),
49+
content: vec![ContentItem::InputText {
50+
text: format_user_shell_command_record(command, exec_output),
51+
}],
52+
}
53+
}
54+
55+
#[cfg(test)]
56+
mod tests {
57+
use super::*;
58+
use crate::exec::StreamOutput;
59+
use pretty_assertions::assert_eq;
60+
61+
#[test]
62+
fn detects_user_shell_command_text_variants() {
63+
assert!(is_user_shell_command_text(
64+
"<user_shell_command>\necho hi\n</user_shell_command>"
65+
));
66+
assert!(!is_user_shell_command_text("echo hi"));
67+
}
68+
69+
#[test]
70+
fn formats_basic_record() {
71+
let exec_output = ExecToolCallOutput {
72+
exit_code: 0,
73+
stdout: StreamOutput::new("hi".to_string()),
74+
stderr: StreamOutput::new(String::new()),
75+
aggregated_output: StreamOutput::new("hi".to_string()),
76+
duration: Duration::from_secs(1),
77+
timed_out: false,
78+
};
79+
let item = user_shell_command_record_item("echo hi", &exec_output);
80+
let ResponseItem::Message { content, .. } = item else {
81+
panic!("expected message");
82+
};
83+
let [ContentItem::InputText { text }] = content.as_slice() else {
84+
panic!("expected input text");
85+
};
86+
assert_eq!(
87+
text,
88+
"<user_shell_command>\n<command>\necho hi\n</command>\n<result>\nExit code: 0\nDuration: 1.0000 seconds\nOutput:\nhi\n</result>\n</user_shell_command>"
89+
);
90+
}
91+
92+
#[test]
93+
fn uses_aggregated_output_over_streams() {
94+
let exec_output = ExecToolCallOutput {
95+
exit_code: 42,
96+
stdout: StreamOutput::new("stdout-only".to_string()),
97+
stderr: StreamOutput::new("stderr-only".to_string()),
98+
aggregated_output: StreamOutput::new("combined output wins".to_string()),
99+
duration: Duration::from_millis(120),
100+
timed_out: false,
101+
};
102+
let record = format_user_shell_command_record("false", &exec_output);
103+
assert_eq!(
104+
record,
105+
"<user_shell_command>\n<command>\nfalse\n</command>\n<result>\nExit code: 42\nDuration: 0.1200 seconds\nOutput:\ncombined output wins\n</result>\n</user_shell_command>"
106+
);
107+
}
108+
}

codex-rs/core/tests/common/responses.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,18 @@ impl ResponsesRequest {
6161
self.0.body_json().unwrap()
6262
}
6363

64+
/// Returns all `input_text` spans from `message` inputs for the provided role.
65+
pub fn message_input_texts(&self, role: &str) -> Vec<String> {
66+
self.inputs_of_type("message")
67+
.into_iter()
68+
.filter(|item| item.get("role").and_then(Value::as_str) == Some(role))
69+
.filter_map(|item| item.get("content").and_then(Value::as_array).cloned())
70+
.flatten()
71+
.filter(|span| span.get("type").and_then(Value::as_str) == Some("input_text"))
72+
.filter_map(|span| span.get("text").and_then(Value::as_str).map(str::to_owned))
73+
.collect()
74+
}
75+
6476
pub fn input(&self) -> Vec<Value> {
6577
self.0.body_json::<Value>().unwrap()["input"]
6678
.as_array()

0 commit comments

Comments
 (0)