diff --git a/include/gc/Dialect/Linalgx/CMakeLists.txt b/include/gc/Dialect/Linalgx/CMakeLists.txt
index 1aceb8345..f33061b2d 100644
--- a/include/gc/Dialect/Linalgx/CMakeLists.txt
+++ b/include/gc/Dialect/Linalgx/CMakeLists.txt
@@ -1,9 +1 @@
-add_mlir_dialect(LinalgxOps linalgx)
-set(LLVM_TARGET_DEFINITIONS LinalgxStructuredOps.td)
-mlir_tablegen(LinalgxStructuredOps.h.inc -gen-op-decls)
-mlir_tablegen(LinalgxStructuredOps.cpp.inc -gen-op-defs)
-add_public_tablegen_target(MLIRLinalgxStructuredOpsIncGen)
-
-add_mlir_doc(LinalgxOps LinalgxOps gc/Dialect/Linalgx/ -gen-op-doc)
-add_mlir_doc(LinalgxDialect LinalgxDialect gc/Dialect/Linalgx/ -gen-dialect-doc)
-add_mlir_doc(LinalgxStructuredOps LinalgxStructuredOps gc/Dialect/Linalgx/ -gen-dialect-doc)
+add_subdirectory(IR)
diff --git a/include/gc/Dialect/Linalgx/IR/CMakeLists.txt b/include/gc/Dialect/Linalgx/IR/CMakeLists.txt
new file mode 100644
index 000000000..27b123839
--- /dev/null
+++ b/include/gc/Dialect/Linalgx/IR/CMakeLists.txt
@@ -0,0 +1,9 @@
+add_mlir_dialect(LinalgxOps linalgx)
+set(LLVM_TARGET_DEFINITIONS LinalgxStructuredOps.td)
+mlir_tablegen(LinalgxStructuredOps.h.inc -gen-op-decls)
+mlir_tablegen(LinalgxStructuredOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRLinalgxStructuredOpsIncGen)
+
+add_mlir_doc(LinalgxOps LinalgxOps gc/Dialect/Linalgx/IR/ -gen-op-doc)
+add_mlir_doc(LinalgxDialect LinalgxDialect gc/Dialect/Linalgx/IR/ -gen-dialect-doc)
+add_mlir_doc(LinalgxStructuredOps LinalgxStructuredOps gc/Dialect/Linalgx/IR/ -gen-dialect-doc)
diff --git a/include/gc/Dialect/Linalgx/LinalgxDialect.h b/include/gc/Dialect/Linalgx/IR/LinalgxDialect.h
similarity index 93%
rename from include/gc/Dialect/Linalgx/LinalgxDialect.h
rename to include/gc/Dialect/Linalgx/IR/LinalgxDialect.h
index 77dedd568..665e39616 100644
--- a/include/gc/Dialect/Linalgx/LinalgxDialect.h
+++ b/include/gc/Dialect/Linalgx/IR/LinalgxDialect.h
@@ -19,6 +19,6 @@
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
-#include "gc/Dialect/Linalgx/LinalgxOpsDialect.h.inc"
+#include "gc/Dialect/Linalgx/IR/LinalgxOpsDialect.h.inc"
 
 #endif // GC_DIALECTS_LINALGXDIALECT_H
diff --git a/include/gc/Dialect/Linalgx/LinalgxDialect.td b/include/gc/Dialect/Linalgx/IR/LinalgxDialect.td
similarity index 100%
rename from include/gc/Dialect/Linalgx/LinalgxDialect.td
rename to include/gc/Dialect/Linalgx/IR/LinalgxDialect.td
diff --git a/include/gc/Dialect/Linalgx/LinalgxOps.h b/include/gc/Dialect/Linalgx/IR/LinalgxOps.h
similarity index 89%
rename from include/gc/Dialect/Linalgx/LinalgxOps.h
rename to include/gc/Dialect/Linalgx/IR/LinalgxOps.h
index 9ea73a91e..a92976ec9 100644
--- a/include/gc/Dialect/Linalgx/LinalgxOps.h
+++ b/include/gc/Dialect/Linalgx/IR/LinalgxOps.h
@@ -20,9 +20,9 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #define GET_OP_CLASSES
-#include "gc/Dialect/Linalgx/LinalgxOps.h.inc"
+#include "gc/Dialect/Linalgx/IR/LinalgxOps.h.inc"
 
 #define GET_OP_CLASSES
-#include "gc/Dialect/Linalgx/LinalgxStructuredOps.h.inc"
+#include "gc/Dialect/Linalgx/IR/LinalgxStructuredOps.h.inc"
 
 #endif // GC_DIALECTS_LINALGXOPS_H
diff --git a/include/gc/Dialect/Linalgx/LinalgxOps.td b/include/gc/Dialect/Linalgx/IR/LinalgxOps.td
similarity index 100%
rename from include/gc/Dialect/Linalgx/LinalgxOps.td
rename to include/gc/Dialect/Linalgx/IR/LinalgxOps.td
diff --git a/include/gc/Dialect/Linalgx/LinalgxStructuredOps.td b/include/gc/Dialect/Linalgx/IR/LinalgxStructuredOps.td
similarity index 100%
rename from include/gc/Dialect/Linalgx/LinalgxStructuredOps.td
rename to include/gc/Dialect/Linalgx/IR/LinalgxStructuredOps.td
diff --git a/include/gc/Dialect/Linalgx/Transforms/AllInterfaces.h b/include/gc/Dialect/Linalgx/Transforms/AllInterfaces.h
new file mode 100644
index 000000000..da9603054
--- /dev/null
+++ b/include/gc/Dialect/Linalgx/Transforms/AllInterfaces.h
@@ -0,0 +1,22 @@
+
+//===-- AllInterfaces.h - linalgx dialect interfaces ------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_LINALGX_TRANSFORMS_ALLINTERFACES_H
+#define DIALECT_LINALGX_TRANSFORMS_ALLINTERFACES_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalgx {
+void registerAllDialectInterfaceImplementations(DialectRegistry &registry);
+} // namespace linalgx
+
+} // namespace mlir
+
+#endif // DIALECT_LINALGX_TRANSFORMS_ALLINTERFACES_H
diff --git a/include/gc/Dialect/Linalgx/Transforms/BufferizableOpInterfaceImpl.h b/include/gc/Dialect/Linalgx/Transforms/BufferizableOpInterfaceImpl.h
new file mode 100644
index 000000000..ab1ff0de4
--- /dev/null
+++ b/include/gc/Dialect/Linalgx/Transforms/BufferizableOpInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- BufferizableOpInterfaceImpl.h - linalgx Bufferize --------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_LINALGX_BUFFERIZABLEOPINTERFACEIMPL_H
+#define DIALECT_LINALGX_BUFFERIZABLEOPINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalgx {
+void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
+} // namespace linalgx
+} // namespace mlir
+
+#endif // DIALECT_LINALGX_BUFFERIZABLEOPINTERFACEIMPL_H
diff --git a/include/gc/Dialect/Linalgx/Transforms/TilingInterfaceImpl.h b/include/gc/Dialect/Linalgx/Transforms/TilingInterfaceImpl.h
new file mode 100644
index 000000000..078be79e9
--- /dev/null
+++ b/include/gc/Dialect/Linalgx/Transforms/TilingInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- TilingInterfaceImpl.h - linalgx Tiling -------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_LINALGX_TILINGINTERFACEIMPL_H
+#define DIALECT_LINALGX_TILINGINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalgx {
+void registerTilingInterfaceExternalModels(DialectRegistry &registry);
+} // namespace linalgx
+} // namespace mlir
+
+#endif // DIALECT_LINALGX_TILINGINTERFACEIMPL_H
diff --git a/lib/gc/Dialect/Linalgx/CMakeLists.txt b/lib/gc/Dialect/Linalgx/CMakeLists.txt
index 636760331..9f57627c3 100644
--- a/lib/gc/Dialect/Linalgx/CMakeLists.txt
+++ b/lib/gc/Dialect/Linalgx/CMakeLists.txt
@@ -1,17 +1,2 @@
-gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR)
-
-add_mlir_dialect_library(MLIRLinalgx
-  LinalgxDialect.cpp
-  LinalgxOps.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${PROJECT_SOURCE_DIR}/include/gc/Dialect/Linalgx
-
-  DEPENDS
-  MLIRLinalgxOpsIncGen
-  MLIRLinalgxStructuredOpsIncGen
-
-  LINK_LIBS PUBLIC
-  ${MLIR_LINK_COMPONENTS}
-)
-set_property(GLOBAL APPEND PROPERTY GC_DIALECT_LIBS MLIRLinalgx)
\ No newline at end of file
+add_subdirectory(IR)
+add_subdirectory(Transforms)
diff --git a/lib/gc/Dialect/Linalgx/IR/CMakeLists.txt b/lib/gc/Dialect/Linalgx/IR/CMakeLists.txt
new file mode 100644
index 000000000..cbdf1f892
--- /dev/null
+++ b/lib/gc/Dialect/Linalgx/IR/CMakeLists.txt
@@ -0,0 +1,17 @@
+gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR)
+
+add_mlir_dialect_library(MLIRLinalgx
+  LinalgxDialect.cpp
+  LinalgxOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/include/gc/Dialect/Linalgx/IR
+
+  DEPENDS
+  MLIRLinalgxOpsIncGen
+  MLIRLinalgxStructuredOpsIncGen
+
+  LINK_LIBS PUBLIC
+  ${MLIR_LINK_COMPONENTS}
+)
+set_property(GLOBAL APPEND PROPERTY GC_DIALECT_LIBS MLIRLinalgx)
\ No newline at end of file
diff --git a/lib/gc/Dialect/Linalgx/LinalgOps.cpp.inc b/lib/gc/Dialect/Linalgx/IR/LinalgOps.cpp.inc
similarity index 100%
rename from lib/gc/Dialect/Linalgx/LinalgOps.cpp.inc
rename to lib/gc/Dialect/Linalgx/IR/LinalgOps.cpp.inc
diff --git a/lib/gc/Dialect/Linalgx/LinalgxDialect.cpp b/lib/gc/Dialect/Linalgx/IR/LinalgxDialect.cpp
similarity index 76%
rename from lib/gc/Dialect/Linalgx/LinalgxDialect.cpp
rename to lib/gc/Dialect/Linalgx/IR/LinalgxDialect.cpp
index d2d7a4389..b624861c8 100644
--- a/lib/gc/Dialect/Linalgx/LinalgxDialect.cpp
+++ b/lib/gc/Dialect/Linalgx/IR/LinalgxDialect.cpp
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "gc/Dialect/Linalgx/LinalgxDialect.h"
-#include "gc/Dialect/Linalgx/LinalgxOps.h"
+#include "gc/Dialect/Linalgx/IR/LinalgxDialect.h"
+#include "gc/Dialect/Linalgx/IR/LinalgxOps.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -21,15 +21,15 @@
 using namespace mlir;
 using namespace mlir::linalgx;
 
-#include "gc/Dialect/Linalgx/LinalgxOpsDialect.cpp.inc"
+#include "gc/Dialect/Linalgx/IR/LinalgxOpsDialect.cpp.inc"
 
 void LinalgxDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
-#include "gc/Dialect/Linalgx/LinalgxOps.cpp.inc"
+#include "gc/Dialect/Linalgx/IR/LinalgxOps.cpp.inc"
       >();
   addOperations<
 #define GET_OP_LIST
-#include "gc/Dialect/Linalgx/LinalgxStructuredOps.cpp.inc"
+#include "gc/Dialect/Linalgx/IR/LinalgxStructuredOps.cpp.inc"
       >();
 }
diff --git a/lib/gc/Dialect/Linalgx/LinalgxOps.cpp b/lib/gc/Dialect/Linalgx/IR/LinalgxOps.cpp
similarity index 99%
rename from lib/gc/Dialect/Linalgx/LinalgxOps.cpp
rename to lib/gc/Dialect/Linalgx/IR/LinalgxOps.cpp
index e276ce780..af1d618bf 100644
--- a/lib/gc/Dialect/Linalgx/LinalgxOps.cpp
+++ b/lib/gc/Dialect/Linalgx/IR/LinalgxOps.cpp
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "gc/Dialect/Linalgx/LinalgxOps.h"
-#include "gc/Dialect/Linalgx/LinalgxDialect.h"
+#include "gc/Dialect/Linalgx/IR/LinalgxOps.h"
+#include "gc/Dialect/Linalgx/IR/LinalgxDialect.h"
 #include "mlir/IR/OpImplementation.h"
 
 //===----------------------------------------------------------------------===//
@@ -616,7 +616,7 @@ void MultiBatchMatmulOp::getEffects(
 /////// Operations corresponding to library calls defined with Tablegen ////////
 
 #define GET_OP_CLASSES
-#include "gc/Dialect/Linalgx/LinalgxOps.cpp.inc"
+#include "gc/Dialect/Linalgx/IR/LinalgxOps.cpp.inc"
 
 #define GET_OP_CLASSES
-#include "gc/Dialect/Linalgx/LinalgxStructuredOps.cpp.inc"
+#include "gc/Dialect/Linalgx/IR/LinalgxStructuredOps.cpp.inc"
diff --git a/lib/gc/Dialect/Linalgx/Transforms/AllInterfaces.cpp b/lib/gc/Dialect/Linalgx/Transforms/AllInterfaces.cpp
new file mode 100644
index 000000000..e906ecf17
--- /dev/null
+++ b/lib/gc/Dialect/Linalgx/Transforms/AllInterfaces.cpp
@@ -0,0 +1,18 @@
+//===-- AllInterfaces.cpp - linalgx dialect interfaces ----------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "gc/Dialect/Linalgx/Transforms/AllInterfaces.h"
+
+#include "gc/Dialect/Linalgx/Transforms/BufferizableOpInterfaceImpl.h"
+#include "gc/Dialect/Linalgx/Transforms/TilingInterfaceImpl.h"
+
+void mlir::linalgx::registerAllDialectInterfaceImplementations(
+    DialectRegistry &registry) {
+  registerBufferizableOpInterfaceExternalModels(registry);
+  registerTilingInterfaceExternalModels(registry);
+}
diff --git a/lib/gc/Dialect/Linalgx/Transforms/BufferizableOpInterfaceImpl.cpp b/lib/gc/Dialect/Linalgx/Transforms/BufferizableOpInterfaceImpl.cpp
new file mode 100644
index 000000000..701dc87d3
--- /dev/null
+++ b/lib/gc/Dialect/Linalgx/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -0,0 +1,35 @@
+//===-- BufferizableOpInterfaceImpl.cpp - linalgx bufferize -----*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "gc/Dialect/Linalgx/Transforms/BufferizableOpInterfaceImpl.h"
+#include "gc/Dialect/Linalgx/IR/LinalgxDialect.h"
+#include "gc/Dialect/Linalgx/IR/LinalgxOps.h"
+
+//===----------------------------------------------------------------------===//
+// Builder helper from Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+//===----------------------------------------------------------------------===//
+
+#include "BufferizableOpInterfaceImpl.cpp.inc"
+
+using namespace mlir;
+using namespace mlir::linalgx;
+
+void mlir::linalgx::registerBufferizableOpInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(
+      +[](MLIRContext *ctx, linalgx::LinalgxDialect *dialect) {
+        // Register all Linalg structured ops. `LinalgOp` is an interface and
+        // it is not possible to attach an external interface to an existing
+        // interface. Therefore, attach the `BufferizableOpInterface` to all
+        // ops one-by-one.
+        LinalgOpInterfaceHelper<
+#define GET_OP_LIST
+#include "gc/Dialect/Linalgx/IR/LinalgxStructuredOps.cpp.inc"
+            >::registerOpInterface(ctx);
+      });
+}
diff --git a/lib/gc/Dialect/Linalgx/Transforms/BufferizableOpInterfaceImpl.cpp.inc b/lib/gc/Dialect/Linalgx/Transforms/BufferizableOpInterfaceImpl.cpp.inc
new file mode 100644
index 000000000..fc53ef5e4
--- /dev/null
+++ b/lib/gc/Dialect/Linalgx/Transforms/BufferizableOpInterfaceImpl.cpp.inc
@@ -0,0 +1,168 @@
+//===-- BufferizableOpInterfaceImpl.cpp.inc ---------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// This file is a partial copy of
+// mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+
+using namespace mlir;
+using namespace linalg;
+using namespace mlir::bufferization;
+
+namespace {
+
+/// Generic conversion for any DestinationStyleOpInterface on tensors.
+static LogicalResult
+bufferizeDestinationStyleOpInterface(RewriterBase &rewriter,
+                                     DestinationStyleOpInterface op,
+                                     const BufferizationOptions &options) {
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+
+  // Nothing to do. This op is already bufferized.
+  if (op.hasPureBufferSemantics())
+    return success();
+
+  // Ensure op has only tensors. Allow mixed tensor-buffer mode on a per-need
+  // basis.
+  if (!op.hasPureTensorSemantics())
+    return op->emitError() << "op does not have pure tensor semantics";
+
+  // New input operands for the cloned op.
+  SmallVector<Value> newInputBuffers;
+  newInputBuffers.reserve(op.getNumDpsInputs());
+  for (OpOperand *opOperand : op.getDpsInputOperands()) {
+    if (op.isScalar(opOperand)) {
+      newInputBuffers.push_back(opOperand->get());
+      continue;
+    }
+    FailureOr<Value> buffer = getBuffer(rewriter, opOperand->get(), options);
+    if (failed(buffer))
+      return failure();
+    newInputBuffers.push_back(*buffer);
+  }
+
+  // New output operands for the cloned op.
+  SmallVector<Value> newOutputBuffers;
+  for (OpResult opResult : op->getOpResults()) {
+    OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber());
+    FailureOr<Value> resultBuffer =
+        getBuffer(rewriter, opOperand->get(), options);
+    if (failed(resultBuffer))
+      return failure();
+    newOutputBuffers.push_back(*resultBuffer);
+  }
+
+  // Merge input/output operands.
+  SmallVector<Value> newOperands = newInputBuffers;
+  newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
+
+  // Set insertion point now that potential alloc/dealloc are introduced.
+  rewriter.setInsertionPoint(op);
+  // Clone the op, but use the new operands. Move the existing block into the
+  // new op. Since the new op does not have any tensor results, it does not
+  // return anything.
+  assert(op->getNumRegions() == 1 && "expected that op has 1 region");
+  OperationState state(op->getLoc(), op->getName(), newOperands, TypeRange{},
+                       op->getAttrs());
+  state.addRegion();
+  Operation *newOp = Operation::create(state);
+  newOp->getRegion(0).getBlocks().splice(newOp->getRegion(0).begin(),
+                                         op->getRegion(0).getBlocks());
+
+  // We don't want the rewriter to track an incomplete operation, so insert the
+  // new operation only after op has been fully constructed.
+  rewriter.insert(newOp);
+
+  // Replace the results of the old op with the new output buffers.
+  replaceOpWithBufferizedValues(rewriter, op, newOutputBuffers);
+
+  return success();
+}
+
+/// Bufferization of linalg.generic. Replace with a new linalg.generic that
+/// operates entirely on memrefs.
+template <typename OpTy>
+struct LinalgOpInterface
+    : public DstBufferizableOpInterfaceExternalModel<LinalgOpInterface<OpTy>,
+                                                     OpTy> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    // Operand is read if it is used in the computation.
+    auto linalgOp = cast<linalg::LinalgOp>(op);
+    return linalgOp.payloadUsesValueFromOperand(&opOperand);
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    // Operand is written to if it is not an input/init.
+    auto dpsOp = cast<DestinationStyleOpInterface>(op);
+    return dpsOp.isDpsInit(&opOperand);
+  }
+
+  bool bufferizesToElementwiseAccess(Operation *op, const AnalysisState &state,
+                                     ArrayRef<OpOperand *> opOperands) const {
+    auto linalgOp = cast<linalg::LinalgOp>(op);
+
+    // Accesses into sparse data structures are not necessarily elementwise.
+    if (sparse_tensor::hasAnySparseOperand(linalgOp))
+      return false;
+
+    // All loops must be parallel.
+    if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
+      return false;
+
+    // All index maps of tensors must be identity maps.
+    SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+    assert(linalgOp->getNumOperands() == indexingMaps.size() &&
+           "unexpected number of indexing maps");
+    for (auto [operand, map] :
+         llvm::zip(linalgOp->getOpOperands(), indexingMaps)) {
+      // Non-tensors do not participate in bufferization, so they can be
+      // ignored.
+      if (!isa<TensorType>(operand.get().getType()))
+        continue;
+      // Only consider operands in `opOperands`.
+      if (!llvm::is_contained(opOperands, &operand))
+        continue;
+      // TODO: This could be generalized to other indexing maps. (All indexing
+      // must be the same.)
+      if (!map.isIdentity())
+        return false;
+    }
+
+    return true;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationOptions &options) const {
+    return bufferizeDestinationStyleOpInterface(
+        rewriter, cast<DestinationStyleOpInterface>(op), options);
+  }
+};
+
+/// Helper structure that iterates over all LinalgOps in `OpTys` and registers
+/// the `BufferizableOpInterface` with each of them.
+template <typename... Ops>
+struct LinalgOpInterfaceHelper {
+  static void registerOpInterface(MLIRContext *ctx) {
+    (Ops::template attachInterface<LinalgOpInterface<Ops>>(*ctx), ...);
+  }
+};
+} // namespace
diff --git a/lib/gc/Dialect/Linalgx/Transforms/CMakeLists.txt b/lib/gc/Dialect/Linalgx/Transforms/CMakeLists.txt
new file mode 100644
index 000000000..fa485d3af
--- /dev/null
+++ b/lib/gc/Dialect/Linalgx/Transforms/CMakeLists.txt
@@ -0,0 +1,20 @@
+gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR)
+
+add_mlir_dialect_library(MLIRLinalgxTransforms
+  AllInterfaces.cpp
+  BufferizableOpInterfaceImpl.cpp
+  TilingInterfaceImpl.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/include/gc/Dialect/Linalgx/Transforms
+
+  DEPENDS
+  MLIRLinalgx
+  MLIRLinalgDialect
+  MLIRBufferizationDialect
+  MLIRBufferizationTransforms
+
+  LINK_LIBS PUBLIC
+  ${MLIR_LINK_COMPONENTS}
+)
+set_property(GLOBAL APPEND PROPERTY GC_PASS_LIBS MLIRLinalgxTransforms)
\ No newline at end of file
diff --git a/lib/gc/Dialect/Linalgx/Transforms/TilingInterfaceImpl.cpp b/lib/gc/Dialect/Linalgx/Transforms/TilingInterfaceImpl.cpp
new file mode 100644
index 000000000..7ac0c8ff8
--- /dev/null
+++ b/lib/gc/Dialect/Linalgx/Transforms/TilingInterfaceImpl.cpp
@@ -0,0 +1,43 @@
+//===-- TilingInterfaceImpl.cpp - linalgx TilingInterface -------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "gc/Dialect/Linalgx/Transforms/TilingInterfaceImpl.h"
+#include "gc/Dialect/Linalgx/IR/LinalgxDialect.h"
+#include "gc/Dialect/Linalgx/IR/LinalgxOps.h"
+
+//===----------------------------------------------------------------------===//
+// Builder helper from Linalg/Transforms/TilingInterfaceImpl.cpp
+//===----------------------------------------------------------------------===//
+
+#include "TilingInterfaceImpl.cpp.inc"
+
+using namespace mlir;
+using namespace mlir::linalgx;
+
+template <typename OpType> static void registerOne(MLIRContext *ctx) {
+  OpType::template attachInterface<LinalgOpTilingInterface<OpType>>(*ctx);
+  OpType::template attachInterface<LinalgOpPartialReductionInterface<OpType>>(
+      *ctx);
+}
+
+/// Variadic helper function.
+template <typename... OpTypes> static void registerAll(MLIRContext *ctx) {
+  (registerOne<OpTypes>(ctx), ...);
+}
+
+#define GET_OP_LIST
+
+void mlir::linalgx::registerTilingInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(
+      +[](MLIRContext *ctx, linalgx::LinalgxDialect *dialect) {
+        registerAll<
+#include "gc/Dialect/Linalgx/IR/LinalgxStructuredOps.cpp.inc"
+            >(ctx);
+      });
+}
diff --git a/lib/gc/Dialect/Linalgx/Transforms/TilingInterfaceImpl.cpp.inc b/lib/gc/Dialect/Linalgx/Transforms/TilingInterfaceImpl.cpp.inc
new file mode 100644
index 000000000..c8073e521
--- /dev/null
+++ b/lib/gc/Dialect/Linalgx/Transforms/TilingInterfaceImpl.cpp.inc
@@ -0,0 +1,502 @@
+//===-- TilingInterfaceImpl.cpp.inc -----------------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// This file is a partial copy of
+// mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include <optional>
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+//===----------------------------------------------------------------------===//
+// Utility methods for implementation of Tiling Interface for Linalg ops
+//===----------------------------------------------------------------------===//
+
+/// Return the SSA values that represent the data point accessed using a given
+/// `indexingMap` for a given point in the iteration space represented by `ivs`.
+static SmallVector<Value> getIndicesForAccess(OpBuilder &b, Location loc,
+                                              AffineMap indexingMap,
+                                              ValueRange ivs) {
+  SmallVector<Value> indices;
+  indices.reserve(indexingMap.getNumResults());
+  for (auto result : indexingMap.getResults()) {
+    AffineMap m = AffineMap::get(indexingMap.getNumDims(),
+                                 indexingMap.getNumSymbols(), result);
+    Value v = b.create<affine::AffineApplyOp>(loc, m, ivs);
+    indices.push_back(v);
+  }
+  return indices;
+}
+
+/// Method to inline the payload of a `linalgOp` given the iteration space
+/// point and values for the arguments of the payload.
+static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp,
+                                   ValueRange ivs, ValueRange argValues) {
+  Block *body = linalgOp.getBlock();
+  IRMapping map;
+  map.map(body->getArguments(), argValues);
+  for (auto &op : body->without_terminator()) {
+    if (auto indexOp = dyn_cast<IndexOp>(&op)) {
+      map.map(indexOp.getResult(), ivs[indexOp.getDim()]);
+      continue;
+    }
+    b.clone(op, map);
+  }
+
+  Operation *terminator = body->getTerminator();
+  Location loc = terminator->getLoc();
+  for (const auto &operand : llvm::enumerate(terminator->getOperands())) {
+    Value toStore = map.lookupOrDefault(operand.value());
+    OpOperand *storeInto = linalgOp.getDpsInitOperand(operand.index());
+    auto indices = getIndicesForAccess(
+        b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs);
+    b.create<memref::StoreOp>(
+        loc, toStore, linalgOp.getDpsInitOperand(operand.index())->get(),
+        indices);
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// External Model for implementing `TilingInterface` for `LinalgOp`s.
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// External model implementation of TilingInterface for LinalgOps. An external
+/// model implementation is used for now till the use of `TilingInterface` is
+/// on-par with the current Linalg tiling + fusion patterns. Once it is, it may
+/// be possible to move this into the op definition (though there are
+/// advantages to leaving it as an external model).
+template <typename LinalgOpTy>
+struct LinalgOpTilingInterface
+    : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
+                                            LinalgOpTy> {
+  /// Return the loop iterator type.
+  SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+    LinalgOpTy concreteOp = cast<LinalgOpTy>(op);
+    return concreteOp.getIteratorTypesArray();
+  }
+
+  /// Return the iteration domain range.
+  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
+    OpBuilder::InsertionGuard g(b);
+    b.setInsertionPoint(op);
+    Location loc = op->getLoc();
+    LinalgOp linalgOp = cast<LinalgOp>(op);
+    SmallVector<OpFoldResult> allShapesSizes =
+        linalgOp.createFlatListOfOperandDims(b, loc);
+    AffineMap map = linalgOp.getShapesToLoopsMap();
+
+    return llvm::to_vector(
+        llvm::map_range(map.getResults(), [&](AffineExpr loopExpr) {
+          OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+              b, loc, loopExpr, allShapesSizes);
+          return Range{b.getIndexAttr(0), ofr, b.getIndexAttr(1)};
+        }));
+  }
+
+  /// Instantiate the tiled implementation of the operation.
+  FailureOr<TilingResult>
+  getTiledImplementation(Operation *op, OpBuilder &b,
+                         ArrayRef<OpFoldResult> offsets,
+                         ArrayRef<OpFoldResult> sizes) const {
+    // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
+    // specified could lead to out of bounds accesses.
+    Location loc = op->getLoc();
+    LinalgOp linalgOp = cast<LinalgOp>(op);
+    SmallVector<Value> valuesToTile = linalgOp->getOperands();
+    SmallVector<Value> tiledOperands = makeTiledShapes(
+        b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
+
+    SmallVector<Type> resultTensorTypes =
+        getTensorOutputTypes(linalgOp, tiledOperands);
+
+    Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
+    offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
+
+    return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+  }
+
+  /// Utility to fetch the offsets and sizes when applied as per the indexing
+  /// map of the linalg op. This helps in fusing the linalg op as a consumer of
+  /// a given slice op.
+  void
+  getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
+                         ArrayRef<OpFoldResult> offsets,
+                         ArrayRef<OpFoldResult> sizes,
+                         SmallVectorImpl<OpFoldResult> &mappedOffsets,
+                         SmallVectorImpl<OpFoldResult> &mappedSizes) const {
+    unsigned numLoops = linalgOp.getNumLoops();
+    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
+    mappedOffsets.resize(numLoops);
+    mappedSizes.resize(numLoops);
+    if (!indexingMap.isPermutation()) {
+      SmallVector<Range> iterationDomain =
+          tilingInterfaceOp.getIterationDomain(b);
+      for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
+        mappedOffsets[index] = value.offset;
+        mappedSizes[index] = value.size;
+      }
+    }
+    for (const auto &&[index, value] :
+         llvm::enumerate(indexingMap.getResults())) {
+      unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
+      mappedOffsets[dimPosition] = offsets[index];
+      mappedSizes[dimPosition] = sizes[index];
+    }
+  }
+
+  /// Method to return the position of the result tile computed by the tiled
+  /// operation.
+  LogicalResult getIterationDomainTileFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
+      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    // Check that the indexing map used for the operand is a projected
+    // permutation. This could be relaxed with a more general approach that can
+    // map the offsets and sizes from the operand to iteration space tiles
+    // (filling in full extent for dimensions not used to access the result).
+    AffineMap indexingMap =
+        linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
+    if (!indexingMap.isProjectedPermutation()) {
+      return op->emitError()
+             << "unhandled get iter domain position when operand is not "
+                "accessed using a permuted projection";
+    }
+
+    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+                           iterDomainOffsets, iterDomainSizes);
+    return success();
+  }
+
+  /// Return the details of the output tile generated by the tiled
+  /// implementation.
+  LogicalResult
+  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
+                        ArrayRef<OpFoldResult> offsets,
+                        ArrayRef<OpFoldResult> sizes,
+                        SmallVector<OpFoldResult> &resultOffsets,
+                        SmallVector<OpFoldResult> &resultSizes) const {
+    Location loc = op->getLoc();
+    LinalgOp linalgOp = cast<LinalgOp>(op);
+
+    AffineExpr d0;
+    bindDims(b.getContext(), d0);
+    SmallVector<OpFoldResult> subShapeSizes =
+        llvm::to_vector(llvm::map_range(sizes, [&](OpFoldResult ofr) {
+          return affine::makeComposedFoldedAffineApply(b, loc, d0 - 1, ofr);
+        }));
+
+    OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
+    SliceParameters sliceParams = computeSliceParameters(
+        b, loc, outOperand->get(), sizes,
+        linalgOp.getMatchingIndexingMap(outOperand), offsets,
+        /*ubs*/ {}, subShapeSizes, true);
+    resultOffsets = sliceParams.offsets;
+    resultSizes = sliceParams.sizes;
+    return success();
+  }
+
+  FailureOr<TilingResult>
+  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
+                          ArrayRef<OpFoldResult> offsets,
+                          ArrayRef<OpFoldResult> sizes) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    // Check that the indexing map used for the output is a projected
+    // permutation. This could be relaxed with a more general approach that can
+    // map the offsets and sizes from the result to iteration space tiles
+    // (filling in full extent for dimensions not used to access the result).
+    AffineMap indexingMap =
+        linalgOp.getIndexingMapMatchingResult(op->getResult(resultNumber));
+    if (!indexingMap.isProjectedPermutation()) {
+      return op->emitOpError(
+          "unhandled tiled implementation generation when result is not "
+          "accessed using a permuted projection");
+    }
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+                           mappedOffsets, mappedSizes);
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    FailureOr<TilingResult> tilingResult =
+        tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
+
+    if (failed(tilingResult))
+      return failure();
+
+    if (tilingResult->tiledOps.size() != 1)
+      return op->emitOpError("failed to generate tiled implementation");
+
+    return TilingResult{
+        tilingResult->tiledOps,
+        SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
+  }
+
+  /// Method to generate the tiled implementation of an operation from the tile
+  /// of the operand.
+  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    if (failed(getIterationDomainTileFromOperandTile(
+            op, b, operandNumber, offsets, sizes, mappedOffsets,
+            mappedSizes))) {
+      return failure();
+    }
+    return getTiledImplementation(op, b, mappedOffsets, mappedSizes);
+  }
+
+  LogicalResult generateScalarImplementation(Operation *op, OpBuilder &builder,
+                                             Location loc,
+                                             ValueRange ivs) const {
+    auto linalgOp = cast<LinalgOp>(op);
+    if (!linalgOp.hasPureBufferSemantics())
+      return op->emitOpError("expected operation to have buffer semantics");
+
+    SmallVector<Value> indexedValues;
+    indexedValues.reserve(linalgOp->getNumOperands());
+    Location linalgOpLoc = op->getLoc();
+    /// Load the data corresponding to the block arguments that
+    /// represent input operands.
+    for (OpOperand &operand : linalgOp->getOpOperands()) {
+      if (!linalgOp.payloadUsesValueFromOperand(&operand)) {
+        indexedValues.push_back(nullptr);
+        continue;
+      }
+      if (linalgOp.isScalar(&operand)) {
+        indexedValues.push_back(operand.get());
+        continue;
+      }
+      SmallVector<Value> indices = getIndicesForAccess(
+          builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs);
+      Value load =
+          builder.create<memref::LoadOp>(linalgOpLoc, operand.get(), indices);
+      indexedValues.push_back(load);
+    }
+
+    /// Inline the op payload and store the result.
+    return inlinePayload(builder, linalgOp, ivs, indexedValues);
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
+//===----------------------------------------------------------------------===//
+
+/// External model implementation of PartialReductionInterface for LinalgOps.
+template <typename LinalgOpTy>
+struct LinalgOpPartialReductionInterface
+    : public PartialReductionOpInterface::ExternalModel<
+          LinalgOpPartialReductionInterface<LinalgOpTy>, LinalgOpTy> {
+  FailureOr<SmallVector<Value>> generateInitialTensorForPartialReduction(
+      Operation *op, OpBuilder &b, Location loc, ArrayRef<OpFoldResult> sizes,
+      ArrayRef<int> reductionDims) const {
+    auto linalgOp = cast<LinalgOp>(op);
+    OpBuilder::InsertionGuard guard(b);
+
+    if (linalgOp.hasPureBufferSemantics())
+      return op->emitOpError("expected operation to have tensor semantics");
+
+    SmallVector<Value> inits;
+    for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
+         ++initIdx) {
+      // Insert the new parallel dimension based on the index of the reduction
+      // loops. This could be controlled by user for more flexibility.
+      SmallVector<Operation *, 4> combinerOps;
+      if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
+                          combinerOps) ||
+          combinerOps.size() != 1)
+        return op->emitOpError("Failed to analyze the reduction operation.");
+
+      Operation *reductionOp = combinerOps[0];
+      std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
+      if (!identity.has_value())
+        return op->emitOpError(
+            "Failed to get an identity value for the reduction operation.");
+
+      ArrayRef<int64_t> oldShape =
+          linalgOp.getShape(linalgOp.getDpsInitOperand(initIdx));
+
+      // Calculate the new shape, we insert the new dimensions based on the
+      // index of the reduction dimensions.
+      SmallVector<int64_t> newOutputShape;
+      SmallVector<Value> dynamicDims;
+      int64_t currReductionDims = 0;
+      DenseSet<int64_t> reductionDimsSet(reductionDims.begin(),
+                                         reductionDims.end());
+      for (int64_t idx :
+           llvm::seq<int64_t>(0, oldShape.size() + reductionDims.size())) {
+        if (reductionDimsSet.contains(idx)) {
+          dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape);
+          currReductionDims++;
+          continue;
+        }
+        int64_t oldIdx = idx - currReductionDims;
+        int64_t dim = oldShape[oldIdx];
+        newOutputShape.push_back(dim);
+        if (ShapedType::isDynamic(dim))
+          dynamicDims.push_back(b.create<tensor::DimOp>(
+              loc, linalgOp.getDpsInitOperand(initIdx)->get(), oldIdx));
+      }
+      Value emptyTensor = b.create<tensor::EmptyOp>(
+          loc, newOutputShape,
+          linalgOp.getRegionOutputArgs()[initIdx].getType(), dynamicDims);
+      Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
+      auto identityTensor =
+          b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
+      inits.push_back(identityTensor.getResult(0));
+    }
+
+    return inits;
+  }
+
+  Operation *tileToPartialReduction(Operation *op, OpBuilder &b, Location loc,
+                                    ValueRange init,
+                                    ArrayRef<OpFoldResult> offsets,
+                                    ArrayRef<OpFoldResult> sizes,
+                                    ArrayRef<int> reductionDims) const {
+    OpBuilder::InsertionGuard guard(b);
+    auto linalgOp = cast<LinalgOp>(op);
+
+    // Step 1. Extend init maps to have reduction dimension dims, since we
+    // are converting them to parallel dimensions.
+    SmallVector<AffineMap> newInitMaps;
+    newInitMaps.reserve(linalgOp.getNumDpsInits());
+    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
+      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
+      // this with a for range loop when we have it.
+      AffineMap newMap =
+          linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
+      for (int redPos : reductionDims) {
+        newMap = newMap.insertResult(b.getAffineDimExpr(redPos),
+                                     newMap.getNumResults());
+      }
+      newInitMaps.push_back(newMap);
+    }
+
+    // Step 2a: Extract a slice of the input operands.
+    SmallVector<Value> tiledInputs = makeTiledShapes(
+        b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true);
+
+    // Step 2b: Extract a slice of the init operands.
+    SmallVector<Value, 1> tiledInits;
+    for (auto [valueMap, valueToTile] : llvm::zip_equal(newInitMaps, init)) {
+      int64_t initRank = valueMap.getNumResults();
+      SmallVector<OpFoldResult> initOffset(initRank, b.getIndexAttr(0));
+      SmallVector<OpFoldResult> initStride(initRank, b.getIndexAttr(1));
+      SmallVector<OpFoldResult> initSizes;
+      for (AffineExpr dimExpr : valueMap.getResults()) {
+        auto dim = cast<AffineDimExpr>(dimExpr);
+        initSizes.push_back(sizes[dim.getPosition()]);
+      }
+      // TODO: Use SubsetExtractOpInterface here once available.
+      auto extractSlice = b.create<tensor::ExtractSliceOp>(
+          loc, valueToTile, initOffset, initSizes, initStride);
+      tiledInits.push_back(extractSlice);
+    }
+
+    // Update the indexing maps.
+    SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
+    // Change the init maps.
+    for (int idx : llvm::seq<int>(0, linalgOp.getNumDpsInits())) {
+      // TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
+      // this with a for range loop when we have it.
+      OpOperand *initOperand = linalgOp.getDpsInitOperand(idx);
+      int64_t mapIdx = linalgOp.getIndexingMapIndex(initOperand);
+      newMaps[mapIdx] = newInitMaps[idx];
+    }
+
+    // Step 3. Change the reduction dim iterator types.
+    SmallVector<utils::IteratorType> newIteratorTypes =
+        linalgOp.getIteratorTypesArray();
+    for (int dim : reductionDims)
+      newIteratorTypes[dim] = utils::IteratorType::parallel;
+
+    // Step 4. Create the new generic op.
+    auto genericOp =
+        b.create<GenericOp>(loc, ValueRange(tiledInits).getTypes(), tiledInputs,
+                            tiledInits, newMaps, newIteratorTypes);
+    IRMapping mapping;
+    op->getRegion(0).cloneInto(&genericOp.getRegion(),
+                               genericOp.getRegion().begin(), mapping);
+    return genericOp.getOperation();
+  }
+
+  Operation *mergeReductions(Operation *op, OpBuilder &b, Location loc,
+                             ValueRange partialReduce,
+                             ArrayRef<int> reductionDims) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    // Step 1. Recover the dims that actually need to be merged from the
+    // original operation. We can classify the original iterators as follows:
+    //
+    // parallel                         --> parallel
+    // reduction + not in reductionDims --> parallel (already reduced)
+    // reduction + in reductionDims     --> reduction (will reduce now)
+    SmallVector<utils::IteratorType> iterators(linalgOp.getNumLoops(),
+                                               utils::IteratorType::parallel);
+    for (int redIdx : reductionDims)
+      iterators[redIdx] = utils::IteratorType::reduction;
+
+    // Step 2. For each partial result, create a map to index it. This map
+    // is simply the indexing map for the original result with reductionDims
+    // appended (as produced in tileToPartialReduction).
+    int64_t numInits = linalgOp.getNumDpsInits();
+    SmallVector<AffineMap> indexingMaps(numInits * 2);
+    for (int idx : llvm::seq<int>(0, numInits)) {
+      AffineMap &inputMap = indexingMaps[idx];
+      AffineMap &outputMap = indexingMaps[numInits + idx];
+
+      outputMap =
+          linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
+      inputMap = outputMap;
+      for (int redPos : reductionDims) {
+        inputMap = inputMap.insertResult(b.getAffineDimExpr(redPos),
+                                         inputMap.getNumResults());
+      }
+    }
+
+    auto reduction = b.create<GenericOp>(
+        loc, op->getResultTypes(), partialReduce, linalgOp.getDpsInits(),
+        indexingMaps, iterators,
+        [&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
+          int64_t numInits = linalgOp.getNumDpsInits();
+          SmallVector<Value> yieldedValues;
+          for (int idx : llvm::seq<int>(0, numInits)) {
+            // Get the combiner op.
+            SmallVector<Operation *, 4> combinerOps;
+            matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
+            Operation *clonedReductionOp = b.clone(*combinerOps[0]);
+            // Combine the input at idx and output at numInits + idx.
+            clonedReductionOp->setOperand(0, inputs[idx]);
+            clonedReductionOp->setOperand(1, inputs[numInits + idx]);
+            // Yield.
+            yieldedValues.push_back(clonedReductionOp->getResult(0));
+          }
+          b.create<linalg::YieldOp>(loc, yieldedValues);
+        });
+    return reduction.getOperation();
+  }
+};
+
+} // namespace
diff --git a/lib/gc/Transforms/OneDNNGraphToLinalg.cpp b/lib/gc/Transforms/OneDNNGraphToLinalg.cpp
index c472dbe87..2ab5b0527 100644
--- a/lib/gc/Transforms/OneDNNGraphToLinalg.cpp
+++ b/lib/gc/Transforms/OneDNNGraphToLinalg.cpp
@@ -9,8 +9,8 @@
 #include
 #include
 
-#include "gc/Dialect/Linalgx/LinalgxDialect.h"
-#include "gc/Dialect/Linalgx/LinalgxOps.h"
+#include "gc/Dialect/Linalgx/IR/LinalgxDialect.h"
+#include "gc/Dialect/Linalgx/IR/LinalgxOps.h"
 #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
 #include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.h"
 #include "gc/Transforms/Passes.h"
diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp
index 72003224a..42ffa5149 100644
--- a/lib/gc/Transforms/Pipeline.cpp
+++ b/lib/gc/Transforms/Pipeline.cpp
@@ -23,7 +23,7 @@
 #include "mlir/Transforms/Passes.h"
 
 #include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h"
-#include "gc/Dialect/Linalgx/LinalgxDialect.h"
+#include "gc/Dialect/Linalgx/IR/LinalgxDialect.h"
 #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
 #include "gc/Transforms/Passes.h"
diff --git a/src/gc-opt/gc-opt.cpp b/src/gc-opt/gc-opt.cpp
index 809363108..d7c71c7f8 100644
--- a/src/gc-opt/gc-opt.cpp
+++ b/src/gc-opt/gc-opt.cpp
@@ -18,10 +18,12 @@
  */
 #include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h"
-#include "gc/Dialect/Linalgx/LinalgxDialect.h"
+#include "gc/Dialect/Linalgx/IR/LinalgxDialect.h"
+#include "gc/Dialect/Linalgx/Transforms/AllInterfaces.h"
 #include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
 #include "gc/Transforms/Passes.h"
 #include "mlir/InitAllDialects.h"
+#include "mlir/InitAllExtensions.h"
 #include "mlir/InitAllPasses.h"
 #include "mlir/Tools/mlir-opt/MlirOptMain.h"
 
@@ -34,7 +36,9 @@ int main(int argc, char *argv[]) {
   registry.insert();
   registry.insert();
   mlir::registerAllDialects(registry);
+  mlir::registerAllExtensions(registry);
   mlir::cpuruntime::registerConvertCPURuntimeToLLVMInterface(registry);
+  mlir::linalgx::registerAllDialectInterfaceImplementations(registry);
   return mlir::asMainReturnCode(mlir::MlirOptMain(
       argc, argv, "Graph Compiler modular optimizer driver\n", registry));
 }
diff --git a/test/mlir/test/gc/Dialect/Linlagx/linalgx-bufferize.mlir b/test/mlir/test/gc/Dialect/Linlagx/linalgx-bufferize.mlir
new file mode 100644
index 000000000..da06bc914
--- /dev/null
+++ b/test/mlir/test/gc/Dialect/Linlagx/linalgx-bufferize.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt --one-shot-bufferize="dialect-filter=linalgx,bufferization copy-before-write unknown-type-conversion=identity-layout-map" -canonicalize -cse -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @batch_reduce_matmul_vnni
+func.func @batch_reduce_matmul_vnni(%arg0: tensor<512x32x64xbf16>, %arg1: tensor<512x32x128x2xbf16>,
+                                    %arg2: tensor<32x128xf32>) -> tensor<32x128xf32> {
+  // CHECK: bufferization.to_memref
+  // CHECK: bufferization.to_memref
+  // CHECK: bufferization.to_memref
+  // CHECK: memref.alloc()
+  // CHECK: memref.copy
+  // CHECK: linalgx.batch_reduce_matmul_vnni
+  // CHECK: bufferization.to_tensor
+  %0 = linalgx.batch_reduce_matmul_vnni ins(%arg0, %arg1 : tensor<512x32x64xbf16>, tensor<512x32x128x2xbf16>)
+                                        outs(%arg2 : tensor<32x128xf32>) -> tensor<32x128xf32>
+  return %0 : tensor<32x128xf32>
+}
diff --git a/test/mlir/test/gc/Dialect/Linlagx/linalgx-tile.mlir b/test/mlir/test/gc/Dialect/Linlagx/linalgx-tile.mlir
new file mode 100644
index 000000000..aea99ca1c
--- /dev/null
+++ b/test/mlir/test/gc/Dialect/Linlagx/linalgx-tile.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt --split-input-file --transform-interpreter --canonicalize %s | FileCheck %s
+
+// CHECK-LABEL: @mm2d_vnni
+func.func @mm2d_vnni(%arg0: tensor<256x64xi8>, %arg1: tensor<16x2x8x32x4xi8>,
+                     %arg2: tensor<256x512xi32>) -> tensor<256x512xi32> {
+  // CHECK: linalgx.mm2d_vnni
+  %0 = linalgx.mm2d_vnni ins(%arg0, %arg1 : tensor<256x64xi8>, tensor<16x2x8x32x4xi8>)
+                         outs(%arg2 : tensor<256x512xi32>) -> tensor<256x512xi32>
+  return %0 : tensor<256x512xi32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalgx.mm2d_vnni"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop:2 = transform.structured.tile_using_for %0 tile_sizes [2, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @mm4d_vnni
+func.func @mm4d_vnni(%arg0: tensor<2x8x32x32xbf16>, %arg1: tensor<4x8x16x32x2xbf16>,
+                     %arg2: tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16> {
+  // CHECK: linalgx.mm4d_vnni
+  %0 = linalgx.mm4d_vnni ins(%arg0, %arg1 : tensor<2x8x32x32xbf16>, tensor<4x8x16x32x2xbf16>)
+                         outs(%arg2 : tensor<2x4x32x32xbf16>) -> tensor<2x4x32x32xbf16>
+  return %0 : tensor<2x4x32x32xbf16>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalgx.mm4d_vnni"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1, %loop:2 = transform.structured.tile_using_for %0 tile_sizes [2, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
diff --git a/test/mlir/unittests/Example/Example.cpp b/test/mlir/unittests/Example/Example.cpp
index d788ba4ec..d93baf6e0 100644
--- a/test/mlir/unittests/Example/Example.cpp
+++ b/test/mlir/unittests/Example/Example.cpp
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "gc/Dialect/Linalgx/LinalgxDialect.h"
+#include "gc/Dialect/Linalgx/IR/LinalgxDialect.h"
 #include "gtest/gtest.h"
 
 TEST(example, HelloWorld) {