Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Stream::add_callback to run on context error again #138

Merged
merged 2 commits into from
Feb 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/cust/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Notable changes to this project will be documented in this file.

- Add `memory::memcpy_dtoh` to allow copying from device to host.
- Add support in `memory` for pitched malloc and 2D memcpy between device and host.
- `Stream::add_callback` now internally uses `cuStreamAddCallback` again, since there are no current plans to remove it (https://stackoverflow.com/a/58173486). As a result, the function again takes a device status as a parameter and *does* execute on context error.

## 0.3.2 - 2/16/22

Expand Down
25 changes: 16 additions & 9 deletions crates/cust/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
use crate::error::{CudaResult, DropResult, ToResult};
use crate::event::Event;
use crate::function::{BlockSize, Function, GridSize};
use crate::sys::{self as cuda, CUstream};
use crate::sys::{self as cuda, cudaError_enum, CUstream};
use std::ffi::c_void;
use std::mem;
use std::panic;
Expand Down Expand Up @@ -151,6 +151,9 @@ impl Stream {
///
/// Callbacks must not make any CUDA API calls.
///
/// The callback will be passed a `CudaResult<()>` indicating the
/// current state of the device with `Ok(())` denoting normal operation.
///
/// # Examples
///
/// ```
Expand All @@ -164,22 +167,23 @@ impl Stream {
///
/// // ... queue up some work on the stream
///
/// stream.add_callback(Box::new(|| {
/// println!("Work is done!");
/// stream.add_callback(Box::new(|status| {
/// println!("Device status is {:?}", status);
/// }));
///
/// // ... queue up some more work on the stream
/// # Ok(())
/// # }
pub fn add_callback<T>(&self, callback: Box<T>) -> CudaResult<()>
where
T: FnOnce() + Send,
T: FnOnce(CudaResult<()>) + Send,
{
unsafe {
cuda::cuLaunchHostFunc(
cuda::cuStreamAddCallback(
self.inner,
Some(callback_wrapper::<T>),
Box::into_raw(callback) as *mut c_void,
0,
)
.to_result()
}
Expand Down Expand Up @@ -339,13 +343,16 @@ impl Drop for Stream {
}
}
}
unsafe extern "C" fn callback_wrapper<T>(callback: *mut c_void)
where
T: FnOnce() + Send,
unsafe extern "C" fn callback_wrapper<T>(
_stream: CUstream,
status: cudaError_enum,
callback: *mut c_void,
) where
T: FnOnce(CudaResult<()>) + Send,
{
// Stop panics from unwinding across the FFI
let _ = panic::catch_unwind(|| {
let callback: Box<T> = Box::from_raw(callback as *mut T);
callback();
callback(status.to_result());
});
}
Loading