|
15 | 15 | /// A planned compiler improvement will enable using `#[target_feature]` instead.
|
16 | 16 | #[inline]
|
17 | 17 | pub fn swizzle_dyn(self, idxs: Simd<u8, N>) -> Self {
|
18 |
| - #![allow(unused_imports, unused_unsafe)] |
| 18 | + #![allow(unused_imports, unused_unsafe, unreachable_patterns)] |
19 | 19 | #[cfg(all(
|
20 | 20 | any(target_arch = "aarch64", target_arch = "arm64ec"),
|
21 | 21 | target_endian = "little"
|
@@ -66,20 +66,146 @@ where
|
66 | 66 | // FIXME: initial AVX512VBMI variant didn't actually pass muster
|
67 | 67 | // #[cfg(target_feature = "avx512vbmi")]
|
68 | 68 | // 64 => transize(x86::_mm512_permutexvar_epi8, self, idxs),
|
69 |
| - _ => { |
70 |
| - let mut array = [0; N]; |
71 |
| - for (i, k) in idxs.to_array().into_iter().enumerate() { |
72 |
| - if (k as usize) < N { |
73 |
| - array[i] = self[k as usize]; |
74 |
| - }; |
75 |
| - } |
76 |
| - array.into() |
77 |
| - } |
| 69 | + #[cfg(any( |
| 70 | + all( |
| 71 | + any( |
| 72 | + target_arch = "aarch64", |
| 73 | + target_arch = "arm64ec", |
| 74 | + all(target_arch = "arm", target_feature = "v7") |
| 75 | + ), |
| 76 | + target_feature = "neon", |
| 77 | + target_endian = "little" |
| 78 | + ), |
| 79 | + target_feature = "ssse3", |
| 80 | + target_feature = "simd128" |
| 81 | + ))] |
| 82 | + _ => dispatch_compat(self, idxs), |
| 83 | + _ => swizzle_dyn_scalar(self, idxs), |
78 | 84 | }
|
79 | 85 | }
|
80 | 86 | }
|
81 | 87 | }
|
82 | 88 |
|
| 89 | +#[inline(always)] |
| 90 | +fn swizzle_dyn_scalar<const N: usize>(bytes: Simd<u8, N>, idxs: Simd<u8, N>) -> Simd<u8, N> |
| 91 | +where |
| 92 | + LaneCount<N>: SupportedLaneCount, |
| 93 | +{ |
| 94 | + let mut array = [0; N]; |
| 95 | + for (i, k) in idxs.to_array().into_iter().enumerate() { |
| 96 | + if (k as usize) < N { |
| 97 | + array[i] = bytes[k as usize]; |
| 98 | + }; |
| 99 | + } |
| 100 | + array.into() |
| 101 | +} |
| 102 | + |
| 103 | +/// Dispatch to swizzle_dyn_compat and swizzle_dyn_zext according to N. |
| 104 | +/// Should only be called if the target architecture has a vectorized swizzle_dyn for some power-of-two size (e.g. 8, 16). |
| 105 | +#[inline(always)] |
| 106 | +fn dispatch_compat<const N: usize>(bytes: Simd<u8, N>, idxs: Simd<u8, N>) -> Simd<u8, N> |
| 107 | +where |
| 108 | + LaneCount<N>: SupportedLaneCount, |
| 109 | +{ |
| 110 | + #![allow( |
| 111 | + dead_code, |
| 112 | + unused_unsafe, |
| 113 | + unreachable_patterns, |
| 114 | + non_contiguous_range_endpoints |
| 115 | + )] |
| 116 | + |
| 117 | + // SAFETY: only unsafe usage is transize, see comment on transize |
| 118 | + unsafe { |
| 119 | + match N { |
| 120 | + 5..16 => swizzle_dyn_zext::<N, 16>(bytes, idxs), |
| 121 | + // only arm actually has 8-byte swizzle_dyn |
| 122 | + #[cfg(all( |
| 123 | + any( |
| 124 | + target_arch = "aarch64", |
| 125 | + target_arch = "arm64ec", |
| 126 | + all(target_arch = "arm", target_feature = "v7") |
| 127 | + ), |
| 128 | + target_feature = "neon", |
| 129 | + target_endian = "little" |
| 130 | + ))] |
| 131 | + 16 => transize(swizzle_dyn_compat::<16, 8>, bytes, idxs), |
| 132 | + 17..32 => swizzle_dyn_zext::<N, 32>(bytes, idxs), |
| 133 | + 32 => transize(swizzle_dyn_compat::<32, 16>, bytes, idxs), |
| 134 | + 33..64 => swizzle_dyn_zext::<N, 64>(bytes, idxs), |
| 135 | + 64 => transize(swizzle_dyn_compat::<64, 32>, bytes, idxs), |
| 136 | + _ => swizzle_dyn_scalar(bytes, idxs), |
| 137 | + } |
| 138 | + } |
| 139 | +} |
| 140 | + |
| 141 | +/// Implement swizzle_dyn for N by temporarily zero extending to N_EXT. |
| 142 | +#[inline(always)] |
| 143 | +#[allow(unused)] |
| 144 | +fn swizzle_dyn_zext<const N: usize, const N_EXT: usize>( |
| 145 | + bytes: Simd<u8, N>, |
| 146 | + idxs: Simd<u8, N>, |
| 147 | +) -> Simd<u8, N> |
| 148 | +where |
| 149 | + LaneCount<N>: SupportedLaneCount, |
| 150 | + LaneCount<N_EXT>: SupportedLaneCount, |
| 151 | +{ |
| 152 | + assert!(N_EXT.is_power_of_two(), "N_EXT should be power of two!"); |
| 153 | + assert!(N < N_EXT, "N_EXT should be larger than N"); |
| 154 | + Simd::swizzle_dyn(bytes.resize::<N_EXT>(0), idxs.resize::<N_EXT>(0)).resize::<N>(0) |
| 155 | +} |
| 156 | + |
| 157 | +/// "Downgrades" a swizzle_dyn op on N lanes to 4 swizzle_dyn ops on N/2 lanes. |
| 158 | +/// |
| 159 | +/// This only makes sense if swizzle_dyn actually has a vectorized implementation for a lower size (N/2, N/4, N/8, etc). |
| 160 | +/// e.g. on x86, swizzle_dyn_compat for N=64 can be efficient if we have at least ssse3 for pshufb |
| 161 | +/// |
| 162 | +/// If there is no vectorized implementation for a lower size, |
| 163 | +/// this runs in N*logN time and will be slower than the scalar implementation. |
| 164 | +#[inline(always)] |
| 165 | +#[allow(unused)] |
| 166 | +fn swizzle_dyn_compat<const N: usize, const HALF_N: usize>( |
| 167 | + bytes: Simd<u8, N>, |
| 168 | + idxs: Simd<u8, N>, |
| 169 | +) -> Simd<u8, N> |
| 170 | +where |
| 171 | + LaneCount<N>: SupportedLaneCount, |
| 172 | + LaneCount<HALF_N>: SupportedLaneCount, |
| 173 | +{ |
| 174 | + use crate::simd::cmp::SimdPartialOrd; |
| 175 | + assert!(N.is_power_of_two(), "doesn't work for non-power-of-two N"); |
| 176 | + assert!(N < u8::MAX as usize, "doesn't work for N >= 256"); |
| 177 | + assert_eq!(N / 2, HALF_N, "HALF_N must equal N divided by two"); |
| 178 | + |
| 179 | + let mid = Simd::splat(HALF_N as u8); |
| 180 | + |
| 181 | + // unset the "mid" bit from the indices, e.g. 8..=15 -> 0..=7 when HALF_N is 8, 16..=31 -> 0..=15 when HALF_N is 16, |
| 182 | + // ensuring that a half-swizzle on the higher half of `bytes` will select the correct indices |
| 183 | + // since N is a power of two, any zeroing indices will remain zeroing |
| 184 | + let idxs_trunc = idxs & !mid; |
| 185 | + |
| 186 | + let idx_lo = Simd::<u8, HALF_N>::from_slice(&idxs_trunc[..HALF_N]); |
| 187 | + let idx_hi = Simd::<u8, HALF_N>::from_slice(&idxs_trunc[HALF_N..]); |
| 188 | + |
| 189 | + let bytes_lo = Simd::<u8, HALF_N>::from_slice(&bytes[..HALF_N]); |
| 190 | + let bytes_hi = Simd::<u8, HALF_N>::from_slice(&bytes[HALF_N..]); |
| 191 | + |
| 192 | + macro_rules! half_swizzle { |
| 193 | + ($bytes:ident) => {{ |
| 194 | + let lo = Simd::swizzle_dyn($bytes, idx_lo); |
| 195 | + let hi = Simd::swizzle_dyn($bytes, idx_hi); |
| 196 | + |
| 197 | + let mut res = [0; N]; |
| 198 | + res[..HALF_N].copy_from_slice(&lo[..]); |
| 199 | + res[HALF_N..].copy_from_slice(&hi[..]); |
| 200 | + Simd::from_array(res) |
| 201 | + }}; |
| 202 | + } |
| 203 | + |
| 204 | + let result_lo = half_swizzle!(bytes_lo); |
| 205 | + let result_hi = half_swizzle!(bytes_hi); |
| 206 | + idxs.simd_lt(mid).select(result_lo, result_hi) |
| 207 | +} |
| 208 | + |
83 | 209 | /// "vpshufb like it was meant to be" on AVX2
|
84 | 210 | ///
|
85 | 211 | /// # Safety
|
|
0 commit comments