From 6527806b7cd1e47007d53e7d4fbc2bcac50569ee Mon Sep 17 00:00:00 2001 From: Jorge Prendes Date: Thu, 6 Mar 2025 11:37:59 +0000 Subject: [PATCH] use async rust-extensions Signed-off-by: Jorge Prendes --- Cargo.lock | 77 ++++++++++++++- crates/containerd-shim-wasm/Cargo.toml | 3 +- .../containerd-shim-wasm/src/sandbox/cli.rs | 97 ++++++++----------- .../src/sandbox/shim/cli.rs | 18 ++-- .../src/sandbox/shim/events.rs | 37 ++++--- .../src/sandbox/shim/local.rs | 27 +++--- .../src/sandbox/shim/otel.rs | 2 +- .../src/sys/unix/container/instance.rs | 2 +- .../src/sys/unix/pid_fd.rs | 33 ++++--- 9 files changed, 178 insertions(+), 118 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 27eab9513..7793571ac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -182,9 +182,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.83" +version = "0.1.87" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" +checksum = "d556ec1359574147ec0c4fc5eb525f3f23263a592b1a9c07e0a75b427de55c97" dependencies = [ "proc-macro2", "quote", @@ -850,9 +850,11 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a6ddc50d113188cb707839b8670faabdbab39c052846e2430ea8d47d893b18d" dependencies = [ + "async-trait", "cgroups-rs", "command-fds", "containerd-shim-protos", + "futures", "go-flag", "lazy_static", "libc", @@ -867,8 +869,10 @@ dependencies = [ "serde_json", "sha2", "signal-hook", + "signal-hook-tokio", "thiserror 2.0.12", "time", + "tokio", "which 7.0.1", "windows-sys 0.52.0", ] @@ -886,6 +890,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fb8db604974f81d1e350d30f274872f43b45e79203ebb8b1ff714e7b18d24e81" dependencies = [ + "async-trait", "protobuf 3.2.0", "ttrpc", "ttrpc-codegen", @@ -907,6 +912,7 @@ name = "containerd-shim-wasm" version = "0.10.0" dependencies = [ "anyhow", + "async-trait", "caps", "chrono", "containerd-client", @@ -3205,6 +3211,15 @@ dependencies = [ "libc", ] +[[package]] +name = "memoffset" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aa361d4faea93603064a027415f07bd8e1d5c88c9fbf68bf56a285428fd79ce" +dependencies = [ + "autocfg", +] + [[package]] name = "memoffset" version = "0.7.1" @@ -3320,6 +3335,18 @@ dependencies = [ "cc", ] +[[package]] +name = "nix" +version = "0.24.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa52e972a9a719cecb6864fb88568781eb706bac2cd1d4f04a648542dbf78069" +dependencies = [ + "bitflags 1.3.2", + "cfg-if 1.0.0", + "libc", + "memoffset 0.6.5", +] + [[package]] name = "nix" version = "0.25.1" @@ -5246,6 +5273,18 @@ dependencies = [ "libc", ] +[[package]] +name = "signal-hook-tokio" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "213241f76fb1e37e27de3b6aa1b068a2c333233b59cca6634f634b80a27ecf1e" +dependencies = [ + "futures-core", + "libc", + "signal-hook", + "tokio", +] + [[package]] name = "simd-adler32" version = "0.3.7" @@ -5693,6 +5732,7 @@ dependencies = [ "bytes", "libc", "mio", + "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2", @@ -5808,6 +5848,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-vsock" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52a15c15b1bc91f90902347eff163b5b682643aff0c8e972912cca79bd9208dd" +dependencies = [ + "bytes", + "futures", + "libc", + "tokio", + "vsock 0.3.0", +] + [[package]] name = "tokio-vsock" version = "0.6.0" @@ -5818,7 +5871,7 @@ dependencies = [ "futures", "libc", "tokio", - "vsock", + "vsock 0.5.1", ] [[package]] @@ -6104,7 +6157,7 @@ dependencies = [ "tokio", "tokio-stream", "tokio-util", - "tokio-vsock", + "tokio-vsock 0.6.0", "trapeze-codegen", "trapeze-macros", "windows-sys 0.59.0", @@ -6145,8 +6198,10 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c580c498a547b4c083ec758be543e11a0772e03013aef4cdb1fbe77c8b62cae" dependencies = [ + "async-trait", "byteorder", "crossbeam", + "futures", "home", "libc", "log", @@ -6154,6 +6209,8 @@ dependencies = [ "protobuf 3.2.0", "protobuf-codegen 3.2.0", "thiserror 1.0.69", + "tokio", + "tokio-vsock 0.4.0", "windows-sys 0.48.0", ] @@ -6190,7 +6247,7 @@ version = "1.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" dependencies = [ - "cfg-if 0.1.10", + "cfg-if 1.0.0", "static_assertions", ] @@ -6400,6 +6457,16 @@ dependencies = [ "virtual-mio", ] +[[package]] +name = "vsock" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c8e1df0bf1e1b28095c24564d1b90acae64ca69b097ed73896e342fa6649c57" +dependencies = [ + "libc", + "nix 0.24.3", +] + [[package]] name = "vsock" version = "0.5.1" diff --git a/crates/containerd-shim-wasm/Cargo.toml b/crates/containerd-shim-wasm/Cargo.toml index c480e7835..41601f121 100644 --- a/crates/containerd-shim-wasm/Cargo.toml +++ b/crates/containerd-shim-wasm/Cargo.toml @@ -14,7 +14,7 @@ doctest = false [dependencies] anyhow = { workspace = true } chrono = { workspace = true } -containerd-shim = { workspace = true } +containerd-shim = { workspace = true, features = ["async"] } containerd-shim-wasm-test-modules = { workspace = true, optional = true } oci-tar-builder = { workspace = true, optional = true } env_logger = { workspace = true, optional = true } @@ -38,6 +38,7 @@ prost = "0.13" toml = "0.8" trait-variant = "0.1" tokio-async-drop = "0.1" +async-trait = "0.1.87" # tracing # note: it's important to keep the version of tracing in sync with tracing-subscriber diff --git a/crates/containerd-shim-wasm/src/sandbox/cli.rs b/crates/containerd-shim-wasm/src/sandbox/cli.rs index fa116957f..a98100faa 100644 --- a/crates/containerd-shim-wasm/src/sandbox/cli.rs +++ b/crates/containerd-shim-wasm/src/sandbox/cli.rs @@ -168,6 +168,34 @@ fn init_zygote_and_logger(debug: bool, config: &Config) { ); } +fn otel_traces_guard() -> Box { + #[cfg(feature = "opentelemetry")] + if otel_traces_enabled() { + let otlp_config = OtlpConfig::build_from_env().expect("Failed to build OtelConfig."); + let guard = otlp_config + .init() + .expect("Failed to initialize OpenTelemetry."); + return Box::new(guard); + } + Box::new(()) +} + +fn init_traces_context() { + #[cfg(feature = "opentelemetry")] + // read TRACECONTEXT env var that's set by the parent process + if let Ok(ctx) = std::env::var("TRACECONTEXT") { + OtlpConfig::set_trace_context(&ctx).unwrap(); + } else { + let ctx = OtlpConfig::get_trace_context().unwrap(); + // SAFETY: although it's in a multithreaded context, + // it's safe to assume that all the other threads are not + // messing with env vars at this point. + unsafe { + std::env::set_var("TRACECONTEXT", ctx); + } + } +} + /// Main entry point for the shim. /// /// If the `opentelemetry` feature is enabled, this function will start the shim with OpenTelemetry tracing. @@ -199,69 +227,24 @@ pub fn shim_main<'a, I>( std::process::exit(0); } + let config = config.unwrap_or_default(); + // Initialize the zygote and logger for the container process #[cfg(unix)] - { - let default_config = Config::default(); - let config = config.as_ref().unwrap_or(&default_config); - init_zygote_and_logger(flags.debug, config); - } + init_zygote_and_logger(flags.debug, &config); - #[cfg(feature = "opentelemetry")] - if otel_traces_enabled() { - // opentelemetry uses tokio, so we need to initialize a runtime - async { - let otlp_config = OtlpConfig::build_from_env().expect("Failed to build OtelConfig."); - let _guard = otlp_config - .init() - .expect("Failed to initialize OpenTelemetry."); - tokio::task::block_in_place(move || { - shim_main_inner::(name, version, revision, shim_version, config); - }); - } - .block_on(); - } else { - shim_main_inner::(name, version, revision, shim_version, config); - } + async { + let _guard = otel_traces_guard(); + init_traces_context(); - #[cfg(not(feature = "opentelemetry"))] - { - shim_main_inner::(name, version, revision, shim_version, config); + let shim_version = shim_version.into().unwrap_or("v1"); + let lower_name = name.to_lowercase(); + let shim_id = format!("io.containerd.{lower_name}.{shim_version}"); + + run::>(&shim_id, Some(config)).await; } + .block_on(); #[cfg(target_os = "linux")] log_mem(); } - -#[cfg_attr(feature = "tracing", tracing::instrument(level = "Info"))] -fn shim_main_inner<'a, I>( - name: &str, - version: &str, - revision: impl Into> + std::fmt::Debug, - shim_version: impl Into> + std::fmt::Debug, - config: Option, -) where - I: 'static + Instance + Sync + Send, -{ - #[cfg(feature = "opentelemetry")] - { - // read TRACECONTEXT env var that's set by the parent process - if let Ok(ctx) = std::env::var("TRACECONTEXT") { - OtlpConfig::set_trace_context(&ctx).unwrap(); - } else { - let ctx = OtlpConfig::get_trace_context().unwrap(); - // SAFETY: although it's in a multithreaded context, - // it's safe to assume that all the other threads are not - // messing with env vars at this point. - unsafe { - std::env::set_var("TRACECONTEXT", ctx); - } - } - } - - let shim_version = shim_version.into().unwrap_or("v1"); - let lower_name = name.to_lowercase(); - let shim_id = format!("io.containerd.{lower_name}.{shim_version}"); - - run::>(&shim_id, config); -} diff --git a/crates/containerd-shim-wasm/src/sandbox/shim/cli.rs b/crates/containerd-shim-wasm/src/sandbox/shim/cli.rs index 9bb0c1a5f..d0d8211e5 100644 --- a/crates/containerd-shim-wasm/src/sandbox/shim/cli.rs +++ b/crates/containerd-shim-wasm/src/sandbox/shim/cli.rs @@ -2,10 +2,11 @@ use std::env::current_dir; use std::fmt::Debug; use std::marker::PhantomData; +use async_trait::async_trait; use chrono::Utc; use containerd_shim::error::Error as ShimError; use containerd_shim::publisher::RemotePublisher; -use containerd_shim::util::write_address; +use containerd_shim::util::write_str_to_file; use containerd_shim::{self as shim, api}; use oci_spec::runtime::Spec; use shim::Flags; @@ -38,6 +39,7 @@ where } } +#[async_trait] impl shim::Shim for Cli where I: Instance + Sync + Send, @@ -45,7 +47,7 @@ where type T = Local; #[cfg_attr(feature = "tracing", tracing::instrument(level = "Info"))] - fn new(_runtime_id: &str, args: &Flags, _config: &mut shim::Config) -> Self { + async fn new(_runtime_id: &str, args: &Flags, _config: &mut shim::Config) -> Self { Cli { namespace: args.namespace.to_string(), containerd_address: args.address.clone(), @@ -56,7 +58,7 @@ where } #[cfg_attr(feature = "tracing", tracing::instrument(level = "Info"))] - fn start_shim(&mut self, opts: containerd_shim::StartOpts) -> shim::Result { + async fn start_shim(&mut self, opts: containerd_shim::StartOpts) -> shim::Result { let dir = current_dir().map_err(|err| ShimError::Other(err.to_string()))?; let spec = Spec::load(dir.join("config.json")).map_err(|err| { shim::Error::InvalidArgument(format!("error loading runtime spec: {}", err)) @@ -69,15 +71,15 @@ where .and_then(|a| a.get("io.kubernetes.cri.sandbox-id")) .unwrap_or(&id); - let (_child, address) = shim::spawn(opts, grouping, vec![])?; + let address = shim::spawn(opts, grouping, vec![]).await?; - write_address(&address)?; + write_str_to_file("address", &address).await?; Ok(address) } #[cfg_attr(feature = "tracing", tracing::instrument(level = "Info"))] - fn wait(&mut self) { + async fn wait(&mut self) { self.exit.wait().block_on(); } @@ -85,14 +87,14 @@ where feature = "tracing", tracing::instrument(skip(publisher), level = "Info") )] - fn create_task_service(&self, publisher: RemotePublisher) -> Self::T { + async fn create_task_service(&self, publisher: RemotePublisher) -> Self::T { let events = RemoteEventSender::new(&self.namespace, publisher); let exit = self.exit.clone(); Local::::new(events, exit, &self.namespace, &self.containerd_address) } #[cfg_attr(feature = "tracing", tracing::instrument(level = "Info"))] - fn delete_shim(&mut self) -> shim::Result { + async fn delete_shim(&mut self) -> shim::Result { Ok(api::DeleteResponse { exit_status: 137, exited_at: Some(Utc::now().to_timestamp()).into(), diff --git a/crates/containerd-shim-wasm/src/sandbox/shim/events.rs b/crates/containerd-shim-wasm/src/sandbox/shim/events.rs index c727fe051..9d563dbba 100644 --- a/crates/containerd-shim-wasm/src/sandbox/shim/events.rs +++ b/crates/containerd-shim-wasm/src/sandbox/shim/events.rs @@ -1,10 +1,11 @@ -use std::sync::Arc; - use chrono::{DateTime, TimeZone}; use containerd_shim::event::Event; use containerd_shim::publisher::RemotePublisher; use log::warn; +use protobuf::MessageDyn; use protobuf::well_known_types::timestamp::Timestamp; +use tokio::sync::mpsc::error::SendError; +use tokio::sync::mpsc::{UnboundedSender, unbounded_channel}; pub trait EventSender: Clone + Send + Sync + 'static { fn send(&self, event: impl Event); @@ -12,23 +13,24 @@ pub trait EventSender: Clone + Send + Sync + 'static { #[derive(Clone)] pub struct RemoteEventSender { - inner: Arc, -} - -struct Inner { - namespace: String, - publisher: RemotePublisher, + tx: UnboundedSender<(String, Box)>, } impl RemoteEventSender { pub fn new(namespace: impl AsRef, publisher: RemotePublisher) -> RemoteEventSender { let namespace = namespace.as_ref().to_string(); - RemoteEventSender { - inner: Arc::new(Inner { - namespace, - publisher, - }), - } + let (tx, mut rx) = unbounded_channel::<(String, Box)>(); + tokio::spawn(async move { + while let Some((topic, event)) = rx.recv().await { + if let Err(err) = publisher + .publish(Default::default(), &topic, &namespace, event) + .await + { + warn!("failed to publish event, topic: {topic}: {err}") + } + } + }); + RemoteEventSender { tx } } } @@ -36,11 +38,8 @@ impl EventSender for RemoteEventSender { fn send(&self, event: impl Event) { let topic = event.topic(); let event = Box::new(event); - let publisher = &self.inner.publisher; - if let Err(err) = - publisher.publish(Default::default(), &topic, &self.inner.namespace, event) - { - warn!("failed to publish event, topic: {}: {}", &topic, err) + if let Err(SendError((topic, _))) = self.tx.send((topic, event)) { + warn!("failed to publish event, topic: {topic}: channel closed") } } } diff --git a/crates/containerd-shim-wasm/src/sandbox/shim/local.rs b/crates/containerd-shim-wasm/src/sandbox/shim/local.rs index 56849cb7c..3847ff924 100644 --- a/crates/containerd-shim-wasm/src/sandbox/shim/local.rs +++ b/crates/containerd-shim-wasm/src/sandbox/shim/local.rs @@ -5,6 +5,7 @@ use std::path::Path; use std::sync::Arc; use anyhow::ensure; +use async_trait::async_trait; use containerd_shim::api::{ ConnectRequest, ConnectResponse, CreateTaskRequest, CreateTaskResponse, DeleteRequest, Empty, KillRequest, ShutdownRequest, StartRequest, StartResponse, StateRequest, StateResponse, @@ -12,10 +13,9 @@ use containerd_shim::api::{ }; use containerd_shim::error::Error as ShimError; use containerd_shim::protos::events::task::{TaskCreate, TaskDelete, TaskExit, TaskIO, TaskStart}; -use containerd_shim::protos::shim::shim_ttrpc::Task; use containerd_shim::protos::types::task::Status; use containerd_shim::util::IntoOption; -use containerd_shim::{DeleteResponse, TtrpcContext, TtrpcResult}; +use containerd_shim::{DeleteResponse, Task, TtrpcContext, TtrpcResult}; use futures::FutureExt as _; use log::debug; use oci_spec::runtime::Spec; @@ -378,9 +378,10 @@ impl Local { } } +#[async_trait] impl Task for Local { #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn create( + async fn create( &self, _ctx: &TtrpcContext, req: CreateTaskRequest, @@ -394,7 +395,7 @@ impl Task for Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn start(&self, _ctx: &TtrpcContext, req: StartRequest) -> TtrpcResult { + async fn start(&self, _ctx: &TtrpcContext, req: StartRequest) -> TtrpcResult { debug!("start: {:?}", req); #[cfg(feature = "opentelemetry")] @@ -404,7 +405,7 @@ impl Task for Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn kill(&self, _ctx: &TtrpcContext, req: KillRequest) -> TtrpcResult { + async fn kill(&self, _ctx: &TtrpcContext, req: KillRequest) -> TtrpcResult { debug!("kill: {:?}", req); #[cfg(feature = "opentelemetry")] @@ -414,7 +415,7 @@ impl Task for Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn delete(&self, _ctx: &TtrpcContext, req: DeleteRequest) -> TtrpcResult { + async fn delete(&self, _ctx: &TtrpcContext, req: DeleteRequest) -> TtrpcResult { debug!("delete: {:?}", req); #[cfg(feature = "opentelemetry")] @@ -424,7 +425,7 @@ impl Task for Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn wait(&self, _ctx: &TtrpcContext, req: WaitRequest) -> TtrpcResult { + async fn wait(&self, _ctx: &TtrpcContext, req: WaitRequest) -> TtrpcResult { debug!("wait: {:?}", req); #[cfg(feature = "opentelemetry")] @@ -464,7 +465,11 @@ impl Task for Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn connect(&self, _ctx: &TtrpcContext, req: ConnectRequest) -> TtrpcResult { + async fn connect( + &self, + _ctx: &TtrpcContext, + req: ConnectRequest, + ) -> TtrpcResult { debug!("connect: {:?}", req); #[cfg(feature = "opentelemetry")] @@ -481,7 +486,7 @@ impl Task for Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn state(&self, _ctx: &TtrpcContext, req: StateRequest) -> TtrpcResult { + async fn state(&self, _ctx: &TtrpcContext, req: StateRequest) -> TtrpcResult { debug!("state: {:?}", req); #[cfg(feature = "opentelemetry")] @@ -491,7 +496,7 @@ impl Task for Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn shutdown(&self, _ctx: &TtrpcContext, _: ShutdownRequest) -> TtrpcResult { + async fn shutdown(&self, _ctx: &TtrpcContext, _: ShutdownRequest) -> TtrpcResult { debug!("shutdown"); #[cfg(feature = "opentelemetry")] @@ -504,7 +509,7 @@ impl Task for Local { } #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), level = "Info"))] - fn stats(&self, _ctx: &TtrpcContext, req: StatsRequest) -> TtrpcResult { + async fn stats(&self, _ctx: &TtrpcContext, req: StatsRequest) -> TtrpcResult { debug!("stats: {:?}", req); #[cfg(feature = "opentelemetry")] diff --git a/crates/containerd-shim-wasm/src/sandbox/shim/otel.rs b/crates/containerd-shim-wasm/src/sandbox/shim/otel.rs index 4d7300108..df6ab38cb 100644 --- a/crates/containerd-shim-wasm/src/sandbox/shim/otel.rs +++ b/crates/containerd-shim-wasm/src/sandbox/shim/otel.rs @@ -87,7 +87,7 @@ impl Config { /// Initializes the tracer, sets up the telemetry and subscriber layers, and sets the global subscriber. /// /// Note: this function should be called only once and be called by the binary entry point. - pub fn init(&self) -> anyhow::Result { + pub fn init(&self) -> anyhow::Result { let tracer = self.init_tracer()?; let telemetry = tracing_opentelemetry::layer().with_tracer(tracer); set_text_map_propagator(TraceContextPropagator::new()); diff --git a/crates/containerd-shim-wasm/src/sys/unix/container/instance.rs b/crates/containerd-shim-wasm/src/sys/unix/container/instance.rs index e404aa809..f25e3ee14 100644 --- a/crates/containerd-shim-wasm/src/sys/unix/container/instance.rs +++ b/crates/containerd-shim-wasm/src/sys/unix/container/instance.rs @@ -93,7 +93,7 @@ impl SandboxInstance for Instance { // Use a pidfd FD so that we can wait for the process to exit asynchronously. // This should be created BEFORE calling container.start() to ensure we never // miss the SIGCHLD event. - let pidfd = PidFd::new(pid)?; + let pidfd = PidFd::new(pid).await?; self.container.start()?; diff --git a/crates/containerd-shim-wasm/src/sys/unix/pid_fd.rs b/crates/containerd-shim-wasm/src/sys/unix/pid_fd.rs index e63ac1910..050f90ed4 100644 --- a/crates/containerd-shim-wasm/src/sys/unix/pid_fd.rs +++ b/crates/containerd-shim-wasm/src/sys/unix/pid_fd.rs @@ -8,6 +8,8 @@ use nix::sys::wait::{Id, WaitPidFlag, WaitStatus, waitid}; use nix::unistd::Pid; use tokio::io::unix::AsyncFd; +use crate::sandbox::async_utils::AmbientRuntime; + pub(super) struct PidFd { fd: OwnedFd, pid: pid_t, @@ -15,10 +17,10 @@ pub(super) struct PidFd { } impl PidFd { - pub(super) fn new(pid: impl Into) -> anyhow::Result { + pub(super) async fn new(pid: impl Into) -> anyhow::Result { use libc::{PIDFD_NONBLOCK, SYS_pidfd_open, syscall}; let pid = pid.into(); - let subs = monitor_subscribe(Topic::Pid)?; + let subs = monitor_subscribe(Topic::Pid).await?; let pidfd = unsafe { syscall(SYS_pidfd_open, pid, PIDFD_NONBLOCK) }; if pidfd == -1 { return Err(std::io::Error::last_os_error().into()); @@ -58,18 +60,19 @@ impl PidFd { } } -pub async fn try_wait_pid(pid: i32, s: Subscription) -> Result { - tokio::task::spawn_blocking(move || { - while let Ok(ExitEvent { subject, exit_code }) = s.rx.recv_timeout(Duration::from_secs(2)) { - let Subject::Pid(p) = subject else { - continue; - }; - if pid == p { - return Ok(exit_code); - } +pub async fn try_wait_pid(pid: i32, mut s: Subscription) -> Result { + while let Some(ExitEvent { subject, exit_code }) = + s.rx.recv() + .with_timeout(Duration::from_secs(2)) + .await + .flatten() + { + let Subject::Pid(p) = subject else { + continue; + }; + if pid == p { + return Ok(exit_code); } - Err(Errno::ECHILD) - }) - .await - .map_err(|_| Errno::ECHILD)? + } + Err(Errno::ECHILD) }