Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AMDGPU/GlobalISel: add RegBankLegalize rules for bit shifts and sext-inreg #132385

Open
wants to merge 1 commit into
base: users/petar-avramovic/select
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm/lib/Target/AMDGPU/AMDGPURegBankLegalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ bool AMDGPURegBankLegalize::runOnMachineFunction(MachineFunction &MF) {
// Opcodes that support pretty much all combinations of reg banks and LLTs
// (except S1). There is no point in writing rules for them.
if (Opc == AMDGPU::G_BUILD_VECTOR || Opc == AMDGPU::G_UNMERGE_VALUES ||
Opc == AMDGPU::G_MERGE_VALUES) {
Opc == AMDGPU::G_MERGE_VALUES || Opc == G_BITCAST) {
RBLHelper.applyMappingTrivial(*MI);
continue;
}
Expand Down
109 changes: 109 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@
#include "AMDGPURegBankLegalizeHelper.h"
#include "AMDGPUGlobalISelUtils.h"
#include "AMDGPUInstrInfo.h"
#include "AMDGPURegBankLegalizeRules.h"
#include "AMDGPURegisterBankInfo.h"
#include "GCNSubtarget.h"
#include "MCTargetDesc/AMDGPUMCTargetDesc.h"
#include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineUniformityAnalysis.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/Support/ErrorHandling.h"
Expand Down Expand Up @@ -166,6 +168,60 @@ void RegBankLegalizeHelper::lowerVccExtToSel(MachineInstr &MI) {
return;
}

/// Split \p Reg into two s32 values holding its low and high 16 bits, each
/// zero-extended.
///
/// \p Reg is bitcast to s32 first, so it is presumably a 32-bit packed value
/// such as <2 x s16> (assumption - confirm against callers).
/// \returns {Lo, Hi} where Lo = Packed & 0xffff and Hi = Packed >> 16
/// (logical shift).
std::pair<Register, Register> RegBankLegalizeHelper::unpackZExt(Register Reg) {
  auto Packed = B.buildBitcast(SgprRB_S32, Reg);

  // Low half: clear the upper 16 bits.
  auto LowMask = B.buildConstant(SgprRB_S32, 0x0000ffff);
  Register LoReg = B.buildAnd(SgprRB_S32, Packed, LowMask).getReg(0);

  // High half: a logical shift right fills the upper bits with zeros.
  auto Sixteen = B.buildConstant(SgprRB_S32, 16);
  Register HiReg = B.buildLShr(SgprRB_S32, Packed, Sixteen).getReg(0);

  return {LoReg, HiReg};
}

/// Split \p Reg into two s32 values holding its low and high 16 bits, each
/// sign-extended.
///
/// \p Reg is bitcast to s32 first, so it is presumably a 32-bit packed value
/// such as <2 x s16> (assumption - confirm against callers).
/// \returns {Lo, Hi} where Lo = sext_inreg(Packed, 16) and Hi = Packed >> 16
/// (arithmetic shift).
std::pair<Register, Register> RegBankLegalizeHelper::unpackSExt(Register Reg) {
  auto Packed = B.buildBitcast(SgprRB_S32, Reg);

  // Low half: replicate bit 15 into the upper 16 bits.
  Register LoReg = B.buildSExtInReg(SgprRB_S32, Packed, 16).getReg(0);

  // High half: an arithmetic shift right keeps the sign of the top element.
  auto Sixteen = B.buildConstant(SgprRB_S32, 16);
  Register HiReg = B.buildAShr(SgprRB_S32, Packed, Sixteen).getReg(0);

  return {LoReg, HiReg};
}

/// Split \p Reg into two s32 values whose low 16 bits hold its low and high
/// halves ("any-extend": the upper 16 bits of Lo are left unspecified).
///
/// \p Reg is bitcast to s32 first, so it is presumably a 32-bit packed value
/// such as <2 x s16> (assumption - confirm against callers).
/// \returns {Lo, Hi} where Lo is the bitcast value itself (its upper 16 bits
/// still contain the high half) and Hi = Packed >> 16 (logical shift).
std::pair<Register, Register> RegBankLegalizeHelper::unpackAExt(Register Reg) {
  auto Packed = B.buildBitcast(SgprRB_S32, Reg);

  // High half is moved down into bits [15:0]; the low half needs no work
  // because only its low 16 bits are meaningful to the consumer.
  auto Sixteen = B.buildConstant(SgprRB_S32, 16);
  Register HiReg = B.buildLShr(SgprRB_S32, Packed, Sixteen).getReg(0);

  return {Packed.getReg(0), HiReg};
}

/// Lower a packed bit shift (G_SHL, G_LSHR or G_ASHR) by unpacking both the
/// value (operand 1) and the shift amount (operand 2) into two s32 halves,
/// shifting each half with the same opcode, and repacking the two results
/// into operand 0 with G_BUILD_VECTOR_TRUNC. \p MI is erased afterwards.
///
/// The extension used while unpacking depends on the opcode:
///   G_SHL  -> any-extend, G_LSHR -> zero-extend, G_ASHR -> sign-extend.
///
/// NOTE(review): for G_SHL the low-half shift amount comes from unpackAExt,
/// i.e. it is the whole packed dword with its upper 16 bits not cleared -
/// confirm this is acceptable for the s32 shift that consumes it.
void RegBankLegalizeHelper::lowerUnpack(MachineInstr &MI) {
  using UnpackFn =
      std::pair<Register, Register> (RegBankLegalizeHelper::*)(Register);

  const unsigned Opc = MI.getOpcode();

  // Choose how the 16-bit halves are widened to 32 bits for this opcode.
  UnpackFn Unpack = nullptr;
  switch (Opc) {
  case AMDGPU::G_SHL:
    Unpack = &RegBankLegalizeHelper::unpackAExt;
    break;
  case AMDGPU::G_LSHR:
    Unpack = &RegBankLegalizeHelper::unpackZExt;
    break;
  case AMDGPU::G_ASHR:
    Unpack = &RegBankLegalizeHelper::unpackSExt;
    break;
  default:
    llvm_unreachable("Unpack lowering not implemented");
  }

  // Unpack the shifted value first, then the shift amount.
  const std::pair<Register, Register> Val =
      (this->*Unpack)(MI.getOperand(1).getReg());
  const std::pair<Register, Register> Amt =
      (this->*Unpack)(MI.getOperand(2).getReg());

  // Perform the shift independently on the low and the high element.
  Register Lo =
      B.buildInstr(Opc, {SgprRB_S32}, {Val.first, Amt.first}).getReg(0);
  Register Hi =
      B.buildInstr(Opc, {SgprRB_S32}, {Val.second, Amt.second}).getReg(0);

  // Truncate both s32 results and pack them back into the destination.
  B.buildBuildVectorTrunc(MI.getOperand(0).getReg(), {Lo, Hi});
  MI.eraseFromParent();
}

bool isSignedBFE(MachineInstr &MI) {
unsigned Opc =
isa<GIntrinsic>(MI) ? MI.getOperand(1).getIntrinsicID() : MI.getOpcode();
Expand Down Expand Up @@ -310,6 +366,34 @@ void RegBankLegalizeHelper::lowerSplitTo32Sel(MachineInstr &MI) {
return;
}

/// Lower a G_SEXT_INREG by splitting its source (operand 1) into two s32
/// halves, computing the sign-extended low and high halves separately, and
/// merging them back into operand 0. Operand 2 is the immediate bit width the
/// value is sign-extended from. \p MI is erased afterwards.
///
/// In the diagrams below 's' is the sign bit, 'x' are bits left unchanged and
/// '?' are bits overwritten by the sign extension.
void RegBankLegalizeHelper::lowerSplitTo32SExtInReg(MachineInstr &MI) {
  auto Halves = B.buildUnmerge(VgprRB_S32, MI.getOperand(1).getReg());
  const int Width = MI.getOperand(2).getImm();

  Register Lo, Hi;
  if (Width > 32) {
    // Sign bit lives in the high half; the low half passes through untouched.
    // Hi|Lo: ?????sxx|xxxxxxxx -> ssssssxx|xxxxxxxx
    Lo = Halves.getReg(0);
    Hi = B.buildSExtInReg(VgprRB_S32, Halves.getReg(1), Width - 32).getReg(0);
  } else {
    // Sign bit lives in the low half. The low half is frozen before it is
    // used to derive both result halves.
    Lo = B.buildFreeze(VgprRB_S32, Halves.getReg(0)).getReg(0);

    // Width == 32: Hi|Lo: ????????|sxxxxxxx -> ssssssss|sxxxxxxx
    //              (low half is already correct as is)
    // Width <  32: Hi|Lo: ????????|???sxxxx -> ssssssss|ssssxxxx
    if (Width != 32)
      Lo = B.buildSExtInReg(VgprRB_S32, Lo, Width).getReg(0);

    // High half is the low half's sign bit replicated 32 times.
    Hi = B.buildAShr(VgprRB_S32, Lo, B.buildConstant(SgprRB_S32, 31))
             .getReg(0);
  }

  B.buildMergeLikeInstr(MI.getOperand(0).getReg(), {Lo, Hi});
  MI.eraseFromParent();
}

void RegBankLegalizeHelper::lower(MachineInstr &MI,
const RegBankLLTMapping &Mapping,
SmallSet<Register, 4> &WaterfallSgprs) {
Expand All @@ -332,6 +416,8 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
MI.eraseFromParent();
return;
}
case Unpack:
return lowerUnpack(MI);
case Ext32To64: {
const RegisterBank *RB = MRI.getRegBank(MI.getOperand(0).getReg());
MachineInstrBuilder Hi;
Expand Down Expand Up @@ -398,6 +484,8 @@ void RegBankLegalizeHelper::lower(MachineInstr &MI,
return lowerSplitTo32(MI);
case SplitTo32Sel:
return lowerSplitTo32Sel(MI);
case SplitTo32SExtInReg:
return lowerSplitTo32SExtInReg(MI);
case SplitLoad: {
LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
unsigned Size = DstTy.getSizeInBits();
Expand Down Expand Up @@ -487,6 +575,13 @@ LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMappingApplyID ID) {
case SgprP5:
case VgprP5:
return LLT::pointer(5, 32);
case SgprV2S16:
case VgprV2S16:
case UniInVgprV2S16:
return LLT::fixed_vector(2, 16);
case SgprV2S32:
case VgprV2S32:
return LLT::fixed_vector(2, 32);
case SgprV4S32:
case VgprV4S32:
case UniInVgprV4S32:
Expand Down Expand Up @@ -560,6 +655,8 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
case SgprP3:
case SgprP4:
case SgprP5:
case SgprV2S16:
case SgprV2S32:
case SgprV4S32:
case SgprB32:
case SgprB64:
Expand All @@ -569,6 +666,7 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
case SgprB512:
case UniInVcc:
case UniInVgprS32:
case UniInVgprV2S16:
case UniInVgprV4S32:
case UniInVgprB32:
case UniInVgprB64:
Expand All @@ -590,6 +688,8 @@ RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
case VgprP3:
case VgprP4:
case VgprP5:
case VgprV2S16:
case VgprV2S32:
case VgprV4S32:
case VgprB32:
case VgprB64:
Expand Down Expand Up @@ -627,6 +727,8 @@ void RegBankLegalizeHelper::applyMappingDst(
case SgprP3:
case SgprP4:
case SgprP5:
case SgprV2S16:
case SgprV2S32:
case SgprV4S32:
case Vgpr16:
case Vgpr32:
Expand All @@ -636,6 +738,8 @@ void RegBankLegalizeHelper::applyMappingDst(
case VgprP3:
case VgprP4:
case VgprP5:
case VgprV2S16:
case VgprV2S32:
case VgprV4S32: {
assert(Ty == getTyFromID(MethodIDs[OpIdx]));
assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
Expand Down Expand Up @@ -670,6 +774,7 @@ void RegBankLegalizeHelper::applyMappingDst(
break;
}
case UniInVgprS32:
case UniInVgprV2S16:
case UniInVgprV4S32: {
assert(Ty == getTyFromID(MethodIDs[OpIdx]));
assert(RB == SgprRB);
Expand Down Expand Up @@ -743,6 +848,8 @@ void RegBankLegalizeHelper::applyMappingSrc(
case SgprP3:
case SgprP4:
case SgprP5:
case SgprV2S16:
case SgprV2S32:
case SgprV4S32: {
assert(Ty == getTyFromID(MethodIDs[i]));
assert(RB == getRegBankFromID(MethodIDs[i]));
Expand All @@ -768,6 +875,8 @@ void RegBankLegalizeHelper::applyMappingSrc(
case VgprP3:
case VgprP4:
case VgprP5:
case VgprV2S16:
case VgprV2S32:
case VgprV4S32: {
assert(Ty == getTyFromID(MethodIDs[i]));
if (RB != VgprRB) {
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,15 @@ class RegBankLegalizeHelper {
SmallSet<Register, 4> &SgprWaterfallOperandRegs);

void lowerVccExtToSel(MachineInstr &MI);
std::pair<Register, Register> unpackZExt(Register Reg);
std::pair<Register, Register> unpackSExt(Register Reg);
std::pair<Register, Register> unpackAExt(Register Reg);
void lowerUnpack(MachineInstr &MI);
void lowerDiv_BFE(MachineInstr &MI);
void lowerUni_BFE(MachineInstr &MI);
void lowerSplitTo32(MachineInstr &MI);
void lowerSplitTo32Sel(MachineInstr &MI);
void lowerSplitTo32SExtInReg(MachineInstr &MI);
};

} // end namespace AMDGPU
Expand Down
43 changes: 41 additions & 2 deletions llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
return MRI.getType(Reg) == LLT::pointer(4, 64);
case P5:
return MRI.getType(Reg) == LLT::pointer(5, 32);
case V2S32:
return MRI.getType(Reg) == LLT::fixed_vector(2, 32);
case V4S32:
return MRI.getType(Reg) == LLT::fixed_vector(4, 32);
case B32:
Expand Down Expand Up @@ -92,6 +94,8 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isUniform(Reg);
case UniP5:
return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isUniform(Reg);
case UniV2S16:
return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isUniform(Reg);
case UniB32:
return MRI.getType(Reg).getSizeInBits() == 32 && MUI.isUniform(Reg);
case UniB64:
Expand Down Expand Up @@ -122,6 +126,8 @@ bool matchUniformityAndLLT(Register Reg, UniformityLLTOpPredicateID UniID,
return MRI.getType(Reg) == LLT::pointer(4, 64) && MUI.isDivergent(Reg);
case DivP5:
return MRI.getType(Reg) == LLT::pointer(5, 32) && MUI.isDivergent(Reg);
case DivV2S16:
return MRI.getType(Reg) == LLT::fixed_vector(2, 16) && MUI.isDivergent(Reg);
case DivB32:
return MRI.getType(Reg).getSizeInBits() == 32 && MUI.isDivergent(Reg);
case DivB64:
Expand Down Expand Up @@ -434,7 +440,7 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST,
MachineRegisterInfo &_MRI)
: ST(&_ST), MRI(&_MRI) {

addRulesForGOpcs({G_ADD}, Standard)
addRulesForGOpcs({G_ADD, G_SUB}, Standard)
.Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}})
.Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}});

Expand All @@ -451,11 +457,36 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST,
.Div(B64, {{VgprB64}, {VgprB64, VgprB64}, SplitTo32});

addRulesForGOpcs({G_SHL}, Standard)
.Uni(S16, {{Sgpr32Trunc}, {Sgpr32AExt, Sgpr32ZExt}})
.Div(S16, {{Vgpr16}, {Vgpr16, Vgpr16}})
.Uni(V2S16, {{SgprV2S16}, {SgprV2S16, SgprV2S16}, Unpack})
.Div(V2S16, {{VgprV2S16}, {VgprV2S16, VgprV2S16}})
.Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}})
.Uni(S64, {{Sgpr64}, {Sgpr64, Sgpr32}})
.Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}})
.Div(S64, {{Vgpr64}, {Vgpr64, Vgpr32}});

addRulesForGOpcs({G_LSHR}, Standard)
.Uni(S16, {{Sgpr32Trunc}, {Sgpr32ZExt, Sgpr32ZExt}})
.Div(S16, {{Vgpr16}, {Vgpr16, Vgpr16}})
.Uni(V2S16, {{SgprV2S16}, {SgprV2S16, SgprV2S16}, Unpack})
.Div(V2S16, {{VgprV2S16}, {VgprV2S16, VgprV2S16}})
.Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}})
.Uni(S64, {{Sgpr64}, {Sgpr64, Sgpr32}})
.Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}})
.Div(S64, {{Vgpr64}, {Vgpr64, Vgpr32}});

addRulesForGOpcs({G_LSHR}, Standard).Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}});
addRulesForGOpcs({G_ASHR}, Standard)
.Uni(S16, {{Sgpr32Trunc}, {Sgpr32SExt, Sgpr32ZExt}})
.Div(S16, {{Vgpr16}, {Vgpr16, Vgpr16}})
.Uni(V2S16, {{SgprV2S16}, {SgprV2S16, SgprV2S16}, Unpack})
.Div(V2S16, {{VgprV2S16}, {VgprV2S16, VgprV2S16}})
.Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32}})
.Uni(S64, {{Sgpr64}, {Sgpr64, Sgpr32}})
.Div(S32, {{Vgpr32}, {Vgpr32, Vgpr32}})
.Div(S64, {{Vgpr64}, {Vgpr64, Vgpr32}});

addRulesForGOpcs({G_FRAME_INDEX}).Any({{UniP5, _}, {{SgprP5}, {None}}});

addRulesForGOpcs({G_UBFX, G_SBFX}, Standard)
.Uni(S32, {{Sgpr32}, {Sgpr32, Sgpr32, Sgpr32}, Uni_BFE})
Expand Down Expand Up @@ -514,6 +545,8 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST,
.Any({{DivS16, S32}, {{Vgpr16}, {Vgpr32}}})
.Any({{UniS32, S64}, {{Sgpr32}, {Sgpr64}}})
.Any({{DivS32, S64}, {{Vgpr32}, {Vgpr64}}})
.Any({{UniV2S16, V2S32}, {{SgprV2S16}, {SgprV2S32}}})
.Any({{DivV2S16, V2S32}, {{VgprV2S16}, {VgprV2S32}}})
// This is non-trivial. VgprToVccCopy is done using compare instruction.
.Any({{DivS1, DivS16}, {{Vcc}, {Vgpr16}, VgprToVccCopy}})
.Any({{DivS1, DivS32}, {{Vcc}, {Vgpr32}, VgprToVccCopy}})
Expand Down Expand Up @@ -549,6 +582,12 @@ RegBankLegalizeRules::RegBankLegalizeRules(const GCNSubtarget &_ST,
.Any({{UniS32, S16}, {{Sgpr32}, {Sgpr16}}})
.Any({{DivS32, S16}, {{Vgpr32}, {Vgpr16}}});

addRulesForGOpcs({G_SEXT_INREG})
.Any({{UniS32, S32}, {{Sgpr32}, {Sgpr32}}})
.Any({{DivS32, S32}, {{Vgpr32}, {Vgpr32}}})
.Any({{UniS64, S64}, {{Sgpr64}, {Sgpr64}}})
.Any({{DivS64, S64}, {{Vgpr64}, {Vgpr64}, SplitTo32SExtInReg}});

bool hasUnalignedLoads = ST->getGeneration() >= AMDGPUSubtarget::GFX12;
bool hasSMRDSmall = ST->hasScalarSubwordLoads();

Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPURegBankLegalizeRules.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ enum UniformityLLTOpPredicateID {
V3S32,
V4S32,

UniV2S16,

DivV2S16,

// B types
B32,
B64,
Expand Down Expand Up @@ -117,7 +121,9 @@ enum RegBankLLTMappingApplyID {
SgprP3,
SgprP4,
SgprP5,
SgprV2S16,
SgprV4S32,
SgprV2S32,
SgprB32,
SgprB64,
SgprB96,
Expand All @@ -134,6 +140,8 @@ enum RegBankLLTMappingApplyID {
VgprP3,
VgprP4,
VgprP5,
VgprV2S16,
VgprV2S32,
VgprB32,
VgprB64,
VgprB96,
Expand All @@ -145,6 +153,7 @@ enum RegBankLLTMappingApplyID {
// Dst only modifiers: read-any-lane and truncs
UniInVcc,
UniInVgprS32,
UniInVgprV2S16,
UniInVgprV4S32,
UniInVgprB32,
UniInVgprB64,
Expand Down Expand Up @@ -173,11 +182,13 @@ enum LoweringMethodID {
DoNotLower,
VccExtToSel,
UniExtToSel,
Unpack,
Uni_BFE,
Div_BFE,
VgprToVccCopy,
SplitTo32,
SplitTo32Sel,
SplitTo32SExtInReg,
Ext32To64,
UniCstExt,
SplitLoad,
Expand Down
10 changes: 5 additions & 5 deletions llvm/test/CodeGen/AMDGPU/GlobalISel/ashr.ll
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -global-isel -mtriple=amdgcn-amd-amdpal -mcpu=tahiti < %s | FileCheck -check-prefixes=GCN,GFX6 %s
; RUN: llc -global-isel -mtriple=amdgcn-amd-amdpal -mcpu=fiji < %s | FileCheck -check-prefixes=GCN,GFX8 %s
; RUN: llc -global-isel -mtriple=amdgcn-amd-amdpal -mcpu=gfx900 < %s | FileCheck -check-prefixes=GCN,GFX9 %s
; RUN: llc -global-isel -mtriple=amdgcn-amd-amdpal -mcpu=gfx1010 < %s | FileCheck -check-prefixes=GFX10PLUS,GFX10 %s
; RUN: llc -global-isel -mtriple=amdgcn-amd-amdpal -mcpu=gfx1100 -amdgpu-enable-delay-alu=0 < %s | FileCheck -check-prefixes=GFX10PLUS,GFX11 %s
; RUN: llc -global-isel -new-reg-bank-select -mtriple=amdgcn-amd-amdpal -mcpu=tahiti < %s | FileCheck -check-prefixes=GCN,GFX6 %s
; RUN: llc -global-isel -new-reg-bank-select -mtriple=amdgcn-amd-amdpal -mcpu=fiji < %s | FileCheck -check-prefixes=GCN,GFX8 %s
; RUN: llc -global-isel -new-reg-bank-select -mtriple=amdgcn-amd-amdpal -mcpu=gfx900 < %s | FileCheck -check-prefixes=GCN,GFX9 %s
; RUN: llc -global-isel -new-reg-bank-select -mtriple=amdgcn-amd-amdpal -mcpu=gfx1010 < %s | FileCheck -check-prefixes=GFX10PLUS,GFX10 %s
; RUN: llc -global-isel -new-reg-bank-select -mtriple=amdgcn-amd-amdpal -mcpu=gfx1100 -amdgpu-enable-delay-alu=0 < %s | FileCheck -check-prefixes=GFX10PLUS,GFX11 %s

define i8 @v_ashr_i8(i8 %value, i8 %amount) {
; GFX6-LABEL: v_ashr_i8:
Expand Down
Loading
Loading