Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added safe wrappers called copy_from_async_sync and copy_to_async_syc in crates/cust/src/memory/device/device_slice.rs #140

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions crates/cust/src/memory/device/device_slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,79 @@ impl<T: DeviceCopy> DeviceSlice<T> {
}
}

impl<T: DeviceCopy> DeviceSlice<T> {
/// Thw Asynchronously should now copy data from the `source` into this slice and then synchronizes the stream.
///
/// This will be a safe wrapper around the unsafe asynchronous copy function. It will first launch the
/// copy on the provided stream and then synchronizes the stream so that the copy is complete
/// before returning it. This eliminates the need for the caller to deal with unsafe code and
/// explicit synchronization.
///
/// # Panics
///
/// Panics will occur if the length of `source.as_ref()` does not match this slice’s length.
///
/// # So Example will be
///
/// ```rust
/// # let _context = cust::quick_init().unwrap();
/// use cust::{memory::*, stream::Stream, stream::StreamFlags};
/// let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
/// let start = [0u64, 1, 2, 3, 4, 5];
/// let mut buf = DeviceBuffer::from_slice(&[0u64; 6]).unwrap();
/// // Now this will safely perform the async copy and then wait for it to complete:
/// buf.as_slice().copy_from_async_sync(&start, &stream).unwrap();
/// let mut host = [0u64; 6];
/// buf.copy_to(&mut host).unwrap();
/// assert_eq!(start, host);
/// ```
pub fn copy_from_async_sync<I>(&mut self, source: &I, stream: &Stream) -> crate::error::CudaResult<()>
where
I: AsRef<[T]> + AsMut<[T]> + ?Sized,
{
//I Call the unsafe async copy and then synchronize the stream.
unsafe {
<Self as crate::memory::AsyncCopyDestination<I>>::async_copy_from(self, source, stream)?
};
stream.synchronize()?;
Ok(())
}

/// This Asynchronously copies data from this slice into `dest` and then synchronizes the stream.
///
/// This is a safe wrapper around the unsafe asynchronous copy function i propose. this launches the copy on
/// the provided stream and then synchronizes, ensuring that the copy is complete before returning without causing error or panics.
///
/// # Panics
///
/// Thus Panics if the length of `dest.as_mut()` does not match this slice’s length.
///
/// # Example
///
/// ```rust
/// # let _context = cust::quick_init().unwrap();
/// use cust::{memory::*, stream::Stream, stream::StreamFlags};
/// let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
/// let start = [0u64, 1, 2, 3, 4, 5];
/// let buf = DeviceBuffer::from_slice(&start).unwrap();
/// let mut host = [0u64; 6];
/// buf.as_slice().copy_to_async_sync(&mut host, &stream).unwrap();
/// assert_eq!(start, host);
/// ```
pub fn copy_to_async_sync<I>(&self, dest: &mut I, stream: &Stream) -> crate::error::CudaResult<()>
where
I: AsRef<[T]> + AsMut<[T]> + ?Sized,
{
unsafe {
<Self as crate::memory::AsyncCopyDestination<I>>::async_copy_to(self, dest, stream)?
};
stream.synchronize()?;
Ok(())
}
}



impl<T: DeviceCopy> crate::private::Sealed for DeviceSlice<T> {}
impl<T: DeviceCopy, I: AsRef<[T]> + AsMut<[T]> + ?Sized> CopyDestination<I> for DeviceSlice<T> {
fn copy_from(&mut self, val: &I) -> CudaResult<()> {
Expand Down
Loading