Skip to content

Commit 169c541

Browse files
committed
Implement fallback to smaller vector size for swizzle_dyn
1 parent 4697d39 commit 169c541

File tree

1 file changed

+136
-10
lines changed

1 file changed

+136
-10
lines changed

crates/core_simd/src/swizzle_dyn.rs

+136-10
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ where
1515
/// A planned compiler improvement will enable using `#[target_feature]` instead.
1616
#[inline]
1717
pub fn swizzle_dyn(self, idxs: Simd<u8, N>) -> Self {
18-
#![allow(unused_imports, unused_unsafe)]
18+
#![allow(unused_imports, unused_unsafe, unreachable_patterns)]
1919
#[cfg(all(
2020
any(target_arch = "aarch64", target_arch = "arm64ec"),
2121
target_endian = "little"
@@ -66,20 +66,146 @@ where
6666
// FIXME: initial AVX512VBMI variant didn't actually pass muster
6767
// #[cfg(target_feature = "avx512vbmi")]
6868
// 64 => transize(x86::_mm512_permutexvar_epi8, self, idxs),
69-
_ => {
70-
let mut array = [0; N];
71-
for (i, k) in idxs.to_array().into_iter().enumerate() {
72-
if (k as usize) < N {
73-
array[i] = self[k as usize];
74-
};
75-
}
76-
array.into()
77-
}
69+
#[cfg(any(
70+
all(
71+
any(
72+
target_arch = "aarch64",
73+
target_arch = "arm64ec",
74+
all(target_arch = "arm", target_feature = "v7")
75+
),
76+
target_feature = "neon",
77+
target_endian = "little"
78+
),
79+
target_feature = "ssse3",
80+
target_feature = "simd128"
81+
))]
82+
_ => dispatch_compat(self, idxs),
83+
_ => swizzle_dyn_scalar(self, idxs),
7884
}
7985
}
8086
}
8187
}
8288

89+
#[inline(always)]
90+
fn swizzle_dyn_scalar<const N: usize>(bytes: Simd<u8, N>, idxs: Simd<u8, N>) -> Simd<u8, N>
91+
where
92+
LaneCount<N>: SupportedLaneCount,
93+
{
94+
let mut array = [0; N];
95+
for (i, k) in idxs.to_array().into_iter().enumerate() {
96+
if (k as usize) < N {
97+
array[i] = bytes[k as usize];
98+
};
99+
}
100+
array.into()
101+
}
102+
103+
/// Dispatch two swizzle_dyn_compat and swizzle_dyn_zext according to N.
104+
/// Should only be called if the target architecture has a vectorized swizzle_dyn for some power-of-two size (e.g 8, 16).
105+
#[inline(always)]
106+
fn dispatch_compat<const N: usize>(bytes: Simd<u8, N>, idxs: Simd<u8, N>) -> Simd<u8, N>
107+
where
108+
LaneCount<N>: SupportedLaneCount,
109+
{
110+
#![allow(
111+
dead_code,
112+
unused_unsafe,
113+
unreachable_patterns,
114+
non_contiguous_range_endpoints
115+
)]
116+
117+
// SAFETY: only unsafe usage is transize, see comment on transize
118+
unsafe {
119+
match N {
120+
5..16 => swizzle_dyn_zext::<N, 16>(bytes, idxs),
121+
// only arm actually has 8-byte swizzle_dyn
122+
#[cfg(all(
123+
any(
124+
target_arch = "aarch64",
125+
target_arch = "arm64ec",
126+
all(target_arch = "arm", target_feature = "v7")
127+
),
128+
target_feature = "neon",
129+
target_endian = "little"
130+
))]
131+
16 => transize(swizzle_dyn_compat::<16, 8>, bytes, idxs),
132+
17..32 => swizzle_dyn_zext::<N, 32>(bytes, idxs),
133+
32 => transize(swizzle_dyn_compat::<32, 16>, bytes, idxs),
134+
33..64 => swizzle_dyn_zext::<N, 64>(bytes, idxs),
135+
64 => transize(swizzle_dyn_compat::<64, 32>, bytes, idxs),
136+
_ => swizzle_dyn_scalar(bytes, idxs),
137+
}
138+
}
139+
}
140+
141+
/// Implement swizzle_dyn for N by temporarily zero extending to N_EXT.
142+
#[inline(always)]
143+
#[allow(unused)]
144+
fn swizzle_dyn_zext<const N: usize, const N_EXT: usize>(
145+
bytes: Simd<u8, N>,
146+
idxs: Simd<u8, N>,
147+
) -> Simd<u8, N>
148+
where
149+
LaneCount<N>: SupportedLaneCount,
150+
LaneCount<N_EXT>: SupportedLaneCount,
151+
{
152+
assert!(N_EXT.is_power_of_two(), "N_EXT should be power of two!");
153+
assert!(N < N_EXT, "N_EXT should be larger than N");
154+
Simd::swizzle_dyn(bytes.resize::<N_EXT>(0), idxs.resize::<N_EXT>(0)).resize::<N>(0)
155+
}
156+
157+
/// "Downgrades" a swizzle_dyn op on N lanes to 4 swizzle_dyn ops on N/2 lanes.
158+
///
159+
/// This only makes sense if swizzle_dyn actually has a vectorized implementation for a lower size (N/2, N/4, N/8, etc).
160+
/// e.g. on x86, swizzle_dyn_compat for N=64 can be efficient if we have at least ssse3 for pshufb
161+
///
162+
/// If there is no vectorized implementation for a lower size,
163+
/// this runs in N*logN time and will be slower than the scalar implementation.
164+
#[inline(always)]
165+
#[allow(unused)]
166+
fn swizzle_dyn_compat<const N: usize, const HALF_N: usize>(
167+
bytes: Simd<u8, N>,
168+
idxs: Simd<u8, N>,
169+
) -> Simd<u8, N>
170+
where
171+
LaneCount<N>: SupportedLaneCount,
172+
LaneCount<HALF_N>: SupportedLaneCount,
173+
{
174+
use crate::simd::cmp::SimdPartialOrd;
175+
assert!(N.is_power_of_two(), "doesn't work for non-power-of-two N");
176+
assert!(N < u8::MAX as usize, "doesn't work for N >= 256");
177+
assert_eq!(N / 2, HALF_N, "HALF_N must equal N divided by two");
178+
179+
let mid = Simd::splat(HALF_N as u8);
180+
181+
// unset the "mid" bit from the indices, e.g. 8..15 -> 0..7, 16..31 -> 8..15,
182+
// ensuring that a half-swizzle on the higher half of `bytes` will select the correct indices
183+
// since N is a power of two, any zeroing indices will remain zeroing
184+
let idxs_trunc = idxs & !mid;
185+
186+
let idx_lo = Simd::<u8, HALF_N>::from_slice(&idxs_trunc[..HALF_N]);
187+
let idx_hi = Simd::<u8, HALF_N>::from_slice(&idxs_trunc[HALF_N..]);
188+
189+
let bytes_lo = Simd::<u8, HALF_N>::from_slice(&bytes[..HALF_N]);
190+
let bytes_hi = Simd::<u8, HALF_N>::from_slice(&bytes[HALF_N..]);
191+
192+
macro_rules! half_swizzle {
193+
($bytes:ident) => {{
194+
let lo = Simd::swizzle_dyn($bytes, idx_lo);
195+
let hi = Simd::swizzle_dyn($bytes, idx_hi);
196+
197+
let mut res = [0; N];
198+
res[..HALF_N].copy_from_slice(&lo[..]);
199+
res[HALF_N..].copy_from_slice(&hi[..]);
200+
Simd::from_array(res)
201+
}};
202+
}
203+
204+
let result_lo = half_swizzle!(bytes_lo);
205+
let result_hi = half_swizzle!(bytes_hi);
206+
idxs.simd_lt(mid).select(result_lo, result_hi)
207+
}
208+
83209
/// "vpshufb like it was meant to be" on AVX2
84210
///
85211
/// # Safety

0 commit comments

Comments
 (0)