Skip to content

Commit 5db35d2

Browse files
authored
Allow Stream::add_callback to run on context error again (#138)
* Revert "Feat: use cuMemHostLaunch instead of cuStreamAddCallback internally" This reverts commit 6ab550c. * Amend the changelog
1 parent 833485c commit 5db35d2

File tree

2 files changed

+17
-9
lines changed

2 files changed

+17
-9
lines changed

crates/cust/CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Notable changes to this project will be documented in this file.
66

77
- Add `memory::memcpy_dtoh` to allow copying from device to host.
88
- Add support in `memory` for pitched malloc and 2D memcpy between device and host.
9+
- `Stream::add_callback` now internally uses `cuStreamAddCallback` again, since there are no current plans to remove it (https://stackoverflow.com/a/58173486). As a result, the function again takes a device status as a parameter and *does* execute on context error.
910

1011
## 0.3.2 - 2/16/22
1112

crates/cust/src/stream.rs

+16-9
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
use crate::error::{CudaResult, DropResult, ToResult};
1414
use crate::event::Event;
1515
use crate::function::{BlockSize, Function, GridSize};
16-
use crate::sys::{self as cuda, CUstream};
16+
use crate::sys::{self as cuda, cudaError_enum, CUstream};
1717
use std::ffi::c_void;
1818
use std::mem;
1919
use std::panic;
@@ -151,6 +151,9 @@ impl Stream {
151151
///
152152
/// Callbacks must not make any CUDA API calls.
153153
///
154+
/// The callback will be passed a `CudaResult<()>` indicating the
155+
/// current state of the device with `Ok(())` denoting normal operation.
156+
///
154157
/// # Examples
155158
///
156159
/// ```
@@ -164,22 +167,23 @@ impl Stream {
164167
///
165168
/// // ... queue up some work on the stream
166169
///
167-
/// stream.add_callback(Box::new(|| {
168-
/// println!("Work is done!");
170+
/// stream.add_callback(Box::new(|status| {
171+
/// println!("Device status is {:?}", status);
169172
/// }));
170173
///
171174
/// // ... queue up some more work on the stream
172175
/// # Ok(())
173176
/// # }
174177
pub fn add_callback<T>(&self, callback: Box<T>) -> CudaResult<()>
175178
where
176-
T: FnOnce() + Send,
179+
T: FnOnce(CudaResult<()>) + Send,
177180
{
178181
unsafe {
179-
cuda::cuLaunchHostFunc(
182+
cuda::cuStreamAddCallback(
180183
self.inner,
181184
Some(callback_wrapper::<T>),
182185
Box::into_raw(callback) as *mut c_void,
186+
0,
183187
)
184188
.to_result()
185189
}
@@ -339,13 +343,16 @@ impl Drop for Stream {
339343
}
340344
}
341345
}
342-
unsafe extern "C" fn callback_wrapper<T>(callback: *mut c_void)
343-
where
344-
T: FnOnce() + Send,
346+
unsafe extern "C" fn callback_wrapper<T>(
347+
_stream: CUstream,
348+
status: cudaError_enum,
349+
callback: *mut c_void,
350+
) where
351+
T: FnOnce(CudaResult<()>) + Send,
345352
{
346353
// Stop panics from unwinding across the FFI
347354
let _ = panic::catch_unwind(|| {
348355
let callback: Box<T> = Box::from_raw(callback as *mut T);
349-
callback();
356+
callback(status.to_result());
350357
});
351358
}

0 commit comments

Comments
 (0)