Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Fix conflict of user defined reserved functions with internal prototypes #123378

Merged
merged 9 commits into from
Jan 28, 2025

Conversation

Luohaothu
Copy link
Contributor

Related to #120950

On lowering from memref to LLVM, malloc and other intrinsic functions from libc will be declared in the current module. User's redefinition of these reserved functions will poison the internal analysis with wrong prototype. This patch adds assertion on the found function's type and reports if it mismatch with the intended type.

@llvmbot
Copy link
Member

llvmbot commented Jan 17, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Luohao Wang (Luohaothu)

Changes

Related to #120950

On lowering from memref to LLVM, malloc and other intrinsic functions from libc will be declared in the current module. User's redefinition of these reserved functions will poison the internal analysis with wrong prototype. This patch adds assertion on the found function's type and reports if it mismatch with the intended type.


Full diff: https://github.com/llvm/llvm-project/pull/123378.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h (+2-1)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp (+36-20)
  • (added) mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir (+11)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 852490cf7428f8..3095c83b90db9e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -64,7 +64,8 @@ LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
 /// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
 LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
                                   ArrayRef<Type> paramTypes = {},
-                                  Type resultType = {}, bool isVarArg = false);
+                                  Type resultType = {}, bool isVarArg = false,
+                                  bool isReserved = false);
 
 } // namespace LLVM
 } // namespace mlir
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 88421a16ccf9fb..ecc31df40ea52f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -48,13 +48,29 @@ static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
                                               StringRef name,
                                               ArrayRef<Type> paramTypes,
-                                              Type resultType, bool isVarArg) {
+                                              Type resultType, bool isVarArg, bool isReserved) {
   assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
          "expected SymbolTable operation");
   auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
       SymbolTable::lookupSymbolIn(moduleOp, name));
-  if (func)
+  auto funcT = LLVMFunctionType::get(resultType, paramTypes, isVarArg);
+  // Assert the signature of the found function is same as expected
+  if (func) {
+    if (funcT != func.getFunctionType()) {
+      if (isReserved) {
+        func.emitError("redefinition of reserved function '" + name + "' of different type ")
+        .append(func.getFunctionType())
+        .append(" is prohibited");
+        exit(0);
+      } else {
+        func.emitError("redefinition of function '" + name + "' of different type ")
+        .append(funcT)
+        .append(" is prohibited");
+        exit(0);
+      }
+    }
     return func;
+  }
   OpBuilder b(moduleOp->getRegion(0));
   return b.create<LLVM::LLVMFuncOp>(
       moduleOp->getLoc(), name,
@@ -64,37 +80,37 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintI64,
                           IntegerType::get(moduleOp->getContext(), 64),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintU64,
                           IntegerType::get(moduleOp->getContext(), 64),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF16,
                           IntegerType::get(moduleOp->getContext(), 16), // bits!
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintBF16,
                           IntegerType::get(moduleOp->getContext(), 16), // bits!
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF32,
                           Float32Type::get(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF64,
                           Float64Type::get(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 static LLVM::LLVMPointerType getCharPtr(MLIRContext *context) {
@@ -110,51 +126,51 @@ LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
     Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
   return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
                           getCharPtr(moduleOp->getContext()),
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintOpen, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintClose, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintComma, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintNewline, {},
-                          LLVM::LLVMVoidType::get(moduleOp->getContext()));
+                          LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
                                                     Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
-                                getVoidPtr(moduleOp->getContext()));
+                                getVoidPtr(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
                                                           Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
-                                getVoidPtr(moduleOp->getContext()));
+                                getVoidPtr(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
   return LLVM::lookupOrCreateFn(
       moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
                                                           Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
-                                getVoidPtr(moduleOp->getContext()));
+                                getVoidPtr(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp
@@ -162,13 +178,13 @@ mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
                                                 Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
                                 {indexType, indexType},
-                                getVoidPtr(moduleOp->getContext()));
+                                getVoidPtr(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
   return LLVM::lookupOrCreateFn(
       moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
 
 LLVM::LLVMFuncOp
@@ -177,5 +193,5 @@ mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
   return LLVM::lookupOrCreateFn(
       moduleOp, kMemRefCopy,
       ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
-      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+      LLVM::LLVMVoidType::get(moduleOp->getContext()), false, true);
 }
diff --git a/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir b/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir
new file mode 100644
index 00000000000000..f744e4f7635ea7
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToLLVM/issue-120950.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s -finalize-memref-to-llvm 2>&1 | FileCheck %s
+
+#map = affine_map<(d0) -> (d0 + 1)>
+module {
+  // CHECK: redefinition of reserved function 'malloc' of different type '!llvm.func<void (i64)>' is prohibited
+  llvm.func @malloc(i64)
+  func.func @issue_120950() {
+    %alloc = memref.alloc() : memref<1024x64xf32, 1>
+    llvm.return
+  }
+}

Copy link

github-actions bot commented Jan 17, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@Luohaothu
Copy link
Contributor Author

@wsmoses @jeanPerier can you help review this PR?

@Luohaothu
Copy link
Contributor Author

Ping

@Dinistro Dinistro self-requested a review January 20, 2025 15:00
Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dropped a few comments. As described, using exit(0) is not how we deal with illegal inputs thus this should be changed.

Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the delayed review.

Thank you very much for the improvements. I added one functional comment and a set of nits, but this is soon ready to be shipped 🙂

getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
getIndexType());
auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
if (failed(allocFuncOp))
return std::make_tuple(Value(), Value());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this result in a crash later on? If so, this would also require a FailureOr wrapping.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The returned pair of value is populated and finally verified here: AllocLikeConvertion.cpp

  // Allocate the underlying buffer.
  auto [allocatedPtr, alignedPtr] =
      this->allocateBuffer(rewriter, loc, size, op);

  if (!allocatedPtr || !alignedPtr)
    return rewriter.notifyMatchFailure(loc,
                                       "underlying buffer allocation failed");

Here empty value is assigned as the invalid state and properly handled. I prefer to leave it as is, and a separate PR might be more appropriate to modernize the error handling style if we do have a preference.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine for me 🙂

getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
getIndexType());
if (failed(allocFuncOp))
return Value();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As above

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Handled at MemRefToLLVM.cpp

    Value ptr = allocateBufferAutoAlign(
        rewriter, loc, sizeBytes, op, &defaultLayout,
        alignedAllocationGetAlignment(rewriter, loc, cast<memref::AllocOp>(op),
                                      &defaultLayout));
    if (!ptr)
      return std::make_tuple(Value(), Value());
    return std::make_tuple(ptr, ptr);

in the same style.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Can you change this test to use --split-input-file and --verify-diagnostics instead of using 2>&1? This makes the test more robust and modern.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bad_address_space test emit error from unknown location, which cannot be handled by --verify-diagnostics. Do you have any suggestions?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to keep this as is. If this should change, the owners of that test should do so.

@Luohaothu Luohaothu requested a review from Dinistro January 26, 2025 07:16
Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for addressing the comments.

Maybe give people another day to also take a look, then you can land this.

getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
getIndexType());
auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
if (failed(allocFuncOp))
return std::make_tuple(Value(), Value());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine for me 🙂

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to keep this as is. If this should change, the owners of that test should do so.

@Luohaothu
Copy link
Contributor Author

Big thanks for the review @Dinistro ! I can hold it for one day or two. Could you help merge the PR after that?

By the way, how to gain write access for the repo? Currently I have to ask for help by comment each time😓

Copy link
Contributor

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is outside of my area, but the patch makes sense to me. We did hit similar crashes in flang (like here), and having proper asserts for helpers retrieving libc functions seems right to me.

@jeanPerier
Copy link
Contributor

By the way, how to gain write access for the repo? Currently I have to ask for help by comment each time😓

You have to open an issue with special tags to request commit access, see https://llvm.org/docs/DeveloperPolicy.html#obtaining-commit-access

@Dinistro
Copy link
Contributor

Big thanks for the review @Dinistro ! I can hold it for one day or two. Could you help merge the PR after that?

By the way, how to gain write access for the repo? Currently I have to ask for help by comment each time😓

I'll add a reminder to land this tomorrow. In case I forget, feel free to ping me 😅

@Dinistro Dinistro merged commit e84f6b6 into llvm:main Jan 28, 2025
8 checks passed
@Luohaothu Luohaothu deleted the res-fix branch February 1, 2025 07:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants