Skip to content

Commit

Permalink
[CIR] Remove return !cir.void from IR and textual representation (#1249)
Browse files Browse the repository at this point in the history
C/C++ functions returning void had an explicit !cir.void return type
while not
having any returned value, which was breaking a lot of MLIR invariants
when the
CIR dialect is used in a greater context, for example with the inliner.

Now, a C/C++ function returning void has no return type and no return
values,
which does not break the MLIR invariant about the same number of return
types
and returned values.

This change does not keeps the same parsing/pretty-printed syntax as
before for
compatibility like in #1203 because
it
requires some new features from the MLIR parser infrastructure itself,
which is
not great.

This uses an optional type for function return type.

The default MLIR parser for optional parameters requires an optional
anchor we
do not have in the syntax, so use a custom FuncType parser to handle the
optional
return type.
  • Loading branch information
keryell authored Jan 10, 2025
1 parent 17cd670 commit 54d48d8
Show file tree
Hide file tree
Showing 15 changed files with 182 additions and 68 deletions.
13 changes: 8 additions & 5 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3474,8 +3474,6 @@ def FuncOp : CIR_Op<"func", [
/// Returns the results types that the callable region produces when
/// executed.
llvm::ArrayRef<mlir::Type> getCallableResults() {
if (::llvm::isa<cir::VoidType>(getFunctionType().getReturnType()))
return {};
return getFunctionType().getReturnTypes();
}

Expand All @@ -3492,10 +3490,15 @@ def FuncOp : CIR_Op<"func", [
}

/// Returns the argument types of this function.
llvm::ArrayRef<mlir::Type> getArgumentTypes() { return getFunctionType().getInputs(); }
llvm::ArrayRef<mlir::Type> getArgumentTypes() {
return getFunctionType().getInputs();
}

/// Returns the result types of this function.
llvm::ArrayRef<mlir::Type> getResultTypes() { return getFunctionType().getReturnTypes(); }
/// Returns 0 or 1 result type of this function (0 in the case of a function
/// returing void)
llvm::ArrayRef<mlir::Type> getResultTypes() {
return getFunctionType().getReturnTypes();
}

/// Hook for OpTrait::FunctionOpInterfaceTrait, called after verifying that
/// the 'type' attribute is present and checks if it holds a function type.
Expand Down
25 changes: 19 additions & 6 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -372,29 +372,38 @@ def CIR_VectorType : CIR_Type<"Vector", "vector",
def CIR_FuncType : CIR_Type<"Func", "func"> {
let summary = "CIR function type";
let description = [{
The `!cir.func` is a function type. It consists of a single return type, a
list of parameter types and can optionally be variadic.
The `!cir.func` is a function type. It consists of an optional return type,
a list of parameter types and can optionally be variadic.

Example:

```mlir
!cir.func<()>
!cir.func<!bool ()>
!cir.func<(!s8i, !s8i)>
!cir.func<!s32i (!s8i, !s8i)>
!cir.func<!s32i (!s32i, ...)>
```
}];

let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs, "mlir::Type":$returnType,
let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs,
"mlir::Type":$optionalReturnType,
"bool":$varArg);
// Use a custom parser to handle the optional return and argument types
// without an optional anchor.
let assemblyFormat = [{
`<` $returnType ` ` `(` custom<FuncTypeArgs>($inputs, $varArg) `>`
`<` custom<FuncType>($optionalReturnType, $inputs, $varArg) `>`
}];

let builders = [
// Construct with an actual return type or explicit !cir.void
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<mlir::Type>":$inputs, "mlir::Type":$returnType,
CArg<"bool", "false">:$isVarArg), [{
return $_get(returnType.getContext(), inputs, returnType, isVarArg);
return $_get(returnType.getContext(), inputs,
mlir::isa<cir::VoidType>(returnType) ? nullptr
: returnType,
isVarArg);
}]>
];

Expand All @@ -408,11 +417,15 @@ def CIR_FuncType : CIR_Type<"Func", "func"> {
/// Returns the number of arguments to the function.
unsigned getNumInputs() const { return getInputs().size(); }

/// Returns the result type of the function as an actual return type or
/// explicit !cir.void
mlir::Type getReturnType() const;

/// Returns the result type of the function as an ArrayRef, enabling better
/// integration with generic MLIR utilities.
llvm::ArrayRef<mlir::Type> getReturnTypes() const;

/// Returns whether the function is returns void.
/// Returns whether the function returns void.
bool isVoid() const;

/// Returns a clone of this function type with the given argument
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CIR/CodeGen/CIRGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ mlir::Type CIRGenTypes::convertFunctionTypeInternal(QualType QFT) {
assert(QFT.isCanonical());
const Type *Ty = QFT.getTypePtr();
const FunctionType *FT = cast<FunctionType>(QFT.getTypePtr());
// First, check whether we can build the full fucntion type. If the function
// First, check whether we can build the full function type. If the function
// type depends on an incomplete type (e.g. a struct or enum), we cannot lower
// the function type.
assert(isFuncTypeConvertible(FT) && "NYI");
Expand Down
14 changes: 7 additions & 7 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2490,13 +2490,8 @@ void cir::FuncOp::print(OpAsmPrinter &p) {
p.printSymbolName(getSymName());
auto fnType = getFunctionType();
llvm::SmallVector<Type, 1> resultTypes;
if (!fnType.isVoid())
function_interface_impl::printFunctionSignature(
p, *this, fnType.getInputs(), fnType.isVarArg(),
fnType.getReturnTypes());
else
function_interface_impl::printFunctionSignature(
p, *this, fnType.getInputs(), fnType.isVarArg(), {});
function_interface_impl::printFunctionSignature(
p, *this, fnType.getInputs(), fnType.isVarArg(), fnType.getReturnTypes());

if (mlir::ArrayAttr annotations = getAnnotationsAttr()) {
p << ' ';
Expand Down Expand Up @@ -2565,6 +2560,11 @@ LogicalResult cir::FuncOp::verifyType() {
if (!getNoProto() && type.isVarArg() && type.getNumInputs() == 0)
return emitError()
<< "prototyped function must have at least one non-variadic input";
if (auto rt = type.getReturnTypes();
!rt.empty() && mlir::isa<cir::VoidType>(rt.front()))
return emitOpError("The return type for a function returning void should "
"be empty instead of an explicit !cir.void");

return success();
}

Expand Down
94 changes: 82 additions & 12 deletions clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <optional>

using cir::MissingFeatures;
Expand All @@ -41,12 +42,13 @@ using cir::MissingFeatures;
// CIR Custom Parser/Printer Signatures
//===----------------------------------------------------------------------===//

static mlir::ParseResult
parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
bool &isVarArg);
static void printFuncTypeArgs(mlir::AsmPrinter &p,
mlir::ArrayRef<mlir::Type> params, bool isVarArg);
static mlir::ParseResult parseFuncType(mlir::AsmParser &p,
mlir::Type &optionalReturnTypes,
llvm::SmallVector<mlir::Type> &params,
bool &isVarArg);

static void printFuncType(mlir::AsmPrinter &p, mlir::Type optionalReturnTypes,
mlir::ArrayRef<mlir::Type> params, bool isVarArg);
static mlir::ParseResult parsePointerAddrSpace(mlir::AsmParser &p,
mlir::Attribute &addrSpaceAttr);
static void printPointerAddrSpace(mlir::AsmPrinter &p,
Expand Down Expand Up @@ -913,9 +915,38 @@ FuncType FuncType::clone(TypeRange inputs, TypeRange results) const {
return get(llvm::to_vector(inputs), results[0], isVarArg());
}

mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p,
llvm::SmallVector<mlir::Type> &params,
bool &isVarArg) {
// A special parser is needed for function returning void to handle the missing
// type.
static mlir::ParseResult parseFuncTypeReturn(mlir::AsmParser &p,
mlir::Type &optionalReturnType) {
if (succeeded(p.parseOptionalLParen())) {
// If we have already a '(', the function has no return type
optionalReturnType = {};
return mlir::success();
}
mlir::Type type;
if (p.parseType(type))
return mlir::failure();
if (isa<cir::VoidType>(type))
// An explicit !cir.void means also no return type.
optionalReturnType = {};
else
// Otherwise use the actual type.
optionalReturnType = type;
return p.parseLParen();
}

// A special pretty-printer for function returning or not a result.
static void printFuncTypeReturn(mlir::AsmPrinter &p,
mlir::Type optionalReturnType) {
if (optionalReturnType)
p << optionalReturnType << ' ';
p << '(';
}

static mlir::ParseResult
parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
bool &isVarArg) {
isVarArg = false;
// `(` `)`
if (succeeded(p.parseOptionalRParen()))
Expand Down Expand Up @@ -945,8 +976,9 @@ mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p,
return p.parseRParen();
}

void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
bool isVarArg) {
static void printFuncTypeArgs(mlir::AsmPrinter &p,
mlir::ArrayRef<mlir::Type> params,
bool isVarArg) {
llvm::interleaveComma(params, p,
[&p](mlir::Type type) { p.printType(type); });
if (isVarArg) {
Expand All @@ -957,11 +989,49 @@ void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
p << ')';
}

// Use a custom parser to handle the optional return and argument types without
// an optional anchor.
static mlir::ParseResult parseFuncType(mlir::AsmParser &p,
mlir::Type &optionalReturnTypes,
llvm::SmallVector<mlir::Type> &params,
bool &isVarArg) {
if (failed(parseFuncTypeReturn(p, optionalReturnTypes)))
return failure();
return parseFuncTypeArgs(p, params, isVarArg);
}

static void printFuncType(mlir::AsmPrinter &p, mlir::Type optionalReturnTypes,
mlir::ArrayRef<mlir::Type> params, bool isVarArg) {
printFuncTypeReturn(p, optionalReturnTypes);
printFuncTypeArgs(p, params, isVarArg);
}

// Return the actual return type or an explicit !cir.void if the function does
// not return anything
mlir::Type FuncType::getReturnType() const {
if (isVoid())
return cir::VoidType::get(getContext());
return static_cast<detail::FuncTypeStorage *>(getImpl())->optionalReturnType;
}

/// Returns the result type of the function as an ArrayRef, enabling better
/// integration with generic MLIR utilities.
llvm::ArrayRef<mlir::Type> FuncType::getReturnTypes() const {
return static_cast<detail::FuncTypeStorage *>(getImpl())->returnType;
if (isVoid())
return {};
return static_cast<detail::FuncTypeStorage *>(getImpl())->optionalReturnType;
}

bool FuncType::isVoid() const { return mlir::isa<VoidType>(getReturnType()); }
// Whether the function returns void
bool FuncType::isVoid() const {
auto rt =
static_cast<detail::FuncTypeStorage *>(getImpl())->optionalReturnType;
assert(!rt ||
!mlir::isa<cir::VoidType>(rt) &&
"The return type for a function returning void should be empty "
"instead of a real !cir.void");
return !rt;
}

//===----------------------------------------------------------------------===//
// MethodType Definitions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ FuncType LowerTypes::getFunctionType(const LowerFunctionInfo &FI) {
}
}

return FuncType::get(getMLIRContext(), ArgTypes, resultType, FI.isVariadic());
return FuncType::get(ArgTypes, resultType, FI.isVariadic());
}

/// Convert a CIR type to its ABI-specific default form.
Expand Down
6 changes: 3 additions & 3 deletions clang/test/CIR/CodeGen/fun-ptr.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ int foo(Data* d) {
return f(d);
}

// CIR: cir.func private {{@.*test.*}}() -> !cir.ptr<!cir.func<!void ()>>
// CIR: cir.func private {{@.*test.*}}() -> !cir.ptr<!cir.func<()>>
// CIR: cir.func {{@.*bar.*}}()
// CIR: [[RET:%.*]] = cir.call {{@.*test.*}}() : () -> !cir.ptr<!cir.func<!void ()>>
// CIR: cir.call [[RET]]() : (!cir.ptr<!cir.func<!void ()>>) -> ()
// CIR: [[RET:%.*]] = cir.call {{@.*test.*}}() : () -> !cir.ptr<!cir.func<()>>
// CIR: cir.call [[RET]]() : (!cir.ptr<!cir.func<()>>) -> ()
// CIR: cir.return

// LLVM: declare ptr {{@.*test.*}}()
Expand Down
2 changes: 1 addition & 1 deletion clang/test/CIR/CodeGen/gnu-extension.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ void bar(void) {
}

//CHECK: cir.func @bar()
//CHECK: {{.*}} = cir.get_global @bar : !cir.ptr<!cir.func<!void ()>>
//CHECK: {{.*}} = cir.get_global @bar : !cir.ptr<!cir.func<()>>
//CHECK: cir.return
8 changes: 4 additions & 4 deletions clang/test/CIR/CodeGen/member-init-struct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ C a, b(x), c(0, 2);
// CHECK: %[[VAL_8:.*]] = cir.get_member %[[VAL_2]][2] {name = "d"} : !cir.ptr<!ty_C> -> !cir.ptr<!cir.array<!s32i x 10>>
// CHECK: %[[VAL_9:.*]] = cir.const {{.*}} : !cir.array<!s32i x 10>
// CHECK: cir.store %[[VAL_9]], %[[VAL_8]] : !cir.array<!s32i x 10>, !cir.ptr<!cir.array<!s32i x 10>>
// CHECK: %[[VAL_10:.*]] = cir.get_member %[[VAL_2]][4] {name = "e"} : !cir.ptr<!ty_C> -> !cir.ptr<!cir.method<!cir.func<!void ()> in !ty_C>>
// CHECK: %[[VAL_11:.*]] = cir.const #cir.method<null> : !cir.method<!cir.func<!void ()> in !ty_C>
// CHECK: cir.store %[[VAL_11]], %[[VAL_10]] : !cir.method<!cir.func<!void ()> in !ty_C>, !cir.ptr<!cir.method<!cir.func<!void ()> in !ty_C>>
// CHECK: cir.return
// CHECK: %[[VAL_10:.*]] = cir.get_member %[[VAL_2]][4] {name = "e"} : !cir.ptr<!ty_C> -> !cir.ptr<!cir.method<!cir.func<()> in !ty_C>>
// CHECK: %[[VAL_11:.*]] = cir.const #cir.method<null> : !cir.method<!cir.func<()> in !ty_C>
// CHECK: cir.store %[[VAL_11]], %[[VAL_10]] : !cir.method<!cir.func<()> in !ty_C>, !cir.ptr<!cir.method<!cir.func<()> in !ty_C>>
// CHECK: cir.return
4 changes: 2 additions & 2 deletions clang/test/CIR/CodeGen/multi-vtable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ int main() {

// CIR: cir.func @main() -> !s32i extra(#fn_attr) {

// CIR: %{{[0-9]+}} = cir.vtable.address_point( %{{[0-9]+}} : !cir.ptr<!cir.ptr<!cir.func<!void (!cir.ptr<!ty_Mother>)>>>, vtable_index = 0, address_point_index = 0) : !cir.ptr<!cir.ptr<!cir.func<!void (!cir.ptr<!ty_Mother>)>>>
// CIR: %{{[0-9]+}} = cir.vtable.address_point( %{{[0-9]+}} : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!ty_Mother>)>>>, vtable_index = 0, address_point_index = 0) : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!ty_Mother>)>>>

// CIR: %{{[0-9]+}} = cir.vtable.address_point( %{{[0-9]+}} : !cir.ptr<!cir.ptr<!cir.func<!void (!cir.ptr<!ty_Child>)>>>, vtable_index = 0, address_point_index = 0) : !cir.ptr<!cir.ptr<!cir.func<!void (!cir.ptr<!ty_Child>)>>>
// CIR: %{{[0-9]+}} = cir.vtable.address_point( %{{[0-9]+}} : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!ty_Child>)>>>, vtable_index = 0, address_point_index = 0) : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!ty_Child>)>>>

// CIR: }

Expand Down
14 changes: 7 additions & 7 deletions clang/test/CIR/CodeGen/no-proto-fun-ptr.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ void check_noproto_ptr() {
}

// CHECK: cir.func no_proto @check_noproto_ptr()
// CHECK: [[ALLOC:%.*]] = cir.alloca !cir.ptr<!cir.func<!void ()>>, !cir.ptr<!cir.ptr<!cir.func<!void ()>>>, ["fun", init] {alignment = 8 : i64}
// CHECK: [[GGO:%.*]] = cir.get_global @empty : !cir.ptr<!cir.func<!void ()>>
// CHECK: cir.store [[GGO]], [[ALLOC]] : !cir.ptr<!cir.func<!void ()>>, !cir.ptr<!cir.ptr<!cir.func<!void ()>>>
// CHECK: [[ALLOC:%.*]] = cir.alloca !cir.ptr<!cir.func<()>>, !cir.ptr<!cir.ptr<!cir.func<()>>>, ["fun", init] {alignment = 8 : i64}
// CHECK: [[GGO:%.*]] = cir.get_global @empty : !cir.ptr<!cir.func<()>>
// CHECK: cir.store [[GGO]], [[ALLOC]] : !cir.ptr<!cir.func<()>>, !cir.ptr<!cir.ptr<!cir.func<()>>>
// CHECK: cir.return

void empty(void) {}
Expand All @@ -20,8 +20,8 @@ void buz() {
}

// CHECK: cir.func no_proto @buz()
// CHECK: [[FNPTR_ALLOC:%.*]] = cir.alloca !cir.ptr<!cir.func<!void (...)>>, !cir.ptr<!cir.ptr<!cir.func<!void (...)>>>, ["func"] {alignment = 8 : i64}
// CHECK: [[FNPTR:%.*]] = cir.load deref [[FNPTR_ALLOC]] : !cir.ptr<!cir.ptr<!cir.func<!void (...)>>>, !cir.ptr<!cir.func<!void (...)>>
// CHECK: [[CAST:%.*]] = cir.cast(bitcast, %1 : !cir.ptr<!cir.func<!void (...)>>), !cir.ptr<!cir.func<!void ()>>
// CHECK: cir.call [[CAST]]() : (!cir.ptr<!cir.func<!void ()>>) -> ()
// CHECK: [[FNPTR_ALLOC:%.*]] = cir.alloca !cir.ptr<!cir.func<(...)>>, !cir.ptr<!cir.ptr<!cir.func<(...)>>>, ["func"] {alignment = 8 : i64}
// CHECK: [[FNPTR:%.*]] = cir.load deref [[FNPTR_ALLOC]] : !cir.ptr<!cir.ptr<!cir.func<(...)>>>, !cir.ptr<!cir.func<(...)>>
// CHECK: [[CAST:%.*]] = cir.cast(bitcast, %1 : !cir.ptr<!cir.func<(...)>>), !cir.ptr<!cir.func<()>>
// CHECK: cir.call [[CAST]]() : (!cir.ptr<!cir.func<()>>) -> ()
// CHECK: cir.return
8 changes: 4 additions & 4 deletions clang/test/CIR/CodeGen/pointer-arith-ext.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ void *f4_1(void *a, int b) { return (a -= b); }

FP f5(FP a, int b) { return a + b; }
// CIR-LABEL: f5
// CIR: %[[PTR:.*]] = cir.load {{.*}} : !cir.ptr<!cir.ptr<!cir.func<!void ()>>>, !cir.ptr<!cir.func<!void ()>>
// CIR: %[[PTR:.*]] = cir.load {{.*}} : !cir.ptr<!cir.ptr<!cir.func<()>>>, !cir.ptr<!cir.func<()>>
// CIR: %[[STRIDE:.*]] = cir.load {{.*}} : !cir.ptr<!s32i>, !s32i
// CIR: cir.ptr_stride(%[[PTR]] : !cir.ptr<!cir.func<!void ()>>, %[[STRIDE]] : !s32i)
// CIR: cir.ptr_stride(%[[PTR]] : !cir.ptr<!cir.func<()>>, %[[STRIDE]] : !s32i)

// LLVM-LABEL: f5
// LLVM: %[[PTR:.*]] = load ptr, ptr {{.*}}, align 8
Expand All @@ -67,10 +67,10 @@ FP f6_1(int a, FP b) { return (a += b); }

FP f7(FP a, int b) { return a - b; }
// CIR-LABEL: f7
// CIR: %[[PTR:.*]] = cir.load {{.*}} : !cir.ptr<!cir.ptr<!cir.func<!void ()>>>, !cir.ptr<!cir.func<!void ()>>
// CIR: %[[PTR:.*]] = cir.load {{.*}} : !cir.ptr<!cir.ptr<!cir.func<()>>>, !cir.ptr<!cir.func<()>>
// CIR: %[[STRIDE:.*]] = cir.load {{.*}} : !cir.ptr<!s32i>, !s32i
// CIR: %[[SUB:.*]] = cir.unary(minus, %[[STRIDE]]) : !s32i, !s32i
// CIR: cir.ptr_stride(%[[PTR]] : !cir.ptr<!cir.func<!void ()>>, %[[SUB]] : !s32i)
// CIR: cir.ptr_stride(%[[PTR]] : !cir.ptr<!cir.func<()>>, %[[SUB]] : !s32i)

// LLVM-LABEL: f7
// LLVM: %[[PTR:.*]] = load ptr, ptr {{.*}}, align 8
Expand Down
Loading

0 comments on commit 54d48d8

Please sign in to comment.