@@ -0,0 +1,242 @@
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index cb4b8c2468d7..4dcf95a0f87a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -573,6 +573,10 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
return getAttrs().getAs<ArrayAttr>("stride");
}

+ ArrayAttr getBlockAttr() {
+ return getAttrs().getAs<ArrayAttr>("block");
+ }
+
}];

}
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index f8b371db498e..93642c2166e1 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -232,7 +232,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
}

- ArrayAttr getStrides() {
+ ArrayAttr getStridesAttr() {
auto layout = getMemLayout();
if (layout && layout.hasAttr("stride")) {
return layout.getStrides();
@@ -245,6 +245,106 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
Builder builder(getContext());
return builder.getI64ArrayAttr(defaultStrides);
}
+
+ /// Heuristic to determine if the MemDesc uses column-major layout,
+ /// based on the rank and the value of the first stride dimension.
+ bool isColMajor() {
+ auto dim0 = dyn_cast<IntegerAttr>(getStridesAttr()[0]);
+ return getRank() == 2 && dim0 && dim0.getInt() == 1;
+ }
+
+ // Get the blocking shape for a MemDescType, which is represented
+ // as a "block" attribute in its MemLayoutAttr. By default it is
+ // the shape of the MemDescType.
+ SmallVector<int64_t> getBlockSize() {
+ SmallVector<int64_t> size(getShape());
+ MemLayoutAttr layout = getMemLayout();
+ if (layout && layout.hasAttr("block")) {
+ ArrayAttr attr = layout.getBlockAttr();
+ size.clear();
+ llvm::for_each(attr, [&](Attribute elem) {
+ if (auto intElem = dyn_cast<IntegerAttr>(elem))
+ size.push_back(intElem.getInt());
+ });
+ }
+ return size;
+ }
+
+ // Get the strides as a vector of integers.
+ // If the layout contains a block attribute, the strides are blocked strides.
+ //
+ // The blocking is applied against the original matrix shape
+ // so that the linear offset is not impacted by the subview.
+ //
+ // It first computes the original matrix shape using the stride info,
+ // then the number of blocks in each dimension of the original shape,
+ // then the outer block shape and stride, and finally combines the
+ // inner and outer block shapes and strides.
+ // e.g. for mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>
+ // its memory layout tuple is ([2,32,16,8],[128,256,1,16])
+ // for mem_desc<256x32xf16, @block=[8, 16]> with default @strides=[32, 1]
+ // its memory layout tuple is ([32,2,8,16],[256,128,16,1])
+ SmallVector<int64_t> getStrides() {
+
+ SmallVector<int64_t> matrixShape(getShape().begin(),
+ getShape().end());
+
+ ArrayAttr strideAttr = getStridesAttr();
+ SmallVector<int64_t> strides;
+ for (Attribute attr : strideAttr.getValue()) {
+ strides.push_back(cast<IntegerAttr>(attr).getInt());
+ }
+
+ SmallVector<int64_t> innerBlkShape = getBlockSize();
+ if (innerBlkShape.empty())
+ return strides;
+
+ SmallVector<int, 4> perm = llvm::to_vector<4>(
+ llvm::seq<int>(0, strides.size()));
+ llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
+
+ assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");
+
+ SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
+ innerBlkStride[perm[0]] = 1;
+ for (size_t i = 1; i < perm.size(); ++i)
+ innerBlkStride[perm[i]] =
+ innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];
+
+ // Compute the original matrix shape from the stride info, along with
+ // the number of blocks in each dimension. The extent of the outermost
+ // dim can't be derived from the strides, but it doesn't impact the
+ // stride computation for the blocked layout.
+ SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
+ SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
+ for (size_t i = 0; i < perm.size() - 1; ++i) {
+ matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
+ BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
+ }
+
+ int64_t innerBlkSize = 1;
+ for (auto s : innerBlkShape)
+ innerBlkSize *= s;
+
+ SmallVector<int64_t> outerBlkStride(matrixShape.size());
+ outerBlkStride[perm[0]] = innerBlkSize;
+ for (size_t i = 0; i < perm.size() - 1; ++i) {
+ outerBlkStride[perm[i + 1]] =
+ outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
+ }
+
+ // combine the inner and outer strides
+ SmallVector<int64_t> blockedStrides;
+ blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
+ blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
+ return blockedStrides;
+ }
+ /// Generates instructions to compute the linearized offset.
+ /// If the memory descriptor is blocked, the offset is linearized against
+ /// the blocked layout; the descriptor's strides are applied either way.
+ Value getLinearOffsets(OpBuilder &builder,
+ Location loc, ArrayRef<OpFoldResult> offsets);
+
}];

let hasCustomAssemblyFormat = true;
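
The blocked-stride computation in getStrides() above can be sanity-checked outside MLIR. The following minimal C++ sketch (an editor's illustration, not part of the patch) mirrors the same logic, assuming plain std::vector in place of SmallVector and folding matrixShapeOrig/BlkShapeOrig into a single block count; it reproduces both layout tuples documented in the comment.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

std::vector<int64_t> blockedStrides(const std::vector<int64_t> &strides,
                                    const std::vector<int64_t> &blk) {
  // Order dimensions by increasing stride; perm[0] is the unit-stride dim.
  std::vector<int> perm(strides.size());
  std::iota(perm.begin(), perm.end(), 0);
  std::sort(perm.begin(), perm.end(),
            [&](int a, int b) { return strides[a] < strides[b]; });
  assert(strides[perm[0]] == 1 && "innermost dim must have stride 1");

  // Dense strides within a single block.
  std::vector<int64_t> inner(blk.size());
  inner[perm[0]] = 1;
  for (size_t i = 1; i < perm.size(); ++i)
    inner[perm[i]] = inner[perm[i - 1]] * blk[perm[i - 1]];

  // Blocks per dimension, recovered from the original strides
  // (the outermost extent is not needed and is left as zero).
  std::vector<int64_t> blkCount(blk.size(), 0);
  for (size_t i = 0; i + 1 < perm.size(); ++i)
    blkCount[perm[i]] = strides[perm[i + 1]] / strides[perm[i]] / blk[perm[i]];

  int64_t blkSize = 1;
  for (int64_t s : blk)
    blkSize *= s;

  // Strides between whole blocks: one full block per unit step.
  std::vector<int64_t> outer(blk.size());
  outer[perm[0]] = blkSize;
  for (size_t i = 0; i + 1 < perm.size(); ++i)
    outer[perm[i + 1]] = outer[perm[i]] * blkCount[perm[i]];

  // Blocked strides are the outer strides followed by the inner ones.
  std::vector<int64_t> result(outer);
  result.insert(result.end(), inner.begin(), inner.end());
  return result;
}

int main() {
  // mem_desc<32x256xf16, @block=[16,8], @strides=[1,32]>
  assert((blockedStrides({1, 32}, {16, 8}) ==
          std::vector<int64_t>({128, 256, 1, 16})));
  // mem_desc<256x32xf16, @block=[8,16]> with default strides [32,1]
  assert((blockedStrides({32, 1}, {8, 16}) ==
          std::vector<int64_t>({256, 128, 16, 1})));
  return 0;
}
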
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 8ea8cb1f4597..808270534459 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -703,6 +703,89 @@ void MemLayoutAttr::print(AsmPrinter &printer) const {
}
printer << ">";
}
+// A helper utility to perform a binary operation on two OpFoldResults.
+// Any attribute operand is first materialized as an arith constant op;
+// the corresponding arith op is then generated on the two values.
+template <typename ArithOp>
+OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
+ OpBuilder &builder) {
+ auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
+ auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
+ return builder.create<ArithOp>(loc, aVal, bVal).getResult();
+}
+
+// A helper utility to perform a division on an OpFoldResult and an int64_t.
+#define div(a, b) \
+ genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// A helper utility to perform a remainder operation on an OpFoldResult and an int64_t.
+#define rem(a, b) \
+ genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// A helper utility to perform a multiplication on an OpFoldResult and an int64_t.
+#define mul(a, b) \
+ genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// A helper utility to perform an addition on two OpFoldResults.
+#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
+
+// Block the given offsets according to the block shape:
+// say the original offset is [y, x] and the block shape is [By, Bx],
+// then the blocked offset is [y/By, x/Bx, y%By, x%Bx].
+SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<int64_t> blockShape) {
+
+ assert(offsets.size() == blockShape.size() &&
+ "offsets and blockShape must have the same size");
+ SmallVector<OpFoldResult> blockedOffsets;
+ SmallVector<OpFoldResult> divs, rems;
+
+ for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+ divs.push_back(div(offset, block));
+ rems.push_back(rem(offset, block));
+ }
+ blockedOffsets.append(divs.begin(), divs.end());
+ blockedOffsets.append(rems.begin(), rems.end());
+
+ return blockedOffsets;
+}
+
+// Calculate the linear offset using the blocked offsets and strides.
+Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
+ ArrayRef<OpFoldResult> offsets) {
+
+ SmallVector<int64_t> blockShape = getBlockSize();
+ SmallVector<int64_t> strides = getStrides();
+ if (!blockShape.empty()) {
+ assert(offsets.size() == blockShape.size() &&
+ "offsets and blockShape must have the same size");
+ // say the original offset is [y, x], and the block shape is [By, Bx],
+ // then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
+ SmallVector<OpFoldResult> blockedOffsets;
+ SmallVector<OpFoldResult> divs, rems;
+
+ for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+ divs.push_back(div(offset, block));
+ rems.push_back(rem(offset, block));
+ }
+ blockedOffsets.append(divs.begin(), divs.end());
+ blockedOffsets.append(rems.begin(), rems.end());
+
+ offsets = blockedOffsets;
+ }
+
+ // Start from zero (the matrix descriptor's base offset) and accumulate.
+ Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
+ for (size_t i = 0; i < offsets.size(); ++i) {
+ OpFoldResult mulResult = mul(offsets[i], strides[i]);
+ Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
+ linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
+ }
+
+ return linearOffset;
+}

} // namespace xegpu
} // namespace mlir
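
The IR built by getLinearOffsets() evaluates to simple scalar arithmetic: re-express the offsets in the blocked space, then take the dot product with the blocked strides. Below is a minimal C++ sketch of that computation (an editor's illustration; linearOffset is a hypothetical helper, not part of the patch).

#include <cassert>
#include <cstdint>
#include <vector>

int64_t linearOffset(std::vector<int64_t> offsets,        // e.g. [y, x]
                     const std::vector<int64_t> &block,   // e.g. [By, Bx]
                     const std::vector<int64_t> &strides) {
  if (!block.empty()) {
    // [y, x] -> [y/By, x/Bx, y%By, x%Bx], matching getBlockedOffsets.
    std::vector<int64_t> divs, rems;
    for (size_t i = 0; i < offsets.size(); ++i) {
      divs.push_back(offsets[i] / block[i]);
      rems.push_back(offsets[i] % block[i]);
    }
    divs.insert(divs.end(), rems.begin(), rems.end());
    offsets = divs;
  }
  // Dot product of the (possibly blocked) offsets with the strides.
  int64_t linear = 0;
  for (size_t i = 0; i < offsets.size(); ++i)
    linear += offsets[i] * strides[i];
  return linear;
}

int main() {
  // mem_desc<256x32xf16, @block=[8,16]>: blocked strides are [256,128,16,1],
  // so element (y=9, x=17) becomes [1,1,1,1] -> 256 + 128 + 16 + 1 = 401.
  assert(linearOffset({9, 17}, {8, 16}, {256, 128, 16, 1}) == 401);
  return 0;
}
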
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index ecee53c56a54..ba38d74f3c7f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1069,7 +1069,7 @@ LogicalResult MemDescSubviewOp::verify() {
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
return emitOpError("result shape must not exceed source shape.");

- if (srcTy.getStrides() != resTy.getStrides())
+ if (srcTy.getStridesAttr() != resTy.getStridesAttr())
return emitOpError("result must inherit the source strides.");

return success();
115 changes: 115 additions & 0 deletions build_tools/patches/0012-memref-view-lowering-spirv.patch
@@ -0,0 +1,115 @@
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 2e00b42f4a56..15529d4c9b54 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -393,8 +393,65 @@ AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor,
return success();
}

+class ViewOpPattern final : public OpConversionPattern<memref::ViewOp> {
+public:
+ using OpConversionPattern<memref::ViewOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::ViewOp operation, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
+
//===----------------------------------------------------------------------===//
-// AllocOp
+// ViewOp
+// %view = memref.view %alloc[%c0][] : memref<2048xi8, 3> to memref<512xf32, 3>
+// spirv.GlobalVariable @__workgroup_mem__1 : !spirv.ptr<!spirv.array<2048 x i8>, Workgroup>
+// %1 = spirv.Bitcast @__workgroup_mem__1 : !spirv.ptr<!spirv.array<2048 x i8>, Workgroup> to !spirv.ptr<!spirv.array<512 x f32>, Workgroup>
+//
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ViewOpPattern::matchAndRewrite(memref::ViewOp operation, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ MemRefType ToType = operation.getType();
+
+ // Insert a spirv.Bitcast that casts the pointer from the converted source type to spirvToType.
+ Type spirvToType = getTypeConverter()->convertType(ToType);
+ if (!spirvToType)
+ return rewriter.notifyMatchFailure(operation, "type conversion failed");
+
+ // Limit the pattern to sources that are memrefs with i8 element type and
+ // static shape; the result memref must also have a static shape.
+ MemRefType FromType = cast<MemRefType>(operation.getSource().getType());
+ if (!FromType.getElementType().isInteger(8) || !FromType.hasStaticShape())
+ return rewriter.notifyMatchFailure(operation, "unhandled view type");
+ if (!ToType.hasStaticShape())
+ return rewriter.notifyMatchFailure(operation, "unhandled view type");
+
+ // Get the base pointer from adaptor.getSource().
+ Value basePtr = adaptor.getSource();
+ // Apply the byte shift, if present.
+ Value offset = adaptor.getByteShift();
+ if (offset) {
+ Location loc = operation.getLoc();
+ auto *spirvTypeConverter = getTypeConverter<SPIRVTypeConverter>();
+ Type materializedIndexType = spirvTypeConverter->getIndexType();
+ Value basePtrAsInt = rewriter.createOrFold<spirv::ConvertPtrToUOp>(loc, materializedIndexType, basePtr);
+ Value newPtrAsInt = rewriter.createOrFold<spirv::IAddOp>(loc, materializedIndexType, basePtrAsInt, offset);
+ Value newPtr = rewriter.createOrFold<spirv::ConvertUToPtrOp>(loc, basePtr.getType(), newPtrAsInt);
+ basePtr = newPtr;
+ }
+
+ Location loc = operation.getLoc();
+ Value castOp = rewriter.createOrFold<spirv::BitcastOp>(
+ loc, spirvToType, basePtr);
+ rewriter.replaceOp(operation, castOp);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// AtomicRMWOp
//===----------------------------------------------------------------------===//

LogicalResult
@@ -1071,7 +1128,7 @@ LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
namespace mlir {
void populateMemRefToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
- patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
+ patterns.add<AllocaOpPattern, AllocOpPattern, ViewOpPattern, AtomicRMWOpPattern,
DeallocOpPattern, IntLoadOpPattern, ImageLoadOpPattern,
IntStoreOpPattern, LoadOpPattern, MemorySpaceCastOpPattern,
StoreOpPattern, ReinterpretCastPattern, CastPattern,
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index e6321e99693a..7308f000cdbe 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -446,6 +446,30 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr

// -----

+// Check memref.view
+
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0,
+ [
+ Kernel, Addresses, GenericPointer, Int8, Int64, StorageBuffer8BitAccess, Shader], [SPV_KHR_8bit_storage]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: func @memory_view
+// CHECK-SAME: (%[[ARG0:.+]]: memref<2048xi8, #spirv.storage_class<Function>>)
+func.func @memory_view(%arg0: memref<2048xi8, #spirv.storage_class<Function>>)
+ -> memref<512xf32, #spirv.storage_class<Function>> {
+// CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<2048xi8, #spirv.storage_class<Function>> to !spirv.ptr<!spirv.array<2048 x i8>, Function>
+// CHECK: %[[BITCAST:.+]] = spirv.Bitcast %[[ARG0_CAST]] : !spirv.ptr<!spirv.array<2048 x i8>, Function> to !spirv.ptr<!spirv.array<512 x f32>, Function>
+ %c0 = arith.constant 0: index
+ %view = memref.view %arg0[%c0][] : memref<2048xi8, #spirv.storage_class<Function>> to memref<512xf32, #spirv.storage_class<Function>>
+ return %view : memref<512xf32, #spirv.storage_class<Function>>
+}
+
+}
+
+// -----
+
module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.5, [Kernel, Int64, Addresses, PhysicalStorageBufferAddresses], []>, #spirv.resource_limits<>>
} {
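
For the nonzero byte-shift path, the op sequence emitted by ViewOpPattern amounts to ordinary pointer arithmetic. A minimal C++ analogue (an editor's sketch, assuming an i8 source and an f32 result as in the test above; each comment names the SPIR-V op the pattern creates):

#include <cstdint>

float *viewWithByteShift(int8_t *base, int64_t byteShift) {
  uintptr_t asInt = reinterpret_cast<uintptr_t>(base); // spirv.ConvertPtrToU
  asInt += static_cast<uintptr_t>(byteShift);          // spirv.IAdd
  int8_t *shifted = reinterpret_cast<int8_t *>(asInt); // spirv.ConvertUToPtr
  return reinterpret_cast<float *>(shifted);           // spirv.Bitcast
}
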
1 change: 1 addition & 0 deletions include/imex/Transforms/Passes.h
@@ -41,6 +41,7 @@ std::unique_ptr<mlir::Pass> createHoistTransposePass();
std::unique_ptr<mlir::Pass> createVnniTransformationPass();
std::unique_ptr<mlir::Pass> createEmulateNonNativeBF16Pass();
std::unique_ptr<mlir::Pass> createTileLoopsPass();
std::unique_ptr<mlir::Pass> createMaterializeMatrixOpPass();

#define GEN_PASS_DECL
#include "imex/Transforms/Passes.h.inc"
15 changes: 15 additions & 0 deletions include/imex/Transforms/Passes.td
@@ -226,4 +226,19 @@ def TileLoops : Pass<"tile-loops", "::mlir::func::FuncOp"> {
];
}

def MaterializeMatrixOp: Pass<"imex-xegpu-materialize-matrix-op"> {
let summary = "materialize matrix ops for Xe2/Xe3";
let description = [{
Converts mem_desc operations (load_matrix, store_matrix) into other xegpu memory operations
(chunked load/store, 1D block load) over shared local memory. It computes physical addresses
using the matrix's layout attributes (@strides, @block) and logical lane coordinates.
}];
let constructor = "imex::createMaterializeMatrixOpPass()";
let dependentDialects = [
"::mlir::xegpu::XeGPUDialect",
"::mlir::vector::VectorDialect",
"::mlir::memref::MemRefDialect"
];
}

#endif // _IMEX_TRANSFORMS_PASSES_TD_INCLUDED_