Skip to content

[Draft] Widen X86::FMIN/MAX for FP16 #143298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 40 additions & 10 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35424,10 +35424,11 @@ bool X86TargetLowering::isBinOp(unsigned Opcode) const {
switch (Opcode) {
// These are non-commutative binops.
// TODO: Add more X86ISD opcodes once we have test coverage.
case X86ISD::ANDNP:
case X86ISD::PCMPGT:
case X86ISD::FMAX:
case X86ISD::FMIN:
return Subtarget.hasVLX();
case X86ISD::ANDNP:
case X86ISD::PCMPGT:
case X86ISD::FANDN:
case X86ISD::VPSHA:
case X86ISD::VPSHL:
Expand Down Expand Up @@ -44211,6 +44212,12 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
SDValue Insert =
insertSubVector(UndefVec, ExtOp, 0, TLO.DAG, DL, ExtSizeInBits);
return TLO.CombineTo(Op, Insert);
}
case X86ISD::FMAX:
case X86ISD::FMIN: {
if (VT.getVectorElementType() == MVT::f16 && !Subtarget.hasVLX())
break;
[[fallthrough]];
}
// Zero upper elements.
case X86ISD::VZEXT_MOVL:
Expand Down Expand Up @@ -44241,8 +44248,6 @@ bool X86TargetLowering::SimplifyDemandedVectorEltsForTargetNode(
case X86ISD::VSRLV:
case X86ISD::VSRAV:
// Float ops.
case X86ISD::FMAX:
case X86ISD::FMIN:
case X86ISD::FMAXC:
case X86ISD::FMINC:
case X86ISD::FRSQRT:
Expand Down Expand Up @@ -55368,25 +55373,46 @@ static SDValue combineFMinNumFMaxNum(SDNode *N, SelectionDAG &DAG,
SDLoc DL(N);
auto MinMaxOp = N->getOpcode() == ISD::FMAXNUM ? X86ISD::FMAX : X86ISD::FMIN;

auto GetNodeOrWiden = [&](SDValue Op0, SDValue Op1) {
if ((VT != MVT::v8f16 && VT != MVT::v16f16) || Subtarget.hasVLX())
return DAG.getNode(MinMaxOp, DL, VT, Op0, Op1, N->getFlags());
Op0 = widenSubVector(MVT::v32f16, Op0, /*ZeroNewElements=*/false, Subtarget,
DAG, DL);
Op1 = widenSubVector(MVT::v32f16, Op1, /*ZeroNewElements=*/false, Subtarget,
DAG, DL);
SDValue Res =
DAG.getNode(MinMaxOp, DL, MVT::v32f16, Op0, Op1, N->getFlags());
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res,
DAG.getVectorIdxConstant(0, DL));
};

// If we don't have to respect NaN inputs, this is a direct translation to x86
// min/max instructions.
if (DAG.getTarget().Options.NoNaNsFPMath || N->getFlags().hasNoNaNs())
return DAG.getNode(MinMaxOp, DL, VT, Op0, Op1, N->getFlags());
return GetNodeOrWiden(Op0, Op1);

// If one of the operands is known non-NaN use the native min/max instructions
// with the non-NaN input as second operand.
if (DAG.isKnownNeverNaN(Op1))
return DAG.getNode(MinMaxOp, DL, VT, Op0, Op1, N->getFlags());
return GetNodeOrWiden(Op0, Op1);
if (DAG.isKnownNeverNaN(Op0))
return DAG.getNode(MinMaxOp, DL, VT, Op1, Op0, N->getFlags());
return GetNodeOrWiden(Op1, Op0);

// If we have to respect NaN inputs, this takes at least 3 instructions.
// Favor a library call when operating on a scalar and minimizing code size.
if (!VT.isVector() && DAG.getMachineFunction().getFunction().hasMinSize())
return SDValue();

EVT WindenVT = VT;
if ((VT == MVT::v8f16 || VT == MVT::v16f16) && !Subtarget.hasVLX()) {
WindenVT = MVT::v32f16;
Op0 = widenSubVector(MVT::v32f16, Op0, /*ZeroNewElements=*/false, Subtarget,
DAG, DL);
Op1 = widenSubVector(MVT::v32f16, Op1, /*ZeroNewElements=*/false, Subtarget,
DAG, DL);
}
EVT SetCCType = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(),
VT);
WindenVT);

// There are 4 possibilities involving NaN inputs, and these are the required
// outputs:
Expand All @@ -55407,12 +55433,16 @@ static SDValue combineFMinNumFMaxNum(SDNode *N, SelectionDAG &DAG,
// use those instructions for fmaxnum by selecting away a NaN input.

// If either operand is NaN, the 2nd source operand (Op0) is passed through.
SDValue MinOrMax = DAG.getNode(MinMaxOp, DL, VT, Op1, Op0);
SDValue MinOrMax = DAG.getNode(MinMaxOp, DL, WindenVT, Op1, Op0);
SDValue IsOp0Nan = DAG.getSetCC(DL, SetCCType, Op0, Op0, ISD::SETUO);

// If Op0 is a NaN, select Op1. Otherwise, select the max. If both operands
// are NaN, the NaN value of Op1 is the result.
return DAG.getSelect(DL, VT, IsOp0Nan, Op1, MinOrMax);
SDValue Res = DAG.getSelect(DL, WindenVT, IsOp0Nan, Op1, MinOrMax);
if (VT != WindenVT)
Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Res,
DAG.getVectorIdxConstant(0, DL));
return Res;
}

static SDValue combineX86INT_TO_FP(SDNode *N, SelectionDAG &DAG,
Expand Down
141 changes: 101 additions & 40 deletions llvm/test/CodeGen/X86/avx512fp16-fmaxnum.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc < %s -verify-machineinstrs --show-mc-encoding -mtriple=x86_64-unknown-unknown -mattr=+avx512fp16,avx512vl | FileCheck %s --check-prefixes=CHECK
; RUN: llc < %s -verify-machineinstrs --show-mc-encoding -mtriple=x86_64-unknown-unknown -mattr=+avx512fp16,avx512vl | FileCheck %s --check-prefixes=CHECK,HasVL
; RUN: llc < %s -verify-machineinstrs --show-mc-encoding -mtriple=x86_64-unknown-unknown -mattr=+avx512fp16 | FileCheck %s --check-prefixes=CHECK,NOVL

declare half @llvm.maxnum.f16(half, half)
declare <2 x half> @llvm.maxnum.v2f16(<2 x half>, <2 x half>)
Expand All @@ -9,61 +10,112 @@ declare <16 x half> @llvm.maxnum.v16f16(<16 x half>, <16 x half>)
declare <32 x half> @llvm.maxnum.v32f16(<32 x half>, <32 x half>)

define half @test_intrinsic_fmaxh(half %x, half %y) {
; CHECK-LABEL: test_intrinsic_fmaxh:
; CHECK: # %bb.0:
; CHECK-NEXT: vmaxsh %xmm0, %xmm1, %xmm2 # encoding: [0x62,0xf5,0x76,0x08,0x5f,0xd0]
; CHECK-NEXT: vcmpunordsh %xmm0, %xmm0, %k1 # encoding: [0x62,0xf3,0x7e,0x08,0xc2,0xc8,0x03]
; CHECK-NEXT: vmovsh %xmm1, %xmm0, %xmm2 {%k1} # encoding: [0x62,0xf5,0x7e,0x09,0x10,0xd1]
; CHECK-NEXT: vmovaps %xmm2, %xmm0 # EVEX TO VEX Compression encoding: [0xc5,0xf8,0x28,0xc2]
; CHECK-NEXT: retq # encoding: [0xc3]
; HasVL-LABEL: test_intrinsic_fmaxh:
; HasVL: # %bb.0:
; HasVL-NEXT: vmaxsh %xmm0, %xmm1, %xmm2 # encoding: [0x62,0xf5,0x76,0x08,0x5f,0xd0]
; HasVL-NEXT: vcmpunordsh %xmm0, %xmm0, %k1 # encoding: [0x62,0xf3,0x7e,0x08,0xc2,0xc8,0x03]
; HasVL-NEXT: vmovsh %xmm1, %xmm0, %xmm2 {%k1} # encoding: [0x62,0xf5,0x7e,0x09,0x10,0xd1]
; HasVL-NEXT: vmovaps %xmm2, %xmm0 # EVEX TO VEX Compression encoding: [0xc5,0xf8,0x28,0xc2]
; HasVL-NEXT: retq # encoding: [0xc3]
;
; NOVL-LABEL: test_intrinsic_fmaxh:
; NOVL: # %bb.0:
; NOVL-NEXT: vmaxsh %xmm0, %xmm1, %xmm2 # encoding: [0x62,0xf5,0x76,0x08,0x5f,0xd0]
; NOVL-NEXT: vcmpunordsh %xmm0, %xmm0, %k1 # encoding: [0x62,0xf3,0x7e,0x08,0xc2,0xc8,0x03]
; NOVL-NEXT: vmovsh %xmm1, %xmm0, %xmm2 {%k1} # encoding: [0x62,0xf5,0x7e,0x09,0x10,0xd1]
; NOVL-NEXT: vmovaps %xmm2, %xmm0 # encoding: [0xc5,0xf8,0x28,0xc2]
; NOVL-NEXT: retq # encoding: [0xc3]
%z = call half @llvm.maxnum.f16(half %x, half %y) readnone
ret half %z
}

define <2 x half> @test_intrinsic_fmax_v2f16(<2 x half> %x, <2 x half> %y) {
; CHECK-LABEL: test_intrinsic_fmax_v2f16:
; CHECK: # %bb.0:
; CHECK-NEXT: vmaxph %xmm0, %xmm1, %xmm2 # encoding: [0x62,0xf5,0x74,0x08,0x5f,0xd0]
; CHECK-NEXT: vcmpunordph %xmm0, %xmm0, %k1 # encoding: [0x62,0xf3,0x7c,0x08,0xc2,0xc8,0x03]
; CHECK-NEXT: vmovdqu16 %xmm1, %xmm2 {%k1} # encoding: [0x62,0xf1,0xff,0x09,0x6f,0xd1]
; CHECK-NEXT: vmovdqa %xmm2, %xmm0 # EVEX TO VEX Compression encoding: [0xc5,0xf9,0x6f,0xc2]
; CHECK-NEXT: retq # encoding: [0xc3]
; HasVL-LABEL: test_intrinsic_fmax_v2f16:
; HasVL: # %bb.0:
; HasVL-NEXT: vmaxph %xmm0, %xmm1, %xmm2 # encoding: [0x62,0xf5,0x74,0x08,0x5f,0xd0]
; HasVL-NEXT: vcmpunordph %xmm0, %xmm0, %k1 # encoding: [0x62,0xf3,0x7c,0x08,0xc2,0xc8,0x03]
; HasVL-NEXT: vmovdqu16 %xmm1, %xmm2 {%k1} # encoding: [0x62,0xf1,0xff,0x09,0x6f,0xd1]
; HasVL-NEXT: vmovdqa %xmm2, %xmm0 # EVEX TO VEX Compression encoding: [0xc5,0xf9,0x6f,0xc2]
; HasVL-NEXT: retq # encoding: [0xc3]
;
; NOVL-LABEL: test_intrinsic_fmax_v2f16:
; NOVL: # %bb.0:
; NOVL-NEXT: # kill: def $xmm1 killed $xmm1 def $zmm1
; NOVL-NEXT: # kill: def $xmm0 killed $xmm0 def $zmm0
; NOVL-NEXT: vmaxph %zmm0, %zmm1, %zmm2 # encoding: [0x62,0xf5,0x74,0x48,0x5f,0xd0]
; NOVL-NEXT: vcmpunordph %zmm0, %zmm0, %k1 # encoding: [0x62,0xf3,0x7c,0x48,0xc2,0xc8,0x03]
; NOVL-NEXT: vmovdqu16 %zmm1, %zmm2 {%k1} # encoding: [0x62,0xf1,0xff,0x49,0x6f,0xd1]
; NOVL-NEXT: vmovdqa %xmm2, %xmm0 # encoding: [0xc5,0xf9,0x6f,0xc2]
; NOVL-NEXT: vzeroupper # encoding: [0xc5,0xf8,0x77]
; NOVL-NEXT: retq # encoding: [0xc3]
%z = call <2 x half> @llvm.maxnum.v2f16(<2 x half> %x, <2 x half> %y) readnone
ret <2 x half> %z
}

define <4 x half> @test_intrinsic_fmax_v4f16(<4 x half> %x, <4 x half> %y) {
; CHECK-LABEL: test_intrinsic_fmax_v4f16:
; CHECK: # %bb.0:
; CHECK-NEXT: vmaxph %xmm0, %xmm1, %xmm2 # encoding: [0x62,0xf5,0x74,0x08,0x5f,0xd0]
; CHECK-NEXT: vcmpunordph %xmm0, %xmm0, %k1 # encoding: [0x62,0xf3,0x7c,0x08,0xc2,0xc8,0x03]
; CHECK-NEXT: vmovdqu16 %xmm1, %xmm2 {%k1} # encoding: [0x62,0xf1,0xff,0x09,0x6f,0xd1]
; CHECK-NEXT: vmovdqa %xmm2, %xmm0 # EVEX TO VEX Compression encoding: [0xc5,0xf9,0x6f,0xc2]
; CHECK-NEXT: retq # encoding: [0xc3]
; HasVL-LABEL: test_intrinsic_fmax_v4f16:
; HasVL: # %bb.0:
; HasVL-NEXT: vmaxph %xmm0, %xmm1, %xmm2 # encoding: [0x62,0xf5,0x74,0x08,0x5f,0xd0]
; HasVL-NEXT: vcmpunordph %xmm0, %xmm0, %k1 # encoding: [0x62,0xf3,0x7c,0x08,0xc2,0xc8,0x03]
; HasVL-NEXT: vmovdqu16 %xmm1, %xmm2 {%k1} # encoding: [0x62,0xf1,0xff,0x09,0x6f,0xd1]
; HasVL-NEXT: vmovdqa %xmm2, %xmm0 # EVEX TO VEX Compression encoding: [0xc5,0xf9,0x6f,0xc2]
; HasVL-NEXT: retq # encoding: [0xc3]
;
; NOVL-LABEL: test_intrinsic_fmax_v4f16:
; NOVL: # %bb.0:
; NOVL-NEXT: # kill: def $xmm1 killed $xmm1 def $zmm1
; NOVL-NEXT: # kill: def $xmm0 killed $xmm0 def $zmm0
; NOVL-NEXT: vmaxph %zmm0, %zmm1, %zmm2 # encoding: [0x62,0xf5,0x74,0x48,0x5f,0xd0]
; NOVL-NEXT: vcmpunordph %zmm0, %zmm0, %k1 # encoding: [0x62,0xf3,0x7c,0x48,0xc2,0xc8,0x03]
; NOVL-NEXT: vmovdqu16 %zmm1, %zmm2 {%k1} # encoding: [0x62,0xf1,0xff,0x49,0x6f,0xd1]
; NOVL-NEXT: vmovdqa %xmm2, %xmm0 # encoding: [0xc5,0xf9,0x6f,0xc2]
; NOVL-NEXT: vzeroupper # encoding: [0xc5,0xf8,0x77]
; NOVL-NEXT: retq # encoding: [0xc3]
%z = call <4 x half> @llvm.maxnum.v4f16(<4 x half> %x, <4 x half> %y) readnone
ret <4 x half> %z
}

define <8 x half> @test_intrinsic_fmax_v8f16(<8 x half> %x, <8 x half> %y) {
; CHECK-LABEL: test_intrinsic_fmax_v8f16:
; CHECK: # %bb.0:
; CHECK-NEXT: vmaxph %xmm0, %xmm1, %xmm2 # encoding: [0x62,0xf5,0x74,0x08,0x5f,0xd0]
; CHECK-NEXT: vcmpunordph %xmm0, %xmm0, %k1 # encoding: [0x62,0xf3,0x7c,0x08,0xc2,0xc8,0x03]
; CHECK-NEXT: vmovdqu16 %xmm1, %xmm2 {%k1} # encoding: [0x62,0xf1,0xff,0x09,0x6f,0xd1]
; CHECK-NEXT: vmovdqa %xmm2, %xmm0 # EVEX TO VEX Compression encoding: [0xc5,0xf9,0x6f,0xc2]
; CHECK-NEXT: retq # encoding: [0xc3]
; HasVL-LABEL: test_intrinsic_fmax_v8f16:
; HasVL: # %bb.0:
; HasVL-NEXT: vmaxph %xmm0, %xmm1, %xmm2 # encoding: [0x62,0xf5,0x74,0x08,0x5f,0xd0]
; HasVL-NEXT: vcmpunordph %xmm0, %xmm0, %k1 # encoding: [0x62,0xf3,0x7c,0x08,0xc2,0xc8,0x03]
; HasVL-NEXT: vmovdqu16 %xmm1, %xmm2 {%k1} # encoding: [0x62,0xf1,0xff,0x09,0x6f,0xd1]
; HasVL-NEXT: vmovdqa %xmm2, %xmm0 # EVEX TO VEX Compression encoding: [0xc5,0xf9,0x6f,0xc2]
; HasVL-NEXT: retq # encoding: [0xc3]
;
; NOVL-LABEL: test_intrinsic_fmax_v8f16:
; NOVL: # %bb.0:
; NOVL-NEXT: # kill: def $xmm1 killed $xmm1 def $zmm1
; NOVL-NEXT: # kill: def $xmm0 killed $xmm0 def $zmm0
; NOVL-NEXT: vmaxph %zmm0, %zmm1, %zmm2 # encoding: [0x62,0xf5,0x74,0x48,0x5f,0xd0]
; NOVL-NEXT: vcmpunordph %zmm0, %zmm0, %k1 # encoding: [0x62,0xf3,0x7c,0x48,0xc2,0xc8,0x03]
; NOVL-NEXT: vmovdqu16 %zmm1, %zmm2 {%k1} # encoding: [0x62,0xf1,0xff,0x49,0x6f,0xd1]
; NOVL-NEXT: vmovdqa %xmm2, %xmm0 # encoding: [0xc5,0xf9,0x6f,0xc2]
; NOVL-NEXT: vzeroupper # encoding: [0xc5,0xf8,0x77]
; NOVL-NEXT: retq # encoding: [0xc3]
%z = call <8 x half> @llvm.maxnum.v8f16(<8 x half> %x, <8 x half> %y) readnone
ret <8 x half> %z
}

define <16 x half> @test_intrinsic_fmax_v16f16(<16 x half> %x, <16 x half> %y) {
; CHECK-LABEL: test_intrinsic_fmax_v16f16:
; CHECK: # %bb.0:
; CHECK-NEXT: vmaxph %ymm0, %ymm1, %ymm2 # encoding: [0x62,0xf5,0x74,0x28,0x5f,0xd0]
; CHECK-NEXT: vcmpunordph %ymm0, %ymm0, %k1 # encoding: [0x62,0xf3,0x7c,0x28,0xc2,0xc8,0x03]
; CHECK-NEXT: vmovdqu16 %ymm1, %ymm2 {%k1} # encoding: [0x62,0xf1,0xff,0x29,0x6f,0xd1]
; CHECK-NEXT: vmovdqa %ymm2, %ymm0 # EVEX TO VEX Compression encoding: [0xc5,0xfd,0x6f,0xc2]
; CHECK-NEXT: retq # encoding: [0xc3]
; HasVL-LABEL: test_intrinsic_fmax_v16f16:
; HasVL: # %bb.0:
; HasVL-NEXT: vmaxph %ymm0, %ymm1, %ymm2 # encoding: [0x62,0xf5,0x74,0x28,0x5f,0xd0]
; HasVL-NEXT: vcmpunordph %ymm0, %ymm0, %k1 # encoding: [0x62,0xf3,0x7c,0x28,0xc2,0xc8,0x03]
; HasVL-NEXT: vmovdqu16 %ymm1, %ymm2 {%k1} # encoding: [0x62,0xf1,0xff,0x29,0x6f,0xd1]
; HasVL-NEXT: vmovdqa %ymm2, %ymm0 # EVEX TO VEX Compression encoding: [0xc5,0xfd,0x6f,0xc2]
; HasVL-NEXT: retq # encoding: [0xc3]
;
; NOVL-LABEL: test_intrinsic_fmax_v16f16:
; NOVL: # %bb.0:
; NOVL-NEXT: # kill: def $ymm1 killed $ymm1 def $zmm1
; NOVL-NEXT: # kill: def $ymm0 killed $ymm0 def $zmm0
; NOVL-NEXT: vmaxph %zmm0, %zmm1, %zmm2 # encoding: [0x62,0xf5,0x74,0x48,0x5f,0xd0]
; NOVL-NEXT: vcmpunordph %zmm0, %zmm0, %k1 # encoding: [0x62,0xf3,0x7c,0x48,0xc2,0xc8,0x03]
; NOVL-NEXT: vmovdqu16 %zmm1, %zmm2 {%k1} # encoding: [0x62,0xf1,0xff,0x49,0x6f,0xd1]
; NOVL-NEXT: vmovdqa %ymm2, %ymm0 # encoding: [0xc5,0xfd,0x6f,0xc2]
; NOVL-NEXT: retq # encoding: [0xc3]
%z = call <16 x half> @llvm.maxnum.v16f16(<16 x half> %x, <16 x half> %y) readnone
ret <16 x half> %z
}
Expand All @@ -81,10 +133,19 @@ define <32 x half> @test_intrinsic_fmax_v32f16(<32 x half> %x, <32 x half> %y) {
}

define <4 x half> @maxnum_intrinsic_nnan_fmf_f432(<4 x half> %a, <4 x half> %b) {
; CHECK-LABEL: maxnum_intrinsic_nnan_fmf_f432:
; CHECK: # %bb.0:
; CHECK-NEXT: vmaxph %xmm1, %xmm0, %xmm0 # encoding: [0x62,0xf5,0x7c,0x08,0x5f,0xc1]
; CHECK-NEXT: retq # encoding: [0xc3]
; HasVL-LABEL: maxnum_intrinsic_nnan_fmf_f432:
; HasVL: # %bb.0:
; HasVL-NEXT: vmaxph %xmm1, %xmm0, %xmm0 # encoding: [0x62,0xf5,0x7c,0x08,0x5f,0xc1]
; HasVL-NEXT: retq # encoding: [0xc3]
;
; NOVL-LABEL: maxnum_intrinsic_nnan_fmf_f432:
; NOVL: # %bb.0:
; NOVL-NEXT: # kill: def $xmm1 killed $xmm1 def $zmm1
; NOVL-NEXT: # kill: def $xmm0 killed $xmm0 def $zmm0
; NOVL-NEXT: vmaxph %zmm1, %zmm0, %zmm0 # encoding: [0x62,0xf5,0x7c,0x48,0x5f,0xc1]
; NOVL-NEXT: # kill: def $xmm0 killed $xmm0 killed $zmm0
; NOVL-NEXT: vzeroupper # encoding: [0xc5,0xf8,0x77]
; NOVL-NEXT: retq # encoding: [0xc3]
%r = tail call nnan <4 x half> @llvm.maxnum.v4f16(<4 x half> %a, <4 x half> %b)
ret <4 x half> %r
}
Expand Down
Loading
Loading