diff --git a/build_tools/patches/0011-Add-Offset-Computing-Interface-For-MemDescType.patch b/build_tools/patches/0011-Add-Offset-Computing-Interface-For-MemDescType.patch new file mode 100644 index 000000000..36a47735c --- /dev/null +++ b/build_tools/patches/0011-Add-Offset-Computing-Interface-For-MemDescType.patch @@ -0,0 +1,242 @@ +diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +index cb4b8c2468d7..4dcf95a0f87a 100644 +--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td ++++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td +@@ -573,6 +573,10 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> { + return getAttrs().getAs("stride"); + } + ++ ArrayAttr getBlockAttr() { ++ return getAttrs().getAs("block"); ++ } ++ + }]; + + } +diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +index f8b371db498e..93642c2166e1 100644 +--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td ++++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td +@@ -232,7 +232,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m + return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout()); + } + +- ArrayAttr getStrides() { ++ ArrayAttr getStridesAttr() { + auto layout = getMemLayout(); + if (layout && layout.hasAttr("stride")) { + return layout.getStrides(); +@@ -245,6 +245,106 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m + Builder builder(getContext()); + return builder.getI64ArrayAttr(defaultStrides); + } ++ ++ /// Heuristic to determine if the MemDesc uses column-major layout, ++ /// based on the rank and the value of the first stride dimension. 
++ bool isColMajor() { ++ auto dim0 = dyn_cast(getStridesAttr()[0]); ++ return getRank() == 2 && dim0 && dim0.getInt() == 1; ++ } ++ ++ // get the Blocking shape for a MemDescType, Which is represented ++ // as an attribute in MemDescType. By default it is the shape ++ // of the mdescTy ++ SmallVector getBlockSize() { ++ SmallVector size(getShape()); ++ MemLayoutAttr layout = getMemLayout(); ++ if (layout && layout.hasAttr("block")) { ++ ArrayAttr attr = layout.getBlockAttr(); ++ size.clear(); ++ llvm::for_each(attr, [&](Attribute elem) { ++ if (auto intElem = dyn_cast(elem)) ++ size.push_back(intElem.getInt()); ++ }); ++ } ++ return size; ++ } ++ ++ // Get strides as vector of integer. ++ // If it contains block attribute, the strides are blocked strides. ++ // ++ // The blocking is applied against the original matrix shape ++ // so that the linear offset is not impacted by the subview. ++ // ++ // It first computes the original matrix shape using the stride info, ++ // then computes the number of blocks in each dimension of original shape, ++ // then compute the outer block shape and stride, ++ // then combines the inner and outer block shape and stride ++ // e.g. 
for mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]> ++ // its memory layout tuple is ([2,32,16,8],[128,256,1,8]) ++ // for mem_desc<256x32xf16, @block=[8, 16]> with default @stride[32, 1] ++ // its memory layout tuple is ([32,2,8,16],[256,128,16,1]) ++ SmallVector getStrides() { ++ ++ SmallVector matrixShape(getShape().begin(), ++ getShape().end()); ++ ++ ArrayAttr strideAttr = getStridesAttr(); ++ SmallVector strides; ++ for (Attribute attr : strideAttr.getValue()) { ++ strides.push_back(cast(attr).getInt()); ++ } ++ ++ SmallVector innerBlkShape = getBlockSize(); ++ if (innerBlkShape.empty()) ++ return strides; ++ ++ SmallVector perm = llvm::to_vector<4>( ++ llvm::seq(0, strides.size())); ++ llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; }); ++ ++ assert(strides[perm[0]] == 1 && "inner most dim must have stride 1"); ++ ++ SmallVector innerBlkStride(innerBlkShape.size()); ++ innerBlkStride[perm[0]] = 1; ++ for (size_t i = 1; i < perm.size(); ++i) ++ innerBlkStride[perm[i]] = ++ innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]]; ++ ++ // compute the original matrix shape using the stride info ++ // and compute the number of blocks in each dimension ++ // The shape of highest dim can't be derived from stride info, ++ // but doesn't impact the stride computation for blocked layout. 
++ SmallVector matrixShapeOrig(matrixShape.size()); ++ SmallVector BlkShapeOrig(matrixShape.size()); ++ for (size_t i = 0; i < perm.size() - 1; ++i) { ++ matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]]; ++ BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]]; ++ } ++ ++ int64_t innerBlkSize = 1; ++ for (auto s : innerBlkShape) ++ innerBlkSize *= s; ++ ++ SmallVector outerBlkStride(matrixShape.size()); ++ outerBlkStride[perm[0]] = innerBlkSize; ++ for (size_t i = 0; i < perm.size() - 1; ++i) { ++ outerBlkStride[perm[i + 1]] = ++ outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]]; ++ } ++ ++ // combine the inner and outer strides ++ SmallVector blockedStrides; ++ blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end()); ++ blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end()); ++ return blockedStrides; ++ } ++ /// Generates instructions to compute the linearize offset ++ // if the memory descriptor is blocked, it returns linearize offset based on the blocked layout ++ // the strides of memory descriptor is always considered regardless of blocked or not ++ Value getLinearOffsets(OpBuilder &builder, ++ Location loc, ArrayRef offsets); ++ + }]; + + let hasCustomAssemblyFormat = true; +diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +index 8ea8cb1f4597..808270534459 100644 +--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp ++++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +@@ -703,6 +703,89 @@ void MemLayoutAttr::print(AsmPrinter &printer) const { + } + printer << ">"; + } ++// a helper utility to perform binary operation on OpFoldResult. ++// If both a and b are attributes, it will simply return the result. ++// Otherwise, the corresponding arith op will be generated, and an ++// contant op will be created if one of them is an attribute. 
++template ++OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc, ++ OpBuilder &builder) { ++ auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a); ++ auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b); ++ return builder.create(loc, aVal, bVal).getResult(); ++} ++ ++// a helper utility to perform division operation on OpFoldResult and int64_t. ++#define div(a, b) \ ++ genBinOp(a, builder.getIndexAttr(b), loc, builder) ++ ++// a helper utility to perform reminder operation on OpFoldResult and int64_t. ++#define rem(a, b) \ ++ genBinOp(a, builder.getIndexAttr(b), loc, builder) ++ ++// a helper utility to perform multiply operation on OpFoldResult and int64_t. ++#define mul(a, b) \ ++ genBinOp(a, builder.getIndexAttr(b), loc, builder) ++ ++// a helper utility to perform addition operation on two OpFoldResult. ++#define add(a, b) genBinOp(a, b, loc, builder) ++ ++// block the given offsets according to the block shape ++// say the original offset is [y, x], and the block shape is [By, Bx], ++// then the blocked offset is [y/By, x/Bx, y%By, x%Bx] ++SmallVector getBlockedOffsets(OpBuilder &builder, Location loc, ++ ArrayRef offsets, ++ ArrayRef blockShape) { ++ ++ assert(offsets.size() == blockShape.size() && ++ "offsets and blockShape must have the same size"); ++ SmallVector blockedOffsets; ++ SmallVector divs, rems; ++ ++ for (auto [offset, block] : llvm::zip(offsets, blockShape)) { ++ divs.push_back(div(offset, block)); ++ rems.push_back(rem(offset, block)); ++ } ++ blockedOffsets.append(divs.begin(), divs.end()); ++ blockedOffsets.append(rems.begin(), rems.end()); ++ ++ return blockedOffsets; ++} ++ ++// Calculate the linear offset using the blocked offsets and stride ++Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc, ++ ArrayRef offsets) { ++ ++ SmallVector blockShape = getBlockSize(); ++ SmallVector strides = getStrides(); ++ if (!blockShape.empty()) { ++ assert(offsets.size() == blockShape.size() && ++ 
"offsets and blockShape must have the same size"); ++ // say the original offset is [y, x], and the block shape is [By, Bx], ++ // then the blocked offset is [y/By, x/Bx, y%By, x%Bx] ++ SmallVector blockedOffsets; ++ SmallVector divs, rems; ++ ++ for (auto [offset, block] : llvm::zip(offsets, blockShape)) { ++ divs.push_back(div(offset, block)); ++ rems.push_back(rem(offset, block)); ++ } ++ blockedOffsets.append(divs.begin(), divs.end()); ++ blockedOffsets.append(rems.begin(), rems.end()); ++ ++ offsets = blockedOffsets; ++ } ++ ++ // Start with initial value as matrix descriptor's base offset. ++ Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0); ++ for (size_t i = 0; i < offsets.size(); ++i) { ++ OpFoldResult mulResult = mul(offsets[i], strides[i]); ++ Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult); ++ linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset); ++ } ++ ++ return linearOffset; ++} + + } // namespace xegpu + } // namespace mlir +diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +index ecee53c56a54..ba38d74f3c7f 100644 +--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp ++++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +@@ -1069,7 +1069,7 @@ LogicalResult MemDescSubviewOp::verify() { + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitOpError("result shape must not exceed source shape."); + +- if (srcTy.getStrides() != resTy.getStrides()) ++ if (srcTy.getStridesAttr() != resTy.getStridesAttr()) + return emitOpError("result must inherit the source strides."); + + return success(); diff --git a/build_tools/patches/0012-memref-view-lowering-spirv.patch b/build_tools/patches/0012-memref-view-lowering-spirv.patch new file mode 100644 index 000000000..9d0da360b --- /dev/null +++ b/build_tools/patches/0012-memref-view-lowering-spirv.patch @@ -0,0 +1,115 @@ +diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp 
b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +index 2e00b42f4a56..15529d4c9b54 100644 +--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp ++++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +@@ -393,8 +393,65 @@ AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, + return success(); + } + ++class ViewOpPattern final : public OpConversionPattern { ++public: ++ using OpConversionPattern::OpConversionPattern; ++ ++ LogicalResult ++ matchAndRewrite(memref::ViewOp operation, OpAdaptor adaptor, ++ ConversionPatternRewriter &rewriter) const override; ++}; ++ ++ + //===----------------------------------------------------------------------===// +-// AllocOp ++// ViewOp ++// %view = memref.view %alloc[%c0][] : memref<2048xi8, 3> to memref<512xf32, 3> ++// spirv.GlobalVariable @__workgroup_mem__1 : !spirv.ptr, Workgroup> ++// %1 = spirv.Bitcast @__workgroup_mem__1 : !spirv.ptr, Workgroup> to !spirv.ptr, Workgroup> ++// ++//===----------------------------------------------------------------------===// ++ ++LogicalResult ++ViewOpPattern::matchAndRewrite(memref::ViewOp operation, OpAdaptor adaptor, ++ ConversionPatternRewriter &rewriter) const { ++ MemRefType ToType = operation.getType(); ++ ++ // insert spirv.bitcast which cast the pointer type from spirvFromType to spirvToType ++ Type spirvToType = getTypeConverter()->convertType(ToType); ++ if (!spirvToType) ++ return rewriter.notifyMatchFailure(operation, "type conversion failed"); ++ ++ // need to limit the case where the source is a memref with element type i8 ++ // the result memref must have static sizes. 
++ MemRefType FromType = cast(operation.getSource().getType()); ++ if (!FromType.getElementType().isInteger(8) || !FromType.hasStaticShape()) ++ return rewriter.notifyMatchFailure(operation, "unhandled view type"); ++ if (!ToType.hasStaticShape()) ++ return rewriter.notifyMatchFailure(operation, "unhandled view type"); ++ ++ // get base pointer from adaptor.getSource() ++ Value basePtr = adaptor.getSource(); ++ // get the offset ++ Value offset = adaptor.getByteShift(); ++ if (offset) { ++ Location loc = operation.getLoc(); ++ auto *spirvTypeConverter = getTypeConverter(); ++ Type materializedIndexType = spirvTypeConverter->getIndexType(); ++ Value basePtrAsInt = rewriter.createOrFold(loc, materializedIndexType, basePtr); ++ Value newPtrAsInt = rewriter.createOrFold(loc, materializedIndexType, basePtrAsInt, offset); ++ Value newPtr = rewriter.createOrFold(loc, basePtr.getType(), newPtrAsInt); ++ basePtr = newPtr; ++ } ++ ++ Location loc = operation.getLoc(); ++ Value castOp = rewriter.createOrFold( ++ loc, spirvToType, basePtr); ++ rewriter.replaceOp(operation, castOp); ++ return success(); ++} ++ ++//===----------------------------------------------------------------------===// ++// AtomicRMWOp + //===----------------------------------------------------------------------===// + + LogicalResult +@@ -1071,7 +1128,7 @@ LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite( + namespace mlir { + void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter, + RewritePatternSet &patterns) { +- patterns.add, #spirv.resource_limits<>> ++} { ++ ++// CHECK-LABEL: func @memory_view ++// CHECK-SAME: (%[[ARG0:.+]]: memref<2048xi8, #spirv.storage_class>) ++func.func @memory_view(%arg0: memref<2048xi8, #spirv.storage_class>) ++ -> memref<512xf32, #spirv.storage_class> { ++// CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2048xi8, #spirv.storage_class> to !spirv.ptr, Function> ++// CHECK: %[[BITCAST:.+]] = 
spirv.Bitcast %[[ARG0_CAST]] : !spirv.ptr, Function> to !spirv.ptr, Function> ++ %c0 = arith.constant 0: index ++ %view = memref.view %arg0[%c0][] : memref<2048xi8, #spirv.storage_class> to memref<512xf32, #spirv.storage_class> ++ return %view : memref<512xf32, #spirv.storage_class> ++} ++ ++} ++ ++// ----- ++ + module attributes { + spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> + } { diff --git a/include/imex/Transforms/Passes.h b/include/imex/Transforms/Passes.h index 40b4e5193..182fcb87c 100644 --- a/include/imex/Transforms/Passes.h +++ b/include/imex/Transforms/Passes.h @@ -41,6 +41,7 @@ std::unique_ptr createHoistTransposePass(); std::unique_ptr createVnniTransformationPass(); std::unique_ptr createEmulateNonNativeBF16Pass(); std::unique_ptr createTileLoopsPass(); +std::unique_ptr createMaterializeMatrixOpPass(); #define GEN_PASS_DECL #include "imex/Transforms/Passes.h.inc" diff --git a/include/imex/Transforms/Passes.td b/include/imex/Transforms/Passes.td index 3ae9c0aa0..b59b05088 100644 --- a/include/imex/Transforms/Passes.td +++ b/include/imex/Transforms/Passes.td @@ -226,4 +226,19 @@ def TileLoops : Pass<"tile-loops", "::mlir::func::FuncOp"> { ]; } +def MaterializeMatrixOp: Pass<"imex-xegpu-materialize-matrix-op"> { + let summary = "materialize matrix ops for Xe2/Xe3"; + let description = [{ + Coverts mem_desc operations (load_matrix, store_matrix) into other xegpu memory operations + (load/store chunk, 1d block load) over shared local memory. It computes physical address + using the matrix's layout attributes (@strides, @block) and logical lane coordinates. 
+ }]; + let constructor = "imex::createMaterializeMatrixOpPass()"; + let dependentDialects = [ + "::mlir::xegpu::XeGPUDialect", + "::mlir::vector::VectorDialect", + "::mlir::memref::MemRefDialect" + ]; +} + #endif // _IMEX_TRANSFORMS_PASSES_TD_INCLUDED_ diff --git a/include/imex/Utils/XeCommon.h b/include/imex/Utils/XeCommon.h index 6686a48b8..cf24d70bd 100644 --- a/include/imex/Utils/XeCommon.h +++ b/include/imex/Utils/XeCommon.h @@ -376,6 +376,25 @@ llvm::SmallVector getStridesOrOffsetsOrShapesInValueType( mlir::PatternRewriter &rewriter, ::llvm::SmallVector mixedOSS, mlir::Location loc); +// This method is essentially to insert ops to do vnni transformation +// on the given rank-2 VectorType value, and returns the value after +// transformation. +// The VC lowering path has to write contiguous 32-bit SLM locations +// using chunk stores, which requires the data is loaded in VNNI fashion. +// If the value is only has one use, which is store to +// slm, it is marked as potentialFoldable. Then if value is produced by +// a LoadNdOp, and the loadNdOp doesn't have packedAttr, it will fold +// the vnni transformation with the LoadNdOp, instead of inserting extra ops. +mlir::Value convertToPackedVector(mlir::PatternRewriter &rewriter, + mlir::Location loc, mlir::Value value, + bool potentialFoldable = false); + +// It converts a VectorType value to a 1D vector of 32-bit element type, +// using shapecast and bitcast operations, e.g., vector<4x4xf16> -> +// vector<8xi32>. 
+mlir::Value convertTo1D32BitVector(mlir::Value value, mlir::Location loc, + mlir::PatternRewriter &rewriter); + } // namespace imex #endif diff --git a/lib/Transforms/CMakeLists.txt b/lib/Transforms/CMakeLists.txt index 706c05b61..8cd3e0d72 100644 --- a/lib/Transforms/CMakeLists.txt +++ b/lib/Transforms/CMakeLists.txt @@ -15,6 +15,7 @@ add_mlir_library(IMEXTransforms OptimizeTranspose.cpp HoistTranspose.cpp TileLoops.cpp + MaterializeMatrixOp.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/imex/Transforms diff --git a/lib/Transforms/MaterializeMatrixOp.cpp b/lib/Transforms/MaterializeMatrixOp.cpp new file mode 100644 index 000000000..05ddd9a2c --- /dev/null +++ b/lib/Transforms/MaterializeMatrixOp.cpp @@ -0,0 +1,340 @@ +//===-- MaterializeMatrixOp.cpp - MaterializeMatrixOpPass ----------*- +// C++-*-===// +// +// Copyright 2025 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains MaterializeMatrixOp pass used for Xe2/Xe3 architecture. 
+/// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Passes.h" + +#include "imex/Utils/XeCommon.h" + +#include + +namespace imex { +#define GEN_PASS_DEF_MATERIALIZEMATRIXOP +#include "imex/Transforms/Passes.h.inc" +} // namespace imex + +using namespace imex; +using namespace mlir; +using namespace mlir::xegpu; + +namespace { + +// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions +// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than +// 32 bits will be converted to 32 bits. +class CreateMemDescOpPattern final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + TypedValue src = op.getSource(); + MemDescType resTy = op.getMemDesc().getType(); + auto *converter = getTypeConverter(); + MemRefType newResTy = converter->convertType(resTy); + Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); + rewriter.replaceOpWithNewOp(op, newResTy, src, zero, + ValueRange()); + return success(); + } +}; + +class MemDescSubviewOpPattern final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::MemDescSubviewOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return rewriter.notifyMatchFailure( + op, "MemDescSubviewOp are not supported on Xe2/Xe3 architecture."); + } +}; + +class LoadMatrixOpPattern final + : public OpConversionPattern { 
+public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::LoadMatrixOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + MLIRContext *context = op.getContext(); + auto loc = op.getLoc(); + MemDescType mdescTy = op.getMemDesc().getType(); + VectorType resTy = op.getRes().getType(); + + SmallVector blockShape = mdescTy.getBlockSize(); + + if (blockShape.empty()) { + // TODO: lowering to regular load/store + // in case the SLM can't be blocked due to some limitation, the lowering + // need to fall back to regular load/store. The inst_data size may be + // bigger than regular load/store so need to be split to multiple regular + // load/store if needed. + return rewriter.notifyMatchFailure( + op, "LoadMatrixOp without blocking layout are not yet supported."); + } + + // TODO: support col-major + if (mdescTy.isColMajor()) + return rewriter.notifyMatchFailure(op, "unsupported memory descriptor"); + + int packSize = getVnniFactor(resTy.getElementType()); + int vecSize = resTy.getNumElements(); + + auto converter = getTypeConverter(); + Type elemTy = converter->convertType(resTy.getElementType()); + Attribute encoding = + BlockTensorDescAttr::get(context, xegpu::MemorySpace::SLM, 1, true); + + SmallVector offsets = op.getMixedOffsets(); + assert(blockShape.size() == 2 && "only support blocking for rank-2 matrix"); + Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets); + if (packSize > 1) { + vecSize = vecSize / packSize; + Value packSizeScalar = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(packSize)); + linearOffset = + arith::DivUIOp::create(rewriter, loc, linearOffset, packSizeScalar); + } + + auto tdescTy = TensorDescType::get(context, vecSize, elemTy, encoding, + /*layout=*/nullptr); + + Value tdesc = xegpu::CreateNdDescOp::create( + rewriter, loc, tdescTy, + dyn_cast>(adaptor.getMemDesc()), + OpFoldResult(linearOffset)); + + auto packAttr = UnitAttr(); + auto 
transAttr = DenseI64ArrayAttr(); + auto bitWidthAttr = IntegerAttr(); + VectorType newResTy = VectorType::get(vecSize, elemTy); + auto ldOp = rewriter.create( + loc, newResTy, tdesc, ValueRange(), DenseI64ArrayAttr(), packAttr, + transAttr, bitWidthAttr, nullptr, nullptr, nullptr); + + Value result = ldOp.getResult(); + + // cast back + elemTy = resTy.getElementType(); + auto castTy = VectorType::get(resTy.getNumElements(), elemTy); + if (castTy != newResTy) + result = rewriter.create(loc, castTy, result); + if (castTy != resTy) + result = rewriter.create(loc, resTy, result); + + rewriter.replaceOp(op, result); + return success(); + } +}; + +/// Convert xegpu::StoreMatrixOp to xegpu::StoreNdOp if MemDesc is +// row-major or xegpu::StoreScatterOp if MemDesc is col-major. +class StoreMatrixOpPattern final + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::StoreMatrixOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value data = adaptor.getData(); + SmallVector offsets = op.getMixedOffsets(); + + VectorType dataTy = op.getData().getType(); + SmallVector dataShape(dataTy.getShape().begin(), + dataTy.getShape().end()); + int packSize = getVnniFactor(dataTy.getElementType()); + + MemDescType mdescTy = op.getMemDesc().getType(); + SmallVector blockShape = mdescTy.getBlockSize(); + + if (blockShape.empty()) { + // TODO: lowering to regular load/store + // in case the SLM can't be blocked due to some limitation, the lowering + // need to fall back to regular load/store. The inst_data size may be + // bigger than regular load/store so need to be split to multiple regular + // load/store if needed. 
+ return rewriter.notifyMatchFailure( + op, + "unblocked StoreMatrixOp are not supported on Xe2/Xe3 architecture."); + } + + assert(blockShape.size() == 2 && "only support blocking for rank-2 matrix"); + Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets); + if (mdescTy.isColMajor()) { + int64_t vecSize = dataShape[1]; + int64_t stride = blockShape[0]; + + auto indexVecTy = VectorType::get(vecSize, rewriter.getIndexType()); + Value linearOffsetVec = + vector::BroadcastOp::create(rewriter, loc, indexVecTy, linearOffset); + + // Generate a vector of indices [0, 1, ..., vecSize-1] as constant index + // values + SmallVector indexAttrs = + llvm::map_to_vector(llvm::seq(0, vecSize), [&](int64_t i) { + return cast(rewriter.getIndexAttr(i)); + }); + Value indexVec = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(indexVecTy, indexAttrs)); + Value strideConst = arith::ConstantIndexOp::create(rewriter, loc, stride); + Value strideVec = + vector::BroadcastOp::create(rewriter, loc, indexVecTy, strideConst); + Value mulOp = arith::MulIOp::create(rewriter, loc, indexVec, strideVec); + Value colOffsets = + arith::AddIOp::create(rewriter, loc, linearOffsetVec, mulOp); + + if (packSize > 1) { + Value packSizeScalar = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(packSize)); + Value packSizeVec = vector::BroadcastOp::create( + rewriter, loc, colOffsets.getType(), packSizeScalar); + colOffsets = + arith::DivUIOp::create(rewriter, loc, colOffsets, packSizeVec); + } + bool tryFold = op.getData().hasOneUse(); + data = convertToPackedVector(rewriter, loc, data, tryFold); + auto maskTy = VectorType::get(blockShape[1], rewriter.getI1Type()); + auto mask = arith::ConstantOp::create( + rewriter, loc, + DenseElementsAttr::get(maskTy, rewriter.getBoolAttr(true))); + SmallVector permutation = {1, 0}; + data = vector::TransposeOp::create(rewriter, loc, data, permutation); + { // using old style. 
+ MLIRContext *context = op.getContext(); + auto converter = getTypeConverter(); + int64_t chunkSize = dataShape[0] / packSize; + + auto encoding = xegpu::ScatterTensorDescAttr::get( + context, xegpu::MemorySpace::SLM, chunkSize); + Type elemTy = converter->convertType(dataTy.getElementType()); + auto tdescTy = TensorDescType::get(context, {vecSize, chunkSize}, + elemTy, encoding, nullptr); + + Value tdesc = xegpu::CreateDescOp::create( + rewriter, loc, tdescTy, adaptor.getMemDesc(), colOffsets); + rewriter.replaceOpWithNewOp( + op, data, tdesc, mask, nullptr, nullptr, nullptr); + } + { // new style. + // auto chunkSize = rewriter.getI64IntegerAttr(blockShape[0] / + // packSize); rewriter.replaceOpWithNewOp( + // op, data, adaptor.getMemDesc(), linearOffset, mask, + // chunkSize, nullptr, nullptr, nullptr); + } + } else { // lower to 1D block TenssorDesc + MLIRContext *context = op.getContext(); + int vecSize = dataTy.getNumElements(); + auto converter = getTypeConverter(); + Type elemTy = converter->convertType(dataTy.getElementType()); + Attribute encoding = + BlockTensorDescAttr::get(context, xegpu::MemorySpace::SLM, 1, true); + + if (packSize > 1) { + vecSize = vecSize / packSize; + Value packSizeScalar = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(packSize)); + linearOffset = + arith::DivUIOp::create(rewriter, loc, linearOffset, packSizeScalar); + } + + auto tdescTy = TensorDescType::get(context, vecSize, elemTy, encoding, + /*layout=*/nullptr); + Value tdesc = xegpu::CreateNdDescOp::create( + rewriter, loc, tdescTy, + dyn_cast>(adaptor.getMemDesc()), + OpFoldResult(linearOffset)); + data = convertTo1D32BitVector(data, loc, rewriter); + rewriter.replaceOpWithNewOp(op, data, tdesc, nullptr, + nullptr, nullptr); + } + return success(); + } +}; + +/// Populate the given list with patterns that convert MemDesc and related ops +void populateMatrixOpConversionPatterns(TypeConverter &converter, + RewritePatternSet &patterns) { + patterns.add( + 
converter, patterns.getContext()); +} + +struct MaterializeMatrixOpPass + : public imex::impl::MaterializeMatrixOpBase { + void runOnOperation() override { + auto mod = getOperation(); + MLIRContext &ctx = getContext(); + TypeConverter typeConverter; + + // Since SLM access instructions on Xe2 and Xe3 operate on 32-bit or + // 64-bit units, all data types smaller than 32 bits has to be converted + // to 32 bits. + typeConverter.addConversion([&](Type type) -> Type { + if (type.isInteger() && type.getIntOrFloatBitWidth() < 32) + return IntegerType::get(type.getContext(), 32); + if (type.isFloat() && type.getIntOrFloatBitWidth() < 32) + return Float32Type::get(type.getContext()); + return type; + }); + + // Convert MemDescType into flattend 32-bit MemRefType for SLM + typeConverter.addConversion([&](MemDescType type) -> Type { + Type elemTy = type.getElementType(); + int packSize = getVnniFactor(elemTy); + elemTy = typeConverter.convertType(elemTy); + int numElems = type.getNumElements() / packSize; + // TODO: Currently, an I64Attr(3) is assumed to represent the address + // space in memref.alloc. This should be standardized for consistency + // in XeGPU. 
+ return MemRefType::get(numElems, elemTy, AffineMap(), 3); + }); + + ConversionTarget target(ctx); + target.addIllegalOp(); + target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; }); + + RewritePatternSet patterns(&ctx); + populateMatrixOpConversionPatterns(typeConverter, patterns); + + if (failed(applyPartialConversion(mod, target, std::move(patterns)))) + return signalPassFailure(); + + mlir::PassManager pm(&ctx); + pm.addPass(mlir::createCSEPass()); + if (mlir::failed(pm.run(mod))) + signalPassFailure(); + } +}; + +} // namespace + +namespace imex { +std::unique_ptr createMaterializeMatrixOpPass() { + return std::make_unique(); +} +} // namespace imex diff --git a/lib/Utils/XeCommon.cpp b/lib/Utils/XeCommon.cpp index 4c68bf312..22ea5c21a 100644 --- a/lib/Utils/XeCommon.cpp +++ b/lib/Utils/XeCommon.cpp @@ -361,4 +361,80 @@ llvm::SmallVector getStridesOrOffsetsOrShapesInValueType( return valueVec; } +// This method is essentially to insert ops to do vnni transformation +// on the given rank-2 VectorType value, and returns the value after +// transformation. +// The VC lowering path has to write contiguous 32-bit SLM locations +// using chunk stores, which requires the data is loaded in VNNI fashion. +// If the value is only has one use, which is store to +// slm, it is marked as potentialFoldable. Then if value is produced by +// a LoadNdOp, and the loadNdOp doesn't have packedAttr, it will fold +// the vnni transformation with the LoadNdOp, instead of inserting extra ops. 
+mlir::Value convertToPackedVector(mlir::PatternRewriter &rewriter, + mlir::Location loc, mlir::Value value, + bool potentialFoldable) { + auto vecTy = mlir::dyn_cast(value.getType()); + if (!vecTy) + return value; + + auto packedTy = getPackedType(vecTy); + if (packedTy != vecTy) { + auto defOp = value.getDefiningOp(); + if (defOp && potentialFoldable && !defOp.getPackedAttr()) { + rewriter.startOpModification(defOp); + defOp.setPacked(true); + value = defOp.getResult(); + value.setType(packedTy); + rewriter.finalizeOpModification(defOp); + } else { + auto typedValue = + mlir::dyn_cast>(value); + value = applyVnniTransform(rewriter, typedValue).first; + } + + auto elemTy = vecTy.getElementType(); + + // shape cast packed type (3D vector) to 2D vector, are required by bitcast + auto shape = packedTy.getShape(); + vecTy = mlir::VectorType::get({shape[0], shape[1] * shape[2]}, elemTy); + value = rewriter.create(loc, vecTy, value); + + // cast to 32-bit data, use i32 for intergers and f32 for floats. + elemTy = mlir::isa(elemTy) + ? (mlir::Type)rewriter.getI32Type() + : (mlir::Type)rewriter.getF32Type(); + vecTy = mlir::VectorType::get(packedTy.getShape().take_front(2), elemTy); + if (vecTy != packedTy) + value = rewriter.create(loc, vecTy, value); + } + return value; +} + +// It converts a VectorType value to a 1D vector of 32-bit element type, +// using shapecast and bitcast operations, e.g., vector<4x4xf16> -> +// vector<8xi32>. +mlir::Value convertTo1D32BitVector(mlir::Value value, mlir::Location loc, + mlir::PatternRewriter &rewriter) { + auto vecTy = mlir::dyn_cast(value.getType()); + if (!vecTy) + return value; + + auto elemTy = vecTy.getElementType(); + auto shapecastTy = mlir::VectorType::get(vecTy.getNumElements(), elemTy); + + if (shapecastTy != vecTy) { + value = rewriter.create(loc, shapecastTy, value); + } + + auto vnni = getVnniFactor(elemTy); + if (vnni > 1) { + elemTy = mlir::isa(elemTy) + ? 
(mlir::Type)rewriter.getI32Type() + : (mlir::Type)rewriter.getF32Type(); + auto castTy = mlir::VectorType::get(vecTy.getNumElements() / vnni, elemTy); + value = rewriter.create(loc, castTy, value); + } + return value; +} + } // namespace imex diff --git a/test/Conversion/MatrixOpMaterialization/lit.local.cfg b/test/Conversion/MatrixOpMaterialization/lit.local.cfg new file mode 100644 index 000000000..17c01e619 --- /dev/null +++ b/test/Conversion/MatrixOpMaterialization/lit.local.cfg @@ -0,0 +1,6 @@ +# it is good for new style (offset in load/store) +excludes_tests = [ + + ] + +config.excludes.update(excludes_tests) diff --git a/test/Conversion/MatrixOpMaterialization/unit_test.mlir b/test/Conversion/MatrixOpMaterialization/unit_test.mlir new file mode 100644 index 000000000..c2a490051 --- /dev/null +++ b/test/Conversion/MatrixOpMaterialization/unit_test.mlir @@ -0,0 +1,88 @@ +// RUN: imex-opt --split-input-file -imex-xegpu-materialize-matrix-op -cse -canonicalize %s -verify-diagnostics -o -| FileCheck %s + +gpu.module @test { + + //CHECK: load_matrix([[m:%.+]]: memref<2048xi8, 3>) + gpu.func @load_matrix(%m: memref<2048xi8, 3>) -> vector<8x16xf16> { + //CHECK: [[offset:%.+]] = arith.constant 192 : index + //CHECK: [[c0:%.+]] = arith.constant 0 : index + //CHECK: [[view:%.+]] = memref.view [[m]][[[c0]]][] : memref<2048xi8, 3> to memref<512xf32, 3> + //CHECK: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[view]][[[offset]]] : memref<512xf32, 3> -> !xegpu.tensor_desc<64xf32, #xegpu.block_tdesc_attr> + //CHECK: [[load:%.+]] = xegpu.load_nd [[tdesc]] : !xegpu.tensor_desc<64xf32, #xegpu.block_tdesc_attr> -> vector<64xf32> + //CHECK: [[cast:%.+]] = vector.bitcast [[load]] : vector<64xf32> to vector<128xf16> + //CHECK: [[result:%.+]] = vector.shape_cast [[cast]] : vector<128xf16> to vector<8x16xf16> + %mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + %load = xegpu.load_matrix %mem_desc[8, 16] : !xegpu.mem_desc<16x64xf16, 
#xegpu.mem_layout> -> vector<8x16xf16> + gpu.return %load: vector<8x16xf16> + } + + + //CHECK: store_matrix([[A:%.+]]: memref<64x64xf16>) + gpu.func @store_matrix(%A: memref<64x64xf16>) { + + //CHECK: [[offset:%.+]] = arith.constant 192 : index + //CHECK: [[c0:%.+]] = arith.constant 0 : index + //CHECK: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[A]][0, 0] : memref<64x64xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK: [[load:%.+]] = xegpu.load_nd [[tdesc]] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + //CHECK: [[alloca:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<2048xi8, 3> + //CHECK: [[view:%.+]] = memref.view [[alloca]][[[c0]]][] : memref<2048xi8, 3> to memref<512xf32, 3> + //CHECK: [[slmtdesc:%.+]] = xegpu.create_nd_tdesc [[view]][[[offset]]] : memref<512xf32, 3> -> !xegpu.tensor_desc<64xf32, #xegpu.block_tdesc_attr> + //CHECK: [[flatten:%.+]] = vector.shape_cast [[load]] : vector<8x16xf16> to vector<128xf16> + //CHECK: [[data:%.+]] = vector.bitcast [[flatten]] : vector<128xf16> to vector<64xf32> + //CHECK: xegpu.store_nd [[data]], [[slmtdesc]] : vector<64xf32>, !xegpu.tensor_desc<64xf32, #xegpu.block_tdesc_attr> + %tdesc = xegpu.create_nd_tdesc %A[0, 0] : memref<64x64xf16> -> !xegpu.tensor_desc<8x16xf16> + %data = xegpu.load_nd %tdesc: !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 3> + %mem_desc_2 = xegpu.create_mem_desc %m : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + xegpu.store_matrix %data, %mem_desc_2[8, 16] : vector<8x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + gpu.return + } + + //CHECK: store_matrix_strided([[A:%.+]]: memref<64x64xf16>) + gpu.func @store_matrix_strided(%A: memref<64x64xf16>) { + //CHECK: [[offsets:%.+]] = arith.constant dense<[132, 140, 148, 156, 164, 172, 180, 188, 196, 204, 212, 220, 228, 236, 244, 252]> : vector<16xindex> + //CHECK: [[mask:%.+]] = arith.constant dense : vector<16xi1> + //CHECK: [[c0:%.+]] = 
arith.constant 0 : index + //CHECK: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[A]][0, 0] : memref<64x64xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK: [[load:%.+]] = xegpu.load_nd [[tdesc]] <{packed}> : !xegpu.tensor_desc<8x16xf16> -> vector<4x16x2xf16> + //CHECK: [[alloca:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<2048xi8, 3> + //CHECK: [[view:%.+]] = memref.view [[alloca]][[[c0]]][] : memref<2048xi8, 3> to memref<512xf32, 3> + //CHECK: [[shapecast:%.+]] = vector.shape_cast [[load]] : vector<4x16x2xf16> to vector<4x32xf16> + //CHECK: [[bcast:%.+]] = vector.bitcast [[shapecast]] : vector<4x32xf16> to vector<4x16xf32> + //CHECK: [[data:%.+]] = vector.transpose [[bcast]], [1, 0] : vector<4x16xf32> to vector<16x4xf32> + //CHECK: [[tdesc2:%.+]] = xegpu.create_tdesc [[view]], [[offsets]] : memref<512xf32, 3>, vector<16xindex> -> !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr> + //CHECK: xegpu.store [[data]], [[tdesc2]], [[mask]] : vector<16x4xf32>, !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> + %tdesc = xegpu.create_nd_tdesc %A[0, 0] : memref<64x64xf16> -> !xegpu.tensor_desc<8x16xf16> + %data = xegpu.load_nd %tdesc: !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 3> + %mem_desc_2 = xegpu.create_mem_desc %m : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + xegpu.store_matrix %data, %mem_desc_2[8, 16] : vector<8x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + gpu.return + } + + //CHECK: store_matrix_strided_v2([[A:%.+]]: memref<64x64xf16>) + gpu.func @store_matrix_strided_v2(%A: memref<64x64xf16>) { + //CHECK: [[offsets:%.+]] = arith.constant dense<[132, 140, 148, 156, 164, 172, 180, 188, 196, 204, 212, 220, 228, 236, 244, 252]> : vector<16xindex> + //CHECK: [[mask:%.+]] = arith.constant dense : vector<16xi1> + //CHECK: [[c0:%.+]] = arith.constant 0 : index + //CHECK: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[A]][0, 0] : 
memref<64x64xf16> -> !xegpu.tensor_desc<8x16xf16> + //CHECK: [[load:%.+]] = xegpu.load_nd [[tdesc]] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + //CHECK: [[alloca:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<2048xi8, 3> + //CHECK: [[view:%.+]] = memref.view [[alloca]][[[c0]]][] : memref<2048xi8, 3> to memref<512xf32, 3> + //CHECK: [[shapecast1:%.+]] = vector.shape_cast [[load]] {packed} : vector<8x16xf16> to vector<128xf16> + //CHECK: [[shuffle:%.+]] = vector.shuffle [[shapecast1]], [[shapecast1]] + //CHECK: [[shapecast:%.+]] = vector.shape_cast [[shuffle]] : vector<128xf16> to vector<4x32xf16> + //CHECK: [[bcast:%.+]] = vector.bitcast [[shapecast]] : vector<4x32xf16> to vector<4x16xf32> + //CHECK: [[data:%.+]] = vector.transpose [[bcast]], [1, 0] : vector<4x16xf32> to vector<16x4xf32> + //CHECK: [[tdesc2:%.+]] = xegpu.create_tdesc [[view]], [[offsets]] : memref<512xf32, 3>, vector<16xindex> -> !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr> + //CHECK: xegpu.store [[data]], [[tdesc2]], [[mask]] : vector<16x4xf32>, !xegpu.tensor_desc<16x4xf32, #xegpu.scatter_tdesc_attr>, vector<16xi1> + %tdesc = xegpu.create_nd_tdesc %A[0, 0] : memref<64x64xf16> -> !xegpu.tensor_desc<8x16xf16> + %data = xegpu.load_nd %tdesc: !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 3> + %mem_desc_2 = xegpu.create_mem_desc %m : memref<2048xi8, 3> -> !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + xegpu.store_matrix %data, %mem_desc_2[8, 16] : vector<8x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout> + xegpu.store_nd %data, %tdesc: vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + gpu.return + } + +} diff --git a/test/Integration/Dialect/XeGPU/VC/sg_coop_transpose.mlir b/test/Integration/Dialect/XeGPU/VC/sg_coop_transpose.mlir new file mode 100644 index 000000000..962a75e29 --- /dev/null +++ b/test/Integration/Dialect/XeGPU/VC/sg_coop_transpose.mlir @@ -0,0 +1,87 @@ +// RUN: %python_executable 
%imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner mlir-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck +// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc.pp \ +// RUN: --runner mlir-runner -e main \ +// RUN: --entry-point-result=void \ +// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck + +module @gemm attributes {gpu.container_module} { + func.func @test(%A: memref<32x32xf16>) -> memref<32x32xf16> attributes {llvm.emit_c_interface} { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c4 = arith.constant 4 : index + %A_gpu = gpu.alloc host_shared () : memref<32x32xf16> + memref.copy %A, %A_gpu : memref<32x32xf16> to memref<32x32xf16> + %B_gpu = gpu.alloc host_shared () : memref<32x32xf16> + gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c4, %c2, %c1) args(%A_gpu : memref<32x32xf16>, %B_gpu : memref<32x32xf16>) + gpu.dealloc %A_gpu : memref<32x32xf16> + return %B_gpu : memref<32x32xf16> + } + gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce, api=OpenCL, #spirv.resource_limits<>>} { + gpu.func @test_kernel(%A: memref<32x32xf16>, %B: memref<32x32xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c8 = arith.constant 8 : index + %c16 = arith.constant 16 : index + + %tid_x = gpu.thread_id x + %tid_y = gpu.thread_id y + %m = arith.muli %tid_x, %c8 : index + %n = arith.muli %tid_y, %c16 : index + + %a_tile = xegpu.create_nd_tdesc %A[%m, %n] : memref<32x32xf16> -> !xegpu.tensor_desc<8x16xf16> + %a = xegpu.load_nd %a_tile : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16> + + %slm = 
memref.alloc() : memref<2048xi8, 3> + + %mem_desc_store = xegpu.create_mem_desc %slm : memref<2048xi8, 3> -> !xegpu.mem_desc<32x32xf16, #xegpu.mem_layout> + xegpu.store_matrix %a, %mem_desc_store[%m, %n] : vector<8x16xf16>, !xegpu.mem_desc<32x32xf16, #xegpu.mem_layout>, index, index + + gpu.barrier + + %mem_desc_load = xegpu.create_mem_desc %slm : memref<2048xi8, 3> -> !xegpu.mem_desc<32x32xf16, #xegpu.mem_layout> + %d = xegpu.load_matrix %mem_desc_load[%m, %n] : !xegpu.mem_desc<32x32xf16, #xegpu.mem_layout>, index, index -> vector<8x16xf16> + + %b_tile = xegpu.create_nd_tdesc %B[%m, %n] : memref<32x32xf16> -> !xegpu.tensor_desc<8x16xf16> + xegpu.store_nd %d, %b_tile: vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16> + gpu.return + } + } + func.func @main() attributes {llvm.emit_c_interface} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %cf_0 = arith.constant 0.0 : f16 + %cf_1 = arith.constant 1.0 : f16 + %A = memref.alloc() : memref<32x32xf16> + %Ref = memref.alloc() : memref<32x32xf32> + + // initialize matrix A with A[i, j] = i*32 + j; Ref holds transpose(A) in f32 + scf.for %i = %c0 to %c32 step %c1 { + scf.for %j = %c0 to %c32 step %c1 { + %mul = arith.muli %i, %c32 : index + %add = arith.addi %mul, %j : index + %t = index.castu %add : index to i16 + %val = arith.uitofp %t : i16 to f16 + memref.store %val, %A[%i, %j] : memref<32x32xf16> + %t32 = index.castu %add : index to i32 + %val32 = arith.uitofp %t32 : i32 to f32 + memref.store %val32, %Ref[%j, %i] : memref<32x32xf32> + } + } + + %B = call @test(%A) : (memref<32x32xf16>) -> memref<32x32xf16> + %cast = memref.cast %B : memref<32x32xf16> to memref<*xf16> + // call @printMemrefF16(%cast) : (memref<*xf16>) -> () + %cast_ref = memref.cast %Ref : memref<32x32xf32> to memref<*xf32> + // CHECK: [ALLCLOSE: TRUE] + call @printAllcloseF16(%cast, %cast_ref) : (memref<*xf16>, memref<*xf32>) -> () + memref.dealloc %A : memref<32x32xf16> + return + } + func.func private @printMemrefF32(memref<*xf32>)
attributes {llvm.emit_c_interface} + func.func private @printMemrefF16(memref<*xf16>) attributes {llvm.emit_c_interface} + func.func private @printAllcloseF16(memref<*xf16>, memref<*xf32>) attributes {llvm.emit_c_interface} +} diff --git a/test/Integration/Dialect/XeGPU/VC/xegpu-to-func-vc.pp b/test/Integration/Dialect/XeGPU/VC/xegpu-to-func-vc.pp index 1c6e3fbda..8ba93d194 100644 --- a/test/Integration/Dialect/XeGPU/VC/xegpu-to-func-vc.pp +++ b/test/Integration/Dialect/XeGPU/VC/xegpu-to-func-vc.pp @@ -1,13 +1,14 @@ -// gpu dialect with intel intrinsic functions (func dialect) to -// llvm dialect (for host code) and -// spirv dialect (for device code) lowering pipeline. -// Ready for imex runner starting from GPU dialect. builtin.module( - gpu.module(imex-xegpu-hoist-transpose, - imex-xegpu-apply-vnni-transformation, - imex-xegpu-optimize-transpose) cse - gpu.module(convert-math-to-vc{enable-high-precision-interim-calculation=true} + gpu.module( + imex-xegpu-materialize-matrix-op + cse + canonicalize + imex-xegpu-hoist-transpose + imex-xegpu-apply-vnni-transformation + imex-xegpu-optimize-transpose + cse + convert-math-to-vc{enable-high-precision-interim-calculation=true} convert-xegpu-to-vc) cse canonicalize @@ -35,4 +36,3 @@ convert-gpux-to-llvm lower-affine reconcile-unrealized-casts) -// End