diff --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td index a270e69b39410..c1021da0cfb21 100644 --- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td +++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td @@ -207,7 +207,9 @@ def cuf_KernelLaunchOp : cuf_Op<"kernel_launch", [CallOpInterface, I32:$block_z, Optional:$bytes, Optional:$stream, - Variadic:$args + Variadic:$args, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let assemblyFormat = [{ diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index 5f0f0b48e892b..8dbc9df9f553d 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -2432,6 +2432,8 @@ def fir_CallOp : fir_Op<"call", let arguments = (ins OptionalAttr:$callee, Variadic:$args, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, OptionalAttr:$procedure_attrs, DefaultValuedAttr:$fastmath @@ -2518,6 +2520,8 @@ def fir_DispatchOp : fir_Op<"dispatch", []> { fir_ClassType:$object, Variadic:$args, OptionalAttr:$pass_arg_pos, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, OptionalAttr:$procedure_attrs ); diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp index 40cd106e63018..7ca2baf0193ce 100644 --- a/flang/lib/Lower/ConvertCall.cpp +++ b/flang/lib/Lower/ConvertCall.cpp @@ -594,7 +594,8 @@ Fortran::lower::genCallOpAndResult( builder.create( loc, funcType.getResults(), funcSymbolAttr, grid_x, grid_y, grid_z, - block_x, block_y, block_z, bytes, stream, operands); + block_x, block_y, block_z, bytes, stream, operands, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); callNumResults = 0; } else if (caller.requireDispatchCall()) { // Procedure call requiring a dynamic dispatch. Call is created with @@ -621,7 +622,8 @@ Fortran::lower::genCallOpAndResult( dispatch = builder.create( loc, funcType.getResults(), builder.getStringAttr(procName), caller.getInputs()[*passArg], operands, - builder.getI32IntegerAttr(*passArg), procAttrs); + builder.getI32IntegerAttr(*passArg), /*arg_attrs=*/nullptr, + /*res_attrs=*/nullptr, procAttrs); } else { // NOPASS const Fortran::evaluate::Component *component = @@ -636,7 +638,8 @@ Fortran::lower::genCallOpAndResult( passObject = builder.create(loc, passObject); dispatch = builder.create( loc, funcType.getResults(), builder.getStringAttr(procName), - passObject, operands, nullptr, procAttrs); + passObject, operands, nullptr, /*arg_attrs=*/nullptr, + /*res_attrs=*/nullptr, procAttrs); } callNumResults = dispatch.getNumResults(); if (callNumResults != 0) @@ -644,7 +647,8 @@ Fortran::lower::genCallOpAndResult( } else { // Standard procedure call with fir.call. auto call = builder.create( - loc, funcType.getResults(), funcSymbolAttr, operands, procAttrs); + loc, funcType.getResults(), funcSymbolAttr, operands, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, procAttrs); callNumResults = call.getNumResults(); if (callNumResults != 0) diff --git a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp index cc6c2b7df825a..c099a08ffd30a 100644 --- a/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp +++ b/flang/lib/Optimizer/CodeGen/TargetRewrite.cpp @@ -518,6 +518,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase { newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end()); llvm::SmallVector newCallResults; + // TODO propagate/update call argument and result attributes. if constexpr (std::is_same_v, mlir::gpu::LaunchFuncOp>) { auto newCall = rewriter->create( loc, callOp.getKernel(), callOp.getGridSizeOperandValues(), @@ -557,6 +558,7 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase { loc, newResTys, rewriter->getStringAttr(callOp.getMethod()), callOp.getOperands()[0], newOpers, rewriter->getI32IntegerAttr(*callOp.getPassArgPos() + passArgShift), + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, callOp.getProcedureAttrsAttr()); if (wrap) newCallResults.push_back((*wrap)(dispatchOp.getOperation())); diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp index b0327cc10e9de..f8badfa639f94 100644 --- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp +++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp @@ -147,6 +147,7 @@ class CallConversion : public mlir::OpRewritePattern { newResultTypes.emplace_back(getVoidPtrType(result.getContext())); Op newOp; + // TODO: propagate argument and result attributes (need to be shifted). // fir::CallOp specific handling. if constexpr (std::is_same_v) { if (op.getCallee()) { @@ -189,9 +190,11 @@ class CallConversion : public mlir::OpRewritePattern { if (op.getPassArgPos()) passArgPos = rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift); + // TODO: propagate argument and result attributes (need to be shifted). newOp = rewriter.create( loc, newResultTypes, rewriter.getStringAttr(op.getMethod()), op.getOperands()[0], newOperands, passArgPos, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, op.getProcedureAttrsAttr()); } diff --git a/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp b/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp index 070889a284f48..0c78a878cdc53 100644 --- a/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp @@ -205,8 +205,9 @@ struct DispatchOpConv : public OpConversionPattern { // Make the call. llvm::SmallVector args{funcPtr}; args.append(dispatch.getArgs().begin(), dispatch.getArgs().end()); - rewriter.replaceOpWithNewOp(dispatch, resTypes, nullptr, args, - dispatch.getProcedureAttrsAttr()); + rewriter.replaceOpWithNewOp( + dispatch, resTypes, nullptr, args, dispatch.getArgAttrsAttr(), + dispatch.getResAttrsAttr(), dispatch.getProcedureAttrsAttr()); return mlir::success(); } diff --git a/mlir/docs/Interfaces.md b/mlir/docs/Interfaces.md index 51747db546bb7..b7e9e64d23d77 100644 --- a/mlir/docs/Interfaces.md +++ b/mlir/docs/Interfaces.md @@ -753,10 +753,15 @@ interface section goes as follows: - (`C++ class` -- `ODS class`(if applicable)) ##### CallInterfaces - * `CallOpInterface` - Used to represent operations like 'call' - `CallInterfaceCallable getCallableForCallee()` - `void setCalleeFromCallable(CallInterfaceCallable)` + - `ArrayAttr getArgAttrsAttr()` + - `ArrayAttr getResAttrsAttr()` + - `void setArgAttrsAttr(ArrayAttr)` + - `void setResAttrsAttr(ArrayAttr)` + - `Attribute removeArgAttrsAttr()` + - `Attribute removeResAttrsAttr()` * `CallableOpInterface` - Used to represent the target callee of call. - `Region * getCallableRegion()` - `ArrayRef getArgumentTypes()` diff --git a/mlir/examples/toy/Ch4/include/toy/Ops.td b/mlir/examples/toy/Ch4/include/toy/Ops.td index 075fd1a9cd473..4441e48ca53c0 100644 --- a/mlir/examples/toy/Ch4/include/toy/Ops.td +++ b/mlir/examples/toy/Ch4/include/toy/Ops.td @@ -215,7 +215,12 @@ def GenericCallOp : Toy_Op<"generic_call", // The generic call operation takes a symbol reference attribute as the // callee, and inputs for the call. - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$inputs, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); // The generic call operation returns a single value of TensorType. let results = (outs F64Tensor); diff --git a/mlir/examples/toy/Ch5/include/toy/Ops.td b/mlir/examples/toy/Ch5/include/toy/Ops.td index ec6762ff406e8..5b7c966de6f08 100644 --- a/mlir/examples/toy/Ch5/include/toy/Ops.td +++ b/mlir/examples/toy/Ch5/include/toy/Ops.td @@ -214,7 +214,12 @@ def GenericCallOp : Toy_Op<"generic_call", // The generic call operation takes a symbol reference attribute as the // callee, and inputs for the call. - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$inputs, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); // The generic call operation returns a single value of TensorType. let results = (outs F64Tensor); diff --git a/mlir/examples/toy/Ch6/include/toy/Ops.td b/mlir/examples/toy/Ch6/include/toy/Ops.td index a52bebc8b67b8..fdbc239a171df 100644 --- a/mlir/examples/toy/Ch6/include/toy/Ops.td +++ b/mlir/examples/toy/Ch6/include/toy/Ops.td @@ -214,7 +214,12 @@ def GenericCallOp : Toy_Op<"generic_call", // The generic call operation takes a symbol reference attribute as the // callee, and inputs for the call. - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$inputs, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); // The generic call operation returns a single value of TensorType. let results = (outs F64Tensor); diff --git a/mlir/examples/toy/Ch7/include/toy/Ops.td b/mlir/examples/toy/Ch7/include/toy/Ops.td index cfd6859eb27bf..71ab7b0aeebb9 100644 --- a/mlir/examples/toy/Ch7/include/toy/Ops.td +++ b/mlir/examples/toy/Ch7/include/toy/Ops.td @@ -237,7 +237,12 @@ def GenericCallOp : Toy_Op<"generic_call", // The generic call operation takes a symbol reference attribute as the // callee, and inputs for the call. - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$inputs); + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$inputs, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); // The generic call operation returns a single value of TensorType or // StructType. @@ -250,7 +255,8 @@ def GenericCallOp : Toy_Op<"generic_call", // Add custom build methods for the generic call operation. let builders = [ - OpBuilder<(ins "StringRef":$callee, "ArrayRef":$arguments)> + OpBuilder<(ins "Type":$result_type, "StringRef":$callee, + "ArrayRef":$arguments)> ]; } diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp index 55c44c45e0f00..52881db87d86b 100644 --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -350,9 +350,9 @@ void FuncOp::print(mlir::OpAsmPrinter &p) { //===----------------------------------------------------------------------===// void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - StringRef callee, ArrayRef arguments) { - // Generic call always returns an unranked Tensor initially. - state.addTypes(UnrankedTensorType::get(builder.getF64Type())); + mlir::Type resultType, StringRef callee, + ArrayRef arguments) { + state.addTypes(resultType); state.addOperands(arguments); state.addAttribute("callee", mlir::SymbolRefAttr::get(builder.getContext(), callee)); diff --git a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp index 090e5ff914604..e554e375209f1 100644 --- a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp @@ -535,8 +535,7 @@ class MLIRGenImpl { } mlir::toy::FuncOp calledFunc = calledFuncIt->second; return builder.create( - location, calledFunc.getFunctionType().getResult(0), - mlir::SymbolRefAttr::get(builder.getContext(), callee), operands); + location, calledFunc.getFunctionType().getResult(0), callee, operands); } /// Emit a print expression. It emits specific operations for two builtins: diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td index a08f5d6e714ef..3d29d5bc7dbb6 100644 --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -208,7 +208,13 @@ def Async_CallOp : Async_Op<"call", ``` }]; - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$operands, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); + let results = (outs Variadic); let builders = [ diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index abcc00feb5816..4fbce995ce5b8 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -551,7 +551,13 @@ def EmitC_CallOp : EmitC_Op<"call", %2 = emitc.call @my_add(%0, %1) : (f32, f32) -> f32 ``` }]; - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$operands, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); + let results = (outs Variadic); let builders = [ diff --git a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td index 4da0efcb13ddf..cdaeb6461afb4 100644 --- a/mlir/include/mlir/Dialect/Func/IR/FuncOps.td +++ b/mlir/include/mlir/Dialect/Func/IR/FuncOps.td @@ -49,8 +49,14 @@ def CallOp : Func_Op<"call", ``` }]; - let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands, - UnitAttr:$no_inline); + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$operands, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, + UnitAttr:$no_inline + ); + let results = (outs Variadic); let builders = [ @@ -73,6 +79,18 @@ def CallOp : Func_Op<"call", CArg<"ValueRange", "{}">:$operands), [{ build($_builder, $_state, StringAttr::get($_builder.getContext(), callee), results, operands); + }]>, + OpBuilder<(ins "TypeRange":$results, "FlatSymbolRefAttr":$callee, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, callee, results, operands); + }]>, + OpBuilder<(ins "TypeRange":$results, "StringAttr":$callee, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, callee, results, operands); + }]>, + OpBuilder<(ins "TypeRange":$results, "StringRef":$callee, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, callee, results, operands); }]>]; let extraClassDeclaration = [{ @@ -136,8 +154,13 @@ def CallIndirectOp : Func_Op<"call_indirect", [ ``` }]; - let arguments = (ins FunctionType:$callee, - Variadic:$callee_operands); + let arguments = (ins + FunctionType:$callee, + Variadic:$callee_operands, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); + let results = (outs Variadic:$results); let builders = [ diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index b2281536aa40b..ee6e10efed4f1 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -633,6 +633,8 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [ OptionalAttr>:$var_callee_type, OptionalAttr:$callee, Variadic:$callee_operands, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs, Variadic:$normalDestOperands, Variadic:$unwindDestOperands, OptionalAttr:$branch_weights, @@ -755,7 +757,9 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call", VariadicOfVariadic:$op_bundle_operands, DenseI32ArrayAttr:$op_bundle_sizes, - OptionalAttr:$op_bundle_tags); + OptionalAttr:$op_bundle_tags, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); // Append the aliasing related attributes defined in LLVM_MemAccessOpBase. let arguments = !con(args, aliasAttrs); let results = (outs Optional:$result); diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td index 991e753d1b359..cc2f0e4962d8a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td @@ -214,13 +214,24 @@ def SPIRV_FunctionCallOp : SPIRV_Op<"FunctionCall", [ let arguments = (ins FlatSymbolRefAttr:$callee, - Variadic:$arguments + Variadic:$arguments, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let results = (outs Optional:$return_value ); + let builders = [ + OpBuilder<(ins "Type":$returnType, "FlatSymbolRefAttr":$callee, + "ValueRange":$arguments), + [{ + build($_builder, $_state, returnType, callee, arguments, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr); + }]> + ]; + let autogenSerialization = 0; let assemblyFormat = [{ diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index 77ed6b322451e..e4eb67c8e14ce 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -886,7 +886,9 @@ def IncludeOp : TransformDialectOp<"include", let arguments = (ins SymbolRefAttr:$target, FailurePropagationMode:$failure_propagation_mode, - Variadic:$operands); + Variadic:$operands, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs); let results = (outs Variadic:$results); let assemblyFormat = diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.h b/mlir/include/mlir/Interfaces/CallInterfaces.h index 0020c19333d10..2bf3a3ca5f8a8 100644 --- a/mlir/include/mlir/Interfaces/CallInterfaces.h +++ b/mlir/include/mlir/Interfaces/CallInterfaces.h @@ -14,6 +14,7 @@ #ifndef MLIR_INTERFACES_CALLINTERFACES_H #define MLIR_INTERFACES_CALLINTERFACES_H +#include "mlir/IR/OpImplementation.h" #include "mlir/IR/SymbolTable.h" #include "llvm/ADT/PointerUnion.h" @@ -35,6 +36,66 @@ namespace call_interface_impl { Operation *resolveCallable(CallOpInterface call, SymbolTableCollection *symbolTable = nullptr); +/// Parse a function or call result list. +/// +/// function-result-list ::= function-result-list-parens +/// | non-function-type +/// function-result-list-parens ::= `(` `)` +/// | `(` function-result-list-no-parens `)` +/// function-result-list-no-parens ::= function-result (`,` function-result)* +/// function-result ::= type attribute-dict? +/// +ParseResult +parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs); + +/// Parses a function signature using `parser`. This does not deal with function +/// signatures containing SSA region arguments (to parse these signatures, use +/// function_interface_impl::parseFunctionSignature). When +/// `mustParseEmptyResult`, `-> ()` is expected when there is no result type. +/// +/// no-ssa-function-signature ::= `(` no-ssa-function-arg-list `)` +/// -> function-result-list +/// no-ssa-function-arg-list ::= no-ssa-function-arg +/// (`,` no-ssa-function-arg)* +/// no-ssa-function-arg ::= type attribute-dict? +ParseResult parseFunctionSignature(OpAsmParser &parser, + SmallVectorImpl &argTypes, + SmallVectorImpl &argAttrs, + SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs, + bool mustParseEmptyResult = true); + +/// Print a function signature for a call or callable operation. If a body +/// region is provided, the SSA arguments are printed in the signature. When +/// `printEmptyResult` is false, `-> function-result-list` is omitted when +/// `resultTypes` is empty. +/// +/// function-signature ::= ssa-function-signature +/// | no-ssa-function-signature +/// ssa-function-signature ::= `(` ssa-function-arg-list `)` +/// -> function-result-list +/// ssa-function-arg-list ::= ssa-function-arg (`,` ssa-function-arg)* +/// ssa-function-arg ::= `%`name `:` type attribute-dict? +void printFunctionSignature(OpAsmPrinter &p, TypeRange argTypes, + ArrayAttr argAttrs, bool isVariadic, + TypeRange resultTypes, ArrayAttr resultAttrs, + Region *body = nullptr, + bool printEmptyResult = true); + +/// Adds argument and result attributes, provided as `argAttrs` and +/// `resultAttrs` arguments, to the list of operation attributes in `result`. +/// Internally, argument and result attributes are stored as dict attributes +/// with special names given by getResultAttrName, getArgumentAttrName. +void addArgAndResultAttrs(Builder &builder, OperationState &result, + ArrayRef argAttrs, + ArrayRef resultAttrs, + StringAttr argAttrsName, StringAttr resAttrsName); +void addArgAndResultAttrs(Builder &builder, OperationState &result, + ArrayRef args, + ArrayRef resultAttrs, + StringAttr argAttrsName, StringAttr resAttrsName); + } // namespace call_interface_impl } // namespace mlir diff --git a/mlir/include/mlir/Interfaces/CallInterfaces.td b/mlir/include/mlir/Interfaces/CallInterfaces.td index c6002da0d491c..e3c2aec401741 100644 --- a/mlir/include/mlir/Interfaces/CallInterfaces.td +++ b/mlir/include/mlir/Interfaces/CallInterfaces.td @@ -17,6 +17,48 @@ include "mlir/IR/OpBase.td" + +/// Interface for operations with arguments attributes (both call-like +/// and callable operations). +def ArgumentAttributesMethods { + list methods = [ + InterfaceMethod<[{ + Get the array of argument attribute dictionaries. The method should + return an array attribute containing only dictionary attributes equal in + number to the number of arguments. Alternatively, the method can + return null to indicate that there are no argument attributes. + }], + "::mlir::ArrayAttr", "getArgAttrsAttr">, + InterfaceMethod<[{ + Get the array of result attribute dictionaries. The method should return + an array attribute containing only dictionary attributes equal in number + to the number of results. Alternatively, the method can return + null to indicate that there are no result attributes. + }], + "::mlir::ArrayAttr", "getResAttrsAttr">, + InterfaceMethod<[{ + Set the array of argument attribute dictionaries. + }], + "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>, + InterfaceMethod<[{ + Set the array of result attribute dictionaries. + }], + "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs)>, + InterfaceMethod<[{ + Remove the array of argument attribute dictionaries. This is the same as + setting all argument attributes to an empty dictionary. The method should + return the removed attribute. + }], + "::mlir::Attribute", "removeArgAttrsAttr">, + InterfaceMethod<[{ + Remove the array of result attribute dictionaries. This is the same as + setting all result attributes to an empty dictionary. The method should + return the removed attribute. + }], + "::mlir::Attribute", "removeResAttrsAttr"> + ]; +} + // `CallInterfaceCallable`: This is a type used to represent a single callable // region. A callable is either a symbol, or an SSA value, that is referenced by // a call-like operation. This represents the destination of the call. @@ -81,7 +123,7 @@ def CallOpInterface : OpInterface<"CallOpInterface"> { return ::mlir::call_interface_impl::resolveCallable($_op); }] > - ]; + ] # ArgumentAttributesMethods.methods; } /// Interface for callable operations. @@ -113,48 +155,7 @@ def CallableOpInterface : OpInterface<"CallableOpInterface"> { allow for this method may be called on function declarations). }], "::llvm::ArrayRef<::mlir::Type>", "getResultTypes">, - - InterfaceMethod<[{ - Get the array of argument attribute dictionaries. The method should - return an array attribute containing only dictionary attributes equal in - number to the number of region arguments. Alternatively, the method can - return null to indicate that the region has no argument attributes. - }], - "::mlir::ArrayAttr", "getArgAttrsAttr", (ins), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>, - InterfaceMethod<[{ - Get the array of result attribute dictionaries. The method should return - an array attribute containing only dictionary attributes equal in number - to the number of region results. Alternatively, the method can return - null to indicate that the region has no result attributes. - }], - "::mlir::ArrayAttr", "getResAttrsAttr", (ins), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>, - InterfaceMethod<[{ - Set the array of argument attribute dictionaries. - }], - "void", "setArgAttrsAttr", (ins "::mlir::ArrayAttr":$attrs), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>, - InterfaceMethod<[{ - Set the array of result attribute dictionaries. - }], - "void", "setResAttrsAttr", (ins "::mlir::ArrayAttr":$attrs), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }]>, - InterfaceMethod<[{ - Remove the array of argument attribute dictionaries. This is the same as - setting all argument attributes to an empty dictionary. The method should - return the removed attribute. - }], - "::mlir::Attribute", "removeArgAttrsAttr", (ins), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>, - InterfaceMethod<[{ - Remove the array of result attribute dictionaries. This is the same as - setting all result attributes to an empty dictionary. The method should - return the removed attribute. - }], - "::mlir::Attribute", "removeResAttrsAttr", (ins), - /*methodBody=*/[{}], /*defaultImplementation=*/[{ return nullptr; }]>, - ]; + ] # ArgumentAttributesMethods.methods; } #endif // MLIR_INTERFACES_CALLINTERFACES diff --git a/mlir/include/mlir/Interfaces/FunctionImplementation.h b/mlir/include/mlir/Interfaces/FunctionImplementation.h index a5e6963e4e666..374c2c534f87d 100644 --- a/mlir/include/mlir/Interfaces/FunctionImplementation.h +++ b/mlir/include/mlir/Interfaces/FunctionImplementation.h @@ -33,19 +33,6 @@ class VariadicFlag { bool variadic; }; -/// Adds argument and result attributes, provided as `argAttrs` and -/// `resultAttrs` arguments, to the list of operation attributes in `result`. -/// Internally, argument and result attributes are stored as dict attributes -/// with special names given by getResultAttrName, getArgumentAttrName. -void addArgAndResultAttrs(Builder &builder, OperationState &result, - ArrayRef argAttrs, - ArrayRef resultAttrs, - StringAttr argAttrsName, StringAttr resAttrsName); -void addArgAndResultAttrs(Builder &builder, OperationState &result, - ArrayRef args, - ArrayRef resultAttrs, - StringAttr argAttrsName, StringAttr resAttrsName); - /// Callback type for `parseFunctionOp`, the callback should produce the /// type that will be associated with a function-like operation from lists of /// function arguments and results, VariadicFlag indicates whether the function @@ -58,11 +45,11 @@ using FuncTypeBuilder = function_ref &arguments, - bool &isVariadic, SmallVectorImpl &resultTypes, - SmallVectorImpl &resultAttrs); +ParseResult parseFunctionSignatureWithArguments( + OpAsmParser &parser, bool allowVariadic, + SmallVectorImpl &arguments, bool &isVariadic, + SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs); /// Parser implementation for function-like operations. Uses /// `funcTypeBuilder` to construct the custom function type given lists of @@ -84,9 +71,14 @@ void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, /// Prints the signature of the function-like operation `op`. Assumes `op` has /// is a FunctionOpInterface and has passed verification. -void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, - ArrayRef argTypes, bool isVariadic, - ArrayRef resultTypes); +inline void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, + ArrayRef argTypes, bool isVariadic, + ArrayRef resultTypes) { + call_interface_impl::printFunctionSignature( + p, argTypes, op.getArgAttrsAttr(), isVariadic, resultTypes, + op.getResAttrsAttr(), &op->getRegion(0), + /*printEmptyResult=*/false); +} /// Prints the list of function prefixed with the "attributes" keyword. The /// attributes with names listed in "elided" as well as those used by the diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp index a3e3f80954efc..d3bb250bb8ab9 100644 --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -308,7 +308,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( builder, state, argAttrs, /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index c818dd18a3d24..728a2d33f46e7 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -529,7 +529,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( builder, state, argAttrs, /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } diff --git a/mlir/lib/Dialect/Func/IR/FuncOps.cpp b/mlir/lib/Dialect/Func/IR/FuncOps.cpp index a490b4c3c4ab4..ba7b84f27d6a8 100644 --- a/mlir/lib/Dialect/Func/IR/FuncOps.cpp +++ b/mlir/lib/Dialect/Func/IR/FuncOps.cpp @@ -190,7 +190,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( builder, state, argAttrs, /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 301066e7d3e1f..d06f10d3137a1 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -1467,7 +1467,7 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) { return failure(); auto signatureLocation = parser.getCurrentLocation(); - if (failed(function_interface_impl::parseFunctionSignature( + if (failed(function_interface_impl::parseFunctionSignatureWithArguments( parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, resultAttrs))) return failure(); @@ -1487,7 +1487,7 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) { result.addAttribute(getFunctionTypeAttrName(result.name), TypeAttr::get(type)); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index ef5f1b069b40a..a6e996f3fb810 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1033,6 +1033,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results, /*memory_effects=*/nullptr, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1060,6 +1061,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1073,6 +1075,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1087,6 +1090,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func, /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr, /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*access_groups=*/nullptr, /*alias_scopes=*/nullptr, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr); } @@ -1527,7 +1531,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func, auto calleeType = func.getFunctionType(); build(builder, state, getCallOpResultTypes(calleeType), getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops, - normalOps, unwindOps, nullptr, nullptr, {}, {}, normal, unwind); + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, normalOps, unwindOps, + nullptr, nullptr, {}, {}, normal, unwind); } void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys, @@ -1535,8 +1540,9 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys, ValueRange normalOps, Block *unwind, ValueRange unwindOps) { build(builder, state, tys, - /*var_callee_type=*/nullptr, callee, ops, normalOps, unwindOps, nullptr, - nullptr, {}, {}, normal, unwind); + /*var_callee_type=*/nullptr, callee, ops, /*arg_attrs=*/nullptr, + /*res_attrs=*/nullptr, normalOps, unwindOps, nullptr, nullptr, {}, {}, + normal, unwind); } void InvokeOp::build(OpBuilder &builder, OperationState &state, @@ -1544,7 +1550,8 @@ void InvokeOp::build(OpBuilder &builder, OperationState &state, ValueRange ops, Block *normal, ValueRange normalOps, Block *unwind, ValueRange unwindOps) { build(builder, state, getCallOpResultTypes(calleeType), - getCallOpVarCalleeType(calleeType), callee, ops, normalOps, unwindOps, + getCallOpVarCalleeType(calleeType), callee, ops, + /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, normalOps, unwindOps, nullptr, nullptr, {}, {}, normal, unwind); } @@ -2510,7 +2517,7 @@ void LLVMFuncOp::build(OpBuilder &builder, OperationState &result, assert(llvm::cast(type).getNumParams() == argAttrs.size() && "expected as many argument attribute lists as arguments"); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( builder, result, argAttrs, /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } @@ -2595,7 +2602,7 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) { auto signatureLocation = parser.getCurrentLocation(); if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), result.attributes) || - function_interface_impl::parseFunctionSignature( + function_interface_impl::parseFunctionSignatureWithArguments( parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes, resultAttrs)) return failure(); @@ -2636,7 +2643,7 @@ ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) { if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes))) return failure(); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( parser.getBuilder(), result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index f0f03e989cb47..160c264fc32d9 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -917,7 +917,7 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) { // Parse the function signature. bool isVariadic = false; - if (function_interface_impl::parseFunctionSignature( + if (function_interface_impl::parseFunctionSignatureWithArguments( parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, resultAttrs)) return failure(); @@ -940,7 +940,7 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) { // Add the attributes to the function arguments. assert(resultAttrs.size() == resultTypes.size()); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 65efc88e9c403..2200af0f67a86 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1297,7 +1297,7 @@ void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); - function_interface_impl::addArgAndResultAttrs( + call_interface_impl::addArgAndResultAttrs( builder, state, argAttrs, /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } diff --git a/mlir/lib/Interfaces/CallInterfaces.cpp b/mlir/lib/Interfaces/CallInterfaces.cpp index da0ca0e24630f..e8ed4b339a0cb 100644 --- a/mlir/lib/Interfaces/CallInterfaces.cpp +++ b/mlir/lib/Interfaces/CallInterfaces.cpp @@ -7,9 +7,178 @@ //===----------------------------------------------------------------------===// #include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/IR/Builders.h" using namespace mlir; +//===----------------------------------------------------------------------===// +// Argument and result attributes utilities +//===----------------------------------------------------------------------===// + +static ParseResult +parseTypeAndAttrList(OpAsmParser &parser, SmallVectorImpl &types, + SmallVectorImpl &attrs) { + // Parse individual function results. + return parser.parseCommaSeparatedList([&]() -> ParseResult { + types.emplace_back(); + attrs.emplace_back(); + NamedAttrList attrList; + if (parser.parseType(types.back()) || + parser.parseOptionalAttrDict(attrList)) + return failure(); + attrs.back() = attrList.getDictionary(parser.getContext()); + return success(); + }); +} + +ParseResult call_interface_impl::parseFunctionResultList( + OpAsmParser &parser, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs) { + if (failed(parser.parseOptionalLParen())) { + // We already know that there is no `(`, so parse a type. + // Because there is no `(`, it cannot be a function type. + Type ty; + if (parser.parseType(ty)) + return failure(); + resultTypes.push_back(ty); + resultAttrs.emplace_back(); + return success(); + } + + // Special case for an empty set of parens. + if (succeeded(parser.parseOptionalRParen())) + return success(); + if (parseTypeAndAttrList(parser, resultTypes, resultAttrs)) + return failure(); + return parser.parseRParen(); +} + +ParseResult call_interface_impl::parseFunctionSignature( + OpAsmParser &parser, SmallVectorImpl &argTypes, + SmallVectorImpl &argAttrs, + SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs, bool mustParseEmptyResult) { + // Parse arguments. + if (parser.parseLParen()) + return failure(); + if (failed(parser.parseOptionalRParen())) { + if (parseTypeAndAttrList(parser, argTypes, argAttrs)) + return failure(); + if (parser.parseRParen()) + return failure(); + } + // Parse results. + if (succeeded(parser.parseOptionalArrow())) + return call_interface_impl::parseFunctionResultList(parser, resultTypes, + resultAttrs); + if (mustParseEmptyResult) + return failure(); + return success(); +} + +/// Print a function result list. The provided `attrs` must either be null, or +/// contain a set of DictionaryAttrs of the same arity as `types`. +static void printFunctionResultList(OpAsmPrinter &p, TypeRange types, + ArrayAttr attrs) { + assert(!types.empty() && "Should not be called for empty result list."); + assert((!attrs || attrs.size() == types.size()) && + "Invalid number of attributes."); + + auto &os = p.getStream(); + bool needsParens = types.size() > 1 || llvm::isa(types[0]) || + (attrs && !llvm::cast(attrs[0]).empty()); + if (needsParens) + os << '('; + llvm::interleaveComma(llvm::seq(0, types.size()), os, [&](size_t i) { + p.printType(types[i]); + if (attrs) + p.printOptionalAttrDict(llvm::cast(attrs[i]).getValue()); + }); + if (needsParens) + os << ')'; +} + +void call_interface_impl::printFunctionSignature( + OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic, + TypeRange resultTypes, ArrayAttr resultAttrs, Region *body, + bool printEmptyResult) { + bool isExternal = !body || body->empty(); + if (!isExternal && !isVariadic && !argAttrs && !resultAttrs && + printEmptyResult) { + p.printFunctionalType(argTypes, resultTypes); + return; + } + + p << '('; + for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { + if (i > 0) + p << ", "; + + if (!isExternal) { + ArrayRef attrs; + if (argAttrs) + attrs = llvm::cast(argAttrs[i]).getValue(); + p.printRegionArgument(body->getArgument(i), attrs); + } else { + p.printType(argTypes[i]); + if (argAttrs) + p.printOptionalAttrDict( + llvm::cast(argAttrs[i]).getValue()); + } + } + + if (isVariadic) { + if (!argTypes.empty()) + p << ", "; + p << "..."; + } + + p << ')'; + + if (!resultTypes.empty()) { + p << " -> "; + printFunctionResultList(p, resultTypes, resultAttrs); + } else if (printEmptyResult) { + p << " -> ()"; + } +} + +void call_interface_impl::addArgAndResultAttrs( + Builder &builder, OperationState &result, ArrayRef argAttrs, + ArrayRef resultAttrs, StringAttr argAttrsName, + StringAttr resAttrsName) { + auto nonEmptyAttrsFn = [](DictionaryAttr attrs) { + return attrs && !attrs.empty(); + }; + // Convert the specified array of dictionary attrs (which may have null + // entries) to an ArrayAttr of dictionaries. + auto getArrayAttr = [&](ArrayRef dictAttrs) { + SmallVector attrs; + for (auto &dict : dictAttrs) + attrs.push_back(dict ? dict : builder.getDictionaryAttr({})); + return builder.getArrayAttr(attrs); + }; + + // Add the attributes to the operation arguments. + if (llvm::any_of(argAttrs, nonEmptyAttrsFn)) + result.addAttribute(argAttrsName, getArrayAttr(argAttrs)); + + // Add the attributes to the operation results. + if (llvm::any_of(resultAttrs, nonEmptyAttrsFn)) + result.addAttribute(resAttrsName, getArrayAttr(resultAttrs)); +} + +void call_interface_impl::addArgAndResultAttrs( + Builder &builder, OperationState &result, + ArrayRef args, ArrayRef resultAttrs, + StringAttr argAttrsName, StringAttr resAttrsName) { + SmallVector argAttrs; + for (const auto &arg : args) + argAttrs.push_back(arg.attrs); + addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName, + resAttrsName); +} + //===----------------------------------------------------------------------===// // CallOpInterface //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Interfaces/FunctionImplementation.cpp b/mlir/lib/Interfaces/FunctionImplementation.cpp index 988feee665fea..90f32896e8181 100644 --- a/mlir/lib/Interfaces/FunctionImplementation.cpp +++ b/mlir/lib/Interfaces/FunctionImplementation.cpp @@ -70,50 +70,7 @@ parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic, }); } -/// Parse a function result list. -/// -/// function-result-list ::= function-result-list-parens -/// | non-function-type -/// function-result-list-parens ::= `(` `)` -/// | `(` function-result-list-no-parens `)` -/// function-result-list-no-parens ::= function-result (`,` function-result)* -/// function-result ::= type attribute-dict? -/// -static ParseResult -parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl &resultTypes, - SmallVectorImpl &resultAttrs) { - if (failed(parser.parseOptionalLParen())) { - // We already know that there is no `(`, so parse a type. - // Because there is no `(`, it cannot be a function type. - Type ty; - if (parser.parseType(ty)) - return failure(); - resultTypes.push_back(ty); - resultAttrs.emplace_back(); - return success(); - } - - // Special case for an empty set of parens. - if (succeeded(parser.parseOptionalRParen())) - return success(); - - // Parse individual function results. - if (parser.parseCommaSeparatedList([&]() -> ParseResult { - resultTypes.emplace_back(); - resultAttrs.emplace_back(); - NamedAttrList attrs; - if (parser.parseType(resultTypes.back()) || - parser.parseOptionalAttrDict(attrs)) - return failure(); - resultAttrs.back() = attrs.getDictionary(parser.getContext()); - return success(); - })) - return failure(); - - return parser.parseRParen(); -} - -ParseResult function_interface_impl::parseFunctionSignature( +ParseResult function_interface_impl::parseFunctionSignatureWithArguments( OpAsmParser &parser, bool allowVariadic, SmallVectorImpl &arguments, bool &isVariadic, SmallVectorImpl &resultTypes, @@ -121,46 +78,11 @@ ParseResult function_interface_impl::parseFunctionSignature( if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic)) return failure(); if (succeeded(parser.parseOptionalArrow())) - return parseFunctionResultList(parser, resultTypes, resultAttrs); + return call_interface_impl::parseFunctionResultList(parser, resultTypes, + resultAttrs); return success(); } -void function_interface_impl::addArgAndResultAttrs( - Builder &builder, OperationState &result, ArrayRef argAttrs, - ArrayRef resultAttrs, StringAttr argAttrsName, - StringAttr resAttrsName) { - auto nonEmptyAttrsFn = [](DictionaryAttr attrs) { - return attrs && !attrs.empty(); - }; - // Convert the specified array of dictionary attrs (which may have null - // entries) to an ArrayAttr of dictionaries. - auto getArrayAttr = [&](ArrayRef dictAttrs) { - SmallVector attrs; - for (auto &dict : dictAttrs) - attrs.push_back(dict ? dict : builder.getDictionaryAttr({})); - return builder.getArrayAttr(attrs); - }; - - // Add the attributes to the function arguments. - if (llvm::any_of(argAttrs, nonEmptyAttrsFn)) - result.addAttribute(argAttrsName, getArrayAttr(argAttrs)); - - // Add the attributes to the function results. - if (llvm::any_of(resultAttrs, nonEmptyAttrsFn)) - result.addAttribute(resAttrsName, getArrayAttr(resultAttrs)); -} - -void function_interface_impl::addArgAndResultAttrs( - Builder &builder, OperationState &result, - ArrayRef args, ArrayRef resultAttrs, - StringAttr argAttrsName, StringAttr resAttrsName) { - SmallVector argAttrs; - for (const auto &arg : args) - argAttrs.push_back(arg.attrs); - addArgAndResultAttrs(builder, result, argAttrs, resultAttrs, argAttrsName, - resAttrsName); -} - ParseResult function_interface_impl::parseFunctionOp( OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, @@ -182,8 +104,8 @@ ParseResult function_interface_impl::parseFunctionOp( // Parse the function signature. SMLoc signatureLocation = parser.getCurrentLocation(); bool isVariadic = false; - if (parseFunctionSignature(parser, allowVariadic, entryArgs, isVariadic, - resultTypes, resultAttrs)) + if (parseFunctionSignatureWithArguments(parser, allowVariadic, entryArgs, + isVariadic, resultTypes, resultAttrs)) return failure(); std::string errorMessage; @@ -221,8 +143,8 @@ ParseResult function_interface_impl::parseFunctionOp( // Add the attributes to the function arguments. assert(resultAttrs.size() == resultTypes.size()); - addArgAndResultAttrs(builder, result, entryArgs, resultAttrs, argAttrsName, - resAttrsName); + call_interface_impl::addArgAndResultAttrs( + builder, result, entryArgs, resultAttrs, argAttrsName, resAttrsName); // Parse the optional function body. The printer will not print the body if // its empty, so disallow parsing of empty body in the parser. @@ -241,68 +163,6 @@ ParseResult function_interface_impl::parseFunctionOp( return success(); } -/// Print a function result list. The provided `attrs` must either be null, or -/// contain a set of DictionaryAttrs of the same arity as `types`. -static void printFunctionResultList(OpAsmPrinter &p, ArrayRef types, - ArrayAttr attrs) { - assert(!types.empty() && "Should not be called for empty result list."); - assert((!attrs || attrs.size() == types.size()) && - "Invalid number of attributes."); - - auto &os = p.getStream(); - bool needsParens = types.size() > 1 || llvm::isa(types[0]) || - (attrs && !llvm::cast(attrs[0]).empty()); - if (needsParens) - os << '('; - llvm::interleaveComma(llvm::seq(0, types.size()), os, [&](size_t i) { - p.printType(types[i]); - if (attrs) - p.printOptionalAttrDict(llvm::cast(attrs[i]).getValue()); - }); - if (needsParens) - os << ')'; -} - -void function_interface_impl::printFunctionSignature( - OpAsmPrinter &p, FunctionOpInterface op, ArrayRef argTypes, - bool isVariadic, ArrayRef resultTypes) { - Region &body = op->getRegion(0); - bool isExternal = body.empty(); - - p << '('; - ArrayAttr argAttrs = op.getArgAttrsAttr(); - for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { - if (i > 0) - p << ", "; - - if (!isExternal) { - ArrayRef attrs; - if (argAttrs) - attrs = llvm::cast(argAttrs[i]).getValue(); - p.printRegionArgument(body.getArgument(i), attrs); - } else { - p.printType(argTypes[i]); - if (argAttrs) - p.printOptionalAttrDict( - llvm::cast(argAttrs[i]).getValue()); - } - } - - if (isVariadic) { - if (!argTypes.empty()) - p << ", "; - p << "..."; - } - - p << ')'; - - if (!resultTypes.empty()) { - p.getStream() << " -> "; - auto resultAttrs = op.getResAttrsAttr(); - printFunctionResultList(p, resultTypes, resultAttrs); - } -} - void function_interface_impl::printFunctionAttributes( OpAsmPrinter &p, Operation *op, ArrayRef elided) { // Print out function attributes, if present. diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 79840094686e1..2aa0658ab0e5d 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -557,7 +557,12 @@ def TestCallOp : TEST_Op<"call", [DeclareOpInterfaceMethods { - let arguments = (ins Variadic:$arg_operands, SymbolRefAttr:$callee); + let arguments = (ins + Variadic:$arg_operands, + SymbolRefAttr:$callee, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); let results = (outs Variadic); let extraClassDeclaration = [{ @@ -618,6 +623,10 @@ def ConversionFuncOp : TEST_Op<"conversion_func_op", [FunctionOpInterface]> { def FunctionalRegionOp : TEST_Op<"functional_region_op", [CallableOpInterface]> { let regions = (region AnyRegion:$body); + let arguments = (ins + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs + ); let results = (outs FunctionType); let extraClassDeclaration = [{ @@ -3309,7 +3318,9 @@ def TestCallAndStoreOp : TEST_Op<"call_and_store", SymbolRefAttr:$callee, Arg:$address, Variadic:$callee_operands, - BoolAttr:$store_before_call + BoolAttr:$store_before_call, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let results = (outs Variadic:$results @@ -3324,7 +3335,9 @@ def TestCallOnDeviceOp : TEST_Op<"call_on_device", let arguments = (ins SymbolRefAttr:$callee, Variadic:$forwarded_operands, - AnyType:$non_forwarded_device_operand + AnyType:$non_forwarded_device_operand, + OptionalAttr:$arg_attrs, + OptionalAttr:$res_attrs ); let results = (outs Variadic:$results