@@ -48,13 +48,29 @@ static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
48
48
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn (Operation *moduleOp,
49
49
StringRef name,
50
50
ArrayRef<Type> paramTypes,
51
- Type resultType, bool isVarArg) {
51
+ Type resultType, bool isVarArg, bool isReserved ) {
52
52
assert (moduleOp->hasTrait <OpTrait::SymbolTable>() &&
53
53
" expected SymbolTable operation" );
54
54
auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
55
55
SymbolTable::lookupSymbolIn (moduleOp, name));
56
- if (func)
56
+ auto funcT = LLVMFunctionType::get (resultType, paramTypes, isVarArg);
57
+ // Assert the signature of the found function is same as expected
58
+ if (func) {
59
+ if (funcT != func.getFunctionType ()) {
60
+ if (isReserved) {
61
+ func.emitError (" redefinition of reserved function '" + name + " ' of different type " )
62
+ .append (func.getFunctionType ())
63
+ .append (" is prohibited" );
64
+ exit (0 );
65
+ } else {
66
+ func.emitError (" redefinition of function '" + name + " ' of different type " )
67
+ .append (funcT)
68
+ .append (" is prohibited" );
69
+ exit (0 );
70
+ }
71
+ }
57
72
return func;
73
+ }
58
74
OpBuilder b (moduleOp->getRegion (0 ));
59
75
return b.create <LLVM::LLVMFuncOp>(
60
76
moduleOp->getLoc (), name,
@@ -64,37 +80,37 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
64
80
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn (Operation *moduleOp) {
65
81
return lookupOrCreateFn (moduleOp, kPrintI64 ,
66
82
IntegerType::get (moduleOp->getContext (), 64 ),
67
- LLVM::LLVMVoidType::get (moduleOp->getContext ()));
83
+ LLVM::LLVMVoidType::get (moduleOp->getContext ()), false , true );
68
84
}
69
85
70
86
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn (Operation *moduleOp) {
71
87
return lookupOrCreateFn (moduleOp, kPrintU64 ,
72
88
IntegerType::get (moduleOp->getContext (), 64 ),
73
- LLVM::LLVMVoidType::get (moduleOp->getContext ()));
89
+ LLVM::LLVMVoidType::get (moduleOp->getContext ()), false , true );
74
90
}
75
91
76
92
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn (Operation *moduleOp) {
77
93
return lookupOrCreateFn (moduleOp, kPrintF16 ,
78
94
IntegerType::get (moduleOp->getContext (), 16 ), // bits!
79
- LLVM::LLVMVoidType::get (moduleOp->getContext ()));
95
+ LLVM::LLVMVoidType::get (moduleOp->getContext ()), false , true );
80
96
}
81
97
82
98
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn (Operation *moduleOp) {
83
99
return lookupOrCreateFn (moduleOp, kPrintBF16 ,
84
100
IntegerType::get (moduleOp->getContext (), 16 ), // bits!
85
- LLVM::LLVMVoidType::get (moduleOp->getContext ()));
101
+ LLVM::LLVMVoidType::get (moduleOp->getContext ()), false , true );
86
102
}
87
103
88
104
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn (Operation *moduleOp) {
89
105
return lookupOrCreateFn (moduleOp, kPrintF32 ,
90
106
Float32Type::get (moduleOp->getContext ()),
91
- LLVM::LLVMVoidType::get (moduleOp->getContext ()));
107
+ LLVM::LLVMVoidType::get (moduleOp->getContext ()), false , true );
92
108
}
93
109
94
110
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn (Operation *moduleOp) {
95
111
return lookupOrCreateFn (moduleOp, kPrintF64 ,
96
112
Float64Type::get (moduleOp->getContext ()),
97
- LLVM::LLVMVoidType::get (moduleOp->getContext ()));
113
+ LLVM::LLVMVoidType::get (moduleOp->getContext ()), false , true );
98
114
}
99
115
100
116
static LLVM::LLVMPointerType getCharPtr (MLIRContext *context) {
@@ -110,65 +126,65 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
110
126
Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
111
127
return lookupOrCreateFn (moduleOp, runtimeFunctionName.value_or (kPrintString ),
112
128
getCharPtr (moduleOp->getContext ()),
113
- LLVM::LLVMVoidType::get (moduleOp->getContext ()));
129
+ LLVM::LLVMVoidType::get (moduleOp->getContext ()), false , true );
114
130
}
115
131
116
132
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn (Operation *moduleOp) {
117
133
return lookupOrCreateFn (moduleOp, kPrintOpen , {},
118
- LLVM::LLVMVoidType::get (moduleOp->getContext ()));
134
+ LLVM::LLVMVoidType::get (moduleOp->getContext ()), false , true );
119
135
}
120
136
121
137
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn (Operation *moduleOp) {
122
138
return lookupOrCreateFn (moduleOp, kPrintClose , {},
123
- LLVM::LLVMVoidType::get (moduleOp->getContext ()));
139
+ LLVM::LLVMVoidType::get (moduleOp->getContext ()), false , true );
124
140
}
125
141
126
142
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn (Operation *moduleOp) {
127
143
return lookupOrCreateFn (moduleOp, kPrintComma , {},
128
- LLVM::LLVMVoidType::get (moduleOp->getContext ()));
144
+ LLVM::LLVMVoidType::get (moduleOp->getContext ()), false , true );
129
145
}
130
146
131
147
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn (Operation *moduleOp) {
132
148
return lookupOrCreateFn (moduleOp, kPrintNewline , {},
133
- LLVM::LLVMVoidType::get (moduleOp->getContext ()));
149
+ LLVM::LLVMVoidType::get (moduleOp->getContext ()), false , true );
134
150
}
135
151
136
152
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn (Operation *moduleOp,
137
153
Type indexType) {
138
154
return LLVM::lookupOrCreateFn (moduleOp, kMalloc , indexType,
139
- getVoidPtr (moduleOp->getContext ()));
155
+ getVoidPtr (moduleOp->getContext ()), false , true );
140
156
}
141
157
142
158
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn (Operation *moduleOp,
143
159
Type indexType) {
144
160
return LLVM::lookupOrCreateFn (moduleOp, kAlignedAlloc , {indexType, indexType},
145
- getVoidPtr (moduleOp->getContext ()));
161
+ getVoidPtr (moduleOp->getContext ()), false , true );
146
162
}
147
163
148
164
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn (Operation *moduleOp) {
149
165
return LLVM::lookupOrCreateFn (
150
166
moduleOp, kFree , getVoidPtr (moduleOp->getContext ()),
151
- LLVM::LLVMVoidType::get (moduleOp->getContext ()));
167
+ LLVM::LLVMVoidType::get (moduleOp->getContext ()), false , true );
152
168
}
153
169
154
170
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn (Operation *moduleOp,
155
171
Type indexType) {
156
172
return LLVM::lookupOrCreateFn (moduleOp, kGenericAlloc , indexType,
157
- getVoidPtr (moduleOp->getContext ()));
173
+ getVoidPtr (moduleOp->getContext ()), false , true );
158
174
}
159
175
160
176
LLVM::LLVMFuncOp
161
177
mlir::LLVM::lookupOrCreateGenericAlignedAllocFn (Operation *moduleOp,
162
178
Type indexType) {
163
179
return LLVM::lookupOrCreateFn (moduleOp, kGenericAlignedAlloc ,
164
180
{indexType, indexType},
165
- getVoidPtr (moduleOp->getContext ()));
181
+ getVoidPtr (moduleOp->getContext ()), false , true );
166
182
}
167
183
168
184
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn (Operation *moduleOp) {
169
185
return LLVM::lookupOrCreateFn (
170
186
moduleOp, kGenericFree , getVoidPtr (moduleOp->getContext ()),
171
- LLVM::LLVMVoidType::get (moduleOp->getContext ()));
187
+ LLVM::LLVMVoidType::get (moduleOp->getContext ()), false , true );
172
188
}
173
189
174
190
LLVM::LLVMFuncOp
@@ -177,5 +193,5 @@ mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
177
193
return LLVM::lookupOrCreateFn (
178
194
moduleOp, kMemRefCopy ,
179
195
ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
180
- LLVM::LLVMVoidType::get (moduleOp->getContext ()));
196
+ LLVM::LLVMVoidType::get (moduleOp->getContext ()), false , true );
181
197
}
0 commit comments