Skip to content

Commit 041e43b

Browse files
committed
Implement fallback to smaller vector size for swizzle_dyn
1 parent f6519c5 commit 041e43b

File tree

1 file changed

+139
-12
lines changed

1 file changed

+139
-12
lines changed

crates/core_simd/src/swizzle_dyn.rs

+139-12
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ where
1515
/// A planned compiler improvement will enable using `#[target_feature]` instead.
1616
#[inline]
1717
pub fn swizzle_dyn(self, idxs: Simd<u8, N>) -> Self {
18-
#![allow(unused_imports, unused_unsafe)]
18+
#![allow(unused_imports, unused_unsafe, unreachable_patterns)]
1919
#[cfg(all(
2020
any(target_arch = "aarch64", target_arch = "arm64ec"),
2121
target_endian = "little"
@@ -57,8 +57,6 @@ where
5757
target_endian = "little"
5858
))]
5959
16 => transize(vqtbl1q_u8, self, idxs),
60-
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi")))]
61-
32 => transize(avx2_pshufb, self, idxs),
6260
#[cfg(all(target_feature = "avx512vl", target_feature = "avx512vbmi"))]
6361
32 => {
6462
// Unlike vpshufb, vpermb doesn't zero out values in the result based on the index high bit
@@ -71,6 +69,8 @@ where
7169
};
7270
transize(swizzler, self, idxs)
7371
}
72+
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi")))]
73+
32 => transize(avx2_pshufb, self, idxs),
7474
// Notable absence: avx512bw pshufb shuffle
7575
#[cfg(all(target_feature = "avx512vl", target_feature = "avx512vbmi"))]
7676
64 => {
@@ -84,20 +84,147 @@ where
8484
};
8585
transize(swizzler, self, idxs)
8686
}
87-
_ => {
88-
let mut array = [0; N];
89-
for (i, k) in idxs.to_array().into_iter().enumerate() {
90-
if (k as usize) < N {
91-
array[i] = self[k as usize];
92-
};
93-
}
94-
array.into()
95-
}
87+
#[cfg(any(
88+
all(
89+
any(
90+
target_arch = "aarch64",
91+
target_arch = "arm64ec",
92+
all(target_arch = "arm", target_feature = "v7")
93+
),
94+
target_feature = "neon",
95+
target_endian = "little"
96+
),
97+
target_feature = "ssse3",
98+
target_feature = "simd128"
99+
))]
100+
_ => dispatch_compat(self, idxs),
101+
_ => swizzle_dyn_scalar(self, idxs),
96102
}
97103
}
98104
}
99105
}
100106

107+
#[inline(always)]
108+
fn swizzle_dyn_scalar<const N: usize>(bytes: Simd<u8, N>, idxs: Simd<u8, N>) -> Simd<u8, N>
109+
where
110+
LaneCount<N>: SupportedLaneCount,
111+
{
112+
let mut array = [0; N];
113+
for (i, k) in idxs.to_array().into_iter().enumerate() {
114+
if (k as usize) < N {
115+
array[i] = bytes[k as usize];
116+
};
117+
}
118+
array.into()
119+
}
120+
121+
/// Dispatch to swizzle_dyn_compat and swizzle_dyn_zext according to N.
122+
/// Should only be called if there exists some power-of-two size for which
123+
/// the target architecture has a vectorized swizzle_dyn (e.g. pshufb, vqtbl).
124+
#[inline(always)]
125+
fn dispatch_compat<const N: usize>(bytes: Simd<u8, N>, idxs: Simd<u8, N>) -> Simd<u8, N>
126+
where
127+
LaneCount<N>: SupportedLaneCount,
128+
{
129+
#![allow(
130+
dead_code,
131+
unused_unsafe,
132+
unreachable_patterns,
133+
non_contiguous_range_endpoints
134+
)]
135+
136+
// SAFETY: only unsafe usage is transize, see comment on transize
137+
unsafe {
138+
match N {
139+
5..16 => swizzle_dyn_zext::<N, 16>(bytes, idxs),
140+
// only arm actually has 8-byte swizzle_dyn
141+
#[cfg(all(
142+
any(
143+
target_arch = "aarch64",
144+
target_arch = "arm64ec",
145+
all(target_arch = "arm", target_feature = "v7")
146+
),
147+
target_feature = "neon",
148+
target_endian = "little"
149+
))]
150+
16 => transize(swizzle_dyn_compat::<16, 8>, bytes, idxs),
151+
17..32 => swizzle_dyn_zext::<N, 32>(bytes, idxs),
152+
32 => transize(swizzle_dyn_compat::<32, 16>, bytes, idxs),
153+
33..64 => swizzle_dyn_zext::<N, 64>(bytes, idxs),
154+
64 => transize(swizzle_dyn_compat::<64, 32>, bytes, idxs),
155+
_ => swizzle_dyn_scalar(bytes, idxs),
156+
}
157+
}
158+
}
159+
160+
/// Implement swizzle_dyn for N by temporarily zero extending to N_EXT.
161+
#[inline(always)]
162+
#[allow(unused)]
163+
fn swizzle_dyn_zext<const N: usize, const N_EXT: usize>(
164+
bytes: Simd<u8, N>,
165+
idxs: Simd<u8, N>,
166+
) -> Simd<u8, N>
167+
where
168+
LaneCount<N>: SupportedLaneCount,
169+
LaneCount<N_EXT>: SupportedLaneCount,
170+
{
171+
assert!(N_EXT.is_power_of_two(), "N_EXT should be power of two!");
172+
assert!(N < N_EXT, "N_EXT should be larger than N");
173+
Simd::swizzle_dyn(bytes.resize::<N_EXT>(0), idxs.resize::<N_EXT>(0)).resize::<N>(0)
174+
}
175+
176+
/// "Downgrades" a swizzle_dyn op on N lanes to 4 swizzle_dyn ops on N/2 lanes.
177+
///
178+
/// This only makes sense if swizzle_dyn actually has a vectorized implementation for a lower size (N/2, N/4, N/8, etc).
179+
/// e.g. on x86, swizzle_dyn_compat for N=64 can be efficient if we have at least ssse3 for pshufb
180+
///
181+
/// If there is no vectorized implementation for a lower size,
182+
/// this runs in N*logN time and will be slower than the scalar implementation.
183+
#[inline(always)]
184+
#[allow(unused)]
185+
fn swizzle_dyn_compat<const N: usize, const HALF_N: usize>(
186+
bytes: Simd<u8, N>,
187+
idxs: Simd<u8, N>,
188+
) -> Simd<u8, N>
189+
where
190+
LaneCount<N>: SupportedLaneCount,
191+
LaneCount<HALF_N>: SupportedLaneCount,
192+
{
193+
use crate::simd::cmp::SimdPartialOrd;
194+
assert!(N.is_power_of_two(), "doesn't work for non-power-of-two N");
195+
assert!(N < u8::MAX as usize, "doesn't work for N >= 256");
196+
assert_eq!(N / 2, HALF_N, "HALF_N must equal N divided by two");
197+
198+
let mid = Simd::splat(HALF_N as u8);
199+
200+
// unset the "mid" bit from the indices, e.g. 8..15 -> 0..7, 16..31 -> 8..15,
201+
// ensuring that a half-swizzle on the higher half of `bytes` will select the correct indices
202+
// since N is a power of two, any zeroing indices will remain zeroing
203+
let idxs_trunc = idxs & !mid;
204+
205+
let idx_lo = Simd::<u8, HALF_N>::from_slice(&idxs_trunc[..HALF_N]);
206+
let idx_hi = Simd::<u8, HALF_N>::from_slice(&idxs_trunc[HALF_N..]);
207+
208+
let bytes_lo = Simd::<u8, HALF_N>::from_slice(&bytes[..HALF_N]);
209+
let bytes_hi = Simd::<u8, HALF_N>::from_slice(&bytes[HALF_N..]);
210+
211+
macro_rules! half_swizzle {
212+
($bytes:ident) => {{
213+
let lo = Simd::swizzle_dyn($bytes, idx_lo);
214+
let hi = Simd::swizzle_dyn($bytes, idx_hi);
215+
216+
let mut res = [0; N];
217+
res[..HALF_N].copy_from_slice(&lo[..]);
218+
res[HALF_N..].copy_from_slice(&hi[..]);
219+
Simd::from_array(res)
220+
}};
221+
}
222+
223+
let result_lo = half_swizzle!(bytes_lo);
224+
let result_hi = half_swizzle!(bytes_hi);
225+
idxs.simd_lt(mid).select(result_lo, result_hi)
226+
}
227+
101228
/// "vpshufb like it was meant to be" on AVX2
102229
///
103230
/// # Safety

0 commit comments

Comments
 (0)