Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 5ec38d6

Browse files
committedJan 21, 2025
[mlir] Wrapped return value of function lookup in FailureOr for error handling
1 parent d26a77d commit 5ec38d6

File tree

10 files changed

+196
-144
lines changed

10 files changed

+196
-144
lines changed
 

‎mlir/include/mlir/Conversion/LLVMCommon/PrintCallHelper.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ namespace LLVM {
2323
/// Generate IR that prints the given string to stdout.
2424
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
2525
/// have the signature void(char const*). The default function is `printString`.
26-
void createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
26+
LogicalResult createPrintStrCall(OpBuilder &builder, Location loc, ModuleOp moduleOp,
2727
StringRef symbolName, StringRef string,
2828
const LLVMTypeConverter &typeConverter,
2929
bool addNewline = true,

‎mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h

+22-21
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
#include "mlir/IR/Operation.h"
1818
#include "mlir/Support/LLVM.h"
19-
#include <optional>
2019

2120
namespace mlir {
2221
class Location;
@@ -29,40 +28,42 @@ class ValueRange;
2928
namespace LLVM {
3029
class LLVMFuncOp;
3130

32-
/// Helper functions to lookup or create the declaration for commonly used
31+
/// Helper functions to look up or create the declaration for commonly used
3332
/// external C function calls. The list of functions provided here must be
3433
/// implemented separately (e.g. as part of a support runtime library or as part
3534
/// of the libc).
36-
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp);
37-
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp);
38-
LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp);
39-
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp);
40-
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp);
41-
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp);
35+
/// Failure if an unexpected version of function is found.
36+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintI64Fn(Operation *moduleOp);
37+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintU64Fn(Operation *moduleOp);
38+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF16Fn(Operation *moduleOp);
39+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintBF16Fn(Operation *moduleOp);
40+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF32Fn(Operation *moduleOp);
41+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintF64Fn(Operation *moduleOp);
4242
/// Declares a function to print a C-string.
4343
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
4444
/// have the signature void(char const*). The default function is `printString`.
45-
LLVM::LLVMFuncOp
45+
FailureOr<LLVM::LLVMFuncOp>
4646
lookupOrCreatePrintStringFn(Operation *moduleOp,
4747
std::optional<StringRef> runtimeFunctionName = {});
48-
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(Operation *moduleOp);
49-
LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(Operation *moduleOp);
50-
LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(Operation *moduleOp);
51-
LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(Operation *moduleOp);
52-
LLVM::LLVMFuncOp lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
53-
LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(Operation *moduleOp,
48+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintOpenFn(Operation *moduleOp);
49+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCloseFn(Operation *moduleOp);
50+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintCommaFn(Operation *moduleOp);
51+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreatePrintNewlineFn(Operation *moduleOp);
52+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
53+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateAlignedAllocFn(Operation *moduleOp,
5454
Type indexType);
55-
LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp);
56-
LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(Operation *moduleOp,
55+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFreeFn(Operation *moduleOp);
56+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAllocFn(Operation *moduleOp,
5757
Type indexType);
58-
LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
58+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
5959
Type indexType);
60-
LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(Operation *moduleOp);
61-
LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
60+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateGenericFreeFn(Operation *moduleOp);
61+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
6262
Type unrankedDescriptorType);
6363

6464
/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
65-
LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
65+
/// Return a failure if the FuncOp found has unexpected signature.
66+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateFn(Operation *moduleOp, StringRef name,
6667
ArrayRef<Type> paramTypes = {},
6768
Type resultType = {}, bool isVarArg = false,
6869
bool isReserved = false);

‎mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -396,8 +396,10 @@ class CoroBeginOpConversion : public AsyncOpConversionPattern<CoroBeginOp> {
396396
// Allocate memory for the coroutine frame.
397397
auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
398398
op->getParentOfType<ModuleOp>(), rewriter.getI64Type());
399+
if (failed(allocFuncOp))
400+
return failure();
399401
auto coroAlloc = rewriter.create<LLVM::CallOp>(
400-
loc, allocFuncOp, ValueRange{coroAlign, coroSize});
402+
loc, allocFuncOp.value(), ValueRange{coroAlign, coroSize});
401403

402404
// Begin a coroutine: @llvm.coro.begin.
403405
auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).getId();
@@ -431,7 +433,9 @@ class CoroFreeOpConversion : public AsyncOpConversionPattern<CoroFreeOp> {
431433
// Free the memory.
432434
auto freeFuncOp =
433435
LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
434-
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp,
436+
if (failed(freeFuncOp))
437+
return failure();
438+
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFuncOp.value(),
435439
ValueRange(coroMem.getResult()));
436440

437441
return success();

‎mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,11 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
6161

6262
// Failed block: Generate IR to print the message and call `abort`.
6363
Block *failureBlock = rewriter.createBlock(opBlock->getParent());
64-
LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
64+
if (LLVM::createPrintStrCall(rewriter, loc, module, "assert_msg", op.getMsg(),
6565
*getTypeConverter(), /*addNewLine=*/false,
66-
/*runtimeFunctionName=*/"puts");
66+
/*runtimeFunctionName=*/"puts").failed()) {
67+
return failure();
68+
}
6769
if (abortOnFailedAssert) {
6870
// Insert the `abort` declaration if necessary.
6971
auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");

‎mlir/lib/Conversion/LLVMCommon/Pattern.cpp

+11-5
Original file line numberDiff line numberDiff line change
@@ -276,11 +276,17 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
276276

277277
// Find the malloc and free, or declare them if necessary.
278278
auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
279-
LLVM::LLVMFuncOp freeFunc, mallocFunc;
280-
if (toDynamic)
279+
FailureOr<LLVM::LLVMFuncOp> freeFunc, mallocFunc;
280+
if (toDynamic) {
281281
mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
282-
if (!toDynamic)
282+
if (failed(mallocFunc))
283+
return failure();
284+
}
285+
if (!toDynamic) {
283286
freeFunc = LLVM::lookupOrCreateFreeFn(module);
287+
if (failed(freeFunc))
288+
return failure();
289+
}
284290

285291
unsigned unrankedMemrefPos = 0;
286292
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
@@ -293,7 +299,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
293299
// Allocate memory, copy, and free the source if necessary.
294300
Value memory =
295301
toDynamic
296-
? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
302+
? builder.create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize)
297303
.getResult()
298304
: builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(),
299305
IntegerType::get(getContext(), 8),
@@ -302,7 +308,7 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
302308
Value source = desc.memRefDescPtr(builder, loc);
303309
builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, false);
304310
if (!toDynamic)
305-
builder.create<LLVM::CallOp>(loc, freeFunc, source);
311+
builder.create<LLVM::CallOp>(loc, freeFunc.value(), source);
306312

307313
// Create a new descriptor. The same descriptor can be returned multiple
308314
// times, attempting to modify its pointer can lead to memory leaks

‎mlir/lib/Conversion/LLVMCommon/PrintCallHelper.cpp

+9-5
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ static std::string ensureSymbolNameIsUnique(ModuleOp moduleOp,
2727
return uniqueName;
2828
}
2929

30-
void mlir::LLVM::createPrintStrCall(
30+
LogicalResult mlir::LLVM::createPrintStrCall(
3131
OpBuilder &builder, Location loc, ModuleOp moduleOp, StringRef symbolName,
3232
StringRef string, const LLVMTypeConverter &typeConverter, bool addNewline,
3333
std::optional<StringRef> runtimeFunctionName) {
@@ -59,8 +59,12 @@ void mlir::LLVM::createPrintStrCall(
5959
SmallVector<LLVM::GEPArg> indices(1, 0);
6060
Value gep =
6161
builder.create<LLVM::GEPOp>(loc, ptrTy, arrayTy, msgAddr, indices);
62-
Operation *printer =
63-
LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName);
64-
builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
65-
gep);
62+
if (auto printer =
63+
LLVM::lookupOrCreatePrintStringFn(moduleOp, runtimeFunctionName); succeeded(printer)) {
64+
builder.create<LLVM::CallOp>(loc, TypeRange(),
65+
SymbolRefAttr::get(printer.value()), gep);
66+
} else {
67+
return failure();
68+
}
69+
return success();
6670
}

‎mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp

+8-6
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
using namespace mlir;
1616

1717
namespace {
18-
LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
18+
FailureOr<LLVM::LLVMFuncOp> getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
1919
Operation *module, Type indexType) {
2020
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
2121
if (useGenericFn)
@@ -24,7 +24,7 @@ LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
2424
return LLVM::lookupOrCreateMallocFn(module, indexType);
2525
}
2626

27-
LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
27+
FailureOr<LLVM::LLVMFuncOp> getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
2828
Operation *module, Type indexType) {
2929
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
3030

@@ -80,10 +80,11 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
8080
<< " to integer address space "
8181
"failed. Consider adding memory space conversions.";
8282
}
83-
LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
83+
FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getNotalignedAllocFn(
8484
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
8585
getIndexType());
86-
auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
86+
if (failed(allocFuncOp)) return std::make_tuple(Value(), Value());
87+
auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp.value(), sizeBytes);
8788

8889
Value allocatedPtr =
8990
castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
@@ -146,11 +147,12 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
146147
sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
147148

148149
Type elementPtrType = this->getElementPtrType(memRefType);
149-
LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
150+
FailureOr<LLVM::LLVMFuncOp> allocFuncOp = getAlignedAllocFn(
150151
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
151152
getIndexType());
153+
if (failed(allocFuncOp)) return Value();
152154
auto results = rewriter.create<LLVM::CallOp>(
153-
loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
155+
loc, allocFuncOp.value(), ValueRange({allocAlignment, sizeBytes}));
154156

155157
return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType,
156158
elementPtrType, *getTypeConverter());

‎mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

+10-5
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ bool isStaticStrideOrOffset(int64_t strideOrOffset) {
4242
return !ShapedType::isDynamic(strideOrOffset);
4343
}
4444

45-
LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter,
46-
ModuleOp module) {
45+
FailureOr<LLVM::LLVMFuncOp> getFreeFn(const LLVMTypeConverter *typeConverter,
46+
ModuleOp module) {
4747
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
4848

4949
if (useGenericFn)
@@ -220,8 +220,10 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
220220
matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
221221
ConversionPatternRewriter &rewriter) const override {
222222
// Insert the `free` declaration if it is not already present.
223-
LLVM::LLVMFuncOp freeFunc =
223+
auto freeFunc =
224224
getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
225+
if (failed(freeFunc))
226+
return failure();
225227
Value allocatedPtr;
226228
if (auto unrankedTy =
227229
llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
@@ -236,7 +238,8 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
236238
allocatedPtr = MemRefDescriptor(adaptor.getMemref())
237239
.allocatedPtr(rewriter, op.getLoc());
238240
}
239-
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, allocatedPtr);
241+
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc.value(),
242+
allocatedPtr);
240243
return success();
241244
}
242245
};
@@ -838,7 +841,9 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
838841
auto elemSize = getSizeInBytes(loc, srcType.getElementType(), rewriter);
839842
auto copyFn = LLVM::lookupOrCreateMemRefCopyFn(
840843
op->getParentOfType<ModuleOp>(), getIndexType(), sourcePtr.getType());
841-
rewriter.create<LLVM::CallOp>(loc, copyFn,
844+
if (failed(copyFn))
845+
return failure();
846+
rewriter.create<LLVM::CallOp>(loc, copyFn.value(),
842847
ValueRange{elemSize, sourcePtr, targetPtr});
843848

844849
// Restore stack used for descriptors

‎mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

+29-19
Original file line numberDiff line numberDiff line change
@@ -1546,24 +1546,32 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
15461546

15471547
auto punct = printOp.getPunctuation();
15481548
if (auto stringLiteral = printOp.getStringLiteral()) {
1549-
LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
1550-
*stringLiteral, *getTypeConverter(),
1551-
/*addNewline=*/false);
1549+
if (LLVM::createPrintStrCall(rewriter, loc, parent, "vector_print_str",
1550+
*stringLiteral, *getTypeConverter(),
1551+
/*addNewline=*/false)
1552+
.failed()) {
1553+
return failure();
1554+
}
15521555
} else if (punct != PrintPunctuation::NoPunctuation) {
1553-
emitCall(rewriter, printOp->getLoc(), [&] {
1554-
switch (punct) {
1555-
case PrintPunctuation::Close:
1556-
return LLVM::lookupOrCreatePrintCloseFn(parent);
1557-
case PrintPunctuation::Open:
1558-
return LLVM::lookupOrCreatePrintOpenFn(parent);
1559-
case PrintPunctuation::Comma:
1560-
return LLVM::lookupOrCreatePrintCommaFn(parent);
1561-
case PrintPunctuation::NewLine:
1562-
return LLVM::lookupOrCreatePrintNewlineFn(parent);
1563-
default:
1564-
llvm_unreachable("unexpected punctuation");
1565-
}
1566-
}());
1556+
if (auto op = [&] -> FailureOr<LLVM::LLVMFuncOp> {
1557+
switch (punct) {
1558+
case PrintPunctuation::Close:
1559+
return LLVM::lookupOrCreatePrintCloseFn(parent);
1560+
case PrintPunctuation::Open:
1561+
return LLVM::lookupOrCreatePrintOpenFn(parent);
1562+
case PrintPunctuation::Comma:
1563+
return LLVM::lookupOrCreatePrintCommaFn(parent);
1564+
case PrintPunctuation::NewLine:
1565+
return LLVM::lookupOrCreatePrintNewlineFn(parent);
1566+
default:
1567+
llvm_unreachable("unexpected punctuation");
1568+
}
1569+
}();
1570+
succeeded(op))
1571+
emitCall(rewriter, printOp->getLoc(), op.value());
1572+
else {
1573+
return failure();
1574+
}
15671575
}
15681576

15691577
rewriter.eraseOp(printOp);
@@ -1588,7 +1596,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
15881596

15891597
// Make sure element type has runtime support.
15901598
PrintConversion conversion = PrintConversion::None;
1591-
Operation *printer;
1599+
FailureOr<Operation *> printer;
15921600
if (printType.isF32()) {
15931601
printer = LLVM::lookupOrCreatePrintF32Fn(parent);
15941602
} else if (printType.isF64()) {
@@ -1631,6 +1639,8 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
16311639
} else {
16321640
return failure();
16331641
}
1642+
if (failed(printer))
1643+
return failure();
16341644

16351645
switch (conversion) {
16361646
case PrintConversion::ZeroExt64:
@@ -1648,7 +1658,7 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
16481658
case PrintConversion::None:
16491659
break;
16501660
}
1651-
emitCall(rewriter, loc, printer, value);
1661+
emitCall(rewriter, loc, printer.value(), value);
16521662
return success();
16531663
}
16541664

‎mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp

+96-78
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,10 @@ static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
4545
static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
4646

4747
/// Generic print function lookupOrCreate helper.
48-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
49-
StringRef name,
50-
ArrayRef<Type> paramTypes,
51-
Type resultType, bool isVarArg,
52-
bool isReserved) {
48+
FailureOr<LLVM::LLVMFuncOp>
49+
mlir::LLVM::lookupOrCreateFn(Operation *moduleOp, StringRef name,
50+
ArrayRef<Type> paramTypes, Type resultType,
51+
bool isVarArg, bool isReserved) {
5352
assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
5453
"expected SymbolTable operation");
5554
auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
@@ -63,14 +62,13 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
6362
"' of different type ")
6463
.append(func.getFunctionType())
6564
.append(" is prohibited");
66-
exit(0);
6765
} else {
6866
func.emitError("redefinition of function '" + name +
6967
"' of different type ")
7068
.append(funcT)
7169
.append(" is prohibited");
72-
exit(0);
7370
}
71+
return failure();
7472
}
7573
return func;
7674
}
@@ -80,42 +78,58 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
8078
LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
8179
}
8280

83-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
84-
return lookupOrCreateFn(
81+
namespace {
82+
FailureOr<LLVM::LLVMFuncOp> lookupOrCreateReservedFn(Operation *moduleOp,
83+
StringRef name,
84+
ArrayRef<Type> paramTypes,
85+
Type resultType) {
86+
return lookupOrCreateFn(moduleOp, name, paramTypes, resultType,
87+
/*isVarArg=*/false, /*isReserved=*/true);
88+
}
89+
} // namespace
90+
91+
FailureOr<LLVM::LLVMFuncOp>
92+
mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
93+
return lookupOrCreateReservedFn(
8594
moduleOp, kPrintI64, IntegerType::get(moduleOp->getContext(), 64),
86-
LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
95+
LLVM::LLVMVoidType::get(moduleOp->getContext()));
8796
}
8897

89-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
90-
return lookupOrCreateFn(
98+
FailureOr<LLVM::LLVMFuncOp>
99+
mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
100+
return lookupOrCreateReservedFn(
91101
moduleOp, kPrintU64, IntegerType::get(moduleOp->getContext(), 64),
92-
LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
102+
LLVM::LLVMVoidType::get(moduleOp->getContext()));
93103
}
94104

95-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
96-
return lookupOrCreateFn(moduleOp, kPrintF16,
97-
IntegerType::get(moduleOp->getContext(), 16), // bits!
98-
LLVM::LLVMVoidType::get(moduleOp->getContext()),
99-
false, true);
105+
FailureOr<LLVM::LLVMFuncOp>
106+
mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
107+
return lookupOrCreateReservedFn(
108+
moduleOp, kPrintF16,
109+
IntegerType::get(moduleOp->getContext(), 16), // bits!
110+
LLVM::LLVMVoidType::get(moduleOp->getContext()));
100111
}
101112

102-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
103-
return lookupOrCreateFn(moduleOp, kPrintBF16,
104-
IntegerType::get(moduleOp->getContext(), 16), // bits!
105-
LLVM::LLVMVoidType::get(moduleOp->getContext()),
106-
false, true);
113+
FailureOr<LLVM::LLVMFuncOp>
114+
mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
115+
return lookupOrCreateReservedFn(
116+
moduleOp, kPrintBF16,
117+
IntegerType::get(moduleOp->getContext(), 16), // bits!
118+
LLVM::LLVMVoidType::get(moduleOp->getContext()));
107119
}
108120

109-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
110-
return lookupOrCreateFn(
121+
FailureOr<LLVM::LLVMFuncOp>
122+
mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
123+
return lookupOrCreateReservedFn(
111124
moduleOp, kPrintF32, Float32Type::get(moduleOp->getContext()),
112-
LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
125+
LLVM::LLVMVoidType::get(moduleOp->getContext()));
113126
}
114127

115-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
116-
return lookupOrCreateFn(
128+
FailureOr<LLVM::LLVMFuncOp>
129+
mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
130+
return lookupOrCreateReservedFn(
117131
moduleOp, kPrintF64, Float64Type::get(moduleOp->getContext()),
118-
LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
132+
LLVM::LLVMVoidType::get(moduleOp->getContext()));
119133
}
120134

121135
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
@@ -127,84 +141,88 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
127141
return getCharPtr(context);
128142
}
129143

130-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
144+
FailureOr<LLVM::LLVMFuncOp> mlir::LLVM::lookupOrCreatePrintStringFn(
131145
Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
132-
return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
133-
getCharPtr(moduleOp->getContext()),
134-
LLVM::LLVMVoidType::get(moduleOp->getContext()),
135-
false, true);
146+
return lookupOrCreateReservedFn(
147+
moduleOp, runtimeFunctionName.value_or(kPrintString),
148+
getCharPtr(moduleOp->getContext()),
149+
LLVM::LLVMVoidType::get(moduleOp->getContext()));
136150
}
137151

138-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
139-
return lookupOrCreateFn(moduleOp, kPrintOpen, {},
140-
LLVM::LLVMVoidType::get(moduleOp->getContext()),
141-
false, true);
152+
FailureOr<LLVM::LLVMFuncOp>
153+
mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
154+
return lookupOrCreateReservedFn(
155+
moduleOp, kPrintOpen, {},
156+
LLVM::LLVMVoidType::get(moduleOp->getContext()));
142157
}
143158

144-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
145-
return lookupOrCreateFn(moduleOp, kPrintClose, {},
146-
LLVM::LLVMVoidType::get(moduleOp->getContext()),
147-
false, true);
159+
FailureOr<LLVM::LLVMFuncOp>
160+
mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
161+
return lookupOrCreateReservedFn(
162+
moduleOp, kPrintClose, {},
163+
LLVM::LLVMVoidType::get(moduleOp->getContext()));
148164
}
149165

150-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
151-
return lookupOrCreateFn(moduleOp, kPrintComma, {},
152-
LLVM::LLVMVoidType::get(moduleOp->getContext()),
153-
false, true);
166+
FailureOr<LLVM::LLVMFuncOp>
167+
mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
168+
return lookupOrCreateReservedFn(
169+
moduleOp, kPrintComma, {},
170+
LLVM::LLVMVoidType::get(moduleOp->getContext()));
154171
}
155172

156-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
157-
return lookupOrCreateFn(moduleOp, kPrintNewline, {},
158-
LLVM::LLVMVoidType::get(moduleOp->getContext()),
159-
false, true);
173+
FailureOr<LLVM::LLVMFuncOp>
174+
mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
175+
return lookupOrCreateReservedFn(
176+
moduleOp, kPrintNewline, {},
177+
LLVM::LLVMVoidType::get(moduleOp->getContext()));
160178
}
161179

162-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
163-
Type indexType) {
164-
return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
165-
getVoidPtr(moduleOp->getContext()), false,
166-
true);
180+
FailureOr<LLVM::LLVMFuncOp>
181+
mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp, Type indexType) {
182+
return lookupOrCreateReservedFn(moduleOp, kMalloc, indexType,
183+
getVoidPtr(moduleOp->getContext()));
167184
}
168185

169-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
170-
Type indexType) {
171-
return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
172-
getVoidPtr(moduleOp->getContext()), false,
173-
true);
186+
FailureOr<LLVM::LLVMFuncOp>
187+
mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp, Type indexType) {
188+
return lookupOrCreateReservedFn(moduleOp, kAlignedAlloc,
189+
{indexType, indexType},
190+
getVoidPtr(moduleOp->getContext()));
174191
}
175192

176-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
177-
return LLVM::lookupOrCreateFn(
193+
FailureOr<LLVM::LLVMFuncOp>
194+
mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
195+
return lookupOrCreateReservedFn(
178196
moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
179-
LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
197+
LLVM::LLVMVoidType::get(moduleOp->getContext()));
180198
}
181199

182-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
183-
Type indexType) {
184-
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
185-
getVoidPtr(moduleOp->getContext()), false,
186-
true);
200+
FailureOr<LLVM::LLVMFuncOp>
201+
mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp, Type indexType) {
202+
return lookupOrCreateReservedFn(moduleOp, kGenericAlloc, indexType,
203+
getVoidPtr(moduleOp->getContext()));
187204
}
188205

189-
LLVM::LLVMFuncOp
206+
FailureOr<LLVM::LLVMFuncOp>
190207
mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
191208
Type indexType) {
192-
return LLVM::lookupOrCreateFn(
193-
moduleOp, kGenericAlignedAlloc, {indexType, indexType},
194-
getVoidPtr(moduleOp->getContext()), false, true);
209+
return lookupOrCreateReservedFn(moduleOp, kGenericAlignedAlloc,
210+
{indexType, indexType},
211+
getVoidPtr(moduleOp->getContext()));
195212
}
196213

197-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
198-
return LLVM::lookupOrCreateFn(
214+
FailureOr<LLVM::LLVMFuncOp>
215+
mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
216+
return lookupOrCreateReservedFn(
199217
moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
200-
LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
218+
LLVM::LLVMVoidType::get(moduleOp->getContext()));
201219
}
202220

203-
LLVM::LLVMFuncOp
221+
FailureOr<LLVM::LLVMFuncOp>
204222
mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
205223
Type unrankedDescriptorType) {
206-
return LLVM::lookupOrCreateFn(
224+
return lookupOrCreateReservedFn(
207225
moduleOp, kMemRefCopy,
208226
ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
209-
LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
227+
LLVM::LLVMVoidType::get(moduleOp->getContext()));
210228
}

0 commit comments

Comments
 (0)
Please sign in to comment.