Skip to content

Commit 3b892e6

Browse files
committed
[mlir] Add assertion on reserved function's type
1 parent 84fa175 commit 3b892e6

File tree

2 files changed

+38
-21
lines changed

2 files changed

+38
-21
lines changed

mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
6464
/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
6565
LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
6666
ArrayRef<Type> paramTypes = {},
67-
Type resultType = {}, bool isVarArg = false);
67+
Type resultType = {}, bool isVarArg = false,
68+
bool isReserved = false);
6869

6970
} // namespace LLVM
7071
} // namespace mlir

mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,29 @@ static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
4848
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
4949
StringRef name,
5050
ArrayRef<Type> paramTypes,
51-
Type resultType, bool isVarArg) {
51+
Type resultType, bool isVarArg, bool isReserved) {
5252
assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
5353
"expected SymbolTable operation");
5454
auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
5555
SymbolTable::lookupSymbolIn(moduleOp, name));
56-
if (func)
56+
auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg);
57+
// Assert the signature of the found function is same as expected
58+
if (func) {
59+
if (funcT != func.getFunctionType()) {
60+
if (isReserved) {
61+
func.emitError("redefinition of reserved function '" + name + "' of different type ")
62+
.append(func.getFunctionType())
63+
.append(" is prohibited");
64+
exit(0);
65+
} else {
66+
func.emitError("redefinition of function '" + name + "' of different type ")
67+
.append(funcT)
68+
.append(" is prohibited");
69+
exit(0);
70+
}
71+
}
5772
return func;
73+
}
5874
OpBuilder b(moduleOp->getRegion(0));
5975
return b.create<LLVM::LLVMFuncOp>(
6076
moduleOp->getLoc(), name,
@@ -64,37 +80,37 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
6480
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
6581
return lookupOrCreateFn(moduleOp, kPrintI64,
6682
IntegerType::get(moduleOp->getContext(), 64),
67-
LLVM::LLVMVoidType::get(moduleOp->getContext()));
83+
LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
6884
}
6985

7086
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
7187
return lookupOrCreateFn(moduleOp, kPrintU64,
7288
IntegerType::get(moduleOp->getContext(), 64),
73-
LLVM::LLVMVoidType::get(moduleOp->getContext()));
89+
LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
7490
}
7591

7692
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
7793
return lookupOrCreateFn(moduleOp, kPrintF16,
7894
IntegerType::get(moduleOp->getContext(), 16), // bits!
79-
LLVM::LLVMVoidType::get(moduleOp->getContext()));
95+
LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
8096
}
8197

8298
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
8399
return lookupOrCreateFn(moduleOp, kPrintBF16,
84100
IntegerType::get(moduleOp->getContext(), 16), // bits!
85-
LLVM::LLVMVoidType::get(moduleOp->getContext()));
101+
LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
86102
}
87103

88104
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
89105
return lookupOrCreateFn(moduleOp, kPrintF32,
90106
Float32Type::get(moduleOp->getContext()),
91-
LLVM::LLVMVoidType::get(moduleOp->getContext()));
107+
LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
92108
}
93109

94110
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
95111
return lookupOrCreateFn(moduleOp, kPrintF64,
96112
Float64Type::get(moduleOp->getContext()),
97-
LLVM::LLVMVoidType::get(moduleOp->getContext()));
113+
LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
98114
}
99115

100116
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
@@ -110,65 +126,65 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
110126
Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
111127
return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
112128
getCharPtr(moduleOp->getContext()),
113-
LLVM::LLVMVoidType::get(moduleOp->getContext()));
129+
LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
114130
}
115131

116132
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
117133
return lookupOrCreateFn(moduleOp, kPrintOpen, {},
118-
LLVM::LLVMVoidType::get(moduleOp->getContext()));
134+
LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
119135
}
120136

121137
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
122138
return lookupOrCreateFn(moduleOp, kPrintClose, {},
123-
LLVM::LLVMVoidType::get(moduleOp->getContext()));
139+
LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
124140
}
125141

126142
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
127143
return lookupOrCreateFn(moduleOp, kPrintComma, {},
128-
LLVM::LLVMVoidType::get(moduleOp->getContext()));
144+
LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
129145
}
130146

131147
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
132148
return lookupOrCreateFn(moduleOp, kPrintNewline, {},
133-
LLVM::LLVMVoidType::get(moduleOp->getContext()));
149+
LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
134150
}
135151

136152
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
137153
Type indexType) {
138154
return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
139-
getVoidPtr(moduleOp->getContext()));
155+
getVoidPtr(moduleOp->getContext()), false, true);
140156
}
141157

142158
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
143159
Type indexType) {
144160
return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
145-
getVoidPtr(moduleOp->getContext()));
161+
getVoidPtr(moduleOp->getContext()), false, true);
146162
}
147163

148164
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
149165
return LLVM::lookupOrCreateFn(
150166
moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
151-
LLVM::LLVMVoidType::get(moduleOp->getContext()));
167+
LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
152168
}
153169

154170
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
155171
Type indexType) {
156172
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
157-
getVoidPtr(moduleOp->getContext()));
173+
getVoidPtr(moduleOp->getContext()), false, true);
158174
}
159175

160176
LLVM::LLVMFuncOp
161177
mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
162178
Type indexType) {
163179
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
164180
{indexType, indexType},
165-
getVoidPtr(moduleOp->getContext()));
181+
getVoidPtr(moduleOp->getContext()), false, true);
166182
}
167183

168184
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
169185
return LLVM::lookupOrCreateFn(
170186
moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
171-
LLVM::LLVMVoidType::get(moduleOp->getContext()));
187+
LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
172188
}
173189

174190
LLVM::LLVMFuncOp
@@ -177,5 +193,5 @@ mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
177193
return LLVM::lookupOrCreateFn(
178194
moduleOp, kMemRefCopy,
179195
ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
180-
LLVM::LLVMVoidType::get(moduleOp->getContext()));
196+
LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
181197
}

0 commit comments

Comments
 (0)