Skip to content

Commit 62671f5

Browse files
committed
[SDAG] Fixups required for InferAS change
1 parent 5d91364 commit 62671f5

File tree

5 files changed

+67
-50
lines changed

5 files changed

+67
-50
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

+27-21
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
#include "llvm/Support/ErrorHandling.h"
2626
#include "llvm/Support/FormatVariadic.h"
2727
#include "llvm/Target/TargetIntrinsicInfo.h"
28+
#include <optional>
2829

2930
using namespace llvm;
3031

@@ -341,29 +342,34 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
341342
return true;
342343
}
343344

344-
static unsigned int getCodeAddrSpace(MemSDNode *N) {
345-
const Value *Src = N->getMemOperand()->getValue();
346-
347-
if (!Src)
345+
static std::optional<unsigned> convertAS(unsigned AS) {
346+
switch (AS) {
347+
case llvm::ADDRESS_SPACE_LOCAL:
348+
return NVPTX::AddressSpace::Local;
349+
case llvm::ADDRESS_SPACE_GLOBAL:
350+
return NVPTX::AddressSpace::Global;
351+
case llvm::ADDRESS_SPACE_SHARED:
352+
return NVPTX::AddressSpace::Shared;
353+
case llvm::ADDRESS_SPACE_GENERIC:
348354
return NVPTX::AddressSpace::Generic;
349-
350-
if (auto *PT = dyn_cast<PointerType>(Src->getType())) {
351-
switch (PT->getAddressSpace()) {
352-
case llvm::ADDRESS_SPACE_LOCAL:
353-
return NVPTX::AddressSpace::Local;
354-
case llvm::ADDRESS_SPACE_GLOBAL:
355-
return NVPTX::AddressSpace::Global;
356-
case llvm::ADDRESS_SPACE_SHARED:
357-
return NVPTX::AddressSpace::Shared;
358-
case llvm::ADDRESS_SPACE_GENERIC:
359-
return NVPTX::AddressSpace::Generic;
360-
case llvm::ADDRESS_SPACE_PARAM:
361-
return NVPTX::AddressSpace::Param;
362-
case llvm::ADDRESS_SPACE_CONST:
363-
return NVPTX::AddressSpace::Const;
364-
default: break;
365-
}
355+
case llvm::ADDRESS_SPACE_PARAM:
356+
return NVPTX::AddressSpace::Param;
357+
case llvm::ADDRESS_SPACE_CONST:
358+
return NVPTX::AddressSpace::Const;
359+
default:
360+
return std::nullopt;
366361
}
362+
}
363+
364+
static unsigned int getCodeAddrSpace(const MemSDNode *N) {
365+
if (const Value *Src = N->getMemOperand()->getValue())
366+
if (auto *PT = dyn_cast<PointerType>(Src->getType()))
367+
if (auto AS = convertAS(PT->getAddressSpace()))
368+
return AS.value();
369+
370+
if (auto AS = convertAS(N->getMemOperand()->getAddrSpace()))
371+
return AS.value();
372+
367373
return NVPTX::AddressSpace::Generic;
368374
}
369375

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

+18-4
Original file line number | Diff line number | Diff line change
@@ -1408,6 +1408,19 @@ static bool shouldConvertToIndirectCall(const CallBase *CB,
14081408
return false;
14091409
}
14101410

1411+
static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
1412+
const DataLayout &DL,
1413+
const TargetLowering &TL) {
1414+
if (Ptr->getOpcode() == ISD::FrameIndex) {
1415+
auto Ty = TL.getPointerTy(DL, ADDRESS_SPACE_LOCAL);
1416+
Ptr = DAG.getAddrSpaceCast(SDLoc(), Ty, Ptr, ADDRESS_SPACE_GENERIC,
1417+
ADDRESS_SPACE_LOCAL);
1418+
1419+
return MachinePointerInfo(ADDRESS_SPACE_LOCAL);
1420+
}
1421+
return MachinePointerInfo();
1422+
}
1423+
14111424
SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
14121425
SmallVectorImpl<SDValue> &InVals) const {
14131426

@@ -1572,11 +1585,12 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
15721585
}
15731586

15741587
if (IsByVal) {
1575-
auto PtrVT = getPointerTy(DL);
1576-
SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
1588+
auto MPI = refinePtrAS(StVal, DAG, DL, *this);
1589+
const EVT PtrVT = StVal.getValueType();
1590+
SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
15771591
DAG.getConstant(CurOffset, dl, PtrVT));
1578-
StVal = DAG.getLoad(EltVT, dl, TempChain, srcAddr, MachinePointerInfo(),
1579-
PartAlign);
1592+
1593+
StVal = DAG.getLoad(EltVT, dl, TempChain, SrcAddr, MPI, PartAlign);
15801594
} else if (ExtendIntegerParam) {
15811595
assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
15821596
// zext/sext to i32

llvm/test/CodeGen/NVPTX/indirect_byval.ll

+11-9
Original file line number | Diff line number | Diff line change
@@ -17,19 +17,20 @@ define internal i32 @foo() {
1717
; CHECK-NEXT: .reg .b64 %SPL;
1818
; CHECK-NEXT: .reg .b16 %rs<2>;
1919
; CHECK-NEXT: .reg .b32 %r<3>;
20-
; CHECK-NEXT: .reg .b64 %rd<3>;
20+
; CHECK-NEXT: .reg .b64 %rd<5>;
2121
; CHECK-EMPTY:
2222
; CHECK-NEXT: // %bb.0: // %entry
2323
; CHECK-NEXT: mov.u64 %SPL, __local_depot0;
2424
; CHECK-NEXT: cvta.local.u64 %SP, %SPL;
2525
; CHECK-NEXT: ld.global.u64 %rd1, [ptr];
26-
; CHECK-NEXT: ld.u8 %rs1, [%SP+1];
27-
; CHECK-NEXT: add.u64 %rd2, %SP, 0;
26+
; CHECK-NEXT: add.u64 %rd3, %SPL, 1;
27+
; CHECK-NEXT: ld.local.u8 %rs1, [%rd3];
28+
; CHECK-NEXT: add.u64 %rd4, %SP, 0;
2829
; CHECK-NEXT: { // callseq 0, 0
2930
; CHECK-NEXT: .param .align 1 .b8 param0[1];
3031
; CHECK-NEXT: st.param.b8 [param0], %rs1;
3132
; CHECK-NEXT: .param .b64 param1;
32-
; CHECK-NEXT: st.param.b64 [param1], %rd2;
33+
; CHECK-NEXT: st.param.b64 [param1], %rd4;
3334
; CHECK-NEXT: .param .b32 retval0;
3435
; CHECK-NEXT: prototype_0 : .callprototype (.param .b32 _) _ (.param .align 1 .b8 _[1], .param .b64 _);
3536
; CHECK-NEXT: call (retval0),
@@ -59,19 +60,20 @@ define internal i32 @bar() {
5960
; CHECK-NEXT: .reg .b64 %SP;
6061
; CHECK-NEXT: .reg .b64 %SPL;
6162
; CHECK-NEXT: .reg .b32 %r<3>;
62-
; CHECK-NEXT: .reg .b64 %rd<4>;
63+
; CHECK-NEXT: .reg .b64 %rd<6>;
6364
; CHECK-EMPTY:
6465
; CHECK-NEXT: // %bb.0: // %entry
6566
; CHECK-NEXT: mov.u64 %SPL, __local_depot1;
6667
; CHECK-NEXT: cvta.local.u64 %SP, %SPL;
6768
; CHECK-NEXT: ld.global.u64 %rd1, [ptr];
68-
; CHECK-NEXT: ld.u64 %rd2, [%SP+8];
69-
; CHECK-NEXT: add.u64 %rd3, %SP, 0;
69+
; CHECK-NEXT: add.u64 %rd3, %SPL, 8;
70+
; CHECK-NEXT: ld.local.u64 %rd4, [%rd3];
71+
; CHECK-NEXT: add.u64 %rd5, %SP, 0;
7072
; CHECK-NEXT: { // callseq 1, 0
7173
; CHECK-NEXT: .param .align 8 .b8 param0[8];
72-
; CHECK-NEXT: st.param.b64 [param0], %rd2;
74+
; CHECK-NEXT: st.param.b64 [param0], %rd4;
7375
; CHECK-NEXT: .param .b64 param1;
74-
; CHECK-NEXT: st.param.b64 [param1], %rd3;
76+
; CHECK-NEXT: st.param.b64 [param1], %rd5;
7577
; CHECK-NEXT: .param .b32 retval0;
7678
; CHECK-NEXT: prototype_1 : .callprototype (.param .b32 _) _ (.param .align 8 .b8 _[8], .param .b64 _);
7779
; CHECK-NEXT: call (retval0),

llvm/test/CodeGen/NVPTX/variadics-backend.ll

+2-2
Original file line number | Diff line number | Diff line change
@@ -397,8 +397,8 @@ define dso_local void @qux() {
397397
; CHECK-PTX-NEXT: st.local.u64 [%rd2+8], %rd6;
398398
; CHECK-PTX-NEXT: mov.b64 %rd7, 1;
399399
; CHECK-PTX-NEXT: st.u64 [%SP+16], %rd7;
400-
; CHECK-PTX-NEXT: ld.u64 %rd8, [%SP];
401-
; CHECK-PTX-NEXT: ld.u64 %rd9, [%SP+8];
400+
; CHECK-PTX-NEXT: ld.local.u64 %rd8, [%rd2];
401+
; CHECK-PTX-NEXT: ld.local.u64 %rd9, [%rd2+8];
402402
; CHECK-PTX-NEXT: add.u64 %rd10, %SP, 16;
403403
; CHECK-PTX-NEXT: { // callseq 3, 0
404404
; CHECK-PTX-NEXT: .param .align 8 .b8 param0[16];

llvm/test/tools/UpdateTestChecks/update_llc_test_checks/Inputs/nvptx-basic.ll.expected

+9-14
Original file line number | Diff line number | Diff line change
@@ -10,11 +10,10 @@ define dso_local void @caller_St8x4(ptr nocapture noundef readonly byval(%struct
1010
; CHECK-NEXT: .reg .b32 %SP;
1111
; CHECK-NEXT: .reg .b32 %SPL;
1212
; CHECK-NEXT: .reg .b32 %r<4>;
13-
; CHECK-NEXT: .reg .b64 %rd<17>;
13+
; CHECK-NEXT: .reg .b64 %rd<13>;
1414
; CHECK-EMPTY:
1515
; CHECK-NEXT: // %bb.0:
1616
; CHECK-NEXT: mov.u32 %SPL, __local_depot0;
17-
; CHECK-NEXT: cvta.local.u32 %SP, %SPL;
1817
; CHECK-NEXT: ld.param.u32 %r1, [caller_St8x4_param_1];
1918
; CHECK-NEXT: add.u32 %r3, %SPL, 0;
2019
; CHECK-NEXT: ld.param.u64 %rd1, [caller_St8x4_param_0+24];
@@ -25,27 +24,23 @@ define dso_local void @caller_St8x4(ptr nocapture noundef readonly byval(%struct
2524
; CHECK-NEXT: st.local.u64 [%r3+8], %rd3;
2625
; CHECK-NEXT: ld.param.u64 %rd4, [caller_St8x4_param_0];
2726
; CHECK-NEXT: st.local.u64 [%r3], %rd4;
28-
; CHECK-NEXT: ld.u64 %rd5, [%SP+8];
29-
; CHECK-NEXT: ld.u64 %rd6, [%SP];
30-
; CHECK-NEXT: ld.u64 %rd7, [%SP+24];
31-
; CHECK-NEXT: ld.u64 %rd8, [%SP+16];
3227
; CHECK-NEXT: { // callseq 0, 0
3328
; CHECK-NEXT: .param .align 16 .b8 param0[32];
34-
; CHECK-NEXT: st.param.v2.b64 [param0], {%rd6, %rd5};
35-
; CHECK-NEXT: st.param.v2.b64 [param0+16], {%rd8, %rd7};
29+
; CHECK-NEXT: st.param.v2.b64 [param0], {%rd4, %rd3};
30+
; CHECK-NEXT: st.param.v2.b64 [param0+16], {%rd2, %rd1};
3631
; CHECK-NEXT: .param .align 16 .b8 retval0[32];
3732
; CHECK-NEXT: call.uni (retval0),
3833
; CHECK-NEXT: callee_St8x4,
3934
; CHECK-NEXT: (
4035
; CHECK-NEXT: param0
4136
; CHECK-NEXT: );
42-
; CHECK-NEXT: ld.param.v2.b64 {%rd9, %rd10}, [retval0];
43-
; CHECK-NEXT: ld.param.v2.b64 {%rd11, %rd12}, [retval0+16];
37+
; CHECK-NEXT: ld.param.v2.b64 {%rd5, %rd6}, [retval0];
38+
; CHECK-NEXT: ld.param.v2.b64 {%rd7, %rd8}, [retval0+16];
4439
; CHECK-NEXT: } // callseq 0
45-
; CHECK-NEXT: st.u64 [%r1], %rd9;
46-
; CHECK-NEXT: st.u64 [%r1+8], %rd10;
47-
; CHECK-NEXT: st.u64 [%r1+16], %rd11;
48-
; CHECK-NEXT: st.u64 [%r1+24], %rd12;
40+
; CHECK-NEXT: st.u64 [%r1], %rd5;
41+
; CHECK-NEXT: st.u64 [%r1+8], %rd6;
42+
; CHECK-NEXT: st.u64 [%r1+16], %rd7;
43+
; CHECK-NEXT: st.u64 [%r1+24], %rd8;
4944
; CHECK-NEXT: ret;
5045
%call = tail call fastcc [4 x i64] @callee_St8x4(ptr noundef nonnull byval(%struct.St8x4) align 8 %in) #2
5146
%.fca.0.extract = extractvalue [4 x i64] %call, 0

0 commit comments

Comments
 (0)