diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td
index 1f30b877..141249f7 100644
--- a/include/gc/Transforms/Passes.td
+++ b/include/gc/Transforms/Passes.td
@@ -261,4 +261,12 @@ def LowerToTileVector : Pass<"lower-to-tile-vector", "func::FuncOp"> {
   ];
 }
 
+def DistributeXeGPU : Pass<"distribute-to-simt", "func::FuncOp"> {
+  let summary = "Convert a SIMD-like module to SIMT";
+  let dependentDialects = [
+    "::mlir::vector::VectorDialect",
+    "::mlir::gpu::GPUDialect",
+  ];
+}
+
 #endif // GC_DIALECT_GC_PASSES
diff --git a/lib/gc/Transforms/GPU/CMakeLists.txt b/lib/gc/Transforms/GPU/CMakeLists.txt
index f4b286b9..6cb299ce 100644
--- a/lib/gc/Transforms/GPU/CMakeLists.txt
+++ b/lib/gc/Transforms/GPU/CMakeLists.txt
@@ -17,6 +17,7 @@ gc_add_mlir_library(GcGpuPasses
   GpuToGpuOcl.cpp
   LinalgToXeGPU.cpp
   Pipeline.cpp
+  XeGPUDistribute.cpp
 
   DEPENDS
   GraphCompilerPassIncGen
diff --git a/lib/gc/Transforms/GPU/XeGPUDistribute.cpp b/lib/gc/Transforms/GPU/XeGPUDistribute.cpp
new file mode 100644
index 00000000..6141d737
--- /dev/null
+++ b/lib/gc/Transforms/GPU/XeGPUDistribute.cpp
@@ -0,0 +1,487 @@
+//===- XeGPUDistribute.cpp - XeGPU distribute ops to work items -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "gc/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace gc {
+#define GEN_PASS_DEF_DISTRIBUTEXEGPU
+#include "gc/Transforms/Passes.h.inc"
+} // namespace gc
+} // namespace mlir
+
+#define DEBUG_TYPE "xegpu-distribute"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+using namespace mlir;
+
+namespace {
+mlir::OpOperand *getWarpResult(vector::WarpExecuteOnLane0Op warpOp,
+                               const std::function<bool(Operation *)> &fn) {
+  auto yield = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  for (mlir::OpOperand &yieldOperand : yield->getOpOperands()) {
+    Value yieldValues = yieldOperand.get();
+    Operation *definedOp = yieldValues.getDefiningOp();
+    if (definedOp && fn(definedOp)) {
+      if (!warpOp.getResult(yieldOperand.getOperandNumber()).use_empty())
+        return &yieldOperand;
+    }
+  }
+  return {};
+}
+
+vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndReplaceReturns(
+    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes) {
+  // Create a new op before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(warpOp);
+  auto newWarpOp = rewriter.create<vector::WarpExecuteOnLane0Op>(
+      warpOp.getLoc(), newReturnTypes, warpOp.getLaneid(), warpOp.getWarpSize(),
+      warpOp.getArgs(), warpOp.getBody()->getArgumentTypes());
+
+  Region &opBody = warpOp.getBodyRegion();
+  Region &newOpBody = newWarpOp.getBodyRegion();
+  Block &newOpFirstBlock = newOpBody.front();
+  rewriter.inlineRegionBefore(opBody, newOpBody, newOpBody.begin());
+  rewriter.eraseBlock(&newOpFirstBlock);
+  assert(newWarpOp.getWarpRegion().hasOneBlock() &&
+         "expected WarpOp with single block");
+
+  auto yield =
+      cast<vector::YieldOp>(newOpBody.getBlocks().begin()->getTerminator());
+
+  rewriter.modifyOpInPlace(
+      yield, [&]() { yield.getOperandsMutable().assign(newYieldedValues); });
+  return newWarpOp;
+}
+
+vector::WarpExecuteOnLane0Op moveRegionToNewWarpOpAndAppendReturns(
+    RewriterBase &rewriter, vector::WarpExecuteOnLane0Op warpOp,
+    ValueRange newYieldedValues, TypeRange newReturnTypes,
+    llvm::SmallVector<size_t> &indices) {
+  SmallVector<Type> types(warpOp.getResultTypes().begin(),
+                          warpOp.getResultTypes().end());
+  auto yield = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
+                                              yield.getOperands().end());
+  for (auto newRet : llvm::zip(newYieldedValues, newReturnTypes)) {
+    if (yieldValues.insert(std::get<0>(newRet))) {
+      types.push_back(std::get<1>(newRet));
+      indices.push_back(yieldValues.size() - 1);
+    } else {
+      // If the value already exits the region, don't create a new output.
+      for (auto [idx, yieldOperand] :
+           llvm::enumerate(yieldValues.getArrayRef())) {
+        if (yieldOperand == std::get<0>(newRet)) {
+          indices.push_back(idx);
+          break;
+        }
+      }
+    }
+  }
+  yieldValues.insert(newYieldedValues.begin(), newYieldedValues.end());
+  vector::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndReplaceReturns(rewriter, warpOp,
+                                             yieldValues.getArrayRef(), types);
+  rewriter.replaceOp(warpOp,
+                     newWarpOp.getResults().take_front(warpOp.getNumResults()));
+  return newWarpOp;
+}
+
+bool divisible(APInt lhs, APInt rhs) { return !lhs.urem(rhs); }
+
+/// Clone a create_nd_tdesc feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0` and put it after the warp op.
+/// The warp op will still contain the original op that will not be used by the
+/// yield op (and should be cleaned up later with dce). The yield op will bypass
+/// the create_nd_tdesc's arguments.
+/// The rewrite will create a subview of the size used by a single work item
+/// with an appropriate offset. The distributed create_nd_tdesc points into the
+/// subview without offset. The tensor descriptor type is distributed according
+/// to the sg_map attribute.
+///
+/// Example:
+///
+/// ```
+/// #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+/// %r = vector.warp_execute_on_lane_0(%laneid) ->
+///                   (!xegpu.tensor_desc<4x8xf32>) {
+///   ...
+///   %td = xegpu.create_nd_tdesc %arg0[0, 0]
+///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
+///   vector.yield %td
+/// }
+/// ```
+/// To
+/// ```
+/// %r:2 = vector.warp_execute_on_lane_0(%laneid) -> () {
+///   ...
+///   %dead = xegpu.create_nd_tdesc %arg0[0, 0]
+///               : memref<4x8xf32> -> !xegpu.tensor_desc<4x8xf32>
+///   vector.yield %arg0, %dead
+/// }
+/// %view = memref.subview %r#0[0, %laneid] [4, 1] [1, 1]
+///                        : memref<4x8xf32> to memref<4x1xf32>
+/// %td = xegpu.create_nd_tdesc %view[0, 0]: memref<4x1xf32>
+///                                 -> !xegpu.tensor_desc<4x1xf32>
+///
+/// ```
+struct WarpOpTensorDescOp final
+    : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Sink a store_nd feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0`. In case arguments for the store are passed
+/// through the warp op interface they would be propagated as returned values.
+/// Both the stored vector type and the tensor descriptor type are distributed
+/// according to the sg_map attribute.
+///
+/// Example:
+///
+/// ```
+/// #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+/// vector.warp_execute_on_lane_0(%laneid) -> () {
+///   ...
+///   xegpu.store_nd %arg0, %arg1: vector<4x8xf32>,
+///                                 !xegpu.tensor_desc<4x8xf32>
+///   vector.yield
+/// }
+/// ```
+/// To
+/// ```
+/// %r = vector.warp_execute_on_lane_0(%laneid) -> () {
+///   ...
+///   vector.yield
+/// }
+/// xegpu.store_nd %arg0, %arg1: vector<4x1xf32>, !xegpu.tensor_desc<4x1xf32>
+///
+/// ```
+struct WarpOpStoreNd final
+    : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Clone a load_nd feeding into vector.yield op for the enclosing
+/// `vector.warp_execute_on_lane_0` and put it after the warp op.
+/// The warp op will still contain the original op that will not be used by the
+/// yield op (and should be cleaned up later with dce). The yield op will bypass
+/// the load's arguments.
+/// Both the loaded vector type and the tensor descriptor type are distributed
+/// according to the sg_map attribute.
+///
+/// Example:
+///
+/// ```
+/// #sg_map_8 = #xegpu.sg_map<wi_layout = [1, 8], wi_data = [1, 1]>
+/// %r = vector.warp_execute_on_lane_0(%laneid) ->
+///                   (vector<4x8xf32>) {
+///   ...
+///   %ld = xegpu.load_nd %arg0: !xegpu.tensor_desc<4x8xf32>
+///                                 -> vector<4x8xf32>
+///   vector.yield %ld
+/// }
+/// ```
+/// To
+/// ```
+/// %r:2 = vector.warp_execute_on_lane_0(%laneid) -> () {
+///   ...
+///   %dead = xegpu.load_nd %arg0:
+///         !xegpu.tensor_desc<4x8xf32> -> vector<4x8xf32>
+///   vector.yield %dead, %arg0
+/// }
+/// %ld = xegpu.load_nd %r#1: !xegpu.tensor_desc<4x1xf32> -> vector<4x1xf32>
+///
+/// ```
+struct WarpOpLoadNd final
+    : public OpRewritePattern<vector::WarpExecuteOnLane0Op> {
+  using OpRewritePattern<vector::WarpExecuteOnLane0Op>::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+FailureOr<VectorType> getDistributedVectorType(VectorType originalT,
+                                               xegpu::SGMapAttr sgMap) {
+  llvm::SmallVector<int64_t> distributedShape;
+  auto layout = sgMap.getWiLayout();
+  auto shape = originalT.getShape();
+  for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
+    if (!divisible(APInt(64, o), APInt(64, l)))
+      return failure();
+    distributedShape.push_back(o / l);
+  }
+  auto newVectorType =
+      VectorType::get(distributedShape, originalT.getElementType(),
+                      originalT.getScalableDims());
+  return newVectorType;
+}
+
+FailureOr<xegpu::TensorDescType>
+getDistributedTensorDescType(xegpu::TensorDescType originalT,
+                             xegpu::SGMapAttr sgMap,
+                             xegpu::MemorySpace memSpace) {
+  llvm::SmallVector<int64_t> distributedShape;
+  auto layout = sgMap.getWiLayout();
+  auto shape = originalT.getShape();
+  for (const auto [l, o] : llvm::zip_equal(layout, shape)) {
+    if (!divisible(APInt(64, o), APInt(64, l)))
+      return failure();
+    distributedShape.push_back(o / l);
+  }
+  xegpu::TensorDescType distributedDescType;
+  if (originalT.isScattered()) {
+    distributedDescType = xegpu::TensorDescType::get(
+        distributedShape, originalT.getElementType(), originalT.getChunkSize(),
+        originalT.getMemorySpace(), originalT.getSGMapAttr());
+  } else {
+    distributedDescType = xegpu::TensorDescType::get(
+        distributedShape, originalT.getElementType(),
+        originalT.getBoundaryCheck(), originalT.getArrayLength(),
+        originalT.getMemorySpace(), originalT.getSGMapAttr());
+  }
+  return distributedDescType;
+}
+} // namespace
+
+LogicalResult
+WarpOpStoreNd::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                               PatternRewriter &rewriter) const {
+  auto yield = cast<vector::YieldOp>(
+      warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
+  Operation *lastNode = yield->getPrevNode();
+  auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode);
+  if (!storeOp)
+    return failure();
+
+  auto origType = storeOp.getTensorDescType();
+  xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        storeOp, "the source tensor descriptor lacks sg_map attribute");
+
+  if (storeOp.getTensorDescType().getShape().size() != 2)
+    return rewriter.notifyMatchFailure(storeOp, "unsupported shape");
+  DBGS() << "Matched store_nd: " << storeOp << "\n";
+
+  auto distributedTypeOrFailure =
+      getDistributedVectorType(storeOp.getValueType(), sgMap);
+  if (failed(distributedTypeOrFailure))
+    return rewriter.notifyMatchFailure(storeOp,
+                                       "Failed to distribute the type");
+  VectorType newVectorType = distributedTypeOrFailure.value();
+
+  auto distributedDescTypeOrFailure = getDistributedTensorDescType(
+      storeOp.getTensorDescType(), sgMap,
+      storeOp.getTensorDescType().getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(storeOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+
+  SmallVector<size_t> newRetIndices;
+  vector::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp,
+          ValueRange{storeOp.getTensorDesc(), storeOp.getValue()},
+          TypeRange{newTDescType, newVectorType}, newRetIndices);
+  rewriter.setInsertionPointAfter(newWarpOp);
+  auto newStoreOp =
+      cast<xegpu::StoreNdOp>(rewriter.clone(*storeOp.getOperation()));
+  rewriter.eraseOp(storeOp);
+  newStoreOp.getTensorDescMutable().assign(
+      newWarpOp.getResult(newRetIndices[0]));
+  newStoreOp.getValueMutable().assign(newWarpOp.getResult(newRetIndices[1]));
+
+  return success();
+}
+
+LogicalResult WarpOpLoadNd::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                            PatternRewriter &rewriter) const {
+  OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+    return isa<xegpu::LoadNdOp>(op) && op->hasOneUse();
+  });
+
+  if (!operand)
+    return rewriter.notifyMatchFailure(warpOp,
+                                       "warp result is not a xegpu::LoadNd op");
+
+  auto loadOp = operand->get().getDefiningOp<xegpu::LoadNdOp>();
+
+  if (loadOp.getPacked())
+    return rewriter.notifyMatchFailure(
+        loadOp, "Packed load distribution not supported");
+
+  xegpu::TensorDescType origType = loadOp.getTensorDescType();
+  xegpu::SGMapAttr sgMap = origType.getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        loadOp, "the source tensor descriptor lacks sg_map attribute");
+
+  auto origShape = origType.getShape();
+  if (origShape.size() != 2)
+    return rewriter.notifyMatchFailure(loadOp, "unsupported shape");
+
+  auto distributedTypeOrFailure =
+      getDistributedVectorType(loadOp.getType(), sgMap);
+  if (failed(distributedTypeOrFailure))
+    return rewriter.notifyMatchFailure(loadOp, "Failed to distribute the type");
+  VectorType newVectorType = distributedTypeOrFailure.value();
+
+  auto distributedDescTypeOrFailure =
+      getDistributedTensorDescType(loadOp.getTensorDescType(), sgMap,
+                                   loadOp.getTensorDescType().getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(loadOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+
+  unsigned operandIdx = operand->getOperandNumber();
+
+  SmallVector<size_t> newRetIndices;
+  vector::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, loadOp.getTensorDesc(), TypeRange{newTDescType},
+          newRetIndices);
+
+  rewriter.setInsertionPointAfter(newWarpOp);
+
+  auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+      loadOp.getLoc(), newVectorType, loadOp.getTensorDesc(),
+      loadOp.getPackedAttr(), loadOp.getTransposeAttr(),
+      loadOp.getTransposeBitWidthAttr(), loadOp.getL1HintAttr(),
+      loadOp.getL2HintAttr(), loadOp.getL3HintAttr());
+
+  newLoadOp.getTensorDescMutable().assign(
+      newWarpOp.getResult(newRetIndices[0]));
+  Value distributedVal = newWarpOp.getResult(operandIdx);
+  rewriter.replaceAllUsesWith(distributedVal, newLoadOp);
+
+  return success();
+}
+
+LogicalResult
+WarpOpTensorDescOp::matchAndRewrite(vector::WarpExecuteOnLane0Op warpOp,
+                                    PatternRewriter &rewriter) const {
+  OpOperand *operand = getWarpResult(warpOp, [](Operation *op) {
+    return isa<xegpu::CreateNdDescOp>(op) && op->hasOneUse();
+  });
+
+  if (!operand)
+    return rewriter.notifyMatchFailure(
+        warpOp, "warp result is not a xegpu::CreateNdDesc op");
+  auto descOp = operand->get().getDefiningOp<xegpu::CreateNdDescOp>();
+  assert(descOp && "desc op must not be null");
+  unsigned operandIdx = operand->getOperandNumber();
+
+  // TODO: check that the memref is uniform in the region.
+  rewriter.setInsertionPoint(warpOp);
+  auto srcTypedVal = dyn_cast<TypedValue<MemRefType>>(descOp.getSource());
+  assert(srcTypedVal && "source value must not be null");
+
+  auto descOffsets = descOp.getMixedOffsets();
+  if (descOffsets.size() != 2)
+    return rewriter.notifyMatchFailure(descOp,
+                                       "offsets size is expected to be 2");
+
+  xegpu::SGMapAttr sgMap = descOp.getType().getSGMapAttr();
+  if (!sgMap)
+    return rewriter.notifyMatchFailure(
+        descOp, "the tensor descriptor lacks sg_map attribute");
+
+  auto layout = sgMap.getWiLayout();
+
+  // Calculate the offset within the tensor descriptor for the current lane_id.
+  // The access to the proper element for a work item is done through a
+  // lane-specific subview (tdesc offsets are used as the base, the lane shift
+  // is added on top).
+  auto laneid = warpOp.getLaneid();
+  auto xDim =
+      rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), layout[0]);
+  auto shiftx = rewriter.create<arith::RemUIOp>(laneid.getLoc(), laneid, xDim);
+  auto shifty = rewriter.create<arith::DivUIOp>(laneid.getLoc(), laneid, xDim);
+
+  auto basex = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
+                                               descOffsets[0]);
+  auto basey = getValueOrCreateConstantIndexOp(rewriter, laneid.getLoc(),
+                                               descOffsets[1]);
+  auto offsetx =
+      rewriter.create<arith::AddIOp>(laneid.getLoc(), shiftx, basex);
+  auto offsety =
+      rewriter.create<arith::AddIOp>(laneid.getLoc(), shifty, basey);
+
+  auto distributedDescTypeOrFailure = getDistributedTensorDescType(
+      descOp.getType(), sgMap, descOp.getType().getMemorySpace());
+  if (failed(distributedDescTypeOrFailure))
+    return rewriter.notifyMatchFailure(descOp,
+                                       "Failed to distribute the desc type");
+  xegpu::TensorDescType newTDescType = distributedDescTypeOrFailure.value();
+  auto distributedShape = newTDescType.getShape();
+  // Use the base memref strides.
+  SmallVector<OpFoldResult> overwriteStrides =
+      getAsIndexOpFoldResult(rewriter.getContext(), SmallVector<int64_t>{1, 1});
+  SmallVector<OpFoldResult> overwriteSizes =
+      getAsIndexOpFoldResult(rewriter.getContext(), distributedShape);
+
+  SmallVector<size_t> newRetIndices;
+  vector::WarpExecuteOnLane0Op newWarpOp =
+      moveRegionToNewWarpOpAndAppendReturns(
+          rewriter, warpOp, descOp.getSource(), descOp.getSourceType(),
+          newRetIndices);
+
+  rewriter.setInsertionPointAfter(newWarpOp);
+  auto subview = rewriter.create<memref::SubViewOp>(
+      newWarpOp.getLoc(), srcTypedVal, getAsOpFoldResult({offsetx, offsety}),
+      overwriteSizes, overwriteStrides);
+  subview.getSourceMutable().assign(newWarpOp.getResult(newRetIndices[0]));
+
+  auto zero = rewriter.create<arith::ConstantIndexOp>(laneid.getLoc(), 0);
+  auto newDescOp = rewriter.create<xegpu::CreateNdDescOp>(
+      newWarpOp.getLoc(), newTDescType, subview,
+      getAsOpFoldResult({zero, zero}));
+
+  Value distributedVal = newWarpOp.getResult(operandIdx);
+  rewriter.replaceAllUsesWith(distributedVal, newDescOp);
+
+  return success();
+}
+
+void populateXeGPUDistributePatterns(RewritePatternSet &patterns) {
+  patterns.add<WarpOpTensorDescOp>(patterns.getContext());
+  patterns.add<WarpOpStoreNd>(patterns.getContext());
+  patterns.add<WarpOpLoadNd>(patterns.getContext());
+}
+
+struct DistributeXeGPU
+    : public gc::impl::DistributeXeGPUBase<DistributeXeGPU> {
+  using Base::Base;
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateXeGPUDistributePatterns(patterns);
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
diff --git a/test/mlir/test/gc/Transforms/GPU/xegpu-distribute.mlir b/test/mlir/test/gc/Transforms/GPU/xegpu-distribute.mlir
new file mode 100644
index 00000000..b8c2465a
--- /dev/null
+++ b/test/mlir/test/gc/Transforms/GPU/xegpu-distribute.mlir
@@ -0,0 +1,81 @@
+// RUN: gc-opt -distribute-to-simt -split-input-file %s | FileCheck %s
+
+#sg_map_16 = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>
+
+// CHECK-LABEL: test_store_nd_distribution
+// CHECK: %[[laneid:.*]] = gpu.lane_id
+// CHECK: %[[res:.*]]:2 = vector.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}}, %{{.*}} : vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK-SAME: -> (!xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, vector<24x2xf16>)
+// CHECK: ^bb0(%[[src:.*]]: vector<24x32xf16>, %[[dst:.*]]: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK: vector.yield %[[dst]], %[[src]] : !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, vector<24x32xf16>
+// CHECK: xegpu.store_nd %[[res]]#1, %[[res]]#0 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> :
+// CHECK-SAME: vector<24x2xf16>, !xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+
+func.func @test_store_nd_distribution(%src: vector<24x32xf16>, %dst: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) -> () {
+  %laneid = gpu.lane_id
+  vector.warp_execute_on_lane_0(%laneid)[16]
+      args(%src, %dst: vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) {
+  ^bb0(%arg0: vector<24x32xf16>, %arg1: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>):
+    xegpu.store_nd %arg0, %arg1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
+  }
+  return
+}
+
+// -----
+
+#sg_map_16 = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>
+
+// CHECK-LABEL: test_load_nd_distribution
+// CHECK: %[[laneid:.*]] = gpu.lane_id
+// CHECK: %[[res:.*]]:2 = vector.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}} : !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK-SAME: -> (vector<24x2xf16>, !xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK: ^bb0(%[[dst:.*]]: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>)
+// CHECK: %[[dead:.*]] = xegpu.load_nd
+// CHECK: vector.yield %[[dead]], %[[dst]] : vector<24x32xf16>, !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+// CHECK: %[[load:.*]] = xegpu.load_nd %[[res]]#1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> :
+// CHECK-SAME: !xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<24x2xf16>
+// CHECK: return %[[load]]
+
+func.func @test_load_nd_distribution(%dst: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) -> (vector<24x2xf16>) {
+  %laneid = gpu.lane_id
+  %r = vector.warp_execute_on_lane_0(%laneid)[16]
+      args(%dst: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>) -> (vector<24x2xf16>) {
+  ^bb0(%arg0: !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>):
+    %0 = xegpu.load_nd %arg0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16> -> vector<24x32xf16>
+    vector.yield %0 : vector<24x32xf16>
+  }
+  return %r : vector<24x2xf16>
+}
+
+// -----
+
+#sg_map_16 = #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>
+#blk_tdesc = #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>
+
+// CHECK-LABEL: test_create_nd_desc_distribution
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[laneid:.*]] = gpu.lane_id
+// CHECK: %[[res:.*]]:2 = vector.warp_execute_on_lane_0(%[[laneid]])[16] args(%{{.*}} : memref<24x32xf16>)
+// CHECK-SAME: -> (!xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, memref<24x32xf16>)
+// CHECK: ^bb0(%[[dst:.*]]: memref<24x32xf16>)
+// CHECK: %[[dead:.*]] = xegpu.create_nd_tdesc
+// CHECK: vector.yield %[[dead]], %[[dst]] :
+// CHECK-SAME: !xegpu.tensor_desc<24x32xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>, memref<24x32xf16>
+// CHECK: %[[view:.*]] = memref.subview %[[res]]#1[%[[C0]], %[[laneid]]] [24, 2] [1, 1] : memref<24x32xf16> to memref<24x2xf16, strided<[32, 1], offset: ?>>
+// CHECK: %[[desc:.*]] = xegpu.create_nd_tdesc %[[view]][0, 0] : memref<24x2xf16, strided<[32, 1], offset: ?>>
+// CHECK-SAME: -> !xegpu.tensor_desc<24x2xf16, #xegpu.block_tdesc_attr<memory_space = global, array_length = 1 : i64, boundary_check = true>, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+// CHECK: return %[[desc]]
+
+func.func @test_create_nd_desc_distribution(%dst: memref<24x32xf16>) -> (!xegpu.tensor_desc<24x2xf16, #blk_tdesc, #sg_map_16>) {
+  %laneid = gpu.lane_id
+  %r = vector.warp_execute_on_lane_0(%laneid)[16]
+      args(%dst: memref<24x32xf16>) -> (!xegpu.tensor_desc<24x2xf16, #blk_tdesc, #sg_map_16>) {
+  ^bb0(%arg0: memref<24x32xf16>):
+    %0 = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
+    vector.yield %0 : !xegpu.tensor_desc<24x32xf16, #blk_tdesc, #sg_map_16>
+  }
+  return %r : !xegpu.tensor_desc<24x2xf16, #blk_tdesc, #sg_map_16>
+}
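Reviewer note (not part of the patch): the distributed shapes in the tests follow the same rule as getDistributedVectorType/getDistributedTensorDescType in XeGPUDistribute.cpp: each dimension is divided by the matching wi_layout entry of the sg_map, so with wi_layout = [1, 16] a 24x32xf16 tile becomes a 24x2xf16 per-work-item fragment (24/1 x 32/16). Below is a minimal standalone sketch of that arithmetic; the helper name and signature are illustrative only and are not part of the pass.

    #include <array>
    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <optional>

    // Divide every dimension by the corresponding wi_layout entry; report
    // failure when a dimension is not evenly divisible, mirroring the
    // divisible() guard in XeGPUDistribute.cpp.
    std::optional<std::array<int64_t, 2>>
    distributeShape(std::array<int64_t, 2> shape,
                    std::array<int64_t, 2> wiLayout) {
      std::array<int64_t, 2> distributed{};
      for (std::size_t i = 0; i < shape.size(); ++i) {
        if (shape[i] % wiLayout[i] != 0)
          return std::nullopt;
        distributed[i] = shape[i] / wiLayout[i];
      }
      return distributed;
    }

    int main() {
      // 24x32 tile with sg_map wi_layout = [1, 16] -> 24x2 per work item,
      // matching the vector<24x2xf16> / tensor_desc<24x2xf16> in the tests.
      auto d = distributeShape({24, 32}, {1, 16});
      assert(d && (*d)[0] == 24 && (*d)[1] == 2);
      return 0;
    }

The new tests can also be exercised standalone with the RUN line as written, e.g. gc-opt -distribute-to-simt -split-input-file test/mlir/test/gc/Transforms/GPU/xegpu-distribute.mlir | FileCheck test/mlir/test/gc/Transforms/GPU/xegpu-distribute.mlir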