Skip to content

Commit 384be9f

Browse files
Merge pull request #376 from rust-lang/find-first-set
Find first set element in a mask
2 parents 7e5c03a + 62bbb36 commit 384be9f

File tree

1 file changed

+80
-8
lines changed

1 file changed

+80
-8
lines changed

crates/core_simd/src/masks.rs

+80-8
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
mod mask_impl;
1414

1515
use crate::simd::{
16-
cmp::SimdPartialEq, intrinsics, LaneCount, Simd, SimdElement, SupportedLaneCount,
16+
cmp::SimdPartialEq, intrinsics, LaneCount, Simd, SimdCast, SimdElement, SupportedLaneCount,
1717
};
1818
use core::cmp::Ordering;
1919
use core::{fmt, mem};
@@ -35,6 +35,10 @@ mod sealed {
3535

3636
fn eq(self, other: Self) -> bool;
3737

38+
fn as_usize(self) -> usize;
39+
40+
type Unsigned: SimdElement;
41+
3842
const TRUE: Self;
3943

4044
const FALSE: Self;
@@ -46,10 +50,10 @@ use sealed::Sealed;
4650
///
4751
/// # Safety
4852
/// Type must be a signed integer.
49-
pub unsafe trait MaskElement: SimdElement + Sealed {}
53+
pub unsafe trait MaskElement: SimdElement<Mask = Self> + SimdCast + Sealed {}
5054

5155
macro_rules! impl_element {
52-
{ $ty:ty } => {
56+
{ $ty:ty, $unsigned:ty } => {
5357
impl Sealed for $ty {
5458
#[inline]
5559
fn valid<const N: usize>(value: Simd<Self, N>) -> bool
@@ -62,6 +66,13 @@ macro_rules! impl_element {
6266
#[inline]
6367
fn eq(self, other: Self) -> bool { self == other }
6468

69+
#[inline]
70+
fn as_usize(self) -> usize {
71+
self as usize
72+
}
73+
74+
type Unsigned = $unsigned;
75+
6576
const TRUE: Self = -1;
6677
const FALSE: Self = 0;
6778
}
@@ -71,11 +82,11 @@ macro_rules! impl_element {
7182
}
7283
}
7384

74-
impl_element! { i8 }
75-
impl_element! { i16 }
76-
impl_element! { i32 }
77-
impl_element! { i64 }
78-
impl_element! { isize }
85+
impl_element! { i8, u8 }
86+
impl_element! { i16, u16 }
87+
impl_element! { i32, u32 }
88+
impl_element! { i64, u64 }
89+
impl_element! { isize, usize }
7990

8091
/// A SIMD vector mask for `N` elements of width specified by `Element`.
8192
///
@@ -298,6 +309,67 @@ where
298309
pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
299310
Self(mask_impl::Mask::from_bitmask_vector(bitmask))
300311
}
312+
313+
/// Find the index of the first set element.
314+
///
/// Returns `None` when no element of the mask is set; otherwise returns
/// `Some(i)`, where `i` is the lowest index whose element is set.
///
315+
/// ```
316+
/// # #![feature(portable_simd)]
317+
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
318+
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
319+
/// # use simd::mask32x8;
320+
/// assert_eq!(mask32x8::splat(false).first_set(), None);
321+
/// assert_eq!(mask32x8::splat(true).first_set(), Some(0));
322+
///
323+
/// let mask = mask32x8::from_array([false, true, false, false, true, false, false, true]);
324+
/// assert_eq!(mask.first_set(), Some(1));
325+
/// ```
326+
#[inline]
327+
#[must_use = "method returns the index and does not mutate the original value"]
328+
pub fn first_set(self) -> Option<usize> {
329+
// If bitmasks are efficient, using them is better.
// `cfg!` is evaluated at compile time, so this fast path is only taken when
// the `sse` target feature is enabled and the mask has at most 64 elements.
330+
if cfg!(target_feature = "sse") && N <= 64 {
331+
// The number of trailing zero bits is the index of the lowest set bit.
let tz = self.to_bitmask().trailing_zeros();
332+
// A count of 64 is treated as "no bit set" (this comparison implies the
// bitmask is 64 bits wide); any smaller count is the answer.
return if tz == 64 { None } else { Some(tz as usize) };
333+
}
334+
335+
// Generic fallback, used when the bitmask fast path above is not taken.
// To find the first set index:
336+
// * create a vector 0..N
337+
// * replace unset mask elements in that vector with -1
338+
// * perform _unsigned_ reduce-min
339+
// * check if the result is -1 or an index
340+
341+
let index = Simd::from_array(
342+
// Build the array [0, 1, ..., N - 1] in a `const` block.
const {
343+
let mut index = [0; N];
344+
let mut i = 0;
345+
while i < N {
346+
index[i] = i;
347+
i += 1;
348+
}
349+
index
350+
},
351+
);
352+
353+
// Convert the indices to the mask's element type `T` so they can be
// selected against the mask's own integer representation.
// Safety: the input and output are integer vectors
354+
let index: Simd<T, N> = unsafe { intrinsics::simd_cast(index) };
355+
356+
// Keep the index where the mask element is set; elsewhere substitute the
// integer form of an all-true mask (the -1 sentinel from the steps above).
let masked_index = self.select(index, Self::splat(true).to_int());
357+
358+
// Switch to the unsigned counterpart of `T` so the reduction below is an
// unsigned minimum, as the algorithm above requires.
// Safety: the input and output are integer vectors
359+
let masked_index: Simd<T::Unsigned, N> = unsafe { intrinsics::simd_cast(masked_index) };
360+
361+
// Safety: the input is an integer vector
362+
let min_index: T::Unsigned = unsafe { intrinsics::simd_reduce_min(masked_index) };
363+
364+
// Reinterpret the unsigned minimum as `T` again so it can be compared with
// `T::TRUE` below.
// Safety: the return value is the unsigned version of T
// (`impl_element!` pairs each signed element type with its same-width
// unsigned type: i8/u8, i16/u16, i32/u32, i64/u64, isize/usize).
365+
let min_index: T = unsafe { core::mem::transmute_copy(&min_index) };
366+
367+
// `T::TRUE` is -1: if the minimum is still the sentinel, no element was set.
if min_index.eq(T::TRUE) {
368+
None
369+
} else {
370+
Some(min_index.as_usize())
371+
}
372+
}
301373
}
302374

303375
// vector/array conversion

0 commit comments

Comments
 (0)