-
Notifications
You must be signed in to change notification settings - Fork 13k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] Fix conflict of user defined reserved functions with internal prototypes #123378
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Luohao Wang (Luohaothu) ChangesRelated to #120950 On lowering from Full diff: https://github.com/llvm/llvm-project/pull/123378.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 852490cf7428f8..3095c83b90db9e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -64,7 +64,8 @@ LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
ArrayRef<Type> paramTypes = {},
- Type resultType = {}, bool isVarArg = false);
+ Type resultType = {}, bool isVarArg = false,
+ bool isReserved = false);
} // namespace LLVM
} // namespace mlir
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 88421a16ccf9fb..ecc31df40ea52f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -48,13 +48,29 @@ static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
StringRef name,
ArrayRef<Type> paramTypes,
- Type resultType, bool isVarArg) {
+ Type resultType, bool isVarArg, bool isReserved) {
assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
"expected SymbolTable operation");
auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
SymbolTable::lookupSymbolIn(moduleOp, name));
- if (func)
+ auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg);
+ // Assert the signature of the found function is same as expected
+ if (func) {
+ if (funcT != func.getFunctionType()) {
+ if (isReserved) {
+ func.emitError("redefinition of reserved function '" + name + "' of different type ")
+ .append(func.getFunctionType())
+ .append(" is prohibited");
+ exit(0);
+ } else {
+ func.emitError("redefinition of function '" + name + "' of different type ")
+ .append(funcT)
+ .append(" is prohibited");
+ exit(0);
+ }
+ }
return func;
+ }
OpBuilder b(moduleOp->getRegion(0));
return b.create<LLVM::LLVMFuncOp>(
moduleOp->getLoc(), name,
@@ -64,37 +80,37 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintI64,
IntegerType::get(moduleOp->getContext(), 64),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintU64,
IntegerType::get(moduleOp->getContext(), 64),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintBF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF32,
Float32Type::get(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF64,
Float64Type::get(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
@@ -110,51 +126,51 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
getCharPtr(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintOpen, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintClose, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintComma, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintNewline, {},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
- getVoidPtr(moduleOp->getContext()));
+ getVoidPtr(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
- getVoidPtr(moduleOp->getContext()));
+ getVoidPtr(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
return LLVM::lookupOrCreateFn(
moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
- getVoidPtr(moduleOp->getContext()));
+ getVoidPtr(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp
@@ -162,13 +178,13 @@ mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
{indexType, indexType},
- getVoidPtr(moduleOp->getContext()));
+ getVoidPtr(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
return LLVM::lookupOrCreateFn(
moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
LLVM::LLVMFuncOp
@@ -177,5 +193,5 @@ mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
return LLVM::lookupOrCreateFn(
moduleOp, kMemRefCopy,
ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
- LLVM::LLVMVoidType::get(moduleOp->getContext()));
+ LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
}
diff --git a/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir b/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir
new file mode 100644
index 00000000000000..f744e4f7635ea7
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -finalize-memref-to-llvm 2>&1 | FileCheck %s
+
+#map = affine_map<(d0) -> (d0 + 1)>
+module {
+ // CHECK: redefinition of reserved function 'malloc' of different type '!llvm.func<void (i64)>' is prohibited
+ llvm.func @malloc(i64)
+ func.func @issue_120950() {
+ %alloc = memref.alloc() : memref<1024x64xf32, 1>
+ llvm.return
+ }
+}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
@wsmoses @jeanPerier can you help review this PR? |
Ping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Dropped a few comments. As described, using exit(0)
is not how we deal with illegal inputs thus this should be changed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the delayed review.
Thank you very much for the improvements. I added one functional comment and a set of nits, but this is soon ready to be shipped 🙂
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(), | ||
getIndexType()); | ||
auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes); | ||
if (failed(allocFuncOp)) | ||
return std::make_tuple(Value(), Value()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will this result in a crash later on? If so, this would also require a FailureOr wrapping.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The returned pair of value is populated and finally verified here: AllocLikeConvertion.cpp
// Allocate the underlying buffer.
auto [allocatedPtr, alignedPtr] =
this->allocateBuffer(rewriter, loc, size, op);
if (!allocatedPtr || !alignedPtr)
return rewriter.notifyMatchFailure(loc,
"underlying buffer allocation failed");
Here empty value is assigned as the invalid state and properly handled. I prefer to leave it as is, and a separate PR might be more appropriate to modernize the error handling style if we do have a preference.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fine for me 🙂
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(), | ||
getIndexType()); | ||
if (failed(allocFuncOp)) | ||
return Value(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Handled at MemRefToLLVM.cpp
Value ptr = allocateBufferAutoAlign(
rewriter, loc, sizeBytes, op, &defaultLayout,
alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
&defaultLayout));
if (!ptr)
return std::make_tuple(Value(), Value());
return std::make_tuple(ptr, ptr);
in the same style.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Can you change this test to use --split-input-file
and --verify-diagnostics
instead of using 2>&1
? This makes the test more robust and modern.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The bad_address_space
test emit error from unknown location, which cannot be handled by --verify-diagnostics
. Do you have any suggestions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Feel free to keep this as is. If this should change, the owners of that test should do so.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for addressing the comments.
Maybe give people another day to also take a look, then you can land this.
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(), | ||
getIndexType()); | ||
auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes); | ||
if (failed(allocFuncOp)) | ||
return std::make_tuple(Value(), Value()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fine for me 🙂
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Feel free to keep this as is. If this should change, the owners of that test should do so.
Big thanks for the review @Dinistro ! I can hold it for one day or two. Could you help merge the PR after that? By the way, how to gain write access for the repo? Currently I have to ask for help by comment each time😓 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is outside of my area, but the patch makes sense to me. We did hit similar crashes in flang (like here), and having proper asserts for helpers retrieving libc functions seems right to me.
You have to open an issue with special tags to request commit access, see https://llvm.org/docs/DeveloperPolicy.html#obtaining-commit-access |
I'll add a reminder to land this tomorrow. In case I forget, feel free to ping me 😅 |
Related to #120950
On lowering from
memref
to LLVM,malloc
and other intrinsic functions fromlibc
will be declared in the current module. User's redefinition of these reserved functions will poison the internal analysis with wrong prototype. This patch adds assertion on the found function's type and reports if it mismatch with the intended type.