Skip to content

Commit 126c431

Browse files
authored
test: nullable cuDF e2e tests (#8128)
Enables nullable cuDF e2e coverage for primitives, decimals, dates, and string views. The string-view case exposed that cuDF reads Arrow validity masks as 32-bit words. To keep those reads in bounds, CUDA buffer allocations now include zeroed tail padding while preserving the visible `BufferHandle` length.

---------

Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent e065c33 commit 126c431

5 files changed

Lines changed: 149 additions & 33 deletions

File tree

vortex-cuda/benches/filter_cuda.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use cudarc::driver::CudaView;
2121
use cudarc::driver::DevicePtr;
2222
use cudarc::driver::DevicePtrMut;
2323
use cudarc::driver::DeviceRepr;
24+
use cudarc::driver::ValidAsZeroBits;
2425
use cudarc::driver::sys::CUevent_flags;
2526
use futures::executor::block_on;
2627
use vortex::error::VortexExpect;
@@ -135,7 +136,15 @@ async fn run_filter_timed<T: CubFilterable + DeviceRepr>(
135136
/// Benchmark filter for a specific type.
136137
fn benchmark_filter_type<T>(c: &mut Criterion, type_name: &str)
137138
where
138-
T: CubFilterable + DeviceRepr + From<u8> + Debug + Clone + Send + Sync + 'static,
139+
T: CubFilterable
140+
+ DeviceRepr
141+
+ ValidAsZeroBits
142+
+ From<u8>
143+
+ Debug
144+
+ Clone
145+
+ Send
146+
+ Sync
147+
+ 'static,
139148
{
140149
let mut group = c.benchmark_group("cuda");
141150

vortex-cuda/src/arrow/canonical.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,12 @@ async fn export_arrow_validity_buffer(
217217
) -> VortexResult<(Option<BufferHandle>, i64)> {
218218
let mask = validity.execute_mask(len, ctx.execution_ctx())?;
219219
let null_count = i64::try_from(mask.false_count())?;
220+
let validity_bits = len + arrow_offset;
221+
let validity_bytes = validity_bits.div_ceil(8);
220222

221223
let validity_buffer = match mask {
222224
Mask::AllTrue(_) => return Ok((None, 0)),
223-
Mask::AllFalse(len) => ByteBuffer::zeroed((len + arrow_offset).div_ceil(8)),
225+
Mask::AllFalse(_) => ByteBuffer::zeroed(validity_bytes),
224226
values @ Mask::Values(_) => values.into_bit_buffer().into_inner().2,
225227
};
226228
let validity = ctx

vortex-cuda/src/executor.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use cudarc::driver::CudaSlice;
1212
use cudarc::driver::DeviceRepr;
1313
use cudarc::driver::LaunchArgs;
1414
use cudarc::driver::LaunchConfig;
15+
use cudarc::driver::ValidAsZeroBits;
1516
use futures::future::BoxFuture;
1617
use tracing::debug;
1718
use tracing::trace;
@@ -256,7 +257,7 @@ impl CudaExecutionCtx {
256257
data: D,
257258
) -> VortexResult<BoxFuture<'static, VortexResult<BufferHandle>>>
258259
where
259-
T: DeviceRepr + Debug + Send + Sync + 'static,
260+
T: DeviceRepr + ValidAsZeroBits + Debug + Send + Sync + 'static,
260261
D: AsRef<[T]> + Send + 'static,
261262
{
262263
self.stream.copy_to_device(data)

vortex-cuda/src/stream.rs

Lines changed: 108 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,22 @@
44
//! CUDA stream utility functions.
55
66
use std::fmt::Debug;
7+
use std::mem::size_of;
8+
use std::mem::size_of_val;
79
use std::ops::Deref;
810
use std::sync::Arc;
911

1012
use cudarc::driver::CudaSlice;
1113
use cudarc::driver::CudaStream;
1214
use cudarc::driver::DeviceRepr;
15+
use cudarc::driver::ValidAsZeroBits;
1316
use cudarc::driver::result::stream;
1417
use futures::future::BoxFuture;
1518
use kanal::Sender;
1619
use tracing::warn;
1720
use vortex::array::buffer::BufferHandle;
1821
use vortex::error::VortexResult;
22+
use vortex::error::vortex_ensure;
1923
use vortex::error::vortex_err;
2024

2125
use crate::CudaDeviceBuffer;
@@ -62,22 +66,32 @@ impl VortexCudaStream {
6266
/// synchronously before returning. For **pinned** host memory the transfer
6367
/// is truly async and the source must stay alive until the copy completes
6468
/// (guaranteed by the returned future capturing it).
69+
///
70+
/// The returned [`BufferHandle`] keeps the source byte length, while its
71+
/// CUDA allocation may include zeroed tail padding. This is needed for
72+
/// Arrow validity buffers passed to cuDF, which reads masks as 32-bit words.
6573
pub(crate) fn copy_to_device<T, D>(
6674
&self,
6775
data: D,
6876
) -> VortexResult<BoxFuture<'static, VortexResult<BufferHandle>>>
6977
where
70-
T: DeviceRepr + Debug + Send + Sync + 'static,
78+
T: DeviceRepr + ValidAsZeroBits + Debug + Send + Sync + 'static,
7179
D: AsRef<[T]> + Send + 'static,
7280
{
7381
let host_slice: &[T] = data.as_ref();
82+
let byte_count = size_of_val(host_slice);
83+
let allocation_len = padded_device_allocation_len::<T>(byte_count)?;
7484
// `device_alloc` binds the CUDA context to the current thread.
75-
let mut cuda_slice: CudaSlice<T> = self.device_alloc(host_slice.len())?;
85+
let mut cuda_slice: CudaSlice<T> = self.device_alloc::<T>(allocation_len)?;
7686

77-
self.memcpy_htod(host_slice, &mut cuda_slice)
87+
let mut values = cuda_slice.slice_mut(..host_slice.len());
88+
self.memcpy_htod(host_slice, &mut values)
7889
.map_err(|e| vortex_err!("Failed to schedule H2D copy: {}", e))?;
7990

91+
zero_padding(self, &mut cuda_slice, host_slice.len())?;
92+
8093
let cuda_buf = CudaDeviceBuffer::new(cuda_slice);
94+
let buffer = BufferHandle::new_device(Arc::new(cuda_buf)).slice(0..byte_count);
8195
let stream = Arc::clone(&self.0);
8296

8397
Ok(Box::pin(async move {
@@ -86,7 +100,7 @@ impl VortexCudaStream {
86100
// Keep source memory alive until copy completes.
87101
let _keep_alive = data;
88102

89-
Ok(BufferHandle::new_device(Arc::new(cuda_buf)))
103+
Ok(buffer)
90104
}))
91105
}
92106

@@ -99,20 +113,62 @@ impl VortexCudaStream {
99113
/// For **pageable** host memory (the common case), `memcpy_htod` stages
100114
/// the source into a driver-managed pinned buffer before returning, so
101115
/// the source data is safe to drop after this call.
116+
///
117+
/// Like [`copy_to_device`](Self::copy_to_device), this preserves the source
118+
/// byte length on the returned handle while keeping any tail padding in the
119+
/// backing CUDA allocation.
102120
pub(crate) fn copy_to_device_sync<T>(&self, data: &[T]) -> VortexResult<BufferHandle>
103121
where
104-
T: DeviceRepr + Debug + Send + Sync + 'static,
122+
T: DeviceRepr + ValidAsZeroBits + Debug + Send + Sync + 'static,
105123
{
106-
let mut cuda_slice: CudaSlice<T> = self.device_alloc(data.len())?;
124+
let byte_count = size_of_val(data);
125+
let allocation_len = padded_device_allocation_len::<T>(byte_count)?;
126+
let mut cuda_slice: CudaSlice<T> = self.device_alloc(allocation_len)?;
107127

108-
self.memcpy_htod(data, &mut cuda_slice)
128+
let mut values = cuda_slice.slice_mut(..data.len());
129+
self.memcpy_htod(data, &mut values)
109130
.map_err(|e| vortex_err!("Failed to schedule H2D copy: {}", e))?;
110131

132+
zero_padding(self, &mut cuda_slice, data.len())?;
133+
111134
let cuda_buf = CudaDeviceBuffer::new(cuda_slice);
112-
Ok(BufferHandle::new_device(Arc::new(cuda_buf)))
135+
Ok(BufferHandle::new_device(Arc::new(cuda_buf)).slice(0..byte_count))
113136
}
114137
}
115138

139+
/// Returns the typed CUDA allocation length for `byte_count`.
140+
///
141+
/// The backing allocation is padded for cuDF's 32-bit validity mask reads.
142+
/// The returned length is in `T` elements.
143+
fn padded_device_allocation_len<T>(byte_count: usize) -> VortexResult<usize> {
144+
let element_size = size_of::<T>();
145+
vortex_ensure!(
146+
element_size != 0,
147+
"cannot copy zero-sized values to CUDA device"
148+
);
149+
let min_allocation_bytes = byte_count.next_multiple_of(size_of::<u32>());
150+
Ok(min_allocation_bytes.div_ceil(element_size))
151+
}
152+
153+
/// Zeroes the allocation tail after the copied values.
154+
///
155+
/// Returned handles are sliced to the copied byte count; the trailing padding
156+
/// exists so a final 32-bit mask read stays within the backing allocation.
157+
fn zero_padding<T: DeviceRepr + ValidAsZeroBits>(
158+
stream: &VortexCudaStream,
159+
cuda_slice: &mut CudaSlice<T>,
160+
copied_len: usize,
161+
) -> VortexResult<()> {
162+
if copied_len >= cuda_slice.len() {
163+
return Ok(());
164+
}
165+
166+
let mut padding = cuda_slice.slice_mut(copied_len..);
167+
stream
168+
.memset_zeros(&mut padding)
169+
.map_err(|e| vortex_err!("Failed to zero device buffer padding: {}", e))
170+
}
171+
116172
/// Registers a callback and asynchronously waits for its completion.
117173
///
118174
/// This function can be used to asynchronously wait for events previously
@@ -191,3 +247,47 @@ fn register_stream_callback(stream: &CudaStream) -> VortexResult<kanal::AsyncRec
191247

192248
Ok(rx.to_async())
193249
}
250+
251+
#[cfg(test)]
mod tests {
    use vortex::error::VortexResult;
    use vortex::session::VortexSession;

    use super::padded_device_allocation_len;
    use crate::CudaSession;

    /// Byte counts round up to whole 32-bit words, then to whole elements.
    #[test]
    fn test_padded_device_allocation_len() -> VortexResult<()> {
        // (byte_count, expected element count) for one-byte elements.
        for (byte_count, expected) in [(0, 0), (1, 4), (4, 4), (5, 8)] {
            assert_eq!(padded_device_allocation_len::<u8>(byte_count)?, expected);
        }

        // Same, for four-byte elements.
        for (byte_count, expected) in [(1, 1), (5, 2)] {
            assert_eq!(padded_device_allocation_len::<u32>(byte_count)?, expected);
        }

        Ok(())
    }

    /// The async copy reports the source length, not the padded allocation length.
    #[crate::test]
    async fn test_copy_to_device_preserves_visible_len_with_padding() -> VortexResult<()> {
        let source = vec![0xab_u8];
        let ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;

        let device_handle = ctx.stream().copy_to_device(source)?.await?;
        assert_eq!(device_handle.len(), 1);

        let round_trip = device_handle.try_to_host()?.await?;
        assert_eq!(round_trip.as_slice(), &[0xab]);

        Ok(())
    }

    /// The sync copy reports the source length, not the padded allocation length.
    #[crate::test]
    async fn test_copy_to_device_sync_preserves_visible_len_with_padding() -> VortexResult<()> {
        let source = [1_u8, 2, 3, 4, 5];
        let ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;

        let device_handle = ctx.stream().copy_to_device_sync(&source)?;
        assert_eq!(device_handle.len(), 5);

        let round_trip = device_handle.try_to_host()?.await?;
        assert_eq!(round_trip.as_slice(), &[1, 2, 3, 4, 5]);

        Ok(())
    }
}

vortex-test/e2e-cuda/src/lib.rs

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -68,17 +68,21 @@ pub unsafe extern "C" fn export_array(
6868
) -> i32 {
6969
let mut ctx = CudaSession::create_execution_ctx(&SESSION).unwrap();
7070

71-
let primitive = PrimitiveArray::from_iter(0u32..5);
72-
let decimal = DecimalArray::from_iter(0i128..5, DecimalDType::new(38, 2));
73-
let strings = VarBinViewArray::from_iter_str([
74-
"one",
75-
"two",
76-
"this string is long three",
77-
"four",
78-
"this string is long five",
71+
let primitive = PrimitiveArray::from_option_iter([Some(0u32), None, Some(2), Some(3), None]);
72+
let decimal = DecimalArray::from_option_iter(
73+
[Some(0i128), Some(1), None, Some(3), Some(4)],
74+
DecimalDType::new(38, 2),
75+
);
76+
let strings = VarBinViewArray::from_iter_nullable_str([
77+
Some("one"),
78+
None,
79+
Some("this string is long three"),
80+
Some("four"),
81+
None,
7982
]);
8083
let dates = TemporalArray::new_date(
81-
PrimitiveArray::from_iter([100i32, 200, 300, 400, 500]).into_array(),
84+
PrimitiveArray::from_option_iter([Some(100i32), None, Some(300), Some(400), None])
85+
.into_array(),
8286
TimeUnit::Days,
8387
);
8488

@@ -124,24 +128,24 @@ pub unsafe extern "C" fn validate_array(
124128
let array = make_array(array_data);
125129
let struct_array = array.as_struct();
126130

127-
let primitive = UInt32Array::from_iter(0..5);
128-
let decimal = Decimal128Array::from_iter_values(0..5)
131+
let primitive = UInt32Array::from_iter([Some(0), None, Some(2), Some(3), None]);
132+
let decimal = Decimal128Array::from_iter([Some(0i128), Some(1), None, Some(3), Some(4)])
129133
.with_precision_and_scale(38, 2)
130134
.expect("with_precision_and_scale");
131-
let string = StringArray::from_iter_values([
132-
"one",
133-
"two",
134-
"this string is long three",
135-
"four",
136-
"this string is long five",
135+
let string = StringArray::from_iter([
136+
Some("one"),
137+
None,
138+
Some("this string is long three"),
139+
Some("four"),
140+
None,
137141
]);
138-
let date = Date32Array::from(vec![100i32, 200, 300, 400, 500]);
142+
let date = Date32Array::from(vec![Some(100i32), None, Some(300), Some(400), None]);
139143

140144
let expected_fields = Fields::from_iter([
141-
Field::new("prims", primitive.data_type().clone(), false),
142-
Field::new("decimals", decimal.data_type().clone(), false),
143-
Field::new("strings", string.data_type().clone(), false),
144-
Field::new("dates", date.data_type().clone(), false),
145+
Field::new("prims", primitive.data_type().clone(), true),
146+
Field::new("decimals", decimal.data_type().clone(), true),
147+
Field::new("strings", string.data_type().clone(), true),
148+
Field::new("dates", date.data_type().clone(), true),
145149
]);
146150

147151
assert_eq!(

0 commit comments

Comments
 (0)