diff --git a/Cargo.lock b/Cargo.lock index 7c98b754..ddfbe373 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -853,11 +853,12 @@ checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" [[package]] name = "openab" -version = "0.6.6" +version = "0.7.2" dependencies = [ "anyhow", "base64", "image", + "libc", "rand 0.8.5", "regex", "reqwest", diff --git a/Cargo.toml b/Cargo.toml index 3b3b1514..3dd0ba7d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ serenity = { version = "0.12", default-features = false, features = ["client", " uuid = { version = "1", features = ["v4"] } regex = "1" anyhow = "1" +libc = "0.2" rand = "0.8" reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "multipart", "json"] } base64 = "0.22" diff --git a/src/acp/connection.rs b/src/acp/connection.rs index 83efd50d..0cc9a0b3 100644 --- a/src/acp/connection.rs +++ b/src/acp/connection.rs @@ -1,4 +1,6 @@ use crate::acp::protocol::{JsonRpcMessage, JsonRpcRequest, JsonRpcResponse}; +#[cfg(unix)] +use libc; use anyhow::{anyhow, Result}; use serde_json::{json, Value}; use std::collections::HashMap; @@ -131,8 +133,10 @@ impl AcpConnection { .stdin(std::process::Stdio::piped()) .stdout(std::process::Stdio::piped()) .stderr(std::process::Stdio::null()) - .current_dir(working_dir) - .kill_on_drop(true); + .current_dir(working_dir); + + #[cfg(unix)] + cmd.process_group(0); for (k, v) in env { cmd.env(k, expand_env(v)); } @@ -384,10 +388,25 @@ impl AcpConnection { } } +#[cfg(unix)] +impl Drop for AcpConnection { + fn drop(&mut self) { + if let Some(pid) = self._proc.id() { + // Send SIGTERM to the entire process group (-PGID) to clean up orphaned grandchildren + unsafe { + let pgid = pid as i32; + libc::kill(-pgid, libc::SIGTERM); + } + } + } +} + #[cfg(test)] mod tests { - use super::{build_permission_response, pick_best_option}; + use super::{build_permission_response, pick_best_option, AcpConnection}; use serde_json::json; + use std::collections::HashMap; + use tokio::time::Duration; #[test] fn picks_allow_always_over_other_options() { @@ -479,4 +498,52 @@ mod tests { json!({"outcome": {"outcome": "selected", "optionId": "allow_always"}}) ); } + + #[tokio::test] + async fn test_process_group_cleanup() -> anyhow::Result<()> { + #[cfg(unix)] + { + // A script that spawns a background process and stays alive + // We use 'sleep 100' as a grandchild that should be killed + let script = "sh -c 'sleep 100' & sleep 100"; + + let conn = + AcpConnection::spawn("sh", &["-c".to_string(), script.to_string()], ".", &HashMap::new()).await?; + + tokio::time::sleep(Duration::from_millis(500)).await; + + let pid = conn._proc.id().expect("should have pid"); + + // Find grandchild pid + let output = std::process::Command::new("pgrep") + .arg("-P") + .arg(pid.to_string()) + .output()?; + let grandchild_pid_str = String::from_utf8_lossy(&output.stdout).trim().to_string(); + assert!( + !grandchild_pid_str.is_empty(), + "Grandchild process should exist" + ); + // If multiple, take the first one + let grandchild_pid_str = grandchild_pid_str.lines().next().unwrap(); + let grandchild_pid: i32 = grandchild_pid_str.parse().expect("should be a pid"); + + // Drop the connection, which should kill the group + drop(conn); + + tokio::time::sleep(Duration::from_millis(500)).await; + + // Check if grandchild is gone. kill -0 pid checks if process exists. + let status = std::process::Command::new("kill") + .arg("-0") + .arg(grandchild_pid.to_string()) + .status(); + + assert!( + status.is_err() || !status.unwrap().success(), + "Grandchild process should be killed" + ); + } + Ok(()) + } }