Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Fix conflict of user defined reserved functions with internal prototypes #123378

Merged
merged 9 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ namespace LLVM {
/// Generate IR that prints the given string to stdout.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
StringRef symbolName, StringRef string,
const LLVMTypeConverter &typeConverter,
bool addNewline = true,
std::optional<StringRef> runtimeFunctionName = {});
LogicalResult createPrintStrCall(
OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
StringRef string, const LLVMTypeConverter &typeConverter,
bool addNewline = true, std::optional<StringRef> runtimeFunctionName = {});
} // namespace LLVM

} // namespace mlir
Expand Down
58 changes: 31 additions & 27 deletions mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include <optional>

namespace mlir {
class Location;
Expand All @@ -29,42 +28,47 @@ class ValueRange;
namespace LLVM {
class LLVMFuncOp;

/// Helper functions to lookup or create the declaration for commonly used
/// Helper functions to look up or create the declaration for commonly used
/// external C function calls. The list of functions provided here must be
/// implemented separately (e.g. as part of a support runtime library or as part
/// of the libc).
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp);
/// Failure if an unexpected version of function is found.
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintI64Fn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintU64Fn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF16Fn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintBF16Fn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF32Fn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF64Fn(Operation *moduleOp);
/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
LLVM::LLVMFuncOp
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreatePrintStringFn(Operation *moduleOp,
std::optional<StringRef> runtimeFunctionName = {});
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(Operation *moduleOp,
Type indexType);
LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(Operation *moduleOp,
Type indexType);
LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
Type indexType);
LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
Type unrankedDescriptorType);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintOpenFn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCloseFn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCommaFn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintNewlineFn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateMallocFn(Operation *moduleOp,
Type indexType);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateAlignedAllocFn(Operation *moduleOp,
Type indexType);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFreeFn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAllocFn(Operation *moduleOp,
Type indexType);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, Type indexType);
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericFreeFn(Operation *moduleOp);
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
Type unrankedDescriptorType);

/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
ArrayRef<Type> paramTypes = {},
Type resultType = {}, bool isVarArg = false);
/// Return a failure if the FuncOp found has unexpected signature.
FailureOr<LLVM::LLVMFuncOp>
lookupOrCreateFn(Operation *moduleOp, StringRef name,
ArrayRef<Type> paramTypes = {}, Type resultType = {},
bool isVarArg = false, bool isReserved = false);

} // namespace LLVM
} // namespace mlir
Expand Down
8 changes: 6 additions & 2 deletions mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,10 @@ class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
// Allocate memory for the coroutine frame.
auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
if (failed(allocFuncOp))
return failure();
auto coroAlloc = rewriter.create<LLVM::CallOp>(
loc, allocFuncOp, ValueRange{coroAlign, coroSize});
loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize});

// Begin a coroutine: @llvm.coro.begin.
auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
Expand Down Expand Up @@ -431,7 +433,9 @@ class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
// Free the memory.
auto freeFuncOp =
LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp,
if (failed(freeFuncOp))
return failure();
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp.value(),
ValueRange(coroMem.getResult()));

return success();
Expand Down
10 changes: 7 additions & 3 deletions mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,13 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {

// Failed block: Generate IR to print the message and call `abort`.
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
*getTypeConverter(), /*addNewLine=*/false,
/*runtimeFunctionName=*/"puts");
auto createResult = LLVM::createPrintStrCall(
rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(),
/*addNewLine=*/false,
/*runtimeFunctionName=*/"puts");
if (createResult.failed())
return failure();

if (abortOnFailedAssert) {
// Insert the `abort` declaration if necessary.
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
Expand Down
17 changes: 12 additions & 5 deletions mlir/lib/Conversion/LLVMCommon/Pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,11 +276,17 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(

// Find the malloc and free, or declare them if necessary.
auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
LLVM::LLVMFuncOp freeFunc, mallocFunc;
if (toDynamic)
FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc;
if (toDynamic) {
mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
if (!toDynamic)
if (failed(mallocFunc))
return failure();
}
if (!toDynamic) {
freeFunc = LLVM::lookupOrCreateFreeFn(module);
if (failed(freeFunc))
return failure();
}

unsigned unrankedMemrefPos = 0;
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
Expand All @@ -293,7 +299,8 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
// Allocate memory, copy, and free the source if necessary.
Value memory =
toDynamic
? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
? builder
.create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
.getResult()
: builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
IntegerType::get(getContext(), 8),
Expand All @@ -302,7 +309,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
Value source = desc.memRefDescPtr(builder, loc);
builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false);
if (!toDynamic)
builder.create<LLVM::CallOp>(loc, freeFunc, source);
builder.create<LLVM::CallOp>(loc, freeFunc.value(), source);

// Create a new descriptor. The same descriptor can be returned multiple
// times, attempting to modify its pointer can lead to memory leaks
Expand Down
11 changes: 7 additions & 4 deletions mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
return uniqueName;
}

void mlir::LLVM::createPrintStrCall(
LogicalResult mlir::LLVM::createPrintStrCall(
OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline,
std::optional<StringRef> runtimeFunctionName) {
Expand Down Expand Up @@ -59,8 +59,11 @@ void mlir::LLVM::createPrintStrCall(
SmallVector<LLVM::GEPArg> indices(1, 0);
Value gep =
builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, msgAddr, indices);
Operation *printer =
FailureOr<LLVM::LLVMFuncOp> printer =
LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName);
builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
gep);
if (failed(printer))
return failure();
builder.create<LLVM::CallOp>(loc, TypeRange(),
SymbolRefAttr::get(printer.value()), gep);
return success();
}
26 changes: 15 additions & 11 deletions mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,19 @@

using namespace mlir;

namespace {
LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
Operation *module, Type indexType) {
static FailureOr<LLVM::LLVMFuncOp>
getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
if (useGenericFn)
return LLVM::lookupOrCreateGenericAllocFn(module, indexType);

return LLVM::lookupOrCreateMallocFn(module, indexType);
}

LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
Operation *module, Type indexType) {
static FailureOr<LLVM::LLVMFuncOp>
getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module,
Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

if (useGenericFn)
Expand All @@ -34,8 +35,6 @@ LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
return LLVM::lookupOrCreateAlignedAllocFn(module, indexType);
}

} // end namespace

Value AllocationOpLLVMLowering::createAligned(
ConversionPatternRewriter &rewriter, Location loc, Value input,
Value alignment) {
Expand Down Expand Up @@ -80,10 +79,13 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
<< " to integer address space "
"failed. Consider adding memory space conversions.";
}
LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
getIndexType());
auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
if (failed(allocFuncOp))
return std::make_tuple(Value(), Value());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this result in a crash later on? If so, this would also require a FailureOr wrapping.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The returned pair of value is populated and finally verified here: AllocLikeConvertion.cpp

  // Allocate the underlying buffer.
  auto [allocatedPtr, alignedPtr] =
      this->allocateBuffer(rewriter, loc, size, op);

  if (!allocatedPtr || !alignedPtr)
    return rewriter.notifyMatchFailure(loc,
                                       "underlying buffer allocation failed");

Here empty value is assigned as the invalid state and properly handled. I prefer to leave it as is, and a separate PR might be more appropriate to modernize the error handling style if we do have a preference.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine for me 🙂

auto results =
rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);

Value allocatedPtr =
castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
Expand Down Expand Up @@ -146,11 +148,13 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

Type elementPtrType = this->getElementPtrType(memRefType);
LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
getIndexType());
if (failed(allocFuncOp))
return Value();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As above

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handled at MemRefToLLVM.cpp

    Value ptr = allocateBufferAutoAlign(
        rewriter, loc, sizeBytes, op, &defaultLayout,
        alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
                                      &defaultLayout));
    if (!ptr)
      return std::make_tuple(Value(), Value());
    return std::make_tuple(ptr, ptr);

in the same style.

auto results = rewriter.create<LLVM::CallOp>(
loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));

return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
elementPtrType, *getTypeConverter());
Expand Down
17 changes: 11 additions & 6 deletions mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ using namespace mlir;

namespace {

bool isStaticStrideOrOffset(int64_t strideOrOffset) {
static bool isStaticStrideOrOffset(int64_t strideOrOffset) {
return !ShapedType::isDynamic(strideOrOffset);
}

LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter,
ModuleOp module) {
static FailureOr<LLVM::LLVMFuncOp>
getFreeFn(const LLVMTypeConverter *typeConverter, ModuleOp module) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

if (useGenericFn)
Expand Down Expand Up @@ -220,8 +220,10 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Insert the `free` declaration if it is not already present.
LLVM::LLVMFuncOp freeFunc =
FailureOr<LLVM::LLVMFuncOp> freeFunc =
getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
if (failed(freeFunc))
return failure();
Value allocatedPtr;
if (auto unrankedTy =
llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
Expand All @@ -236,7 +238,8 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
allocatedPtr = MemRefDescriptor(adaptor.getMemref())
.allocatedPtr(rewriter, op.getLoc());
}
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, allocatedPtr);
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
allocatedPtr);
return success();
}
};
Expand Down Expand Up @@ -838,7 +841,9 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
rewriter.create<LLVM::CallOp>(loc, copyFn,
if (failed(copyFn))
return failure();
rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
ValueRange{elemSize, sourcePtr, targetPtr});

// Restore stack used for descriptors
Expand Down
23 changes: 16 additions & 7 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1546,11 +1546,15 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {

auto punct = printOp.getPunctuation();
if (auto stringLiteral = printOp.getStringLiteral()) {
LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
*stringLiteral, *getTypeConverter(),
/*addNewline=*/false);
auto createResult =
LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
*stringLiteral, *getTypeConverter(),
/*addNewline=*/false);
if (createResult.failed())
return failure();

} else if (punct != PrintPunctuation::NoPunctuation) {
emitCall(rewriter, printOp->getLoc(), [&] {
FailureOr<LLVM::LLVMFuncOp> op = [&]() {
switch (punct) {
case PrintPunctuation::Close:
return LLVM::lookupOrCreatePrintCloseFn(parent);
Expand All @@ -1563,7 +1567,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
default:
llvm_unreachable("unexpected punctuation");
}
}());
}();
if (failed(op))
return failure();
emitCall(rewriter, printOp->getLoc(), op.value());
}

rewriter.eraseOp(printOp);
Expand All @@ -1588,7 +1595,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {

// Make sure element type has runtime support.
PrintConversion conversion = PrintConversion::None;
Operation *printer;
FailureOr<Operation *> printer;
if (printType.isF32()) {
printer = LLVM::lookupOrCreatePrintF32Fn(parent);
} else if (printType.isF64()) {
Expand Down Expand Up @@ -1631,6 +1638,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
} else {
return failure();
}
if (failed(printer))
return failure();

switch (conversion) {
case PrintConversion::ZeroExt64:
Expand All @@ -1648,7 +1657,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
case PrintConversion::None:
break;
}
emitCall(rewriter, loc, printer, value);
emitCall(rewriter, loc, printer.value(), value);
return success();
}

Expand Down
Loading
Loading