diff --git a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h index c2742b6fc1d73..33402301115b7 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h @@ -23,11 +23,10 @@ namespace LLVM { /// Generate IR that prints the given string to stdout. /// If a custom runtime function is defined via `runtimeFunctionName`, it must /// have the signature void(char const*). The default function is `printString`. -void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp, - StringRef symbolName, StringRef string, - const LLVMTypeConverter &typeConverter, - bool addNewline = true, - std::optional runtimeFunctionName = {}); +LogicalResult createPrintStrCall( + OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, + StringRef string, const LLVMTypeConverter &typeConverter, + bool addNewline = true, std::optional runtimeFunctionName = {}); } // namespace LLVM } // namespace mlir diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h index 852490cf7428f..05e9fe9d58859 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h +++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h @@ -16,7 +16,6 @@ #include "mlir/IR/Operation.h" #include "mlir/Support/LLVM.h" -#include namespace mlir { class Location; @@ -29,42 +28,47 @@ class ValueRange; namespace LLVM { class LLVMFuncOp; -/// Helper functions to lookup or create the declaration for commonly used +/// Helper functions to look up or create the declaration for commonly used /// external C function calls. The list of functions provided here must be /// implemented separately (e.g. as part of a support runtime library or as part /// of the libc). -LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp); +/// Failure if an unexpected version of function is found. +FailureOr lookupOrCreatePrintI64Fn(Operation *moduleOp); +FailureOr lookupOrCreatePrintU64Fn(Operation *moduleOp); +FailureOr lookupOrCreatePrintF16Fn(Operation *moduleOp); +FailureOr lookupOrCreatePrintBF16Fn(Operation *moduleOp); +FailureOr lookupOrCreatePrintF32Fn(Operation *moduleOp); +FailureOr lookupOrCreatePrintF64Fn(Operation *moduleOp); /// Declares a function to print a C-string. /// If a custom runtime function is defined via `runtimeFunctionName`, it must /// have the signature void(char const*). The default function is `printString`. -LLVM::LLVMFuncOp +FailureOr lookupOrCreatePrintStringFn(Operation *moduleOp, std::optional runtimeFunctionName = {}); -LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreateMallocFn(Operation *moduleOp, Type indexType); -LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(Operation *moduleOp, - Type indexType); -LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(Operation *moduleOp, - Type indexType); -LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, - Type indexType); -LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(Operation *moduleOp); -LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType, - Type unrankedDescriptorType); +FailureOr lookupOrCreatePrintOpenFn(Operation *moduleOp); +FailureOr lookupOrCreatePrintCloseFn(Operation *moduleOp); +FailureOr lookupOrCreatePrintCommaFn(Operation *moduleOp); +FailureOr lookupOrCreatePrintNewlineFn(Operation *moduleOp); +FailureOr lookupOrCreateMallocFn(Operation *moduleOp, + Type indexType); +FailureOr lookupOrCreateAlignedAllocFn(Operation *moduleOp, + Type indexType); +FailureOr lookupOrCreateFreeFn(Operation *moduleOp); +FailureOr lookupOrCreateGenericAllocFn(Operation *moduleOp, + Type indexType); +FailureOr +lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, Type indexType); +FailureOr lookupOrCreateGenericFreeFn(Operation *moduleOp); +FailureOr +lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType, + Type unrankedDescriptorType); /// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`. -LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name, - ArrayRef paramTypes = {}, - Type resultType = {}, bool isVarArg = false); +/// Return a failure if the FuncOp found has unexpected signature. +FailureOr +lookupOrCreateFn(Operation *moduleOp, StringRef name, + ArrayRef paramTypes = {}, Type resultType = {}, + bool isVarArg = false, bool isReserved = false); } // namespace LLVM } // namespace mlir diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp index 9b5aeb3fef30b..47d4474a5c28d 100644 --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -396,8 +396,10 @@ class CoroBeginOpConversion : public AsyncOpConversionPattern { // Allocate memory for the coroutine frame. auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn( op->getParentOfType(), rewriter.getI64Type()); + if (failed(allocFuncOp)) + return failure(); auto coroAlloc = rewriter.create( - loc, allocFuncOp, ValueRange{coroAlign, coroSize}); + loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize}); // Begin a coroutine: @llvm.coro.begin. auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId(); @@ -431,7 +433,9 @@ class CoroFreeOpConversion : public AsyncOpConversionPattern { // Free the memory. auto freeFuncOp = LLVM::lookupOrCreateFreeFn(op->getParentOfType()); - rewriter.replaceOpWithNewOp(op, freeFuncOp, + if (failed(freeFuncOp)) + return failure(); + rewriter.replaceOpWithNewOp(op, freeFuncOp.value(), ValueRange(coroMem.getResult())); return success(); diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index d0ffb94f3f96a..debfd003bd5b5 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -61,9 +61,13 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern { // Failed block: Generate IR to print the message and call `abort`. Block *failureBlock = rewriter.createBlock(opBlock->getParent()); - LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(), - *getTypeConverter(), /*addNewLine=*/false, - /*runtimeFunctionName=*/"puts"); + auto createResult = LLVM::createPrintStrCall( + rewriter, loc, module, "assert_msg", op.getMsg(), *getTypeConverter(), + /*addNewLine=*/false, + /*runtimeFunctionName=*/"puts"); + if (createResult.failed()) + return failure(); + if (abortOnFailedAssert) { // Insert the `abort` declaration if necessary. auto abortFunc = module.lookupSymbol("abort"); diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index a47a2872ceb07..840bd3df61a06 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -276,11 +276,17 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( // Find the malloc and free, or declare them if necessary. auto module = builder.getInsertionPoint()->getParentOfType(); - LLVM::LLVMFuncOp freeFunc, mallocFunc; - if (toDynamic) + FailureOr freeFunc, mallocFunc; + if (toDynamic) { mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType); - if (!toDynamic) + if (failed(mallocFunc)) + return failure(); + } + if (!toDynamic) { freeFunc = LLVM::lookupOrCreateFreeFn(module); + if (failed(freeFunc)) + return failure(); + } unsigned unrankedMemrefPos = 0; for (unsigned i = 0, e = operands.size(); i < e; ++i) { @@ -293,7 +299,8 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( // Allocate memory, copy, and free the source if necessary. Value memory = toDynamic - ? builder.create(loc, mallocFunc, allocationSize) + ? builder + .create(loc, mallocFunc.value(), allocationSize) .getResult() : builder.create(loc, getVoidPtrType(), IntegerType::get(getContext(), 8), @@ -302,7 +309,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( Value source = desc.memRefDescPtr(builder, loc); builder.create(loc, memory, source, allocationSize, false); if (!toDynamic) - builder.create(loc, freeFunc, source); + builder.create(loc, freeFunc.value(), source); // Create a new descriptor. The same descriptor can be returned multiple // times, attempting to modify its pointer can lead to memory leaks diff --git a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp index bd7b401efec17..337c01f01a7cc 100644 --- a/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp +++ b/mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp @@ -27,7 +27,7 @@ static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp, return uniqueName; } -void mlir::LLVM::createPrintStrCall( +LogicalResult mlir::LLVM::createPrintStrCall( OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName, StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline, std::optional runtimeFunctionName) { @@ -59,8 +59,11 @@ void mlir::LLVM::createPrintStrCall( SmallVector indices(1, 0); Value gep = builder.create(loc, ptrTy, arrayTy, msgAddr, indices); - Operation *printer = + FailureOr printer = LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName); - builder.create(loc, TypeRange(), SymbolRefAttr::get(printer), - gep); + if (failed(printer)) + return failure(); + builder.create(loc, TypeRange(), + SymbolRefAttr::get(printer.value()), gep); + return success(); } diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp index a6408391b1330..c5b2e83df93dc 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp @@ -14,9 +14,9 @@ using namespace mlir; -namespace { -LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, - Operation *module, Type indexType) { +static FailureOr +getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module, + Type indexType) { bool useGenericFn = typeConverter->getOptions().useGenericFunctions; if (useGenericFn) return LLVM::lookupOrCreateGenericAllocFn(module, indexType); @@ -24,8 +24,9 @@ LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter, return LLVM::lookupOrCreateMallocFn(module, indexType); } -LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter, - Operation *module, Type indexType) { +static FailureOr +getAlignedAllocFn(const LLVMTypeConverter *typeConverter, Operation *module, + Type indexType) { bool useGenericFn = typeConverter->getOptions().useGenericFunctions; if (useGenericFn) @@ -34,8 +35,6 @@ LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter, return LLVM::lookupOrCreateAlignedAllocFn(module, indexType); } -} // end namespace - Value AllocationOpLLVMLowering::createAligned( ConversionPatternRewriter &rewriter, Location loc, Value input, Value alignment) { @@ -80,10 +79,13 @@ std::tuple AllocationOpLLVMLowering::allocateBufferManuallyAlign( << " to integer address space " "failed. Consider adding memory space conversions."; } - LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn( + FailureOr allocFuncOp = getNotalignedAllocFn( getTypeConverter(), op->getParentWithTrait(), getIndexType()); - auto results = rewriter.create(loc, allocFuncOp, sizeBytes); + if (failed(allocFuncOp)) + return std::make_tuple(Value(), Value()); + auto results = + rewriter.create(loc, allocFuncOp.value(), sizeBytes); Value allocatedPtr = castAllocFuncResult(rewriter, loc, results.getResult(), memRefType, @@ -146,11 +148,13 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign( sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment); Type elementPtrType = this->getElementPtrType(memRefType); - LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn( + FailureOr allocFuncOp = getAlignedAllocFn( getTypeConverter(), op->getParentWithTrait(), getIndexType()); + if (failed(allocFuncOp)) + return Value(); auto results = rewriter.create( - loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes})); + loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes})); return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType, elementPtrType, *getTypeConverter()); diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index f7542b8b3bc5c..af1dba4587dc1 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -38,12 +38,12 @@ using namespace mlir; namespace { -bool isStaticStrideOrOffset(int64_t strideOrOffset) { +static bool isStaticStrideOrOffset(int64_t strideOrOffset) { return !ShapedType::isDynamic(strideOrOffset); } -LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter, - ModuleOp module) { +static FailureOr +getFreeFn(const LLVMTypeConverter *typeConverter, ModuleOp module) { bool useGenericFn = typeConverter->getOptions().useGenericFunctions; if (useGenericFn) @@ -220,8 +220,10 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern { matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Insert the `free` declaration if it is not already present. - LLVM::LLVMFuncOp freeFunc = + FailureOr freeFunc = getFreeFn(getTypeConverter(), op->getParentOfType()); + if (failed(freeFunc)) + return failure(); Value allocatedPtr; if (auto unrankedTy = llvm::dyn_cast(op.getMemref().getType())) { @@ -236,7 +238,8 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern { allocatedPtr = MemRefDescriptor(adaptor.getMemref()) .allocatedPtr(rewriter, op.getLoc()); } - rewriter.replaceOpWithNewOp(op, freeFunc, allocatedPtr); + rewriter.replaceOpWithNewOp(op, freeFunc.value(), + allocatedPtr); return success(); } }; @@ -838,7 +841,9 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern { auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter); auto copyFn = LLVM::lookupOrCreateMemRefCopyFn( op->getParentOfType(), getIndexType(), sourcePtr.getType()); - rewriter.create(loc, copyFn, + if (failed(copyFn)) + return failure(); + rewriter.create(loc, copyFn.value(), ValueRange{elemSize, sourcePtr, targetPtr}); // Restore stack used for descriptors diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index a1e21cb524bd9..baed98c13adc7 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1546,11 +1546,15 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { auto punct = printOp.getPunctuation(); if (auto stringLiteral = printOp.getStringLiteral()) { - LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str", - *stringLiteral, *getTypeConverter(), - /*addNewline=*/false); + auto createResult = + LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str", + *stringLiteral, *getTypeConverter(), + /*addNewline=*/false); + if (createResult.failed()) + return failure(); + } else if (punct != PrintPunctuation::NoPunctuation) { - emitCall(rewriter, printOp->getLoc(), [&] { + FailureOr op = [&]() { switch (punct) { case PrintPunctuation::Close: return LLVM::lookupOrCreatePrintCloseFn(parent); @@ -1563,7 +1567,10 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { default: llvm_unreachable("unexpected punctuation"); } - }()); + }(); + if (failed(op)) + return failure(); + emitCall(rewriter, printOp->getLoc(), op.value()); } rewriter.eraseOp(printOp); @@ -1588,7 +1595,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { // Make sure element type has runtime support. PrintConversion conversion = PrintConversion::None; - Operation *printer; + FailureOr printer; if (printType.isF32()) { printer = LLVM::lookupOrCreatePrintF32Fn(parent); } else if (printType.isF64()) { @@ -1631,6 +1638,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { } else { return failure(); } + if (failed(printer)) + return failure(); switch (conversion) { case PrintConversion::ZeroExt64: @@ -1648,7 +1657,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern { case PrintConversion::None: break; } - emitCall(rewriter, loc, printer, value); + emitCall(rewriter, loc, printer.value(), value); return success(); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp index 88421a16ccf9f..68d4426e65301 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -45,56 +45,85 @@ static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free"; static constexpr llvm::StringRef kMemRefCopy = "memrefCopy"; /// Generic print function lookupOrCreate helper. -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, - StringRef name, - ArrayRef paramTypes, - Type resultType, bool isVarArg) { +FailureOr +mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, StringRef name, + ArrayRef paramTypes, Type resultType, + bool isVarArg, bool isReserved) { assert(moduleOp->hasTrait() && "expected SymbolTable operation"); auto func = llvm::dyn_cast_or_null( SymbolTable::lookupSymbolIn(moduleOp, name)); - if (func) + auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg); + // Assert the signature of the found function is same as expected + if (func) { + if (funcT != func.getFunctionType()) { + if (isReserved) { + func.emitError("redefinition of reserved function '") + << name << "' of different type " << func.getFunctionType() + << " is prohibited"; + } else { + func.emitError("redefinition of function '") + << name << "' of different type " << funcT << " is prohibited"; + } + return failure(); + } return func; + } OpBuilder b(moduleOp->getRegion(0)); return b.create( moduleOp->getLoc(), name, LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg)); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) { - return lookupOrCreateFn(moduleOp, kPrintI64, - IntegerType::get(moduleOp->getContext(), 64), - LLVM::LLVMVoidType::get(moduleOp->getContext())); +static FailureOr +lookupOrCreateReservedFn(Operation *moduleOp, StringRef name, + ArrayRef paramTypes, Type resultType) { + return lookupOrCreateFn(moduleOp, name, paramTypes, resultType, + /*isVarArg=*/false, /*isReserved=*/true); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) { - return lookupOrCreateFn(moduleOp, kPrintU64, - IntegerType::get(moduleOp->getContext(), 64), - LLVM::LLVMVoidType::get(moduleOp->getContext())); +FailureOr +mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) { + return lookupOrCreateReservedFn( + moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64), + LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) { - return lookupOrCreateFn(moduleOp, kPrintF16, - IntegerType::get(moduleOp->getContext(), 16), // bits! - LLVM::LLVMVoidType::get(moduleOp->getContext())); +FailureOr +mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) { + return lookupOrCreateReservedFn( + moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64), + LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) { - return lookupOrCreateFn(moduleOp, kPrintBF16, - IntegerType::get(moduleOp->getContext(), 16), // bits! - LLVM::LLVMVoidType::get(moduleOp->getContext())); +FailureOr +mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) { + return lookupOrCreateReservedFn( + moduleOp, kPrintF16, + IntegerType::get(moduleOp->getContext(), 16), // bits! + LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) { - return lookupOrCreateFn(moduleOp, kPrintF32, - Float32Type::get(moduleOp->getContext()), - LLVM::LLVMVoidType::get(moduleOp->getContext())); +FailureOr +mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) { + return lookupOrCreateReservedFn( + moduleOp, kPrintBF16, + IntegerType::get(moduleOp->getContext(), 16), // bits! + LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) { - return lookupOrCreateFn(moduleOp, kPrintF64, - Float64Type::get(moduleOp->getContext()), - LLVM::LLVMVoidType::get(moduleOp->getContext())); +FailureOr +mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) { + return lookupOrCreateReservedFn( + moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()), + LLVM::LLVMVoidType::get(moduleOp->getContext())); +} + +FailureOr +mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) { + return lookupOrCreateReservedFn( + moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()), + LLVM::LLVMVoidType::get(moduleOp->getContext())); } static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) { @@ -106,75 +135,87 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) { return getCharPtr(context); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn( +FailureOr mlir::LLVM::lookupOrCreatePrintStringFn( Operation *moduleOp, std::optional runtimeFunctionName) { - return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString), - getCharPtr(moduleOp->getContext()), - LLVM::LLVMVoidType::get(moduleOp->getContext())); + return lookupOrCreateReservedFn( + moduleOp, runtimeFunctionName.value_or(kPrintString), + getCharPtr(moduleOp->getContext()), + LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) { - return lookupOrCreateFn(moduleOp, kPrintOpen, {}, - LLVM::LLVMVoidType::get(moduleOp->getContext())); +FailureOr +mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) { + return lookupOrCreateReservedFn( + moduleOp, kPrintOpen, {}, + LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) { - return lookupOrCreateFn(moduleOp, kPrintClose, {}, - LLVM::LLVMVoidType::get(moduleOp->getContext())); +FailureOr +mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) { + return lookupOrCreateReservedFn( + moduleOp, kPrintClose, {}, + LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) { - return lookupOrCreateFn(moduleOp, kPrintComma, {}, - LLVM::LLVMVoidType::get(moduleOp->getContext())); +FailureOr +mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) { + return lookupOrCreateReservedFn( + moduleOp, kPrintComma, {}, + LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) { - return lookupOrCreateFn(moduleOp, kPrintNewline, {}, - LLVM::LLVMVoidType::get(moduleOp->getContext())); +FailureOr +mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) { + return lookupOrCreateReservedFn( + moduleOp, kPrintNewline, {}, + LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp, - Type indexType) { - return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType, - getVoidPtr(moduleOp->getContext())); +FailureOr +mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp, Type indexType) { + return lookupOrCreateReservedFn(moduleOp, kMalloc, indexType, + getVoidPtr(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp, - Type indexType) { - return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType}, - getVoidPtr(moduleOp->getContext())); +FailureOr +mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp, Type indexType) { + return lookupOrCreateReservedFn(moduleOp, kAlignedAlloc, + {indexType, indexType}, + getVoidPtr(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) { - return LLVM::lookupOrCreateFn( +FailureOr +mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) { + return lookupOrCreateReservedFn( moduleOp, kFree, getVoidPtr(moduleOp->getContext()), LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp, - Type indexType) { - return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType, - getVoidPtr(moduleOp->getContext())); +FailureOr +mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp, Type indexType) { + return lookupOrCreateReservedFn(moduleOp, kGenericAlloc, indexType, + getVoidPtr(moduleOp->getContext())); } -LLVM::LLVMFuncOp +FailureOr mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp, Type indexType) { - return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc, - {indexType, indexType}, - getVoidPtr(moduleOp->getContext())); + return lookupOrCreateReservedFn(moduleOp, kGenericAlignedAlloc, + {indexType, indexType}, + getVoidPtr(moduleOp->getContext())); } -LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) { - return LLVM::lookupOrCreateFn( +FailureOr +mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) { + return lookupOrCreateReservedFn( moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()), LLVM::LLVMVoidType::get(moduleOp->getContext())); } -LLVM::LLVMFuncOp +FailureOr mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType, Type unrankedDescriptorType) { - return LLVM::lookupOrCreateFn( + return lookupOrCreateReservedFn( moduleOp, kMemRefCopy, ArrayRef{indexType, unrankedDescriptorType, unrankedDescriptorType}, LLVM::LLVMVoidType::get(moduleOp->getContext())); diff --git a/mlir/test/Conversion/MemRefToLLVM/invalid.mlir b/mlir/test/Conversion/MemRefToLLVM/invalid.mlir index 40dd75af1dd77..1e12b83a24b5a 100644 --- a/mlir/test/Conversion/MemRefToLLVM/invalid.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/invalid.mlir @@ -2,6 +2,13 @@ // Since the error is at an unknown location, we use FileCheck instead of // -veri-y-diagnostics here +// CHECK: redefinition of reserved function 'malloc' of different type '!llvm.func' is prohibited +llvm.func @malloc(i64) +func.func @redef_reserved() { + %alloc = memref.alloc() : memref<1024x64xf32, 1> + llvm.return +} + // CHECK: conversion of memref memory space "foo" to integer address space failed. Consider adding memory space conversions. // CHECK-LABEL: @bad_address_space func.func @bad_address_space(%a: memref<2xindex, "foo">) {