Skip to content

Commit 18aaec5

Browse files
committed
[SDAG] Fixups required for InferAS change
1 parent c57002f commit 18aaec5

File tree

6 files changed

+72
-50
lines changed

6 files changed

+72
-50
lines changed

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

+6
Original file line number | Diff line number | Diff line change
@@ -954,6 +954,12 @@ static void AddNodeIDCustom(FoldingSetNodeID &ID, const SDNode *N) {
954954
ID.AddInteger(M);
955955
break;
956956
}
957+
case ISD::ADDRSPACECAST: {
958+
const AddrSpaceCastSDNode *ASC = cast<AddrSpaceCastSDNode>(N);
959+
ID.AddInteger(ASC->getSrcAddressSpace());
960+
ID.AddInteger(ASC->getDestAddressSpace());
961+
break;
962+
}
957963
case ISD::TargetBlockAddress:
958964
case ISD::BlockAddress: {
959965
const BlockAddressSDNode *BA = cast<BlockAddressSDNode>(N);

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

+27-21
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
#include "llvm/Support/ErrorHandling.h"
2525
#include "llvm/Support/FormatVariadic.h"
2626
#include "llvm/Target/TargetIntrinsicInfo.h"
27+
#include <optional>
2728

2829
using namespace llvm;
2930

@@ -334,29 +335,34 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
334335
return true;
335336
}
336337

337-
static unsigned int getCodeAddrSpace(MemSDNode *N) {
338-
const Value *Src = N->getMemOperand()->getValue();
339-
340-
if (!Src)
338+
static std::optional<unsigned> convertAS(unsigned AS) {
339+
switch (AS) {
340+
case llvm::ADDRESS_SPACE_LOCAL:
341+
return NVPTX::AddressSpace::Local;
342+
case llvm::ADDRESS_SPACE_GLOBAL:
343+
return NVPTX::AddressSpace::Global;
344+
case llvm::ADDRESS_SPACE_SHARED:
345+
return NVPTX::AddressSpace::Shared;
346+
case llvm::ADDRESS_SPACE_GENERIC:
341347
return NVPTX::AddressSpace::Generic;
342-
343-
if (auto *PT = dyn_cast<PointerType>(Src->getType())) {
344-
switch (PT->getAddressSpace()) {
345-
case llvm::ADDRESS_SPACE_LOCAL:
346-
return NVPTX::AddressSpace::Local;
347-
case llvm::ADDRESS_SPACE_GLOBAL:
348-
return NVPTX::AddressSpace::Global;
349-
case llvm::ADDRESS_SPACE_SHARED:
350-
return NVPTX::AddressSpace::Shared;
351-
case llvm::ADDRESS_SPACE_GENERIC:
352-
return NVPTX::AddressSpace::Generic;
353-
case llvm::ADDRESS_SPACE_PARAM:
354-
return NVPTX::AddressSpace::Param;
355-
case llvm::ADDRESS_SPACE_CONST:
356-
return NVPTX::AddressSpace::Const;
357-
default: break;
358-
}
348+
case llvm::ADDRESS_SPACE_PARAM:
349+
return NVPTX::AddressSpace::Param;
350+
case llvm::ADDRESS_SPACE_CONST:
351+
return NVPTX::AddressSpace::Const;
352+
default:
353+
return std::nullopt;
359354
}
355+
}
356+
357+
static unsigned int getCodeAddrSpace(const MemSDNode *N) {
358+
if (const Value *Src = N->getMemOperand()->getValue())
359+
if (auto *PT = dyn_cast<PointerType>(Src->getType()))
360+
if (auto AS = convertAS(PT->getAddressSpace()))
361+
return AS.value();
362+
363+
if (auto AS = convertAS(N->getMemOperand()->getAddrSpace()))
364+
return AS.value();
365+
360366
return NVPTX::AddressSpace::Generic;
361367
}
362368

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

+17-4
Original file line number | Diff line number | Diff line change
@@ -1389,6 +1389,18 @@ static bool shouldConvertToIndirectCall(const CallBase *CB,
13891389
return false;
13901390
}
13911391

1392+
/// Refine the address space of \p Ptr when it can be determined from the node.
///
/// If \p Ptr is a frame index, it is replaced in place by an ADDRSPACECAST
/// from ADDRESS_SPACE_GENERIC to ADDRESS_SPACE_LOCAL and the returned pointer
/// info is tagged with ADDRESS_SPACE_LOCAL. For any other node \p Ptr is left
/// untouched and a default-constructed MachinePointerInfo is returned.
///
/// \param Ptr  address to refine; overwritten when it is a frame index.
/// \param DAG  DAG in which the cast node is created.
/// \param DL   data layout used to look up the pointer value type.
/// \param TL   target lowering used to look up the pointer value type.
/// \returns pointer info describing the address space of the (possibly
///          rewritten) \p Ptr.
static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
                                      const DataLayout &DL,
                                      const TargetLowering &TL) {
  if (Ptr->getOpcode() != ISD::FrameIndex)
    return MachinePointerInfo();

  // The cast yields a pointer in the local address space, so its value type
  // must be the pointer type of ADDRESS_SPACE_LOCAL. Using the generic
  // address space's pointer type here would be wrong whenever the target
  // gives the two address spaces different pointer types.
  const auto LocalPtrVT = TL.getPointerTy(DL, ADDRESS_SPACE_LOCAL);
  Ptr = DAG.getAddrSpaceCast(SDLoc(), LocalPtrVT, Ptr, ADDRESS_SPACE_GENERIC,
                             ADDRESS_SPACE_LOCAL);

  return MachinePointerInfo(ADDRESS_SPACE_LOCAL);
}
1403+
13921404
SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
13931405
SmallVectorImpl<SDValue> &InVals) const {
13941406

@@ -1553,11 +1565,12 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
15531565
}
15541566

15551567
if (IsByVal) {
1556-
auto PtrVT = getPointerTy(DL);
1557-
SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
1568+
auto MPI = refinePtrAS(StVal, DAG, DL, *this);
1569+
const EVT PtrVT = StVal.getValueType();
1570+
SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
15581571
DAG.getConstant(CurOffset, dl, PtrVT));
1559-
StVal = DAG.getLoad(EltVT, dl, TempChain, srcAddr, MachinePointerInfo(),
1560-
PartAlign);
1572+
1573+
StVal = DAG.getLoad(EltVT, dl, TempChain, SrcAddr, MPI, PartAlign);
15611574
} else if (ExtendIntegerParam) {
15621575
assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
15631576
// zext/sext to i32

llvm/test/CodeGen/NVPTX/indirect_byval.ll

+11-9
Original file line number | Diff line number | Diff line change
@@ -17,19 +17,20 @@ define internal i32 @foo() {
1717
; CHECK-NEXT: .reg .b64 %SPL;
1818
; CHECK-NEXT: .reg .b16 %rs<2>;
1919
; CHECK-NEXT: .reg .b32 %r<3>;
20-
; CHECK-NEXT: .reg .b64 %rd<3>;
20+
; CHECK-NEXT: .reg .b64 %rd<5>;
2121
; CHECK-EMPTY:
2222
; CHECK-NEXT: // %bb.0: // %entry
2323
; CHECK-NEXT: mov.u64 %SPL, __local_depot0;
2424
; CHECK-NEXT: cvta.local.u64 %SP, %SPL;
2525
; CHECK-NEXT: ld.global.u64 %rd1, [ptr];
26-
; CHECK-NEXT: ld.u8 %rs1, [%SP+1];
27-
; CHECK-NEXT: add.u64 %rd2, %SP, 0;
26+
; CHECK-NEXT: add.u64 %rd3, %SPL, 1;
27+
; CHECK-NEXT: ld.local.u8 %rs1, [%rd3];
28+
; CHECK-NEXT: add.u64 %rd4, %SP, 0;
2829
; CHECK-NEXT: { // callseq 0, 0
2930
; CHECK-NEXT: .param .align 1 .b8 param0[1];
3031
; CHECK-NEXT: st.param.b8 [param0], %rs1;
3132
; CHECK-NEXT: .param .b64 param1;
32-
; CHECK-NEXT: st.param.b64 [param1], %rd2;
33+
; CHECK-NEXT: st.param.b64 [param1], %rd4;
3334
; CHECK-NEXT: .param .b32 retval0;
3435
; CHECK-NEXT: prototype_0 : .callprototype (.param .b32 _) _ (.param .align 1 .b8 _[1], .param .b64 _);
3536
; CHECK-NEXT: call (retval0),
@@ -59,19 +60,20 @@ define internal i32 @bar() {
5960
; CHECK-NEXT: .reg .b64 %SP;
6061
; CHECK-NEXT: .reg .b64 %SPL;
6162
; CHECK-NEXT: .reg .b32 %r<3>;
62-
; CHECK-NEXT: .reg .b64 %rd<4>;
63+
; CHECK-NEXT: .reg .b64 %rd<6>;
6364
; CHECK-EMPTY:
6465
; CHECK-NEXT: // %bb.0: // %entry
6566
; CHECK-NEXT: mov.u64 %SPL, __local_depot1;
6667
; CHECK-NEXT: cvta.local.u64 %SP, %SPL;
6768
; CHECK-NEXT: ld.global.u64 %rd1, [ptr];
68-
; CHECK-NEXT: ld.u64 %rd2, [%SP+8];
69-
; CHECK-NEXT: add.u64 %rd3, %SP, 0;
69+
; CHECK-NEXT: add.u64 %rd3, %SPL, 8;
70+
; CHECK-NEXT: ld.local.u64 %rd4, [%rd3];
71+
; CHECK-NEXT: add.u64 %rd5, %SP, 0;
7072
; CHECK-NEXT: { // callseq 1, 0
7173
; CHECK-NEXT: .param .align 8 .b8 param0[8];
72-
; CHECK-NEXT: st.param.b64 [param0], %rd2;
74+
; CHECK-NEXT: st.param.b64 [param0], %rd4;
7375
; CHECK-NEXT: .param .b64 param1;
74-
; CHECK-NEXT: st.param.b64 [param1], %rd3;
76+
; CHECK-NEXT: st.param.b64 [param1], %rd5;
7577
; CHECK-NEXT: .param .b32 retval0;
7678
; CHECK-NEXT: prototype_1 : .callprototype (.param .b32 _) _ (.param .align 8 .b8 _[8], .param .b64 _);
7779
; CHECK-NEXT: call (retval0),

llvm/test/CodeGen/NVPTX/variadics-backend.ll

+2-2
Original file line number | Diff line number | Diff line change
@@ -397,8 +397,8 @@ define dso_local void @qux() {
397397
; CHECK-PTX-NEXT: st.local.u64 [%rd2+8], %rd6;
398398
; CHECK-PTX-NEXT: mov.b64 %rd7, 1;
399399
; CHECK-PTX-NEXT: st.u64 [%SP+16], %rd7;
400-
; CHECK-PTX-NEXT: ld.u64 %rd8, [%SP];
401-
; CHECK-PTX-NEXT: ld.u64 %rd9, [%SP+8];
400+
; CHECK-PTX-NEXT: ld.local.u64 %rd8, [%rd2];
401+
; CHECK-PTX-NEXT: ld.local.u64 %rd9, [%rd2+8];
402402
; CHECK-PTX-NEXT: add.u64 %rd10, %SP, 16;
403403
; CHECK-PTX-NEXT: { // callseq 3, 0
404404
; CHECK-PTX-NEXT: .param .align 8 .b8 param0[16];

llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected

+9-14
Original file line number | Diff line number | Diff line change
@@ -10,11 +10,10 @@ define dso_local void @caller_St8x4(ptr nocapture noundef readonly byval(%struct
1010
; CHECK-NEXT: .reg .b32 %SP;
1111
; CHECK-NEXT: .reg .b32 %SPL;
1212
; CHECK-NEXT: .reg .b32 %r<4>;
13-
; CHECK-NEXT: .reg .b64 %rd<17>;
13+
; CHECK-NEXT: .reg .b64 %rd<13>;
1414
; CHECK-EMPTY:
1515
; CHECK-NEXT: // %bb.0:
1616
; CHECK-NEXT: mov.u32 %SPL, __local_depot0;
17-
; CHECK-NEXT: cvta.local.u32 %SP, %SPL;
1817
; CHECK-NEXT: ld.param.u32 %r1, [caller_St8x4_param_1];
1918
; CHECK-NEXT: add.u32 %r3, %SPL, 0;
2019
; CHECK-NEXT: ld.param.u64 %rd1, [caller_St8x4_param_0+24];
@@ -25,27 +24,23 @@ define dso_local void @caller_St8x4(ptr nocapture noundef readonly byval(%struct
2524
; CHECK-NEXT: st.local.u64 [%r3+8], %rd3;
2625
; CHECK-NEXT: ld.param.u64 %rd4, [caller_St8x4_param_0];
2726
; CHECK-NEXT: st.local.u64 [%r3], %rd4;
28-
; CHECK-NEXT: ld.u64 %rd5, [%SP+8];
29-
; CHECK-NEXT: ld.u64 %rd6, [%SP];
30-
; CHECK-NEXT: ld.u64 %rd7, [%SP+24];
31-
; CHECK-NEXT: ld.u64 %rd8, [%SP+16];
3227
; CHECK-NEXT: { // callseq 0, 0
3328
; CHECK-NEXT: .param .align 16 .b8 param0[32];
34-
; CHECK-NEXT: st.param.v2.b64 [param0], {%rd6, %rd5};
35-
; CHECK-NEXT: st.param.v2.b64 [param0+16], {%rd8, %rd7};
29+
; CHECK-NEXT: st.param.v2.b64 [param0], {%rd4, %rd3};
30+
; CHECK-NEXT: st.param.v2.b64 [param0+16], {%rd2, %rd1};
3631
; CHECK-NEXT: .param .align 16 .b8 retval0[32];
3732
; CHECK-NEXT: call.uni (retval0),
3833
; CHECK-NEXT: callee_St8x4,
3934
; CHECK-NEXT: (
4035
; CHECK-NEXT: param0
4136
; CHECK-NEXT: );
42-
; CHECK-NEXT: ld.param.v2.b64 {%rd9, %rd10}, [retval0];
43-
; CHECK-NEXT: ld.param.v2.b64 {%rd11, %rd12}, [retval0+16];
37+
; CHECK-NEXT: ld.param.v2.b64 {%rd5, %rd6}, [retval0];
38+
; CHECK-NEXT: ld.param.v2.b64 {%rd7, %rd8}, [retval0+16];
4439
; CHECK-NEXT: } // callseq 0
45-
; CHECK-NEXT: st.u64 [%r1], %rd9;
46-
; CHECK-NEXT: st.u64 [%r1+8], %rd10;
47-
; CHECK-NEXT: st.u64 [%r1+16], %rd11;
48-
; CHECK-NEXT: st.u64 [%r1+24], %rd12;
40+
; CHECK-NEXT: st.u64 [%r1], %rd5;
41+
; CHECK-NEXT: st.u64 [%r1+8], %rd6;
42+
; CHECK-NEXT: st.u64 [%r1+16], %rd7;
43+
; CHECK-NEXT: st.u64 [%r1+24], %rd8;
4944
; CHECK-NEXT: ret;
5045
%call = tail call fastcc [4 x i64] @callee_St8x4(ptr noundef nonnull byval(%struct.St8x4) align 8 %in) #2
5146
%.fca.0.extract = extractvalue [4 x i64] %call, 0

0 commit comments

Comments
 (0)