From 014660575eb01c13aacfc9da971fc333e14c2345 Mon Sep 17 00:00:00 2001 From: Juniper Tyree <50025784+juntyr@users.noreply.github.com> Date: Sun, 2 Feb 2025 10:19:34 +0000 Subject: [PATCH 1/2] Revert "Feat: use cuMemHostLaunch instead of cuStreamAddCallback internally" This reverts commit 6ab550cc6a15db6d5ec47768f8c08cc8df6bfd31. --- crates/cust/src/stream.rs | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/crates/cust/src/stream.rs b/crates/cust/src/stream.rs index 25a4178d..e28be5cf 100644 --- a/crates/cust/src/stream.rs +++ b/crates/cust/src/stream.rs @@ -13,7 +13,7 @@ use crate::error::{CudaResult, DropResult, ToResult}; use crate::event::Event; use crate::function::{BlockSize, Function, GridSize}; -use crate::sys::{self as cuda, CUstream}; +use crate::sys::{self as cuda, cudaError_enum, CUstream}; use std::ffi::c_void; use std::mem; use std::panic; @@ -151,6 +151,9 @@ impl Stream { /// /// Callbacks must not make any CUDA API calls. /// + /// The callback will be passed a `CudaResult<()>` indicating the + /// current state of the device with `Ok(())` denoting normal operation. + /// /// # Examples /// /// ``` @@ -164,8 +167,8 @@ impl Stream { /// /// // ... queue up some work on the stream /// - /// stream.add_callback(Box::new(|| { - /// println!("Work is done!"); + /// stream.add_callback(Box::new(|status| { + /// println!("Device status is {:?}", status); /// })); /// /// // ... 
queue up some more work on the stream @@ -173,13 +176,14 @@ impl Stream { /// # } pub fn add_callback<T>(&self, callback: Box<T>) -> CudaResult<()> where - T: FnOnce() + Send, + T: FnOnce(CudaResult<()>) + Send, { unsafe { - cuda::cuLaunchHostFunc( + cuda::cuStreamAddCallback( self.inner, Some(callback_wrapper::<T>), Box::into_raw(callback) as *mut c_void, + 0, ) .to_result() } @@ -339,13 +343,16 @@ impl Drop for Stream { } } } -unsafe extern "C" fn callback_wrapper<T>(callback: *mut c_void) -where - T: FnOnce() + Send, +unsafe extern "C" fn callback_wrapper<T>( + _stream: CUstream, + status: cudaError_enum, + callback: *mut c_void, +) where + T: FnOnce(CudaResult<()>) + Send, { // Stop panics from unwinding across the FFI let _ = panic::catch_unwind(|| { let callback: Box<T> = Box::from_raw(callback as *mut T); - callback(); + callback(status.to_result()); }); } From 97ebe94705add424c7661bd6d73e8e95a1892ed1 Mon Sep 17 00:00:00 2001 From: Juniper Tyree <50025784+juntyr@users.noreply.github.com> Date: Sun, 2 Feb 2025 10:24:05 +0000 Subject: [PATCH 2/2] Amend the changelog --- crates/cust/CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/cust/CHANGELOG.md b/crates/cust/CHANGELOG.md index d734c687..ff16db76 100644 --- a/crates/cust/CHANGELOG.md +++ b/crates/cust/CHANGELOG.md @@ -6,6 +6,7 @@ Notable changes to this project will be documented in this file. - Add `memory::memcpy_dtoh` to allow copying from device to host. - Add support in `memory` for pitched malloc and 2D memcpy between device and host. + - `Stream::add_callback` now internally uses `cuStreamAddCallback` again, since there are no current plans to remove it (https://stackoverflow.com/a/58173486). As a result, the function again takes a device status as a parameter and *does* execute on context error. ## 0.3.2 - 2/16/22