@@ -1,8 +1,7 @@
 //! Masks that take up full SIMD vector registers.
 
-use super::{to_bitmask::ToBitMaskArray, MaskElement};
 use crate::simd::intrinsics;
-use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
+use crate::simd::{LaneCount, MaskElement, Simd, SupportedLaneCount};
 
 #[repr(transparent)]
 pub struct Mask<T, const N: usize>(Simd<T, N>)
@@ -143,94 +142,157 @@ where
     }
 
     #[inline]
-    #[must_use = "method returns a new array and does not mutate the original value"]
-    pub fn to_bitmask_array<const M: usize>(self) -> [u8; M]
-    where
-        super::Mask<T, N>: ToBitMaskArray,
-    {
+    #[must_use = "method returns a new vector and does not mutate the original value"]
+    pub fn to_bitmask_vector(self) -> Simd<u8, N> {
+        let mut bitmask = Simd::splat(0);
+
         // Safety: Bytes is the right size array
         unsafe {
             // Compute the bitmask
-            let bitmask: <super::Mask<T, N> as ToBitMaskArray>::BitMaskArray =
+            let mut bytes: <LaneCount<N> as SupportedLaneCount>::BitMask =
                 intrinsics::simd_bitmask(self.0);
 
-            // Transmute to the return type
-            let mut bitmask: [u8; M] = core::mem::transmute_copy(&bitmask);
-
             // LLVM assumes bit order should match endianness
             if cfg!(target_endian = "big") {
-                for x in bitmask.as_mut() {
-                    *x = x.reverse_bits();
+                for x in bytes.as_mut() {
+                    *x = x.reverse_bits()
                 }
-            };
+            }
 
-            bitmask
+            bitmask.as_mut_array()[..bytes.as_ref().len()].copy_from_slice(bytes.as_ref());
         }
+
+        bitmask
     }
 
     #[inline]
     #[must_use = "method returns a new mask and does not mutate the original value"]
-    pub fn from_bitmask_array<const M: usize>(mut bitmask: [u8; M]) -> Self
-    where
-        super::Mask<T, N>: ToBitMaskArray,
-    {
+    pub fn from_bitmask_vector(bitmask: Simd<u8, N>) -> Self {
+        let mut bytes = <LaneCount<N> as SupportedLaneCount>::BitMask::default();
+
         // Safety: Bytes is the right size array
         unsafe {
+            let len = bytes.as_ref().len();
+            bytes.as_mut().copy_from_slice(&bitmask.as_array()[..len]);
+
             // LLVM assumes bit order should match endianness
             if cfg!(target_endian = "big") {
-                for x in bitmask.as_mut() {
+                for x in bytes.as_mut() {
                     *x = x.reverse_bits();
                 }
             }
 
-            // Transmute to the bitmask
-            let bitmask: <super::Mask<T, N> as ToBitMaskArray>::BitMaskArray =
-                core::mem::transmute_copy(&bitmask);
-
             // Compute the regular mask
             Self::from_int_unchecked(intrinsics::simd_select_bitmask(
-                bitmask,
+                bytes,
                 Self::splat(true).to_int(),
                 Self::splat(false).to_int(),
             ))
         }
     }
 
     #[inline]
-    pub(crate) fn to_bitmask_integer<U: ReverseBits>(self) -> U
+    unsafe fn to_bitmask_impl<U: ReverseBits, const M: usize>(self) -> U
     where
-        super::Mask<T, N>: ToBitMask<BitMask = U>,
+        LaneCount<M>: SupportedLaneCount,
     {
-        // Safety: U is required to be the appropriate bitmask type
-        let bitmask: U = unsafe { intrinsics::simd_bitmask(self.0) };
+        let resized = self.to_int().resize::<M>(T::FALSE);
+
+        // Safety: `resized` is an integer vector with length M, which must match T
+        let bitmask: U = unsafe { intrinsics::simd_bitmask(resized) };
 
         // LLVM assumes bit order should match endianness
         if cfg!(target_endian = "big") {
-            bitmask.reverse_bits(N)
+            bitmask.reverse_bits(M)
         } else {
             bitmask
         }
     }
 
     #[inline]
-    pub(crate) fn from_bitmask_integer<U: ReverseBits>(bitmask: U) -> Self
+    unsafe fn from_bitmask_impl<U: ReverseBits, const M: usize>(bitmask: U) -> Self
     where
-        super::Mask<T, N>: ToBitMask<BitMask = U>,
+        LaneCount<M>: SupportedLaneCount,
     {
         // LLVM assumes bit order should match endianness
         let bitmask = if cfg!(target_endian = "big") {
-            bitmask.reverse_bits(N)
+            bitmask.reverse_bits(M)
         } else {
             bitmask
         };
 
-        // Safety: U is required to be the appropriate bitmask type
-        unsafe {
-            Self::from_int_unchecked(intrinsics::simd_select_bitmask(
+        // SAFETY: `mask` is the correct bitmask type for a u64 bitmask
+        let mask: Simd<T, M> = unsafe {
+            intrinsics::simd_select_bitmask(
                 bitmask,
-                Self::splat(true).to_int(),
-                Self::splat(false).to_int(),
-            ))
+                Simd::<T, M>::splat(T::TRUE),
+                Simd::<T, M>::splat(T::FALSE),
+            )
+        };
+
+        // SAFETY: `mask` only contains `T::TRUE` or `T::FALSE`
+        unsafe { Self::from_int_unchecked(mask.resize::<N>(T::FALSE)) }
+    }
+
+    #[inline]
+    pub(crate) fn to_bitmask_integer(self) -> u64 {
+        // TODO modify simd_bitmask to zero-extend output, making this unnecessary
+        macro_rules! bitmask {
+            { $($ty:ty: $($len:literal),*;)* } => {
+                match N {
+                    $($(
+                        // Safety: bitmask matches length
+                        $len => unsafe { self.to_bitmask_impl::<$ty, $len>() as u64 },
+                    )*)*
+                    // Safety: bitmask matches length
+                    _ => unsafe { self.to_bitmask_impl::<u64, 64>() },
+                }
+            }
+        }
+        #[cfg(all_lane_counts)]
+        bitmask! {
+            u8: 1, 2, 3, 4, 5, 6, 7, 8;
+            u16: 9, 10, 11, 12, 13, 14, 15, 16;
+            u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32;
+            u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64;
+        }
+        #[cfg(not(all_lane_counts))]
+        bitmask! {
+            u8: 1, 2, 4, 8;
+            u16: 16;
+            u32: 32;
+            u64: 64;
+        }
+    }
+
+    #[inline]
+    pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
+        // TODO modify simd_bitmask_select to truncate input, making this unnecessary
+        macro_rules! bitmask {
+            { $($ty:ty: $($len:literal),*;)* } => {
+                match N {
+                    $($(
+                        // Safety: bitmask matches length
+                        $len => unsafe { Self::from_bitmask_impl::<$ty, $len>(bitmask as $ty) },
+                    )*)*
+                    // Safety: bitmask matches length
+                    _ => unsafe { Self::from_bitmask_impl::<u64, 64>(bitmask) },
+                }
+            }
+        }
+        #[cfg(all_lane_counts)]
+        bitmask! {
+            u8: 1, 2, 3, 4, 5, 6, 7, 8;
+            u16: 9, 10, 11, 12, 13, 14, 15, 16;
+            u32: 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32;
+            u64: 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64;
+        }
+        #[cfg(not(all_lane_counts))]
+        bitmask! {
+            u8: 1, 2, 4, 8;
+            u16: 16;
+            u32: 32;
+            u64: 64;
+        }
         }
     }
 
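A minimal usage sketch of the reworked bitmask API (not part of the commit): it assumes the public core::simd::Mask wrapper exposes the full-mask methods above as `to_bitmask_vector`/`from_bitmask_vector` and forwards `to_bitmask_integer`/`from_bitmask_integer` as `to_bitmask`/`from_bitmask` returning `u64`, on a nightly toolchain with the `portable_simd` feature.

#![feature(portable_simd)]
use core::simd::{Mask, Simd};

fn main() {
    let mask = Mask::<i32, 4>::from_array([true, false, true, true]);

    // Byte-vector form: lane 0 maps to the least significant bit of byte 0.
    let bytes: Simd<u8, 4> = mask.to_bitmask_vector();
    assert_eq!(bytes.to_array()[0], 0b1101);
    assert_eq!(Mask::<i32, 4>::from_bitmask_vector(bytes), mask);

    // Integer form: a u64 regardless of lane count after this change.
    let bits: u64 = mask.to_bitmask();
    assert_eq!(bits, 0b1101);
    assert_eq!(Mask::<i32, 4>::from_bitmask(bits), mask);
}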