Skip to content

Commit 4f0ba1a

Browse files
committed
Add support for masked loads & stores
1 parent 5794c83 commit 4f0ba1a

File tree

3 files changed

+285
-0
lines changed

3 files changed

+285
-0
lines changed

crates/core_simd/src/masks.rs

+6
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ mod sealed {
3434
fn eq(self, other: Self) -> bool;
3535

3636
fn to_usize(self) -> usize;
37+
fn max_unsigned() -> u64;
3738

3839
type Unsigned: SimdElement;
3940

@@ -78,6 +79,11 @@ macro_rules! impl_element {
7879
self as usize
7980
}
8081

82+
/// Returns `<$unsigned>::MAX` widened to `u64`, i.e. the largest value of the unsigned
/// integer type that the macro caller supplied as `$unsigned`.
#[inline]
83+
fn max_unsigned() -> u64 {
84+
<$unsigned>::MAX as u64
85+
}
86+
8187
type Unsigned = $unsigned;
8288

8389
const TRUE: Self = -1;

crates/core_simd/src/vector.rs

+244
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::simd::{
22
cmp::SimdPartialOrd,
3+
num::SimdUint,
34
ptr::{SimdConstPtr, SimdMutPtr},
45
LaneCount, Mask, MaskElement, SupportedLaneCount, Swizzle,
56
};
@@ -261,6 +262,7 @@ where
261262
/// # Panics
262263
///
263264
/// Panics if the slice's length is less than the vector's `Simd::N`.
265+
/// Use `load_or_default` for an alternative that does not panic.
264266
///
265267
/// # Example
266268
///
@@ -314,6 +316,143 @@ where
314316
unsafe { self.store(slice.as_mut_ptr().cast()) }
315317
}
316318

319+
/// Reads contiguous elements from `slice`. Elements are read so long as they're in-bounds for
320+
/// the `slice`. Otherwise, the default value for the element type is returned.
321+
///
322+
/// # Examples
323+
/// ```
324+
/// # #![feature(portable_simd)]
325+
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
326+
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
327+
/// # use simd::{Simd, Mask};
328+
/// let vec: Vec<i32> = vec![10, 11];
329+
///
330+
/// let result = Simd::<i32, 4>::load_or_default(&vec);
331+
/// assert_eq!(result, Simd::from_array([10, 11, 0, 0]));
332+
/// ```
333+
#[must_use]
334+
#[inline]
335+
pub fn load_or_default(slice: &[T]) -> Self
336+
where
337+
T: Default,
338+
{
339+
Self::load_or(slice, Default::default())
340+
}
341+
342+
/// Reads contiguous elements from `slice`. Elements are read so long as they're in-bounds for
343+
/// the `slice`. Otherwise, the corresponding value from `or` is passed through.
344+
///
345+
/// # Examples
346+
/// ```
347+
/// # #![feature(portable_simd)]
348+
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
349+
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
350+
/// # use simd::{Simd, Mask};
351+
/// let vec: Vec<i32> = vec![10, 11];
352+
/// let or = Simd::from_array([-5, -4, -3, -2]);
353+
///
354+
/// let result = Simd::load_or(&vec, or);
355+
/// assert_eq!(result, Simd::from_array([10, 11, -3, -2]));
356+
/// ```
357+
#[must_use]
358+
#[inline]
359+
pub fn load_or(slice: &[T], or: Self) -> Self {
360+
Self::load_select(slice, Mask::splat(true), or)
361+
}
362+
363+
/// Reads contiguous elements from `slice`. Each element is read from memory if its
364+
/// corresponding element in `enable` is `true`.
365+
///
366+
/// When the element is disabled or out of bounds for the slice, that memory location
367+
/// is not accessed and the corresponding value from `or` is passed through.
368+
///
369+
/// # Examples
370+
/// ```
371+
/// # #![feature(portable_simd)]
372+
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
373+
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
374+
/// # use simd::{Simd, Mask};
375+
/// let vec: Vec<i32> = vec![10, 11, 12, 13, 14, 15, 16, 17, 18];
376+
/// let enable = Mask::from_array([true, true, false, true]);
377+
/// let or = Simd::from_array([-5, -4, -3, -2]);
378+
///
379+
/// let result = Simd::load_select(&vec, enable, or);
380+
/// assert_eq!(result, Simd::from_array([10, 11, -3, 13]));
381+
/// ```
382+
#[must_use]
383+
#[inline]
384+
pub fn load_select_or_default(slice: &[T], enable: Mask<<T as SimdElement>::Mask, N>) -> Self
385+
where
386+
T: Default,
387+
{
388+
Self::load_select(slice, enable, Default::default())
389+
}
390+
391+
/// Reads contiguous elements from `slice`. Each element is read from memory if its
392+
/// corresponding element in `enable` is `true`.
393+
///
394+
/// When the element is disabled or out of bounds for the slice, that memory location
395+
/// is not accessed and the corresponding value from `or` is passed through.
396+
///
397+
/// # Examples
398+
/// ```
399+
/// # #![feature(portable_simd)]
400+
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
401+
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
402+
/// # use simd::{Simd, Mask};
403+
/// let vec: Vec<i32> = vec![10, 11, 12, 13, 14, 15, 16, 17, 18];
404+
/// let enable = Mask::from_array([true, true, false, true]);
405+
/// let or = Simd::from_array([-5, -4, -3, -2]);
406+
///
407+
/// let result = Simd::load_select(&vec, enable, or);
408+
/// assert_eq!(result, Simd::from_array([10, 11, -3, 13]));
409+
/// ```
410+
#[must_use]
411+
#[inline]
412+
pub fn load_select(
413+
slice: &[T],
414+
mut enable: Mask<<T as SimdElement>::Mask, N>,
415+
or: Self,
416+
) -> Self {
417+
enable &= mask_up_to(slice.len());
418+
// SAFETY: We performed the bounds check by updating the mask. &[T] is properly aligned to
419+
// the element.
420+
unsafe { Self::load_select_ptr(slice.as_ptr(), enable, or) }
421+
}
422+
423+
/// Reads contiguous elements from `slice`. Each element is read from memory if its
424+
/// corresponding element in `enable` is `true`.
425+
///
426+
/// When the element is disabled, that memory location is not accessed and the corresponding
427+
/// value from `or` is passed through.
428+
#[must_use]
429+
#[inline]
430+
pub unsafe fn load_select_unchecked(
431+
slice: &[T],
432+
enable: Mask<<T as SimdElement>::Mask, N>,
433+
or: Self,
434+
) -> Self {
435+
let ptr = slice.as_ptr();
436+
// SAFETY: The safety of reading elements from `slice` is ensured by the caller.
437+
unsafe { Self::load_select_ptr(ptr, enable, or) }
438+
}
439+
440+
/// Reads contiguous elements starting at `ptr`. Each element is read from memory if its
441+
/// corresponding element in `enable` is `true`.
442+
///
443+
/// When the element is disabled, that memory location is not accessed and the corresponding
444+
/// value from `or` is passed through.
445+
#[must_use]
446+
#[inline]
447+
pub unsafe fn load_select_ptr(
448+
ptr: *const T,
449+
enable: Mask<<T as SimdElement>::Mask, N>,
450+
or: Self,
451+
) -> Self {
452+
// SAFETY: The safety of reading elements through `ptr` is ensured by the caller.
453+
unsafe { core::intrinsics::simd::simd_masked_load(enable.to_int(), ptr, or) }
454+
}
455+
317456
/// Reads from potentially discontiguous indices in `slice` to construct a SIMD vector.
318457
/// If an index is out-of-bounds, the element is instead selected from the `or` vector.
319458
///
@@ -492,6 +631,77 @@ where
492631
unsafe { core::intrinsics::simd::simd_gather(or, source, enable.to_int()) }
493632
}
494633

634+
/// Conditionally write contiguous elements to `slice`. The `enable` mask controls
635+
/// which elements are written, as long as they're in-bounds of the `slice`.
636+
/// If the element is disabled or out of bounds, no memory access to that location
637+
/// is made.
638+
///
639+
/// # Examples
640+
/// ```
641+
/// # #![feature(portable_simd)]
642+
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
643+
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
644+
/// # use simd::{Simd, Mask};
645+
/// let mut arr = [0i32; 4];
646+
/// let write = Simd::from_array([-5, -4, -3, -2]);
647+
/// let enable = Mask::from_array([false, true, true, true]);
648+
///
649+
/// write.store_select(&mut arr[..3], enable);
650+
/// assert_eq!(arr, [0, -4, -3, 0]);
651+
/// ```
652+
#[inline]
653+
pub fn store_select(self, slice: &mut [T], mut enable: Mask<<T as SimdElement>::Mask, N>) {
654+
enable &= mask_up_to(slice.len());
655+
// SAFETY: We performed the bounds check by updating the mask. &[T] is properly aligned to
656+
// the element.
657+
unsafe { self.store_select_ptr(slice.as_mut_ptr(), enable) }
658+
}
659+
660+
/// Conditionally write contiguous elements to `slice`. The `enable` mask controls
661+
/// which elements are written.
662+
///
663+
/// # Safety
664+
///
665+
/// Every enabled element must be in bounds for the `slice`.
666+
///
667+
/// # Examples
668+
/// ```
669+
/// # #![feature(portable_simd)]
670+
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
671+
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
672+
/// # use simd::{Simd, Mask};
673+
/// let mut arr = [0i32; 4];
674+
/// let write = Simd::from_array([-5, -4, -3, -2]);
675+
/// let enable = Mask::from_array([false, true, true, true]);
676+
///
677+
/// unsafe { write.store_select_unchecked(&mut arr, enable) };
678+
/// assert_eq!(arr, [0, -4, -3, -2]);
679+
/// ```
680+
#[inline]
681+
pub unsafe fn store_select_unchecked(
682+
self,
683+
slice: &mut [T],
684+
enable: Mask<<T as SimdElement>::Mask, N>,
685+
) {
686+
let ptr = slice.as_mut_ptr();
687+
// SAFETY: The safety of writing elements in `slice` is ensured by the caller.
688+
unsafe { self.store_select_ptr(ptr, enable) }
689+
}
690+
691+
/// Conditionally write contiguous elements starting from `ptr`.
692+
/// The `enable` mask controls which elements are written.
693+
/// When disabled, the memory location corresponding to that element is not accessed.
694+
///
695+
/// # Safety
696+
///
697+
/// Memory addresses for element are calculated [`core::ptr::wrapping_offset`] and
698+
/// each enabled element must satisfy the same conditions as [`core::ptr::write`].
699+
#[inline]
700+
pub unsafe fn store_select_ptr(self, ptr: *mut T, enable: Mask<<T as SimdElement>::Mask, N>) {
701+
// SAFETY: The safety of writing elements through `ptr` is ensured by the caller.
702+
unsafe { core::intrinsics::simd::simd_masked_store(enable.to_int(), ptr, self) }
703+
}
704+
495705
/// Writes the values in a SIMD vector to potentially discontiguous indices in `slice`.
496706
/// If an index is out-of-bounds, the write is suppressed without panicking.
497707
/// If two elements in the scattered vector would write to the same index
@@ -979,3 +1189,37 @@ where
9791189
{
9801190
type Mask = isize;
9811191
}
1192+
1193+
#[inline]
1194+
fn lane_indices<const N: usize>() -> Simd<usize, N>
1195+
where
1196+
LaneCount<N>: SupportedLaneCount,
1197+
{
1198+
let mut index = [0; N];
1199+
for i in 0..N {
1200+
index[i] = i;
1201+
}
1202+
Simd::from_array(index)
1203+
}
1204+
1205+
#[inline]
1206+
fn mask_up_to<M, const N: usize>(len: usize) -> Mask<M, N>
1207+
where
1208+
LaneCount<N>: SupportedLaneCount,
1209+
M: MaskElement,
1210+
{
1211+
let index = lane_indices::<N>();
1212+
let max_value: u64 = M::max_unsigned();
1213+
macro_rules! case {
1214+
($ty:ty) => {
1215+
if N < <$ty>::MAX as usize && max_value as $ty as u64 == max_value {
1216+
return index.cast().simd_lt(Simd::splat(len.min(N) as $ty)).cast();
1217+
}
1218+
};
1219+
}
1220+
case!(u8);
1221+
case!(u16);
1222+
case!(u32);
1223+
case!(u64);
1224+
index.simd_lt(Simd::splat(len)).cast()
1225+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#![feature(portable_simd)]
2+
use core_simd::simd::prelude::*;
3+
4+
#[cfg(target_arch = "wasm32")]
5+
use wasm_bindgen_test::*;
6+
7+
#[cfg(target_arch = "wasm32")]
8+
wasm_bindgen_test_configure!(run_in_browser);
9+
10+
/// Exercises the checked masked store and load paths on slices shorter than the vector,
/// so that some lanes fall outside the slice.
#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn masked_load_store() {
    let mut arr = [u8::MAX; 7];

    // `arr[5..]` has two elements; the enabled lanes 1 and 3 map to indices 6 and 8 of `arr`.
    u8x4::splat(0).store_select(&mut arr[5..], Mask::from_array([false, true, false, true]));
    // write to index 8 is OOB and dropped
    assert_eq!(arr, [255u8, 255, 255, 255, 255, 255, 0]);

    u8x4::from_array([0, 1, 2, 3]).store_select(&mut arr[1..], Mask::splat(true));
    assert_eq!(arr, [255u8, 0, 1, 2, 3, 255, 0]);

    // `arr[4..]` has three elements, so the four lanes map to indices 4, 5, 6 and 7 of `arr`.
    // read from index 7 is OOB and dropped
    assert_eq!(
        u8x4::load_or(&arr[4..], u8x4::splat(42)),
        u8x4::from_array([3, 255, 0, 42])
    );
    assert_eq!(
        u8x4::load_select(
            &arr[4..],
            Mask::from_array([true, false, true, true]),
            u8x4::splat(42)
        ),
        u8x4::from_array([3, 42, 0, 42])
    );
}

0 commit comments

Comments
 (0)