Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions codex-rs/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions codex-rs/utils/pty/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@ workspace = true

[dependencies]
anyhow = { workspace = true }
libc = { workspace = true }
portable-pty = { workspace = true }
tokio = { workspace = true, features = [
"macros",
"rt-multi-thread",
"sync",
"time",
] }
tracing = { workspace = true }
154 changes: 154 additions & 0 deletions codex-rs/utils/pty/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::Mutex as TokioMutex;
use tokio::task::JoinHandle;
use tracing::trace;

#[derive(Debug)]
pub struct ExecCommandSession {
Expand All @@ -26,6 +27,7 @@ pub struct ExecCommandSession {
wait_handle: StdMutex<Option<JoinHandle<()>>>,
exit_status: Arc<AtomicBool>,
exit_code: Arc<StdMutex<Option<i32>>>,
pid: Arc<StdMutex<Option<u32>>>,
}

impl ExecCommandSession {
Expand All @@ -39,6 +41,7 @@ impl ExecCommandSession {
wait_handle: JoinHandle<()>,
exit_status: Arc<AtomicBool>,
exit_code: Arc<StdMutex<Option<i32>>>,
pid: Arc<StdMutex<Option<u32>>>,
) -> (Self, broadcast::Receiver<Vec<u8>>) {
let initial_output_rx = output_tx.subscribe();
(
Expand All @@ -51,6 +54,7 @@ impl ExecCommandSession {
wait_handle: StdMutex::new(Some(wait_handle)),
exit_status,
exit_code,
pid,
},
initial_output_rx,
)
Expand All @@ -73,8 +77,46 @@ impl ExecCommandSession {
}
}

/// Kills the process group that `pid` belongs to using `SIGKILL`.
///
/// The group is looked up with `getpgid()` and signalled with `killpg()`, so
/// the process and every other member of its group receive the signal.
///
/// # Errors
///
/// Returns `Ok(())` when the signal was sent, or when the process or its
/// group no longer exists (`ESRCH`). Returns an `InvalidInput` error when
/// `pid` is zero or does not fit in `pid_t`, or when the resolved group id is
/// not greater than 1. Any other OS error is returned unchanged.
#[cfg(unix)]
fn kill_child_process_group(pid: u32) -> std::io::Result<()> {
    use std::io::Error;
    use std::io::ErrorKind;

    // Treats "no such process" as success and passes every other error on.
    // `ErrorKind::NotFound` is the decoding of ENOENT, not ESRCH, so the raw
    // errno has to be compared directly.
    fn ignore_missing_process(err: Error) -> std::io::Result<()> {
        if err.raw_os_error() == Some(libc::ESRCH) {
            Ok(())
        } else {
            Err(err)
        }
    }

    // `getpgid(0)` resolves to the *caller's* process group, and a value
    // above `pid_t::MAX` would wrap to a negative pid; reject both instead of
    // signalling the wrong group.
    let pid = match libc::pid_t::try_from(pid) {
        Ok(pid) if pid > 0 => pid,
        _ => {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                format!("invalid child pid {pid}"),
            ));
        }
    };

    // SAFETY: `getpgid` takes a plain integer and has no memory-safety
    // preconditions.
    let pgid = unsafe { libc::getpgid(pid) };
    if pgid == -1 {
        return ignore_missing_process(Error::last_os_error());
    }

    // POSIX leaves `killpg()` undefined for `pgrp <= 1`; never forward such a
    // group id.
    if pgid <= 1 {
        return Err(Error::new(
            ErrorKind::InvalidInput,
            format!("refusing to signal process group {pgid}"),
        ));
    }

    // SAFETY: `killpg` takes plain integers and has no memory-safety
    // preconditions.
    if unsafe { libc::killpg(pgid, libc::SIGKILL) } == -1 {
        return ignore_missing_process(Error::last_os_error());
    }

    Ok(())
}

impl Drop for ExecCommandSession {
fn drop(&mut self) {
#[cfg(unix)]
if let Ok(mut pid_guard) = self.pid.lock() {
if let Some(pid) = pid_guard.take() {
if let Err(e) = kill_child_process_group(pid) {
trace!("Failed to kill process group for pid {}: {}", pid, e);
}
Comment on lines 76 to +116
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Guard against killing reused PIDs after child exit

The session destructor kills a process group whenever pid is still set (kill_child_process_group(pid)), but the PID is only cleared in the wait thread after child.wait() returns. Because wait() reaps the process and frees its PID before the mutex is updated (lines ~233‑247), there is a brief window where the child has fully exited and its PID may already be reused by an unrelated process while pid_guard still contains the stale value. If the session is dropped during that window, getpgid/killpg will target the new process group and may SIGKILL an unrelated process tree. Consider checking exit_status before killing or clearing pid before calling wait() so that drop never acts on a PID that might already be recycled.

Useful? React with 👍 / 👎.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Taking a look.

}
}

if let Ok(mut killer_opt) = self.killer.lock() {
if let Some(mut killer) = killer_opt.take() {
let _ = killer.kill();
Expand Down Expand Up @@ -136,6 +178,14 @@ pub async fn spawn_pty_process(
}

let mut child = pair.slave.spawn_command(command_builder)?;

// portable_pty calls setsid(), which creates a process group where
// pgid == pid. This allows us to kill all descendants via killpg().
//
// TODO: We cannot set PR_SET_PDEATHSIG here because portable_pty doesn't
// expose a way to extend its pre_exec callback.
let child_pid = child.process_id();

let killer = child.clone_killer();

let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
Expand Down Expand Up @@ -180,6 +230,8 @@ pub async fn spawn_pty_process(
let wait_exit_status = Arc::clone(&exit_status);
let exit_code = Arc::new(StdMutex::new(None));
let wait_exit_code = Arc::clone(&exit_code);
let pid = Arc::new(StdMutex::new(child_pid));
let wait_pid = Arc::clone(&pid);
let wait_handle: JoinHandle<()> = tokio::task::spawn_blocking(move || {
let code = match child.wait() {
Ok(status) => status.exit_code() as i32,
Expand All @@ -189,6 +241,10 @@ pub async fn spawn_pty_process(
if let Ok(mut guard) = wait_exit_code.lock() {
*guard = Some(code);
}
// Clear PID to prevent killing wrong process on drop
if let Ok(mut guard) = wait_pid.lock() {
*guard = None;
}
let _ = exit_tx.send(code);
});

Expand All @@ -201,6 +257,7 @@ pub async fn spawn_pty_process(
wait_handle,
exit_status,
exit_code,
pid,
);

Ok(SpawnedPty {
Expand All @@ -209,3 +266,100 @@ pub async fn spawn_pty_process(
exit_rx,
})
}

#[cfg(test)]
mod tests {
use super::*;

/// Dropping a spawned PTY must also kill background processes that the shell
/// started inside the session.
///
/// The shell prints the PID of a backgrounded `sleep`; after the `SpawnedPty`
/// is dropped the test polls that PID until the OS reports it as gone.
#[cfg(unix)]
#[tokio::test]
async fn test_pty_kills_grandchildren_on_drop() -> Result<()> {
    let bg_pid: i32;

    {
        let spawned = spawn_pty_process(
            "/bin/bash",
            &["-c".to_string(), "sleep 60 & echo $!; sleep 60".to_string()],
            &std::env::current_dir()?,
            &std::env::vars().collect(),
            &None,
        )
        .await?;

        let mut output = Vec::new();
        let mut rx = spawned.output_rx;

        for _ in 0..10 {
            tokio::time::sleep(Duration::from_millis(100)).await;
            while let Ok(chunk) = rx.try_recv() {
                output.extend_from_slice(&chunk);
            }
            // Wait for a complete line rather than the first bytes: the PID
            // may arrive split across chunks, and parsing a truncated number
            // would make the liveness check below probe the wrong process.
            if output.contains(&b'\n') {
                break;
            }
        }

        let stdout = String::from_utf8_lossy(&output);
        let pid_line = stdout.lines().next().unwrap_or("").trim();
        bg_pid = pid_line.parse().map_err(|error| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("Failed to parse pid from stdout '{pid_line}': {error}"),
            )
        })?;

        // SpawnedPty drops here, which triggers process group kill.
    }

    // Verify background child was killed.
    let mut killed = false;
    for _ in 0..20 {
        // Signal 0 performs only the existence/permission check; ESRCH means
        // the process is gone.
        // SAFETY: `kill` takes plain integers and has no memory-safety
        // preconditions.
        if unsafe { libc::kill(bg_pid, 0) } == -1 {
            if let Some(libc::ESRCH) = std::io::Error::last_os_error().raw_os_error() {
                killed = true;
                break;
            }
        }
        tokio::time::sleep(Duration::from_millis(100)).await;
    }

    assert!(
        killed,
        "grandchild process with pid {bg_pid} is still alive"
    );
    Ok(())
}

/// Once the child has exited, the session's stored PID must be reset to
/// `None`.
#[cfg(unix)]
#[tokio::test]
async fn test_pty_clears_pid_after_exit() -> Result<()> {
    let shell_args = ["-c".to_string(), "exit 0".to_string()];
    let spawned = spawn_pty_process(
        "/bin/bash",
        &shell_args,
        &std::env::current_dir()?,
        &std::env::vars().collect(),
        &None,
    )
    .await?;

    // Block until the exit notification arrives; its value is irrelevant here.
    let _ = spawned.exit_rx.await;

    // Poll up to ten times, pausing before each look at the stored PID.
    let mut pid_cleared = false;
    let mut attempts_left = 10;
    while attempts_left > 0 && !pid_cleared {
        tokio::time::sleep(Duration::from_millis(50)).await;
        // A poisoned lock counts as "not cleared".
        pid_cleared = match spawned.session.pid.lock() {
            Ok(stored_pid) => stored_pid.is_none(),
            Err(_) => false,
        };
        attempts_left -= 1;
    }

    assert!(pid_cleared, "PID should be cleared after process exits");
    Ok(())
}
}
Loading