Skip to content

Float to integer casts on x86 produce scalar instructions #325

Open
@rnarubin

Description

@rnarubin

It seems that float-to-int SIMD casts on x86_64 fall back to individual scalar `cvttss2si` instructions (and the like), each followed by scalar bounds checks:

#![feature(portable_simd)]
use std::simd::*;

/// Converts every `f32` lane to `i32` with the same rules as a scalar `as` cast:
/// truncate toward zero, saturate out-of-range lanes, and map NaN lanes to zero.
pub fn cast(x: f32x8) -> i32x8 {
    x.cast::<i32>()
}
Output with scalar casts and bounds checks
.LCPI0_0:
        .long   0x4effffff
example::cast:
        vmovss  xmm1, dword ptr [rsi]
        vmovss  xmm2, dword ptr [rsi + 4]
        vcvttss2si      r10d, xmm2
        vmovss  xmm0, dword ptr [rip + .LCPI0_0]
        vucomiss        xmm2, xmm0
        mov     r8d, 2147483647
        cmova   r10d, r8d
        xor     r9d, r9d
        vucomiss        xmm2, xmm2
        cmovp   r10d, r9d
        vcvttss2si      eax, xmm1
        vucomiss        xmm1, xmm0
        cmova   eax, r8d
        vucomiss        xmm1, xmm1
        cmovp   eax, r9d
        vmovss  xmm1, dword ptr [rsi + 8]
        vcvttss2si      ecx, xmm1
        vucomiss        xmm1, xmm0
        cmova   ecx, r8d
        vucomiss        xmm1, xmm1
        cmovp   ecx, r9d
        vmovss  xmm1, dword ptr [rsi + 12]
        vcvttss2si      edx, xmm1
        vucomiss        xmm1, xmm0
        vmovd   xmm2, eax
        cmova   edx, r8d
        vpinsrd xmm2, xmm2, r10d, 1
        vucomiss        xmm1, xmm1
        cmovp   edx, r9d
        vmovss  xmm1, dword ptr [rsi + 20]
        vcvttss2si      eax, xmm1
        vucomiss        xmm1, xmm0
        cmova   eax, r8d
        vpinsrd xmm2, xmm2, ecx, 2
        vucomiss        xmm1, xmm1
        cmovp   eax, r9d
        vmovss  xmm1, dword ptr [rsi + 16]
        vcvttss2si      ecx, xmm1
        vucomiss        xmm1, xmm0
        cmova   ecx, r8d
        vpinsrd xmm2, xmm2, edx, 3
        vucomiss        xmm1, xmm1
        cmovp   ecx, r9d
        vmovss  xmm1, dword ptr [rsi + 24]
        vcvttss2si      edx, xmm1
        vucomiss        xmm1, xmm0
        vmovd   xmm3, ecx
        cmova   edx, r8d
        vpinsrd xmm3, xmm3, eax, 1
        vucomiss        xmm1, xmm1
        cmovp   edx, r9d
        vmovss  xmm1, dword ptr [rsi + 28]
        vcvttss2si      eax, xmm1
        vucomiss        xmm1, xmm0
        cmova   eax, r8d
        vpinsrd xmm0, xmm3, edx, 2
        vucomiss        xmm1, xmm1
        cmovp   eax, r9d
        vpinsrd xmm0, xmm0, eax, 3
        mov     rax, rdi
        vmovdqa xmmword ptr [rdi + 16], xmm0
        vmovdqa xmmword ptr [rdi], xmm2
        ret

I would have expected the bounds checks and the cast to be done with vector operations. Here's a rough sketch (it might not be right at the edges!):

/// Saturating lane-wise `f32 -> i32` cast built from vector compares and selects.
///
/// Mirrors the semantics of a scalar `as` cast:
/// - NaN lanes become `0`.
/// - Lanes at or above `2^31` (including `+inf`) saturate to `i32::MAX`.
/// - Lanes below `-2^31` (including `-inf`) saturate to `i32::MIN`.
/// - All other lanes are truncated toward zero.
pub fn manual_cast(x: f32x8) -> i32x8 {
    // `i32::MAX as f32` rounds up to 2^31 (2147483648.0), which is itself NOT
    // representable as i32, so the upper bound check must be inclusive.
    // `i32::MIN as f32` is exactly -2^31, which IS representable, so the
    // lower bound check stays strict.
    let upper = f32x8::splat(i32::MAX as f32);
    let lower = f32x8::splat(i32::MIN as f32);

    // check bounds
    let is_nan = x.is_nan();
    let too_high = x.simd_ge(upper);
    let too_low = x.simd_lt(lower);

    // zero-out every lane that cannot be converted directly
    let x = (is_nan | too_high | too_low).select(f32x8::splat(0.0), x);

    // SAFETY: after the select above, every lane is finite and lies in
    // [-2^31, 2^31), so truncation toward zero always yields a value
    // representable as i32, which is the precondition of `to_int_unchecked`.
    let cast: i32x8 = unsafe { x.to_int_unchecked() };

    // populate out-of-range lanes with saturated values
    // (NaN lanes keep the 0 produced by the zeroing step)
    let cast = too_high.select(i32x8::splat(i32::MAX), cast);
    too_low.select(i32x8::splat(i32::MIN), cast)
}
Output with vector cast and bounds checks
.LCPI1_0:
        .long   0x4f000000
.LCPI1_1:
        .long   0xcf000000
.LCPI1_2:
        .long   2147483647
.LCPI1_3:
        .long   2147483648
example::manual_cast:
        mov     rax, rdi
        vmovaps ymm0, ymmword ptr [rsi]
        vxorps  xmm1, xmm1, xmm1
        vcmpunordps     ymm1, ymm0, ymm1
        vbroadcastss    ymm2, dword ptr [rip + .LCPI1_0]
        vcmpltps        ymm2, ymm2, ymm0
        vbroadcastss    ymm3, dword ptr [rip + .LCPI1_1]
        vcmpltps        ymm3, ymm0, ymm3
        vorps   ymm1, ymm1, ymm2
        vorps   ymm1, ymm1, ymm3
        vcvttps2dq      ymm0, ymm0
        vandnps ymm0, ymm1, ymm0
        vbroadcastss    ymm1, dword ptr [rip + .LCPI1_2]
        vblendvps       ymm0, ymm0, ymm1, ymm2
        vbroadcastss    ymm1, dword ptr [rip + .LCPI1_3]
        vblendvps       ymm0, ymm0, ymm1, ymm3
        vmovaps ymmword ptr [rdi], ymm0
        vzeroupper
        ret

I'm calling out x86 specifically because aarch64 doesn't seem to have this problem; I haven't checked other architectures.

Output on aarch64 is vectorized
example::cast:
        ldp     q1, q0, [x0]
        fcvtzs  v1.4s, v1.4s
        fcvtzs  v0.4s, v0.4s
        stp     q1, q0, [x8]
        ret

Meta

rustc 1.68.0-nightly (659e169d3 2023-01-04)

-C opt-level=3 --target x86_64-unknown-linux-gnu -C target-feature=+avx2
-C opt-level=3 --target aarch64-unknown-linux-gnu -C target-feature=+neon

Metadata

Metadata

Assignees

No one assigned

    Labels

    A-LLVM (Area: LLVM) · C-bug (Category: Bug) · I-scalarize (Impact: code that should be vectorized, isn't)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions