Skip to content

Commit 7e5c03a

Browse files
Merge pull request #375 from rust-lang/bitmask
Simplify bitmasks
2 parents 8d9bcda + 0ad68db commit 7e5c03a

File tree

6 files changed

+202
-185
lines changed

6 files changed

+202
-185
lines changed

crates/core_simd/src/masks.rs

+39-3
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@
1212
)]
1313
mod mask_impl;
1414

15-
mod to_bitmask;
16-
pub use to_bitmask::{ToBitMask, ToBitMaskArray};
17-
1815
use crate::simd::{
1916
cmp::SimdPartialEq, intrinsics, LaneCount, Simd, SimdElement, SupportedLaneCount,
2017
};
@@ -262,6 +259,45 @@ where
262259
pub fn all(self) -> bool {
263260
self.0.all()
264261
}
262+
263+
/// Create a bitmask from a mask.
264+
///
265+
/// Each bit is set if the corresponding element in the mask is `true`.
266+
/// If the mask contains more than 64 elements, the bitmask is truncated to the first 64.
267+
#[inline]
268+
#[must_use = "method returns a new integer and does not mutate the original value"]
269+
pub fn to_bitmask(self) -> u64 {
270+
self.0.to_bitmask_integer()
271+
}
272+
273+
/// Create a mask from a bitmask.
274+
///
275+
/// For each bit, if it is set, the corresponding element in the mask is set to `true`.
276+
/// If the mask contains more than 64 elements, the remainder are set to `false`.
277+
#[inline]
278+
#[must_use = "method returns a new mask and does not mutate the original value"]
279+
pub fn from_bitmask(bitmask: u64) -> Self {
280+
Self(mask_impl::Mask::from_bitmask_integer(bitmask))
281+
}
282+
283+
/// Create a bitmask vector from a mask.
284+
///
285+
/// Each bit is set if the corresponding element in the mask is `true`.
286+
/// The remaining bits are unset.
287+
#[inline]
288+
#[must_use = "method returns a new integer and does not mutate the original value"]
289+
pub fn to_bitmask_vector(self) -> Simd<u8, N> {
290+
self.0.to_bitmask_vector()
291+
}
292+
293+
/// Create a mask from a bitmask vector.
294+
///
295+
/// For each bit, if it is set, the corresponding element in the mask is set to `true`.
296+
#[inline]
297+
#[must_use = "method returns a new mask and does not mutate the original value"]
298+
pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
299+
Self(mask_impl::Mask::from_bitmask_vector(bitmask))
300+
}
265301
}
266302

267303
// vector/array conversion

crates/core_simd/src/masks/bitmask.rs

+22-24
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#![allow(unused_imports)]
22
use super::MaskElement;
33
use crate::simd::intrinsics;
4-
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
4+
use crate::simd::{LaneCount, Simd, SupportedLaneCount};
55
use core::marker::PhantomData;
66

77
/// A mask where each lane is represented by a single bit.
@@ -120,39 +120,37 @@ where
120120
}
121121

122122
#[inline]
123-
#[must_use = "method returns a new array and does not mutate the original value"]
124-
pub fn to_bitmask_array<const M: usize>(self) -> [u8; M] {
125-
assert!(core::mem::size_of::<Self>() == M);
126-
127-
// Safety: converting an integer to an array of bytes of the same size is safe
128-
unsafe { core::mem::transmute_copy(&self.0) }
123+
#[must_use = "method returns a new vector and does not mutate the original value"]
124+
pub fn to_bitmask_vector(self) -> Simd<u8, N> {
125+
let mut bitmask = Simd::splat(0);
126+
bitmask.as_mut_array()[..self.0.as_ref().len()].copy_from_slice(self.0.as_ref());
127+
bitmask
129128
}
130129

131130
#[inline]
132131
#[must_use = "method returns a new mask and does not mutate the original value"]
133-
pub fn from_bitmask_array<const M: usize>(bitmask: [u8; M]) -> Self {
134-
assert!(core::mem::size_of::<Self>() == M);
135-
136-
// Safety: converting an array of bytes to an integer of the same size is safe
137-
Self(unsafe { core::mem::transmute_copy(&bitmask) }, PhantomData)
132+
pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
133+
let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
134+
let len = bytes.as_ref().len();
135+
bytes.as_mut().copy_from_slice(&bitmask.as_array()[..len]);
136+
Self(bytes, PhantomData)
138137
}
139138

140139
#[inline]
141-
pub fn to_bitmask_integer<U>(self) -> U
142-
where
143-
super::Mask<T, N>: ToBitMask<BitMask = U>,
144-
{
145-
// Safety: these are the same types
146-
unsafe { core::mem::transmute_copy(&self.0) }
140+
pub fn to_bitmask_integer(self) -> u64 {
141+
let mut bitmask = [0u8; 8];
142+
bitmask[..self.0.as_ref().len()].copy_from_slice(self.0.as_ref());
143+
u64::from_ne_bytes(bitmask)
147144
}
148145

149146
#[inline]
150-
pub fn from_bitmask_integer<U>(bitmask: U) -> Self
151-
where
152-
super::Mask<T, N>: ToBitMask<BitMask = U>,
153-
{
154-
// Safety: these are the same types
155-
unsafe { Self(core::mem::transmute_copy(&bitmask), PhantomData) }
147+
pub fn from_bitmask_integer(bitmask: u64) -> Self {
148+
let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
149+
let len = bytes.as_mut().len();
150+
bytes
151+
.as_mut()
152+
.copy_from_slice(&bitmask.to_ne_bytes()[..len]);
153+
Self(bytes, PhantomData)
156154
}
157155

158156
#[inline]

crates/core_simd/src/masks/full_masks.rs

+101-39
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
//! Masks that take up full SIMD vector registers.
22
3-
use super::{to_bitmask::ToBitMaskArray, MaskElement};
43
use crate::simd::intrinsics;
5-
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
4+
use crate::simd::{LaneCount, MaskElement, Simd, SupportedLaneCount};
65

76
#[repr(transparent)]
87
pub struct Mask<T, const N: usize>(Simd<T, N>)
@@ -143,94 +142,157 @@ where
143142
}
144143

145144
#[inline]
146-
#[must_use = "method returns a new array and does not mutate the original value"]
147-
pub fn to_bitmask_array<const M: usize>(self) -> [u8; M]
148-
where
149-
super::Mask<T, N>: ToBitMaskArray,
150-
{
145+
#[must_use = "method returns a new vector and does not mutate the original value"]
146+
pub fn to_bitmask_vector(self) -> Simd<u8, N> {
147+
let mut bitmask = Simd::splat(0);
148+
151149
// Safety: Bytes is the right size array
152150
unsafe {
153151
// Compute the bitmask
154-
let bitmask: <super::Mask<T, N> as ToBitMaskArray>::BitMaskArray =
152+
let mut bytes: <LaneCount<N> as SupportedLaneCount>::BitMask =
155153
intrinsics::simd_bitmask(self.0);
156154

157-
// Transmute to the return type
158-
let mut bitmask: [u8; M] = core::mem::transmute_copy(&bitmask);
159-
160155
// LLVM assumes bit order should match endianness
161156
if cfg!(target_endian = "big") {
162-
for x in bitmask.as_mut() {
163-
*x = x.reverse_bits();
157+
for x in bytes.as_mut() {
158+
*x = x.reverse_bits()
164159
}
165-
};
160+
}
166161

167-
bitmask
162+
bitmask.as_mut_array()[..bytes.as_ref().len()].copy_from_slice(bytes.as_ref());
168163
}
164+
165+
bitmask
169166
}
170167

171168
#[inline]
172169
#[must_use = "method returns a new mask and does not mutate the original value"]
173-
pub fn from_bitmask_array<const M: usize>(mut bitmask: [u8; M]) -> Self
174-
where
175-
super::Mask<T, N>: ToBitMaskArray,
176-
{
170+
pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
171+
let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
172+
177173
// Safety: Bytes is the right size array
178174
unsafe {
175+
let len = bytes.as_ref().len();
176+
bytes.as_mut().copy_from_slice(&bitmask.as_array()[..len]);
177+
179178
// LLVM assumes bit order should match endianness
180179
if cfg!(target_endian = "big") {
181-
for x in bitmask.as_mut() {
180+
for x in bytes.as_mut() {
182181
*x = x.reverse_bits();
183182
}
184183
}
185184

186-
// Transmute to the bitmask
187-
let bitmask: <super::Mask<T, N> as ToBitMaskArray>::BitMaskArray =
188-
core::mem::transmute_copy(&bitmask);
189-
190185
// Compute the regular mask
191186
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
192-
bitmask,
187+
bytes,
193188
Self::splat(true).to_int(),
194189
Self::splat(false).to_int(),
195190
))
196191
}
197192
}
198193

199194
#[inline]
200-
pub(crate) fn to_bitmask_integer<U: ReverseBits>(self) -> U
195+
unsafe fn to_bitmask_impl<U: ReverseBits, const M: usize>(self) -> U
201196
where
202-
super::Mask<T, N>: ToBitMask<BitMask = U>,
197+
LaneCount<M>: SupportedLaneCount,
203198
{
204-
// Safety: U is required to be the appropriate bitmask type
205-
let bitmask: U = unsafe { intrinsics::simd_bitmask(self.0) };
199+
let resized = self.to_int().resize::<M>(T::FALSE);
200+
201+
// Safety: `resized` is an integer vector with length M, which must match T
202+
let bitmask: U = unsafe { intrinsics::simd_bitmask(resized) };
206203

207204
// LLVM assumes bit order should match endianness
208205
if cfg!(target_endian = "big") {
209-
bitmask.reverse_bits(N)
206+
bitmask.reverse_bits(M)
210207
} else {
211208
bitmask
212209
}
213210
}
214211

215212
#[inline]
216-
pub(crate) fn from_bitmask_integer<U: ReverseBits>(bitmask: U) -> Self
213+
unsafe fn from_bitmask_impl<U: ReverseBits, const M: usize>(bitmask: U) -> Self
217214
where
218-
super::Mask<T, N>: ToBitMask<BitMask = U>,
215+
LaneCount<M>: SupportedLaneCount,
219216
{
220217
// LLVM assumes bit order should match endianness
221218
let bitmask = if cfg!(target_endian = "big") {
222-
bitmask.reverse_bits(N)
219+
bitmask.reverse_bits(M)
223220
} else {
224221
bitmask
225222
};
226223

227-
// Safety: U is required to be the appropriate bitmask type
228-
unsafe {
229-
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
224+
// SAFETY: `mask` is the correct bitmask type for a u64 bitmask
225+
let mask: Simd<T, M> = unsafe {
226+
intrinsics::simd_select_bitmask(
230227
bitmask,
231-
Self::splat(true).to_int(),
232-
Self::splat(false).to_int(),
233-
))
228+
Simd::<T, M>::splat(T::TRUE),
229+
Simd::<T, M>::splat(T::FALSE),
230+
)
231+
};
232+
233+
// SAFETY: `mask` only contains `T::TRUE` or `T::FALSE`
234+
unsafe { Self::from_int_unchecked(mask.resize::<N>(T::FALSE)) }
235+
}
236+
237+
#[inline]
238+
pub(crate) fn to_bitmask_integer(self) -> u64 {
239+
// TODO modify simd_bitmask to zero-extend output, making this unnecessary
240+
macro_rules! bitmask {
241+
{ $($ty:ty: $($len:literal),*;)* } => {
242+
match N {
243+
$($(
244+
// Safety: bitmask matches length
245+
$len => unsafe { self.to_bitmask_impl::<$ty, $len>() as u64 },
246+
)*)*
247+
// Safety: bitmask matches length
248+
_ => unsafe { self.to_bitmask_impl::<u64, 64>() },
249+
}
250+
}
251+
}
252+
#[cfg(all_lane_counts)]
253+
bitmask! {
254+
u8: 1, 2, 3, 4, 5, 6, 7, 8;
255+
u16: 9, 10, 11, 12, 13, 14, 15, 16;
256+
u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32;
257+
u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64;
258+
}
259+
#[cfg(not(all_lane_counts))]
260+
bitmask! {
261+
u8: 1, 2, 4, 8;
262+
u16: 16;
263+
u32: 32;
264+
u64: 64;
265+
}
266+
}
267+
268+
#[inline]
269+
pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
270+
// TODO modify simd_bitmask_select to truncate input, making this unnecessary
271+
macro_rules! bitmask {
272+
{ $($ty:ty: $($len:literal),*;)* } => {
273+
match N {
274+
$($(
275+
// Safety: bitmask matches length
276+
$len => unsafe { Self::from_bitmask_impl::<$ty, $len>(bitmask as $ty) },
277+
)*)*
278+
// Safety: bitmask matches length
279+
_ => unsafe { Self::from_bitmask_impl::<u64, 64>(bitmask) },
280+
}
281+
}
282+
}
283+
#[cfg(all_lane_counts)]
284+
bitmask! {
285+
u8: 1, 2, 3, 4, 5, 6, 7, 8;
286+
u16: 9, 10, 11, 12, 13, 14, 15, 16;
287+
u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32;
288+
u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64;
289+
}
290+
#[cfg(not(all_lane_counts))]
291+
bitmask! {
292+
u8: 1, 2, 4, 8;
293+
u16: 16;
294+
u32: 32;
295+
u64: 64;
234296
}
235297
}
236298

0 commit comments

Comments
 (0)