From 327d627066e6452b081365b595657d17f2690a3b Mon Sep 17 00:00:00 2001 From: jeanPerier Date: Mon, 3 Feb 2025 11:27:14 +0100 Subject: [PATCH] [mlir] share argument attributes interface between calls and callables (#123176) This patch shares core interface methods dealing with argument and result attributes from CallableOpInterface with the CallOpInterface and makes them mandatory to gives more consistent guarantees about concrete operations using these interfaces. This allows adding argument attributes on call like operations, which is sometimes required to get proper ABI, like with llvm.call (and llvm.invoke). The patch adds optional `arg_attrs` and `res_attrs` attributes to operations using these interfaces that did not have that already. They can then re-use the common "rich function signature" printing/parsing helpers if they want (for the LLVM dialect, this is done in the next patch). Part of RFC: https://discourse.llvm.org/t/mlir-rfc-adding-argument-and-result-attributes-to-llvm-call/84107 --- .../flang/Optimizer/Dialect/CUF/CUFOps.td | 4 +- .../include/flang/Optimizer/Dialect/FIROps.td | 4 + flang/lib/Lower/ConvertCall.cpp | 12 +- flang/lib/Optimizer/CodeGen/TargetRewrite.cpp | 2 + .../Optimizer/Transforms/AbstractResult.cpp | 3 + .../Transforms/PolymorphicOpConversion.cpp | 5 +- mlir/docs/Interfaces.md | 7 +- mlir/examples/toy/Ch4/include/toy/Ops.td | 7 +- mlir/examples/toy/Ch5/include/toy/Ops.td | 7 +- mlir/examples/toy/Ch6/include/toy/Ops.td | 7 +- mlir/examples/toy/Ch7/include/toy/Ops.td | 10 +- mlir/examples/toy/Ch7/mlir/Dialect.cpp | 6 +- mlir/examples/toy/Ch7/mlir/MLIRGen.cpp | 3 +- .../include/mlir/Dialect/Async/IR/AsyncOps.td | 8 +- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 8 +- mlir/include/mlir/Dialect/Func/IR/FuncOps.td | 31 +++- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 6 +- .../Dialect/SPIRV/IR/SPIRVControlFlowOps.td | 13 +- .../mlir/Dialect/Transform/IR/TransformOps.td | 4 +- mlir/include/mlir/Interfaces/CallInterfaces.h | 61 +++++++ .../include/mlir/Interfaces/CallInterfaces.td | 87 ++++----- .../mlir/Interfaces/FunctionImplementation.h | 34 ++-- mlir/lib/Dialect/Async/IR/Async.cpp | 2 +- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 2 +- mlir/lib/Dialect/Func/IR/FuncOps.cpp | 2 +- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 4 +- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 21 ++- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 4 +- mlir/lib/Dialect/Shape/IR/Shape.cpp | 2 +- mlir/lib/Interfaces/CallInterfaces.cpp | 169 ++++++++++++++++++ .../lib/Interfaces/FunctionImplementation.cpp | 154 +--------------- mlir/test/lib/Dialect/Test/TestOps.td | 19 +- 32 files changed, 452 insertions(+), 256 deletions(-) diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td index a270e69b39410..c1021da0cfb21 100644 --- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td +++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td @@ -207,7 +207,9 @@ def cuf_KernelLaunchOp : cuf_Op<"kernel_launch", [CallOpInterface, I32:$block_z, Optional:$bytes, Optional:$stream, - Variadic:$args + Variadic:$args, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let assemblyFormat = [{ diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index 5f0f0b48e892b..8dbc9df9f553d 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -2432,6 +2432,8 @@ def fir_CallOp : fir_Op<"call", let arguments = (ins OptionalAttr:$callee, Variadic:$args, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, OptionalAttr:$procedure_attrs, DefaultValuedAttr:$fastmath @@ -2518,6 +2520,8 @@ def fir_DispatchOp : fir_Op<"dispatch", []> { fir_ClassType:$object, Variadic:$args, OptionalAttr:$pass_arg_pos, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, OptionalAttr:$procedure_attrs ); diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp index 40cd106e63018..7ca2baf0193ce 100644 --- a/flang/lib/Lower/ConvertCall.cpp +++ b/flang/lib/Lower/ConvertCall.cpp @@ -594,7 +594,8 @@ Fortran::lower::genCallOpAndResult( builder.create( loc, funcType.getResults(), funcSymbolAttr, grid_x, grid_y, grid_z, - block_x, block_y, block_z, bytes, stream, operands); + block_x, block_y, block_z, bytes, stream, operands, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); callNumResults = 0; } else if (caller.requireDispatchCall()) { // Procedure call requiring a dynamic dispatch. Call is created with @@ -621,7 +622,8 @@ Fortran::lower::genCallOpAndResult( dispatch = builder.create( loc, funcType.getResults(), builder.getStringAttr(procName), caller.getInputs()[*passArg], operands, - builder.getI32IntegerAttr(*passArg), procAttrs); + builder.getI32IntegerAttr(*passArg), /*arg_attrs=*/nullptr, + /*res_attrs=*/nullptr, procAttrs); } else { // NOPASS const Fortran::evaluate::Component *component = @@ -636,7 +638,8 @@ Fortran::lower::genCallOpAndResult( passObject = builder.create(loc, passObject); dispatch = builder.create( loc, funcType.getResults(), builder.getStringAttr(procName), - passObject, operands, nullptr, procAttrs); + passObject, operands, nullptr, /*arg_attrs=*/nullptr, + /*res_attrs=*/nullptr, procAttrs); } callNumResults = dispatch.getNumResults(); if (callNumResults != 0) @@ -644,7 +647,8 @@ Fortran::lower::genCallOpAndResult( } else { // Standard procedure call with fir.call. auto call = builder.create( - loc, funcType.getResults(), funcSymbolAttr, operands, procAttrs); + loc, funcType.getResults(), funcSymbolAttr, operands, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, procAttrs); callNumResults = call.getNumResults(); if (callNumResults != 0) diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp index cc6c2b7df825a..c099a08ffd30a 100644 --- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -518,6 +518,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase { newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end()); llvm::SmallVector newCallResults; + // TODO propagate/update call argument and result attributes. if constexpr (std::is_same_v, mlir::gpu::LaunchFuncOp>) { auto newCall = rewriter->create( loc, callOp.getKernel(), callOp.getGridSizeOperandValues(), @@ -557,6 +558,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase { loc, newResTys, rewriter->getStringAttr(callOp.getMethod()), callOp.getOperands()[0], newOpers, rewriter->getI32IntegerAttr(*callOp.getPassArgPos() + passArgShift), + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, callOp.getProcedureAttrsAttr()); if (wrap) newCallResults.push_back((*wrap)(dispatchOp.getOperation())); diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp index b0327cc10e9de..f8badfa639f94 100644 --- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp +++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp @@ -147,6 +147,7 @@ class CallConversion : public mlir::OpRewritePattern { newResultTypes.emplace_back(getVoidPtrType(result.getContext())); Op newOp; + // TODO: propagate argument and result attributes (need to be shifted). // fir::CallOp specific handling. if constexpr (std::is_same_v) { if (op.getCallee()) { @@ -189,9 +190,11 @@ class CallConversion : public mlir::OpRewritePattern { if (op.getPassArgPos()) passArgPos = rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift); + // TODO: propagate argument and result attributes (need to be shifted). newOp = rewriter.create( loc, newResultTypes, rewriter.getStringAttr(op.getMethod()), op.getOperands()[0], newOperands, passArgPos, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, op.getProcedureAttrsAttr()); } diff --git a/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp b/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp index 070889a284f48..0c78a878cdc53 100644 --- a/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp @@ -205,8 +205,9 @@ struct DispatchOpConv : public OpConversionPattern { // Make the call. llvm::SmallVector args{funcPtr}; args.append(dispatch.getArgs().begin(), dispatch.getArgs().end()); - rewriter.replaceOpWithNewOp(dispatch, resTypes, nullptr, args, - dispatch.getProcedureAttrsAttr()); + rewriter.replaceOpWithNewOp( + dispatch, resTypes, nullptr, args, dispatch.getArgAttrsAttr(), + dispatch.getResAttrsAttr(), dispatch.getProcedureAttrsAttr()); return mlir::success(); } diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md index 51747db546bb7..b7e9e64d23d77 100644 --- a/mlir/docs/Interfaces.md +++ b/mlir/docs/Interfaces.md @@ -753,10 +753,15 @@ interface section goes as follows: - (`C++ class` -- `ODS class`(if applicable)) ##### CallInterfaces - * `CallOpInterface` - Used to represent operations like 'call' - `CallInterfaceCallable getCallableForCallee()` - `void setCalleeFromCallable(CallInterfaceCallable)` + - `ArrayAttr getArgAttrsAttr()` + - `ArrayAttr getResAttrsAttr()` + - `void setArgAttrsAttr(ArrayAttr)` + - `void setResAttrsAttr(ArrayAttr)` + - `Attribute removeArgAttrsAttr()` + - `Attribute removeResAttrsAttr()` * `CallableOpInterface` - Used to represent the target callee of call. - `Region * getCallableRegion()` - `ArrayRef getArgumentTypes()` diff --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td index 075fd1a9cd473..4441e48ca53c0 100644 --- a/mlir/examples/toy/Ch4/include/toy/Ops.td +++ b/mlir/examples/toy/Ch4/include/toy/Ops.td @@ -215,7 +215,12 @@ def GenericCallOp : Toy_Op<"generic_call", // The generic call operation takes a symbol reference attribute as the // callee, and inputs for the call. - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$inputs, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); // The generic call operation returns a single value of TensorType. let results = (outs F64Tensor); diff --git a/mlir/examples/toy/Ch5/include/toy/Ops.td b/mlir/examples/toy/Ch5/include/toy/Ops.td index ec6762ff406e8..5b7c966de6f08 100644 --- a/mlir/examples/toy/Ch5/include/toy/Ops.td +++ b/mlir/examples/toy/Ch5/include/toy/Ops.td @@ -214,7 +214,12 @@ def GenericCallOp : Toy_Op<"generic_call", // The generic call operation takes a symbol reference attribute as the // callee, and inputs for the call. - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$inputs, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); // The generic call operation returns a single value of TensorType. let results = (outs F64Tensor); diff --git a/mlir/examples/toy/Ch6/include/toy/Ops.td b/mlir/examples/toy/Ch6/include/toy/Ops.td index a52bebc8b67b8..fdbc239a171df 100644 --- a/mlir/examples/toy/Ch6/include/toy/Ops.td +++ b/mlir/examples/toy/Ch6/include/toy/Ops.td @@ -214,7 +214,12 @@ def GenericCallOp : Toy_Op<"generic_call", // The generic call operation takes a symbol reference attribute as the // callee, and inputs for the call. - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$inputs, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); // The generic call operation returns a single value of TensorType. let results = (outs F64Tensor); diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td index cfd6859eb27bf..71ab7b0aeebb9 100644 --- a/mlir/examples/toy/Ch7/include/toy/Ops.td +++ b/mlir/examples/toy/Ch7/include/toy/Ops.td @@ -237,7 +237,12 @@ def GenericCallOp : Toy_Op<"generic_call", // The generic call operation takes a symbol reference attribute as the // callee, and inputs for the call. - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$inputs, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); // The generic call operation returns a single value of TensorType or // StructType. @@ -250,7 +255,8 @@ def GenericCallOp : Toy_Op<"generic_call", // Add custom build methods for the generic call operation. let builders = [ - OpBuilder<(ins "StringRef":$callee, "ArrayRef":$arguments)> + OpBuilder<(ins "Type":$result_type, "StringRef":$callee, + "ArrayRef":$arguments)> ]; } diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp index 55c44c45e0f00..52881db87d86b 100644 --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -350,9 +350,9 @@ void FuncOp::print(mlir::OpAsmPrinter &p) { //===----------------------------------------------------------------------===// void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - StringRef callee, ArrayRef arguments) { - // Generic call always returns an unranked Tensor initially. - state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + mlir::Type resultType, StringRef callee, + ArrayRef arguments) { + state.addTypes(resultType); state.addOperands(arguments); state.addAttribute("callee", mlir::SymbolRefAttr::get(builder.getContext(), callee)); diff --git a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp index 090e5ff914604..e554e375209f1 100644 --- a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp @@ -535,8 +535,7 @@ class MLIRGenImpl { } mlir::toy::FuncOp calledFunc = calledFuncIt->second; return builder.create( - location, calledFunc.getFunctionType().getResult(0), - mlir::SymbolRefAttr::get(builder.getContext(), callee), operands); + location, calledFunc.getFunctionType().getResult(0), callee, operands); } /// Emit a print expression. It emits specific operations for two builtins: diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td index a08f5d6e714ef..3d29d5bc7dbb6 100644 --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -208,7 +208,13 @@ def Async_CallOp : Async_Op<"call", ``` }]; - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$operands, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); + let results = (outs Variadic); let builders = [ diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index abcc00feb5816..4fbce995ce5b8 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -551,7 +551,13 @@ def EmitC_CallOp : EmitC_Op<"call", %2 = emitc.call @my_add(%0, %1) : (f32, f32) -> f32 ``` }]; - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$operands, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); + let results = (outs Variadic); let builders = [ diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td index 4da0efcb13ddf..cdaeb6461afb4 100644 --- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td +++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td @@ -49,8 +49,14 @@ def CallOp : Func_Op<"call", ``` }]; - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands, - UnitAttr:$no_inline); + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$operands, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, + UnitAttr:$no_inline + ); + let results = (outs Variadic); let builders = [ @@ -73,6 +79,18 @@ def CallOp : Func_Op<"call", CArg<"ValueRange", "{}">:$operands), [{ build($_builder, $_state, StringAttr::get($_builder.getContext(), callee), results, operands); + }]>, + OpBuilder<(ins "TypeRange":$results, "FlatSymbolRefAttr":$callee, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, callee, results, operands); + }]>, + OpBuilder<(ins "TypeRange":$results, "StringAttr":$callee, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, callee, results, operands); + }]>, + OpBuilder<(ins "TypeRange":$results, "StringRef":$callee, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, callee, results, operands); }]>]; let extraClassDeclaration = [{ @@ -136,8 +154,13 @@ def CallIndirectOp : Func_Op<"call_indirect", [ ``` }]; - let arguments = (ins FunctionType:$callee, - Variadic:$callee_operands); + let arguments = (ins + FunctionType:$callee, + Variadic:$callee_operands, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); + let results = (outs Variadic:$results); let builders = [ diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index b2281536aa40b..ee6e10efed4f1 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -633,6 +633,8 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [ OptionalAttr>:$var_callee_type, OptionalAttr:$callee, Variadic:$callee_operands, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, Variadic:$normalDestOperands, Variadic:$unwindDestOperands, OptionalAttr:$branch_weights, @@ -755,7 +757,9 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call", VariadicOfVariadic:$op_bundle_operands, DenseI32ArrayAttr:$op_bundle_sizes, - OptionalAttr:$op_bundle_tags); + OptionalAttr:$op_bundle_tags, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); // Append the aliasing related attributes defined in LLVM_MemAccessOpBase. let arguments = !con(args, aliasAttrs); let results = (outs Optional:$result); diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td index 991e753d1b359..cc2f0e4962d8a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td @@ -214,13 +214,24 @@ def SPIRV_FunctionCallOp : SPIRV_Op<"FunctionCall", [ let arguments = (ins FlatSymbolRefAttr:$callee, - Variadic:$arguments + Variadic:$arguments, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let results = (outs Optional:$return_value ); + let builders = [ + OpBuilder<(ins "Type":$returnType, "FlatSymbolRefAttr":$callee, + "ValueRange":$arguments), + [{ + build($_builder, $_state, returnType, callee, arguments, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); + }]> + ]; + let autogenSerialization = 0; let assemblyFormat = [{ diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index 77ed6b322451e..e4eb67c8e14ce 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -886,7 +886,9 @@ def IncludeOp : TransformDialectOp<"include", let arguments = (ins SymbolRefAttr:$target, FailurePropagationMode:$failure_propagation_mode, - Variadic:$operands); + Variadic:$operands, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); let results = (outs Variadic:$results); let assemblyFormat = diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.h b/mlir/include/mlir/Interfaces/CallInterfaces.h index 0020c19333d10..2bf3a3ca5f8a8 100644 --- a/mlir/include/mlir/Interfaces/CallInterfaces.h +++ b/mlir/include/mlir/Interfaces/CallInterfaces.h @@ -14,6 +14,7 @@ #ifndef MLIR_INTERFACES_CALLINTERFACES_H #define MLIR_INTERFACES_CALLINTERFACES_H +#include "mlir/IR/OpImplementation.h" #include "mlir/IR/SymbolTable.h" #include "llvm/ADT/PointerUnion.h" @@ -35,6 +36,66 @@ namespace call_interface_impl { Operation *resolveCallable(CallOpInterface call, SymbolTableCollection *symbolTable = nullptr); +/// Parse a function or call result list. +/// +/// function-result-list ::= function-result-list-parens +/// | non-function-type +/// function-result-list-parens ::= `(` `)` +/// | `(` function-result-list-no-parens `)` +/// function-result-list-no-parens ::= function-result (`,` function-result)* +/// function-result ::= type attribute-dict? +/// +ParseResult +parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs); + +/// Parses a function signature using `parser`. This does not deal with function +/// signatures containing SSA region arguments (to parse these signatures, use +/// function_interface_impl::parseFunctionSignature). When +/// `mustParseEmptyResult`, `-> ()` is expected when there is no result type. +/// +/// no-ssa-function-signature ::= `(` no-ssa-function-arg-list `)` +/// -> function-result-list +/// no-ssa-function-arg-list ::= no-ssa-function-arg +/// (`,` no-ssa-function-arg)* +/// no-ssa-function-arg ::= type attribute-dict? +ParseResult parseFunctionSignature(OpAsmParser &parser, + SmallVectorImpl &argTypes, + SmallVectorImpl &argAttrs, + SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs, + bool mustParseEmptyResult = true); + +/// Print a function signature for a call or callable operation. If a body +/// region is provided, the SSA arguments are printed in the signature. When +/// `printEmptyResult` is false, `-> function-result-list` is omitted when +/// `resultTypes` is empty. +/// +/// function-signature ::= ssa-function-signature +/// | no-ssa-function-signature +/// ssa-function-signature ::= `(` ssa-function-arg-list `)` +/// -> function-result-list +/// ssa-function-arg-list ::= ssa-function-arg (`,` ssa-function-arg)* +/// ssa-function-arg ::= `%`name `:` type attribute-dict? +void printFunctionSignature(OpAsmPrinter &p, TypeRange argTypes, + ArrayAttr argAttrs, bool isVariadic, + TypeRange resultTypes, ArrayAttr resultAttrs, + Region *body = nullptr, + bool printEmptyResult = true); + +/// Adds argument and result attributes, provided as `argAttrs` and +/// `resultAttrs` arguments, to the list of operation attributes in `result`. +/// Internally, argument and result attributes are stored as dict attributes +/// with special names given by getResultAttrName, getArgumentAttrName. +void addArgAndResultAttrs(Builder &builder, OperationState &result, + ArrayRef argAttrs, + ArrayRef resultAttrs, + StringAttr argAttrsName, StringAttr resAttrsName); +void addArgAndResultAttrs(Builder &builder, OperationState &result, + ArrayRef args, + ArrayRef resultAttrs, + StringAttr argAttrsName, StringAttr resAttrsName); + } // namespace call_interface_impl } // namespace mlir diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td index c6002da0d491c..e3c2aec401741 100644 --- a/mlir/include/mlir/Interfaces/CallInterfaces.td +++ b/mlir/include/mlir/Interfaces/CallInterfaces.td @@ -17,6 +17,48 @@ include "mlir/IR/OpBase.td" + +/// Interface for operations with arguments attributes (both call-like +/// and callable operations). +def ArgumentAttributesMethods { + list methods = [ + InterfaceMethod<[{ + Get the array of argument attribute dictionaries. The method should + return an array attribute containing only dictionary attributes equal in + number to the number of arguments. Alternatively, the method can + return null to indicate that there are no argument attributes. + }], + "::mlir::ArrayAttr", "getArgAttrsAttr">, + InterfaceMethod<[{ + Get the array of result attribute dictionaries. The method should return + an array attribute containing only dictionary attributes equal in number + to the number of results. Alternatively, the method can return + null to indicate that there are no result attributes. + }], + "::mlir::ArrayAttr", "getResAttrsAttr">, + InterfaceMethod<[{ + Set the array of argument attribute dictionaries. + }], + "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>, + InterfaceMethod<[{ + Set the array of result attribute dictionaries. + }], + "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>, + InterfaceMethod<[{ + Remove the array of argument attribute dictionaries. This is the same as + setting all argument attributes to an empty dictionary. The method should + return the removed attribute. + }], + "::mlir::Attribute", "removeArgAttrsAttr">, + InterfaceMethod<[{ + Remove the array of result attribute dictionaries. This is the same as + setting all result attributes to an empty dictionary. The method should + return the removed attribute. + }], + "::mlir::Attribute", "removeResAttrsAttr"> + ]; +} + // `CallInterfaceCallable`: This is a type used to represent a single callable // region. A callable is either a symbol, or an SSA value, that is referenced by // a call-like operation. This represents the destination of the call. @@ -81,7 +123,7 @@ def CallOpInterface : OpInterface<"CallOpInterface"> { return ::mlir::call_interface_impl::resolveCallable($_op); }] > - ]; + ] # ArgumentAttributesMethods.methods; } /// Interface for callable operations. @@ -113,48 +155,7 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> { allow for this method may be called on function declarations). }], "::llvm::ArrayRef<::mlir::Type>", "getResultTypes">, - - InterfaceMethod<[{ - Get the array of argument attribute dictionaries. The method should - return an array attribute containing only dictionary attributes equal in - number to the number of region arguments. Alternatively, the method can - return null to indicate that the region has no argument attributes. - }], - "::mlir::ArrayAttr", "getArgAttrsAttr", (ins), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>, - InterfaceMethod<[{ - Get the array of result attribute dictionaries. The method should return - an array attribute containing only dictionary attributes equal in number - to the number of region results. Alternatively, the method can return - null to indicate that the region has no result attributes. - }], - "::mlir::ArrayAttr", "getResAttrsAttr", (ins), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>, - InterfaceMethod<[{ - Set the array of argument attribute dictionaries. - }], - "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>, - InterfaceMethod<[{ - Set the array of result attribute dictionaries. - }], - "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>, - InterfaceMethod<[{ - Remove the array of argument attribute dictionaries. This is the same as - setting all argument attributes to an empty dictionary. The method should - return the removed attribute. - }], - "::mlir::Attribute", "removeArgAttrsAttr", (ins), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>, - InterfaceMethod<[{ - Remove the array of result attribute dictionaries. This is the same as - setting all result attributes to an empty dictionary. The method should - return the removed attribute. - }], - "::mlir::Attribute", "removeResAttrsAttr", (ins), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>, - ]; + ] # ArgumentAttributesMethods.methods; } #endif // MLIR_INTERFACES_CALLINTERFACES diff --git a/mlir/include/mlir/Interfaces/FunctionImplementation.h b/mlir/include/mlir/Interfaces/FunctionImplementation.h index a5e6963e4e666..374c2c534f87d 100644 --- a/mlir/include/mlir/Interfaces/FunctionImplementation.h +++ b/mlir/include/mlir/Interfaces/FunctionImplementation.h @@ -33,19 +33,6 @@ class VariadicFlag { bool variadic; }; -/// Adds argument and result attributes, provided as `argAttrs` and -/// `resultAttrs` arguments, to the list of operation attributes in `result`. -/// Internally, argument and result attributes are stored as dict attributes -/// with special names given by getResultAttrName, getArgumentAttrName. -void addArgAndResultAttrs(Builder &builder, OperationState &result, - ArrayRef argAttrs, - ArrayRef resultAttrs, - StringAttr argAttrsName, StringAttr resAttrsName); -void addArgAndResultAttrs(Builder &builder, OperationState &result, - ArrayRef args, - ArrayRef resultAttrs, - StringAttr argAttrsName, StringAttr resAttrsName); - /// Callback type for `parseFunctionOp`, the callback should produce the /// type that will be associated with a function-like operation from lists of /// function arguments and results, VariadicFlag indicates whether the function @@ -58,11 +45,11 @@ using FuncTypeBuilder = function_ref &arguments, - bool &isVariadic, SmallVectorImpl &resultTypes, - SmallVectorImpl &resultAttrs); +ParseResult parseFunctionSignatureWithArguments( + OpAsmParser &parser, bool allowVariadic, + SmallVectorImpl &arguments, bool &isVariadic, + SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs); /// Parser implementation for function-like operations. Uses /// `funcTypeBuilder` to construct the custom function type given lists of @@ -84,9 +71,14 @@ void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, /// Prints the signature of the function-like operation `op`. Assumes `op` has /// is a FunctionOpInterface and has passed verification. -void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, - ArrayRef argTypes, bool isVariadic, - ArrayRef resultTypes); +inline void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, + ArrayRef argTypes, bool isVariadic, + ArrayRef resultTypes) { + call_interface_impl::printFunctionSignature( + p, argTypes, op.getArgAttrsAttr(), isVariadic, resultTypes, + op.getResAttrsAttr(), &op->getRegion(0), + /*printEmptyResult=*/false); +} /// Prints the list of function prefixed with the "attributes" keyword. The /// attributes with names listed in "elided" as well as those used by the diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp index a3e3f80954efc..d3bb250bb8ab9 100644 --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -308,7 +308,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( builder, state, argAttrs, /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index c818dd18a3d24..728a2d33f46e7 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -529,7 +529,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( builder, state, argAttrs, /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp index a490b4c3c4ab4..ba7b84f27d6a8 100644 --- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp +++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp @@ -190,7 +190,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( builder, state, argAttrs, /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 301066e7d3e1f..d06f10d3137a1 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -1467,7 +1467,7 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); auto signatureLocation = parser.getCurrentLocation(); - if (failed(function_interface_impl::parseFunctionSignature( + if (failed(function_interface_impl::parseFunctionSignatureWithArguments( parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, resultAttrs))) return failure(); @@ -1487,7 +1487,7 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) { result.addAttribute(getFunctionTypeAttrName(result.name), TypeAttr::get(type)); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index ef5f1b069b40a..a6e996f3fb810 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1033,6 +1033,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results, /*memory_effects=*/nullptr, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1060,6 +1061,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1073,6 +1075,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1087,6 +1090,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1527,7 +1531,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func, auto calleeType = func.getFunctionType(); build(builder, state, getCallOpResultTypes(calleeType), getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops, - normalOps, unwindOps, nullptr, nullptr, {}, {}, normal, unwind); + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, normalOps, unwindOps, + nullptr, nullptr, {}, {}, normal, unwind); } void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys, @@ -1535,8 +1540,9 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys, ValueRange normalOps, Block *unwind, ValueRange unwindOps) { build(builder, state, tys, - /*var_callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr, - nullptr, {}, {}, normal, unwind); + /*var_callee_type=*/nullptr, callee, ops, /*arg_attrs=*/nullptr, + /*res_attrs=*/nullptr, normalOps, unwindOps, nullptr, nullptr, {}, {}, + normal, unwind); } void InvokeOp::build(OpBuilder &builder, OperationState &state, @@ -1544,7 +1550,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, ValueRange ops, Block *normal, ValueRange normalOps, Block *unwind, ValueRange unwindOps) { build(builder, state, getCallOpResultTypes(calleeType), - getCallOpVarCalleeType(calleeType), callee, ops, normalOps, unwindOps, + getCallOpVarCalleeType(calleeType), callee, ops, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, normalOps, unwindOps, nullptr, nullptr, {}, {}, normal, unwind); } @@ -2510,7 +2517,7 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, assert(llvm::cast(type).getNumParams() == argAttrs.size() && "expected as many argument attribute lists as arguments"); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( builder, result, argAttrs, /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } @@ -2595,7 +2602,7 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) { auto signatureLocation = parser.getCurrentLocation(); if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), result.attributes) || - function_interface_impl::parseFunctionSignature( + function_interface_impl::parseFunctionSignatureWithArguments( parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes, resultAttrs)) return failure(); @@ -2636,7 +2643,7 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) { if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) return failure(); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( parser.getBuilder(), result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index f0f03e989cb47..160c264fc32d9 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -917,7 +917,7 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) { // Parse the function signature. bool isVariadic = false; - if (function_interface_impl::parseFunctionSignature( + if (function_interface_impl::parseFunctionSignatureWithArguments( parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, resultAttrs)) return failure(); @@ -940,7 +940,7 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) { // Add the attributes to the function arguments. assert(resultAttrs.size() == resultTypes.size()); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 65efc88e9c403..2200af0f67a86 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1297,7 +1297,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( builder, state, argAttrs, /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } diff --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp index da0ca0e24630f..e8ed4b339a0cb 100644 --- a/mlir/lib/Interfaces/CallInterfaces.cpp +++ b/mlir/lib/Interfaces/CallInterfaces.cpp @@ -7,9 +7,178 @@ //===----------------------------------------------------------------------===// #include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/IR/Builders.h" using namespace mlir; +//===----------------------------------------------------------------------===// +// Argument and result attributes utilities +//===----------------------------------------------------------------------===// + +static ParseResult +parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl &types, + SmallVectorImpl &attrs) { + // Parse individual function results. + return parser.parseCommaSeparatedList([&]() -> ParseResult { + types.emplace_back(); + attrs.emplace_back(); + NamedAttrList attrList; + if (parser.parseType(types.back()) || + parser.parseOptionalAttrDict(attrList)) + return failure(); + attrs.back() = attrList.getDictionary(parser.getContext()); + return success(); + }); +} + +ParseResult call_interface_impl::parseFunctionResultList( + OpAsmParser &parser, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs) { + if (failed(parser.parseOptionalLParen())) { + // We already know that there is no `(`, so parse a type. + // Because there is no `(`, it cannot be a function type. + Type ty; + if (parser.parseType(ty)) + return failure(); + resultTypes.push_back(ty); + resultAttrs.emplace_back(); + return success(); + } + + // Special case for an empty set of parens. + if (succeeded(parser.parseOptionalRParen())) + return success(); + if (parseTypeAndAttrList(parser, resultTypes, resultAttrs)) + return failure(); + return parser.parseRParen(); +} + +ParseResult call_interface_impl::parseFunctionSignature( + OpAsmParser &parser, SmallVectorImpl &argTypes, + SmallVectorImpl &argAttrs, + SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs, bool mustParseEmptyResult) { + // Parse arguments. + if (parser.parseLParen()) + return failure(); + if (failed(parser.parseOptionalRParen())) { + if (parseTypeAndAttrList(parser, argTypes, argAttrs)) + return failure(); + if (parser.parseRParen()) + return failure(); + } + // Parse results. + if (succeeded(parser.parseOptionalArrow())) + return call_interface_impl::parseFunctionResultList(parser, resultTypes, + resultAttrs); + if (mustParseEmptyResult) + return failure(); + return success(); +} + +/// Print a function result list. The provided `attrs` must either be null, or +/// contain a set of DictionaryAttrs of the same arity as `types`. +static void printFunctionResultList(OpAsmPrinter &p, TypeRange types, + ArrayAttr attrs) { + assert(!types.empty() && "Should not be called for empty result list."); + assert((!attrs || attrs.size() == types.size()) && + "Invalid number of attributes."); + + auto &os = p.getStream(); + bool needsParens = types.size() > 1 || llvm::isa(types[0]) || + (attrs && !llvm::cast(attrs[0]).empty()); + if (needsParens) + os << '('; + llvm::interleaveComma(llvm::seq(0, types.size()), os, [&](size_t i) { + p.printType(types[i]); + if (attrs) + p.printOptionalAttrDict(llvm::cast(attrs[i]).getValue()); + }); + if (needsParens) + os << ')'; +} + +void call_interface_impl::printFunctionSignature( + OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic, + TypeRange resultTypes, ArrayAttr resultAttrs, Region *body, + bool printEmptyResult) { + bool isExternal = !body || body->empty(); + if (!isExternal && !isVariadic && !argAttrs && !resultAttrs && + printEmptyResult) { + p.printFunctionalType(argTypes, resultTypes); + return; + } + + p << '('; + for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { + if (i > 0) + p << ", "; + + if (!isExternal) { + ArrayRef attrs; + if (argAttrs) + attrs = llvm::cast(argAttrs[i]).getValue(); + p.printRegionArgument(body->getArgument(i), attrs); + } else { + p.printType(argTypes[i]); + if (argAttrs) + p.printOptionalAttrDict( + llvm::cast(argAttrs[i]).getValue()); + } + } + + if (isVariadic) { + if (!argTypes.empty()) + p << ", "; + p << "..."; + } + + p << ')'; + + if (!resultTypes.empty()) { + p << " -> "; + printFunctionResultList(p, resultTypes, resultAttrs); + } else if (printEmptyResult) { + p << " -> ()"; + } +} + +void call_interface_impl::addArgAndResultAttrs( + Builder &builder, OperationState &result, ArrayRef argAttrs, + ArrayRef resultAttrs, StringAttr argAttrsName, + StringAttr resAttrsName) { + auto nonEmptyAttrsFn = [](DictionaryAttr attrs) { + return attrs && !attrs.empty(); + }; + // Convert the specified array of dictionary attrs (which may have null + // entries) to an ArrayAttr of dictionaries. + auto getArrayAttr = [&](ArrayRef dictAttrs) { + SmallVector attrs; + for (auto &dict : dictAttrs) + attrs.push_back(dict ? dict : builder.getDictionaryAttr({})); + return builder.getArrayAttr(attrs); + }; + + // Add the attributes to the operation arguments. + if (llvm::any_of(argAttrs, nonEmptyAttrsFn)) + result.addAttribute(argAttrsName, getArrayAttr(argAttrs)); + + // Add the attributes to the operation results. + if (llvm::any_of(resultAttrs, nonEmptyAttrsFn)) + result.addAttribute(resAttrsName, getArrayAttr(resultAttrs)); +} + +void call_interface_impl::addArgAndResultAttrs( + Builder &builder, OperationState &result, + ArrayRef args, ArrayRef resultAttrs, + StringAttr argAttrsName, StringAttr resAttrsName) { + SmallVector argAttrs; + for (const auto &arg : args) + argAttrs.push_back(arg.attrs); + addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName, + resAttrsName); +} + //===----------------------------------------------------------------------===// // CallOpInterface //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Interfaces/FunctionImplementation.cpp b/mlir/lib/Interfaces/FunctionImplementation.cpp index 988feee665fea..90f32896e8181 100644 --- a/mlir/lib/Interfaces/FunctionImplementation.cpp +++ b/mlir/lib/Interfaces/FunctionImplementation.cpp @@ -70,50 +70,7 @@ parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic, }); } -/// Parse a function result list. -/// -/// function-result-list ::= function-result-list-parens -/// | non-function-type -/// function-result-list-parens ::= `(` `)` -/// | `(` function-result-list-no-parens `)` -/// function-result-list-no-parens ::= function-result (`,` function-result)* -/// function-result ::= type attribute-dict? -/// -static ParseResult -parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl &resultTypes, - SmallVectorImpl &resultAttrs) { - if (failed(parser.parseOptionalLParen())) { - // We already know that there is no `(`, so parse a type. - // Because there is no `(`, it cannot be a function type. - Type ty; - if (parser.parseType(ty)) - return failure(); - resultTypes.push_back(ty); - resultAttrs.emplace_back(); - return success(); - } - - // Special case for an empty set of parens. - if (succeeded(parser.parseOptionalRParen())) - return success(); - - // Parse individual function results. - if (parser.parseCommaSeparatedList([&]() -> ParseResult { - resultTypes.emplace_back(); - resultAttrs.emplace_back(); - NamedAttrList attrs; - if (parser.parseType(resultTypes.back()) || - parser.parseOptionalAttrDict(attrs)) - return failure(); - resultAttrs.back() = attrs.getDictionary(parser.getContext()); - return success(); - })) - return failure(); - - return parser.parseRParen(); -} - -ParseResult function_interface_impl::parseFunctionSignature( +ParseResult function_interface_impl::parseFunctionSignatureWithArguments( OpAsmParser &parser, bool allowVariadic, SmallVectorImpl &arguments, bool &isVariadic, SmallVectorImpl &resultTypes, @@ -121,46 +78,11 @@ ParseResult function_interface_impl::parseFunctionSignature( if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic)) return failure(); if (succeeded(parser.parseOptionalArrow())) - return parseFunctionResultList(parser, resultTypes, resultAttrs); + return call_interface_impl::parseFunctionResultList(parser, resultTypes, + resultAttrs); return success(); } -void function_interface_impl::addArgAndResultAttrs( - Builder &builder, OperationState &result, ArrayRef argAttrs, - ArrayRef resultAttrs, StringAttr argAttrsName, - StringAttr resAttrsName) { - auto nonEmptyAttrsFn = [](DictionaryAttr attrs) { - return attrs && !attrs.empty(); - }; - // Convert the specified array of dictionary attrs (which may have null - // entries) to an ArrayAttr of dictionaries. - auto getArrayAttr = [&](ArrayRef dictAttrs) { - SmallVector attrs; - for (auto &dict : dictAttrs) - attrs.push_back(dict ? dict : builder.getDictionaryAttr({})); - return builder.getArrayAttr(attrs); - }; - - // Add the attributes to the function arguments. - if (llvm::any_of(argAttrs, nonEmptyAttrsFn)) - result.addAttribute(argAttrsName, getArrayAttr(argAttrs)); - - // Add the attributes to the function results. - if (llvm::any_of(resultAttrs, nonEmptyAttrsFn)) - result.addAttribute(resAttrsName, getArrayAttr(resultAttrs)); -} - -void function_interface_impl::addArgAndResultAttrs( - Builder &builder, OperationState &result, - ArrayRef args, ArrayRef resultAttrs, - StringAttr argAttrsName, StringAttr resAttrsName) { - SmallVector argAttrs; - for (const auto &arg : args) - argAttrs.push_back(arg.attrs); - addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName, - resAttrsName); -} - ParseResult function_interface_impl::parseFunctionOp( OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, @@ -182,8 +104,8 @@ ParseResult function_interface_impl::parseFunctionOp( // Parse the function signature. SMLoc signatureLocation = parser.getCurrentLocation(); bool isVariadic = false; - if (parseFunctionSignature(parser, allowVariadic, entryArgs, isVariadic, - resultTypes, resultAttrs)) + if (parseFunctionSignatureWithArguments(parser, allowVariadic, entryArgs, + isVariadic, resultTypes, resultAttrs)) return failure(); std::string errorMessage; @@ -221,8 +143,8 @@ ParseResult function_interface_impl::parseFunctionOp( // Add the attributes to the function arguments. assert(resultAttrs.size() == resultTypes.size()); - addArgAndResultAttrs(builder, result, entryArgs, resultAttrs, argAttrsName, - resAttrsName); + call_interface_impl::addArgAndResultAttrs( + builder, result, entryArgs, resultAttrs, argAttrsName, resAttrsName); // Parse the optional function body. The printer will not print the body if // its empty, so disallow parsing of empty body in the parser. @@ -241,68 +163,6 @@ ParseResult function_interface_impl::parseFunctionOp( return success(); } -/// Print a function result list. The provided `attrs` must either be null, or -/// contain a set of DictionaryAttrs of the same arity as `types`. -static void printFunctionResultList(OpAsmPrinter &p, ArrayRef types, - ArrayAttr attrs) { - assert(!types.empty() && "Should not be called for empty result list."); - assert((!attrs || attrs.size() == types.size()) && - "Invalid number of attributes."); - - auto &os = p.getStream(); - bool needsParens = types.size() > 1 || llvm::isa(types[0]) || - (attrs && !llvm::cast(attrs[0]).empty()); - if (needsParens) - os << '('; - llvm::interleaveComma(llvm::seq(0, types.size()), os, [&](size_t i) { - p.printType(types[i]); - if (attrs) - p.printOptionalAttrDict(llvm::cast(attrs[i]).getValue()); - }); - if (needsParens) - os << ')'; -} - -void function_interface_impl::printFunctionSignature( - OpAsmPrinter &p, FunctionOpInterface op, ArrayRef argTypes, - bool isVariadic, ArrayRef resultTypes) { - Region &body = op->getRegion(0); - bool isExternal = body.empty(); - - p << '('; - ArrayAttr argAttrs = op.getArgAttrsAttr(); - for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { - if (i > 0) - p << ", "; - - if (!isExternal) { - ArrayRef attrs; - if (argAttrs) - attrs = llvm::cast(argAttrs[i]).getValue(); - p.printRegionArgument(body.getArgument(i), attrs); - } else { - p.printType(argTypes[i]); - if (argAttrs) - p.printOptionalAttrDict( - llvm::cast(argAttrs[i]).getValue()); - } - } - - if (isVariadic) { - if (!argTypes.empty()) - p << ", "; - p << "..."; - } - - p << ')'; - - if (!resultTypes.empty()) { - p.getStream() << " -> "; - auto resultAttrs = op.getResAttrsAttr(); - printFunctionResultList(p, resultTypes, resultAttrs); - } -} - void function_interface_impl::printFunctionAttributes( OpAsmPrinter &p, Operation *op, ArrayRef elided) { // Print out function attributes, if present. diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 79840094686e1..2aa0658ab0e5d 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -557,7 +557,12 @@ def TestCallOp : TEST_Op<"call", [DeclareOpInterfaceMethods { - let arguments = (ins Variadic:$arg_operands, SymbolRefAttr:$callee); + let arguments = (ins + Variadic:$arg_operands, + SymbolRefAttr:$callee, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); let results = (outs Variadic); let extraClassDeclaration = [{ @@ -618,6 +623,10 @@ def ConversionFuncOp : TEST_Op<"conversion_func_op", [FunctionOpInterface]> { def FunctionalRegionOp : TEST_Op<"functional_region_op", [CallableOpInterface]> { let regions = (region AnyRegion:$body); + let arguments = (ins + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); let results = (outs FunctionType); let extraClassDeclaration = [{ @@ -3309,7 +3318,9 @@ def TestCallAndStoreOp : TEST_Op<"call_and_store", SymbolRefAttr:$callee, Arg:$address, Variadic:$callee_operands, - BoolAttr:$store_before_call + BoolAttr:$store_before_call, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let results = (outs Variadic:$results @@ -3324,7 +3335,9 @@ def TestCallOnDeviceOp : TEST_Op<"call_on_device", let arguments = (ins SymbolRefAttr:$callee, Variadic:$forwarded_operands, - AnyType:$non_forwarded_device_operand + AnyType:$non_forwarded_device_operand, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let results = (outs Variadic:$results