Skip to content

Commit

Permalink
[AArch64][CostModel] Alter sdiv/srem cost where the divisor is consta…
Browse files Browse the repository at this point in the history
…nt (llvm#123552)

This patch revises the cost model for sdiv/srem and draws its inspiration from the udiv/urem patch llvm#122236

The typical codegen for the different scenarios is documented as notes/comments in the code itself (there are too many scenarios to list them all here in the patch description).
  • Loading branch information
sushgokh authored Mar 10, 2025
1 parent 58fc4b1 commit c480874
Show file tree
Hide file tree
Showing 8 changed files with 608 additions and 568 deletions.
138 changes: 105 additions & 33 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "llvm/CodeGen/BasicTTIImpl.h"
#include "llvm/CodeGen/CostTable.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsAArch64.h"
Expand Down Expand Up @@ -3531,23 +3532,111 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
default:
return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
Op2Info);
case ISD::SREM:
case ISD::SDIV:
if (Op2Info.isConstant() && Op2Info.isUniform() && Op2Info.isPowerOf2()) {
// On AArch64, scalar signed division by a constant power of two is
// normally expanded to the sequence ADD + CMP + SELECT + SRA.
// The OperandValue properties may not be the same as those of the previous
// operation; conservatively assume OP_None.
InstructionCost Cost = getArithmeticInstrCost(
Instruction::Add, Ty, CostKind,
Op1Info.getNoProps(), Op2Info.getNoProps());
Cost += getArithmeticInstrCost(Instruction::Sub, Ty, CostKind,
Op1Info.getNoProps(), Op2Info.getNoProps());
Cost += getArithmeticInstrCost(
Instruction::Select, Ty, CostKind,
Op1Info.getNoProps(), Op2Info.getNoProps());
Cost += getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
Op1Info.getNoProps(), Op2Info.getNoProps());
return Cost;
/*
Notes for sdiv/srem specific costs:
1. This only considers the cases where the divisor is constant, uniform and
(pow-of-2/non-pow-of-2). Other cases are not important since they either
result in some form of (ldr + adrp), corresponding to constant vectors, or
scalarization of the division operation.
2. Constant divisors, whether wholly or partially negative, don't result in
significantly different codegen as compared to positive constant divisors.
So, we don't consider negative divisors separately.
3. If the codegen is significantly different with SVE, it has been indicated
using comments at appropriate places.
sdiv specific cases:
-----------------------------------------------------------------------
codegen | pow-of-2 | Type
-----------------------------------------------------------------------
add + cmp + csel + asr | Y | i64
add + cmp + csel + asr | Y | i32
-----------------------------------------------------------------------
srem specific cases:
-----------------------------------------------------------------------
codegen | pow-of-2 | Type
-----------------------------------------------------------------------
negs + and + and + csneg | Y | i64
negs + and + and + csneg | Y | i32
-----------------------------------------------------------------------
other sdiv/srem cases:
-------------------------------------------------------------------------
common codegen | + srem | + sdiv | pow-of-2 | Type
-------------------------------------------------------------------------
smulh + asr + add + add | - | - | N | i64
smull + lsr + add + add | - | - | N | i32
usra | and + sub | sshr | Y | <2 x i64>
2 * (scalar code) | - | - | N | <2 x i64>
usra | bic + sub | sshr + neg | Y | <4 x i32>
smull2 + smull + uzp2 | mls | - | N | <4 x i32>
+ sshr + usra | | | |
-------------------------------------------------------------------------
*/
if (Op2Info.isConstant() && Op2Info.isUniform()) {
InstructionCost AddCost =
getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
Op1Info.getNoProps(), Op2Info.getNoProps());
InstructionCost AsrCost =
getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
Op1Info.getNoProps(), Op2Info.getNoProps());
InstructionCost MulCost =
getArithmeticInstrCost(Instruction::Mul, Ty, CostKind,
Op1Info.getNoProps(), Op2Info.getNoProps());
// add/cmp/csel/csneg should have similar cost while asr/negs/and should
// have similar cost.
auto VT = TLI->getValueType(DL, Ty);
if (LT.second.isScalarInteger() && VT.getSizeInBits() <= 64) {
if (Op2Info.isPowerOf2()) {
return ISD == ISD::SDIV ? (3 * AddCost + AsrCost)
: (3 * AsrCost + AddCost);
} else {
return MulCost + AsrCost + 2 * AddCost;
}
} else if (VT.isVector()) {
InstructionCost UsraCost = 2 * AsrCost;
if (Op2Info.isPowerOf2()) {
// Division with scalable types corresponds to native 'asrd'
// instruction when SVE is available.
// e.g. %1 = sdiv <vscale x 4 x i32> %a, splat (i32 8)
if (Ty->isScalableTy() && ST->hasSVE())
return 2 * AsrCost;
return UsraCost +
(ISD == ISD::SDIV
? (LT.second.getScalarType() == MVT::i64 ? 1 : 2) *
AsrCost
: 2 * AddCost);
} else if (LT.second == MVT::v2i64) {
return VT.getVectorNumElements() *
getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind,
Op1Info.getNoProps(),
Op2Info.getNoProps());
} else {
// When SVE is available, we get:
// smulh + lsr + add/sub + asr + add/sub.
if (Ty->isScalableTy() && ST->hasSVE())
return MulCost /*smulh cost*/ + 2 * AddCost + 2 * AsrCost;
return 2 * MulCost + AddCost /*uzp2 cost*/ + AsrCost + UsraCost;
}
}
}
if (Op2Info.isConstant() && !Op2Info.isUniform() &&
LT.second.isFixedLengthVector()) {
// FIXME: When the constant vector is non-uniform, this may result in
// loading the vector from constant pool or in some cases, may also result
// in scalarization. For now, we are approximating this with the
// scalarization cost.
auto ExtractCost = 2 * getVectorInstrCost(Instruction::ExtractElement, Ty,
CostKind, -1, nullptr, nullptr);
auto InsertCost = getVectorInstrCost(Instruction::InsertElement, Ty,
CostKind, -1, nullptr, nullptr);
unsigned NElts = cast<FixedVectorType>(Ty)->getNumElements();
return ExtractCost + InsertCost +
NElts * getArithmeticInstrCost(Opcode, Ty->getScalarType(),
CostKind, Op1Info.getNoProps(),
Op2Info.getNoProps());
}
[[fallthrough]];
case ISD::UDIV:
Expand Down Expand Up @@ -3587,23 +3676,6 @@ InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
AddCost * 2 + ShrCost;
return DivCost + (ISD == ISD::UREM ? MulCost + AddCost : 0);
}

// TODO: Fix SDIV and SREM costs, similar to the above.
if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT) &&
Op2Info.isUniform() && !VT.isScalableVector()) {
// Vector signed division by a constant is expanded to the
// sequence MULHS + ADD/SUB + SRA + SRL + ADD.
InstructionCost MulCost =
getArithmeticInstrCost(Instruction::Mul, Ty, CostKind,
Op1Info.getNoProps(), Op2Info.getNoProps());
InstructionCost AddCost =
getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
Op1Info.getNoProps(), Op2Info.getNoProps());
InstructionCost ShrCost =
getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
Op1Info.getNoProps(), Op2Info.getNoProps());
return MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1;
}
}

// div i128's are lowered as libcalls. Pass nullptr as (u)divti3 calls are
Expand Down
Loading

0 comments on commit c480874

Please sign in to comment.