Commit 852b848

Update

Add getScrathMemoryPtr for SLM and GLM
ESI-SYD committed Jan 22, 2025
1 parent d048556 commit 852b848
Showing 27 changed files with 277 additions and 157 deletions.
5 changes: 5 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
@@ -1,6 +1,7 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETINFOBASE_H

+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "triton/Conversion/MLIRTypes.h"

namespace mlir::triton {
@@ -89,6 +90,10 @@ class TargetInfoBase {

  virtual int getSharedAddressSpace() const = 0;

+  virtual Value getScrathMemoryPtr(::mlir::gpu::AddressSpace addressSpace,
+                                   Location loc, RewriterBase &rewriter,
+                                   Operation *op, Value allocOffset = {},
+                                   bool getstackptr = false) const = 0;
  virtual bool supportVectorizedAtomics() const = 0;

  virtual ~TargetInfoBase() {}
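The new virtual folds the previously separate shared-local-memory (SLM, Workgroup) and global-memory (GLM, Global) scratch-pointer helpers into one target-dispatched hook. A minimal sketch of the two call shapes used throughout this commit, assuming `targetInfo`, `loc`, `rewriter`, `op`, and `offVal` are in scope as in the conversion patterns below:

// Shared (Workgroup) scratch base for `op`; unless getstackptr is true,
// the op must carry an "allocation.offset" attribute.
Value smemBase = targetInfo.getScrathMemoryPtr(
    ::mlir::gpu::AddressSpace::Workgroup, loc, rewriter, op);

// Global scratch pointer at byte offset `offVal` for this program.
Value gmemPtr = targetInfo.getScrathMemoryPtr(
    ::mlir::gpu::AddressSpace::Global, loc, rewriter, op, offVal);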
136 changes: 71 additions & 65 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -414,84 +414,90 @@ inline bool isKernel(FunctionOpInterface funcOp) {
return funcOp.getVisibility() == SymbolTable::Visibility::Public;
}

-inline Value getStackPointer(RewriterBase &rewriter,
-                             FunctionOpInterface funcOp) {
-  // See NOTE: [Additional Function Arguments]
-  if (!isKernel(funcOp)) {
-    return funcOp.getArgument(funcOp.getNumArguments() - 2);
-  }
-
-  auto mod = funcOp->getParentOfType<ModuleOp>();
-  auto globalBase = dyn_cast<LLVM::GlobalOp>(mod.lookupSymbol("global_smem"));
-  assert(globalBase);
-  return rewriter.create<LLVM::AddressOfOp>(funcOp.getLoc(), globalBase);
-}
-
-inline Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
-                                 FunctionOpInterface funcOp,
-                                 Value allocOffset = {}) {
-  // See NOTE: [Additional Function Arguments]
-  if (!isKernel(funcOp)) {
-    // Base for this function
-    auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
-    if (!allocOffset) {
-      return gmemBase;
-    }
-
-    auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
-    return gep(ptrTy, i8_ty, gmemBase, allocOffset);
-  }
-
-  // Base for entire kernel
-  auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
-
-  ModuleOp mod = funcOp.getOperation()->getParentOfType<ModuleOp>();
-  auto allocSizeAttr = mod.getOperation()->getAttrOfType<mlir::IntegerAttr>(
-      "ttg.global_scratch_memory_size");
-  if (!allocSizeAttr) {
-    return gmemBase;
-  }
-
-  Value gridIdx[3];
-  Value gridDim[2];
-  for (int k = 0; k < 3; ++k) {
-    gridIdx[k] = rewriter.create<GetProgramIdOp>(loc, k);
-  }
-  for (int k = 0; k < 2; ++k) {
-    gridDim[k] = rewriter.create<GetNumProgramsOp>(loc, k);
-  }
-
-  Value linearId = gridIdx[2];
-  for (int k = 0; k < 2; ++k) {
-    linearId = add(gridIdx[1 - k], mul(linearId, gridDim[1 - k]));
-  }
-
-  auto allocSize = allocSizeAttr.getValue().getZExtValue();
-
-  Value offset = mul(linearId, i32_val(allocSize));
-  if (allocOffset) {
-    offset = add(offset, allocOffset);
-  }
-
-  auto *ctx = rewriter.getContext();
-  auto res =
-      gep(mlir::LLVM::LLVMPointerType::get(ctx, 1), i8_ty, gmemBase, offset);
-  return res;
-}
-
-inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
-                                 const TargetInfoBase &target, Operation *op) {
-  auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(),
-                                          target.getSharedAddressSpace());
-  FunctionOpInterface func =
-      op->template getParentOfType<FunctionOpInterface>();
-  assert(op->hasAttr("allocation.offset"));
-  size_t offset = cast<IntegerAttr>(op->getAttr("allocation.offset"))
-                      .getValue()
-                      .getZExtValue();
-  Value offVal = i32_val(offset);
-  Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal);
-  return base;
-}
+inline Value getScrathMemoryPtr(::mlir::gpu::AddressSpace addressSpace,
+                                Location loc, RewriterBase &rewriter,
+                                Operation *op, Value allocOffset,
+                                bool getstackptr) {
+  FunctionOpInterface funcOp = op->getParentOfType<FunctionOpInterface>();
+  switch (addressSpace) {
+  case ::mlir::gpu::AddressSpace::Workgroup: {
+    auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
+    auto mod = funcOp->getParentOfType<ModuleOp>();
+    Value stackPtr, offVal;
+    if (!isKernel(funcOp)) {
+      stackPtr = funcOp.getArgument(funcOp.getNumArguments() - 2);
+    } else {
+      auto globalBase =
+          dyn_cast<LLVM::GlobalOp>(mod.lookupSymbol("global_smem"));
+      assert(globalBase);
+      stackPtr =
+          rewriter.create<LLVM::AddressOfOp>(funcOp.getLoc(), globalBase);
+    }
+    if (getstackptr) {
+      return stackPtr;
+    }
+    assert(op->hasAttr("allocation.offset"));
+    size_t offset = cast<IntegerAttr>(op->getAttr("allocation.offset"))
+                        .getValue()
+                        .getZExtValue();
+    offVal = i32_val(offset);
+    return gep(ptrTy, i8_ty, stackPtr, offVal);
+    break;
+  }
+  case ::mlir::gpu::AddressSpace::Global: {
+    if (!isKernel(funcOp)) {
+      // Base for this function
+      auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
+      if (!allocOffset) {
+        return gmemBase;
+      }
+      auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
+      return gep(ptrTy, i8_ty, gmemBase, allocOffset);
+    }
+
+    // Base for entire kernel
+    auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
+
+    ModuleOp mod = funcOp.getOperation()->getParentOfType<ModuleOp>();
+    auto allocSizeAttr = mod.getOperation()->getAttrOfType<mlir::IntegerAttr>(
+        "ttg.global_scratch_memory_size");
+    if (!allocSizeAttr) {
+      return gmemBase;
+    }
+
+    Value gridIdx[3];
+    Value gridDim[2];
+    for (int k = 0; k < 3; ++k) {
+      gridIdx[k] = rewriter.create<GetProgramIdOp>(loc, k);
+    }
+    for (int k = 0; k < 2; ++k) {
+      gridDim[k] = rewriter.create<GetNumProgramsOp>(loc, k);
+    }
+
+    Value linearId = gridIdx[2];
+    for (int k = 0; k < 2; ++k) {
+      linearId = add(gridIdx[1 - k], mul(linearId, gridDim[1 - k]));
+    }
+
+    auto allocSize = allocSizeAttr.getValue().getZExtValue();
+
+    Value offset = mul(linearId, i32_val(allocSize));
+    if (allocOffset) {
+      offset = add(offset, allocOffset);
+    }
+
+    auto *ctx = rewriter.getContext();
+    auto res =
+        gep(mlir::LLVM::LLVMPointerType::get(ctx, 1), i8_ty, gmemBase, offset);
+    return res;
+    break;
+  }
+  default: {
+    llvm_unreachable("not avaiable addspace type in getScrathMemoryPtr");
+    break;
+  }
+  }
+}

// -----------------------------------------------------------------------
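In the Global case, each program (CTA) owns an `allocSize`-byte slice of one flat scratch buffer, selected by the linearized program ID `x + dimX * (y + dimY * z)`; the reversed loop above builds exactly that expression. A small self-contained C++ model of the emitted arithmetic (plain host code with assumed grid values, not the MLIR lowering itself):

#include <cassert>
#include <cstdint>

// Mirrors: linearId = gridIdx[2];
//          for k in {0,1}: linearId = gridIdx[1-k] + linearId * gridDim[1-k];
//          offset   = linearId * allocSize (+ allocOffset)
uint64_t scratchOffset(uint32_t idX, uint32_t idY, uint32_t idZ,
                       uint32_t dimX, uint32_t dimY,
                       uint64_t allocSize, uint64_t allocOffset) {
  uint64_t linearId = idZ;
  linearId = idY + linearId * dimY; // k = 0
  linearId = idX + linearId * dimX; // k = 1
  return linearId * allocSize + allocOffset;
}

int main() {
  // Program (1, 2, 0) in a 4 x 4 x 1 grid, 256 bytes per program,
  // querying an allocation that starts 32 bytes into the slice.
  assert(scratchOffset(1, 2, 0, 4, 4, 256, 32) == 9 * 256 + 32); // program 9
  return 0;
}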
12 changes: 8 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp
@@ -83,10 +83,14 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
        callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
        adaptor.getOperands(), rewriter);
    if (!caller->hasAttr("allocation.offset")) {
-      auto base = LLVM::getStackPointer(rewriter, caller);
+      auto base = targetInfo.getScrathMemoryPtr(
+          ::mlir::gpu::AddressSpace::Workgroup, loc, rewriter, callOp, {},
+          /*getstackptr=*/true);
      promotedOperands.push_back(base);
+
    } else {
-      auto base = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, callOp);
+      auto base = targetInfo.getScrathMemoryPtr(
+          ::mlir::gpu::AddressSpace::Workgroup, loc, rewriter, callOp);
      promotedOperands.push_back(base);
    }

@@ -98,8 +102,8 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
      opOffsetVal = i32_val(opOffset);
    }

-    promotedOperands.push_back(
-        LLVM::getGlobalScratchPtr(loc, rewriter, caller, opOffsetVal));
+    promotedOperands.push_back(targetInfo.getScrathMemoryPtr(
+        ::mlir::gpu::AddressSpace::Global, loc, rewriter, callOp, opOffsetVal));
    return promotedOperands;
  }
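Both branches feed the convention that `getScrathMemoryPtr` relies on (see NOTE: [Additional Function Arguments]): every non-kernel callee is handed the shared-memory base and the global scratch pointer as trailing operands. Inside the helper this is read back as shown below, a sketch of the assumed layout rather than new API:

// Last two parameters of a non-kernel function, appended by the compiler:
Value smemBase = funcOp.getArgument(funcOp.getNumArguments() - 2);    // ptr<3>
Value gmemScratch = funcOp.getArgument(funcOp.getNumArguments() - 1); // ptr<1>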
10 changes: 6 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -161,8 +161,9 @@ struct ConvertLayoutOpConversion
      return success();
    }

-    Value smemBase =
-        LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
+    Value smemBase = targetInfo.getScrathMemoryPtr(
+        ::mlir::gpu::AddressSpace::Workgroup, loc, rewriter, op.getOperation());
+
    auto shape = dstTy.getShape();
    unsigned rank = dstTy.getRank();
    SmallVector<unsigned> numReplicates(rank);
@@ -502,8 +503,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
    SmallVector<SmallVector<int>> outRegsForIter =
        collectRegsForIter(ctx, shmemLoadLayout);

-    Value smemBase =
-        LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
+    Value smemBase = targetInfo.getScrathMemoryPtr(
+        ::mlir::gpu::AddressSpace::Workgroup, loc, rewriter, op.getOperation());
+
    auto sharedPtrTy = smemBase.getType();
    Type elemTy = inVals[0].getType();
    auto outSize = shmemLoadLayout.getInDimSize(kRegister);
3 changes: 2 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp
@@ -85,7 +85,8 @@ void GatherOpConversion::emitGatherInShared(
  assert(srcValues.size() == srcIndices.size());

  // Get the base pointer to the scratch memory.
-  Value smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op);
+  Value smemBase = targetInfo.getScrathMemoryPtr(
+      ::mlir::gpu::AddressSpace::Workgroup, loc, rewriter, op);

  // For each src element owned by the thread, index into the scratch memory and
  // then store it.
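The surrounding comments describe a scatter-then-gather staged through scratch memory: source values are first stored to their canonical slots, then read back through the index tensor. A tiny single-threaded C++ model of that assumed data flow (not the Triton lowering itself):

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> src = {10, 20, 30, 40};
  std::vector<int> idx = {3, 1, 1, 0};
  std::vector<int> scratch(src.size());
  for (size_t i = 0; i < src.size(); ++i)
    scratch[i] = src[i]; // store phase (before a barrier on the GPU)
  std::vector<int> out(idx.size());
  for (size_t i = 0; i < idx.size(); ++i)
    out[i] = scratch[idx[i]]; // load phase (after the barrier)
  for (int v : out)
    printf("%d ", v); // prints: 40 20 20 10
  printf("\n");
  return 0;
}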
5 changes: 3 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp
@@ -181,8 +181,9 @@ struct HistogramOpConversion
  // TODO: we could skip this for cases with num_warps=1 as long as we can
  // generate the right layout. Currently the warp level histogram generates
  // data in the default blocked layout.
-  Value baseSharedMemPtr =
-      LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
+
+  Value baseSharedMemPtr = targetInfo.getScrathMemoryPtr(
+      ::mlir::gpu::AddressSpace::Workgroup, loc, rewriter, op.getOperation());
  auto dstType = op.getType();
  Attribute dstEncoding = dstType.getEncoding();
  auto indices = emitIndices(op.getLoc(), rewriter, targetInfo, dstEncoding,
18 changes: 12 additions & 6 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -33,8 +33,9 @@ void lowerDistributedToShared(
struct GlobalScratchAllocOpConversion
    : public ConvertOpToLLVMPattern<triton::gpu::GlobalScratchAllocOp> {
  GlobalScratchAllocOpConversion(LLVMTypeConverter &converter,
+                                 const TargetInfoBase &targetInfo,
                                 PatternBenefit benefit)
-      : ConvertOpToLLVMPattern(converter, benefit) {}
+      : ConvertOpToLLVMPattern(converter, benefit), targetInfo(targetInfo) {}

  LogicalResult
  matchAndRewrite(triton::gpu::GlobalScratchAllocOp op, OpAdaptor adaptor,
@@ -51,11 +52,14 @@ struct GlobalScratchAllocOpConversion
      return failure();
    }
    Value ptr =
-        LLVM::getGlobalScratchPtr(loc, rewriter, funcOp, i32_val(opOffset));
+        targetInfo.getScrathMemoryPtr(::mlir::gpu::AddressSpace::Global, loc,
+                                      rewriter, funcOp, i32_val(opOffset));
    rewriter.replaceOp(op, ptr);
    return success();
  }
+
+private:
+  const TargetInfoBase &targetInfo;
};

struct LocalAllocOpConversion
@@ -72,8 +76,9 @@ struct LocalAllocOpConversion
    if (!op.isSharedMemoryAlloc())
      return failure();
    Location loc = op->getLoc();
-    Value smemBase =
-        LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
+    Value smemBase = targetInfo.getScrathMemoryPtr(
+        ::mlir::gpu::AddressSpace::Workgroup, loc, rewriter, op.getOperation());
+
    auto resultTy = cast<MemDescType>(op.getType());
    auto typeConverter = getTypeConverter();
    auto sharedLayout =
@@ -198,7 +203,8 @@ void mlir::triton::populateMemoryOpToLLVMPatterns(
    LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
    RewritePatternSet &patterns, PatternBenefit benefit,
    std::optional<BackendCallbacks> backendCallbacks) {
-  patterns.add<GlobalScratchAllocOpConversion>(typeConverter, benefit);
+  patterns.add<GlobalScratchAllocOpConversion>(typeConverter, targetInfo,
+                                               benefit);
  patterns.add<LocalAllocOpConversion>(typeConverter, targetInfo, benefit);
  patterns.add<LocalDeallocOpConversion>(typeConverter, benefit);
  patterns.add<LocalLoadOpConversion>(typeConverter, targetInfo, benefit);
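With the target info threaded through the constructor, the registration call is the only other change a backend sees. A hedged usage sketch, assuming `typeConverter`, `targetInfo`, `patterns`, and `benefit` are set up as in an existing backend pipeline:

// Populate the memory-op lowering patterns; GlobalScratchAllocOpConversion
// now receives the TargetInfoBase so it can call getScrathMemoryPtr.
mlir::triton::populateMemoryOpToLLVMPatterns(typeConverter, targetInfo,
                                             patterns, benefit,
                                             /*backendCallbacks=*/std::nullopt);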
4 changes: 2 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h
@@ -151,8 +151,8 @@ class ConvertTritonGPUReduceScanToLLVMPattern
    });
    // Assign base index to each operand in their order in indices
    std::map<unsigned, Value> indexToBase;
-    auto basePtr =
-        LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
+    auto basePtr = targetInfo.getScrathMemoryPtr(
+        ::mlir::gpu::AddressSpace::Workgroup, loc, rewriter, op.getOperation());
    indexToBase[indices[0]] = basePtr;
    for (unsigned i = 1; i < op.getNumOperands(); ++i) {
      indexToBase[indices[i]] =
2 changes: 0 additions & 2 deletions python/test/unit/language/test_core.py
@@ -6816,8 +6816,6 @@ def gather_test_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0:
    ([128, 64], [128, 128], 1),
])
def test_gather(src_shape, indices_shape, axis, device):
-    if is_xpu():
-        pytest.skip("Fail on XPU")

    def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor):
        output = torch.empty(indices.shape, dtype=src.dtype, device=src.device)
17 changes: 10 additions & 7 deletions third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -19,7 +19,6 @@ using namespace mlir;
using namespace mlir::triton::gpu;

using ::mlir::LLVM::delinearize;
-using ::mlir::LLVM::getSharedMemoryBase;
using ::mlir::LLVM::AMD::getVectorSize;
using ::mlir::LLVM::AMD::llLoad;
using ::mlir::LLVM::AMD::llStore;
@@ -890,8 +889,9 @@ struct AtomicCASOpConversion
      if (atomicNeedsSharedMemory(op.getResult())) {
        // Extract the new_loaded value from the pair.
        Value newLoaded = extract_val(valueElemTy, cmpxchg, 0);
-        Value atomPtr =
-            getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
+        Value atomPtr = targetInfo.getScrathMemoryPtr(
+            ::mlir::gpu::AddressSpace::Workgroup, loc, rewriter,
+            op.getOperation());
        store(newLoaded, atomPtr);
      }

@@ -910,7 +910,8 @@ struct AtomicCASOpConversion
        BuilderMemfenceLDS.launch(rewriter, loc, void_ty(ctx));
        barrier();
        Value atomPtr =
-            getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
+            targetInfo.getScrathMemoryPtr(::mlir::gpu::AddressSpace::Workgroup,
+                                          loc, rewriter, op.getOperation());
        Value ret = load(valueElemTy, atomPtr);
        rewriter.replaceOp(op, {ret});
      }
@@ -1132,8 +1133,9 @@ struct AtomicRMWOpConversion

      if (!tensorTy) {
        if (atomicNeedsSharedMemory(op.getResult())) {
-          Value atomPtr =
-              getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
+          Value atomPtr = targetInfo.getScrathMemoryPtr(
+              ::mlir::gpu::AddressSpace::Workgroup, loc, rewriter,
+              op.getOperation());
          store(atom, atomPtr);
        }
      }
@@ -1168,7 +1170,8 @@ struct AtomicRMWOpConversion
        return success();
      }
      Value atomPtr =
-          getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
+          targetInfo.getScrathMemoryPtr(::mlir::gpu::AddressSpace::Workgroup,
+                                        loc, rewriter, op.getOperation());
      barrier();
      Value ret = load(valueElemTy, atomPtr);
      rewriter.replaceOp(op, {ret});
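All four hunks preserve the same broadcast idiom for scalar atomics: the lane that performed the atomic parks the returned value in the op's shared scratch slot, and every lane reloads it after a barrier. Sketched with the names from the hunks above (a shape sketch, not compilable in isolation):

Value atomPtr = targetInfo.getScrathMemoryPtr(
    ::mlir::gpu::AddressSpace::Workgroup, loc, rewriter, op.getOperation());
store(newLoaded, atomPtr); // publishing lane writes the atomic's old value
barrier();                 // make the slot visible to the whole workgroup
Value ret = load(valueElemTy, atomPtr);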
8 changes: 8 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp
@@ -430,6 +430,14 @@ void TargetInfo::assertFail(RewriterBase &rewriter, Location loc,

int TargetInfo::getSharedAddressSpace() const { return 3; }

+Value TargetInfo::getScrathMemoryPtr(::mlir::gpu::AddressSpace addressSpace,
+                                     Location loc, RewriterBase &rewriter,
+                                     Operation *op, Value allocOffset,
+                                     bool getstackptr) const {
+  return LLVM::getScrathMemoryPtr(addressSpace, loc, rewriter, op, allocOffset,
+                                  getstackptr);
+}
+
bool TargetInfo::supportVectorizedAtomics() const {
  // Note: not currently tested or used, but AMD generally supports vectorized
  // atomics.
5 changes: 5 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h
@@ -66,6 +66,11 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
                 StringRef file, StringRef func, int line) const override;
  int getSharedAddressSpace() const override;

+  Value getScrathMemoryPtr(::mlir::gpu::AddressSpace addressSpace, Location loc,
+                           RewriterBase &rewriter, Operation *op,
+                           Value allocOffset = {},
+                           bool getstackptr = false) const override;
+
  bool supportVectorizedAtomics() const override;

private:
(Diffs for the remaining changed files are not shown.)