Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVPTX][InferAS] assume alloca instructions are in local AS #121710

Merged
merged 4 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 21 additions & 22 deletions llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include <optional>

using namespace llvm;

Expand Down Expand Up @@ -342,30 +343,28 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
return true;
}

static unsigned int getCodeAddrSpace(MemSDNode *N) {
const Value *Src = N->getMemOperand()->getValue();

if (!Src)
static std::optional<unsigned> convertAS(unsigned AS) {
switch (AS) {
case llvm::ADDRESS_SPACE_LOCAL:
return NVPTX::AddressSpace::Local;
case llvm::ADDRESS_SPACE_GLOBAL:
return NVPTX::AddressSpace::Global;
case llvm::ADDRESS_SPACE_SHARED:
return NVPTX::AddressSpace::Shared;
case llvm::ADDRESS_SPACE_GENERIC:
return NVPTX::AddressSpace::Generic;

if (auto *PT = dyn_cast<PointerType>(Src->getType())) {
switch (PT->getAddressSpace()) {
case llvm::ADDRESS_SPACE_LOCAL:
return NVPTX::AddressSpace::Local;
case llvm::ADDRESS_SPACE_GLOBAL:
return NVPTX::AddressSpace::Global;
case llvm::ADDRESS_SPACE_SHARED:
return NVPTX::AddressSpace::Shared;
case llvm::ADDRESS_SPACE_GENERIC:
return NVPTX::AddressSpace::Generic;
case llvm::ADDRESS_SPACE_PARAM:
return NVPTX::AddressSpace::Param;
case llvm::ADDRESS_SPACE_CONST:
return NVPTX::AddressSpace::Const;
default: break;
}
case llvm::ADDRESS_SPACE_PARAM:
return NVPTX::AddressSpace::Param;
case llvm::ADDRESS_SPACE_CONST:
return NVPTX::AddressSpace::Const;
default:
return std::nullopt;
}
return NVPTX::AddressSpace::Generic;
}

/// Returns the NVPTX address-space code for the memory accessed by \p N.
///
/// The address space is read from the node's memory operand and translated
/// with convertAS(). When convertAS() yields no value for that address
/// space, the access is treated as NVPTX::AddressSpace::Generic.
static unsigned int getCodeAddrSpace(const MemSDNode *N) {
  const unsigned IRAddrSpace = N->getMemOperand()->getAddrSpace();
  const std::optional<unsigned> Code = convertAS(IRAddrSpace);
  if (Code.has_value())
    return *Code;
  // No dedicated code for this address space: fall back to generic.
  return NVPTX::AddressSpace::Generic;
}

namespace {
Expand Down
22 changes: 18 additions & 4 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1405,6 +1405,19 @@ static bool shouldConvertToIndirectCall(const CallBase *CB,
return false;
}

/// Refines \p Ptr to a more specific address space when one can be proven.
///
/// If \p Ptr is an ISD::FrameIndex node, it is replaced in place by an
/// address-space cast from ADDRESS_SPACE_GENERIC to ADDRESS_SPACE_LOCAL
/// (typed with the target's pointer type for the local address space), and a
/// MachinePointerInfo for ADDRESS_SPACE_LOCAL is returned. For any other node
/// \p Ptr is left untouched and a default MachinePointerInfo is returned.
///
/// \param Ptr  Pointer value to refine; may be rewritten (in/out).
/// \param DAG  SelectionDAG used to build the address-space cast.
/// \param DL   Data layout used to look up the local pointer type.
/// \param TL   Target lowering that supplies getPointerTy().
static MachinePointerInfo refinePtrAS(SDValue &Ptr, SelectionDAG &DAG,
                                      const DataLayout &DL,
                                      const TargetLowering &TL) {
  // Only frame indices are refined here; everything else keeps its pointer.
  if (Ptr->getOpcode() != ISD::FrameIndex)
    return MachinePointerInfo();

  auto LocalPtrTy = TL.getPointerTy(DL, ADDRESS_SPACE_LOCAL);
  Ptr = DAG.getAddrSpaceCast(SDLoc(), LocalPtrTy, Ptr, ADDRESS_SPACE_GENERIC,
                             ADDRESS_SPACE_LOCAL);
  return MachinePointerInfo(ADDRESS_SPACE_LOCAL);
}

SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SmallVectorImpl<SDValue> &InVals) const {

Expand Down Expand Up @@ -1564,11 +1577,12 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
}

if (IsByVal) {
auto PtrVT = getPointerTy(DL);
SDValue srcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
auto MPI = refinePtrAS(StVal, DAG, DL, *this);
const EVT PtrVT = StVal.getValueType();
SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, PtrVT, StVal,
DAG.getConstant(CurOffset, dl, PtrVT));
StVal = DAG.getLoad(EltVT, dl, TempChain, srcAddr, MachinePointerInfo(),
PartAlign);

StVal = DAG.getLoad(EltVT, dl, TempChain, SrcAddr, MPI, PartAlign);
} else if (ExtendIntegerParam) {
assert(VTs.size() == 1 && "Scalar can't have multiple parts.");
// zext/sext to i32
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "llvm/IR/Value.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/NVPTXAddrSpace.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include <optional>
using namespace llvm;
Expand Down Expand Up @@ -564,6 +565,13 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
return nullptr;
}

unsigned NVPTXTTIImpl::getAssumedAddrSpace(const Value *V) const {
if (isa<AllocaInst>(V))
return ADDRESS_SPACE_LOCAL;

Comment on lines +569 to +571
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is worse than just changing the alloca addrspace

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While I tend to agree that converting all allocas to the local addrspace would be a good idea, it's also a deep can of worms. As @Artem-B points out here #106127 (comment), optimizations may not handle allocas in a specific AS as well as they handle generic ones, and supporting local allocas would likely require changes in the backend and datalayout.

I do hope to pursue this in the future, but in the meantime this change is both simple and correct. Even if we were to address the above it still seems preferable to handle generic allocas here in case InferAS happens to encounter one in the IR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only optimization that using a non-0 addrspace really breaks is folding out compares to null, which rarely matters if you can see the original alloca. We should also make the non-0 address space null handling datalayout configurable.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Artem-B were there other concerns you had with moving allocas to non-0 AS?

Even if we want to go this route I fear it will be pretty difficult and invasive. getAssumedAddrSpace is supposed to return the true AS of a given value that is in AS 0. Handling this case there doesn't preclude moving allocas to non-0 AS in the future. I think it is good to add this check regardless, to make InferAS more robust to different possible IRs.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ping @Artem-B / @arsenm for any further thoughts on the feasibility of switching allocas to the local addrspace and whether this change is acceptable as a complementary/intermediate-term solution.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@arsenm would it be alright if I proceeded with this change despite your point about switching to specific-AS allocas? I agree this is worth investigating but it will take a fair bit of work I think, and this solution is small, shows immediate improvements, and won't hamper any efforts to handle specific-AS allocas in the future.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@arsenm ping on this discussion.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since I haven't heard back on this discussion for a while now, and this seems more like a possible future enhancement than a blocking issue, I'll plan to land this change at the end of the week unless I hear back otherwise. CC @arsenm

return -1;
}

void NVPTXTTIImpl::collectKernelLaunchBounds(
const Function &F,
SmallVectorImpl<std::pair<StringRef, int64_t>> &LB) const {
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {

Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV,
Value *NewV) const;
unsigned getAssumedAddrSpace(const Value *V) const;

void collectKernelLaunchBounds(
const Function &F,
Expand Down
20 changes: 11 additions & 9 deletions llvm/test/CodeGen/NVPTX/indirect_byval.ll
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,20 @@ define internal i32 @foo() {
; CHECK-NEXT: .reg .b64 %SPL;
; CHECK-NEXT: .reg .b16 %rs<2>;
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-NEXT: .reg .b64 %rd<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: mov.u64 %SPL, __local_depot0;
; CHECK-NEXT: cvta.local.u64 %SP, %SPL;
; CHECK-NEXT: ld.global.u64 %rd1, [ptr];
; CHECK-NEXT: ld.u8 %rs1, [%SP+1];
; CHECK-NEXT: add.u64 %rd2, %SP, 0;
; CHECK-NEXT: add.u64 %rd3, %SPL, 1;
; CHECK-NEXT: ld.local.u8 %rs1, [%rd3];
; CHECK-NEXT: add.u64 %rd4, %SP, 0;
; CHECK-NEXT: { // callseq 0, 0
; CHECK-NEXT: .param .align 1 .b8 param0[1];
; CHECK-NEXT: st.param.b8 [param0], %rs1;
; CHECK-NEXT: .param .b64 param1;
; CHECK-NEXT: st.param.b64 [param1], %rd2;
; CHECK-NEXT: st.param.b64 [param1], %rd4;
; CHECK-NEXT: .param .b32 retval0;
; CHECK-NEXT: prototype_0 : .callprototype (.param .b32 _) _ (.param .align 1 .b8 _[1], .param .b64 _);
; CHECK-NEXT: call (retval0),
Expand Down Expand Up @@ -59,19 +60,20 @@ define internal i32 @bar() {
; CHECK-NEXT: .reg .b64 %SP;
; CHECK-NEXT: .reg .b64 %SPL;
; CHECK-NEXT: .reg .b32 %r<3>;
; CHECK-NEXT: .reg .b64 %rd<4>;
; CHECK-NEXT: .reg .b64 %rd<6>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: mov.u64 %SPL, __local_depot1;
; CHECK-NEXT: cvta.local.u64 %SP, %SPL;
; CHECK-NEXT: ld.global.u64 %rd1, [ptr];
; CHECK-NEXT: ld.u64 %rd2, [%SP+8];
; CHECK-NEXT: add.u64 %rd3, %SP, 0;
; CHECK-NEXT: add.u64 %rd3, %SPL, 8;
; CHECK-NEXT: ld.local.u64 %rd4, [%rd3];
; CHECK-NEXT: add.u64 %rd5, %SP, 0;
; CHECK-NEXT: { // callseq 1, 0
; CHECK-NEXT: .param .align 8 .b8 param0[8];
; CHECK-NEXT: st.param.b64 [param0], %rd2;
; CHECK-NEXT: st.param.b64 [param0], %rd4;
; CHECK-NEXT: .param .b64 param1;
; CHECK-NEXT: st.param.b64 [param1], %rd3;
; CHECK-NEXT: st.param.b64 [param1], %rd5;
; CHECK-NEXT: .param .b32 retval0;
; CHECK-NEXT: prototype_1 : .callprototype (.param .b32 _) _ (.param .align 8 .b8 _[8], .param .b64 _);
; CHECK-NEXT: call (retval0),
Expand Down
Loading