@@ -151,10 +151,10 @@ static void buildLinalgRegion(Operation *op, bool createTemporaryOp = false) {
151
151
// Check if the linalgOp need to be legalized to f32 accumulation type
152
152
static bool needToLegalizeDtype (linalg::LinalgOp linalgOp) {
153
153
mlir::Type dataType =
154
- dyn_cast<mlir::RankedTensorType >(linalgOp.getDpsInputs ()[0 ].getType ())
154
+ dyn_cast<mlir::ShapedType >(linalgOp.getDpsInputs ()[0 ].getType ())
155
155
.getElementType ();
156
156
mlir::Type resultType =
157
- dyn_cast<mlir::RankedTensorType >(linalgOp.getDpsInits ()[0 ].getType ())
157
+ dyn_cast<mlir::ShapedType >(linalgOp.getDpsInits ()[0 ].getType ())
158
158
.getElementType ();
159
159
return (dataType.isBF16 () || dataType.isF16 ()) && dataType == resultType;
160
160
}
@@ -372,7 +372,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
372
372
linalg::LinalgOp currentOp = linalgOp;
373
373
374
374
bool hasFullResult = !option.isPartialResult ;
375
- for (auto [i, loopType] : llvm::enumerate (loopType)) {
375
+ for (auto && [i, loopType] : llvm::enumerate (loopType)) {
376
376
ArrayRef<size_t > currentDim = loopDim[i];
377
377
ArrayRef<size_t > currentTileSize = nestedTileSizes[i];
378
378
if (loopType == OuterLoopGenerationOption::LoopType::ForOp) {
@@ -420,7 +420,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
420
420
cast<TilingInterface>(currentOp.getOperation ()).getIterationDomain (b);
421
421
currentOp.getReductionDims (reductionDims);
422
422
bool tileOnReduction = false ;
423
- for (auto [d, tile] : llvm::zip (currentDim, currentTileSize)) {
423
+ for (auto && [d, tile] : llvm::zip (currentDim, currentTileSize)) {
424
424
if (llvm::find (reductionDims, d) != reductionDims.end () && tile != 0 &&
425
425
(!getConstantIntValue (loopRanges[d].size ) ||
426
426
tile != static_cast <size_t >(
@@ -438,22 +438,23 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
438
438
OpBuilder::InsertionGuard guard (b);
439
439
b.setInsertionPoint (currentOp);
440
440
if (tileOnReduction) {
441
- for (auto [idx, tile] : llvm::enumerate (tileSizes)) {
441
+ for (auto && [idx, tile] : llvm::enumerate (tileSizes)) {
442
442
if (isConstantIntValue (tile, 0 ) &&
443
443
llvm::find (reductionDims, idx) != reductionDims.end ()) {
444
444
tileSizes[idx] = loopRanges[idx].size ;
445
445
}
446
446
}
447
447
SmallVector<OpFoldResult> newParallelDims;
448
- for (size_t i = 0UL ; i < reductionDims.size (); i++) {
449
- newParallelDims.push_back (getAsIndexOpFoldResult (b.getContext (), i));
448
+ for (auto iter : llvm::enumerate (reductionDims)) {
449
+ newParallelDims.push_back (
450
+ getAsIndexOpFoldResult (b.getContext (), iter.index ()));
450
451
}
451
452
FailureOr<linalg::ForallReductionTilingResult> tilingResult =
452
453
linalgX::tileReductionUsingForall (
453
454
b, cast<PartialReductionOpInterface>(currentOp.getOperation ()),
454
455
{}, tileSizes, newParallelDims, std::nullopt);
455
456
if (failed (tilingResult) &&
456
- tilingResult->parallelTiledOps . size () == 1UL )
457
+ llvm::hasSingleElement ( tilingResult->parallelTiledOps ) )
457
458
return failure ();
458
459
currentOp =
459
460
dyn_cast<linalg::LinalgOp>(tilingResult->parallelTiledOps .back ());
@@ -585,7 +586,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
585
586
: cfg.NBlock ;
586
587
587
588
// Outer loop tile size
588
- for (auto [tile, dim] :
589
+ for (auto && [tile, dim] :
589
590
llvm::zip (SmallVector<size_t >{KParallelBlockSize, MParallelBlockSize,
590
591
NParallelBlockSize},
591
592
SmallVector<size_t >{KDimPos[0 ], MDimPos[0 ], NDimPos[0 ]})) {
@@ -596,27 +597,27 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
596
597
}
597
598
598
599
// Middle loop tile size
599
- for (auto [tile, dim] :
600
+ for (auto && [tile, dim] :
600
601
llvm::zip (SmallVector<size_t >{MOuterBlockSize, NOuterBlockSize,
601
602
KOuterBlockSize},
602
603
SmallVector<size_t >{MDimPos[0 ], NDimPos[0 ], KDimPos[0 ]})) {
603
604
option.nestedTileSizes .emplace_back (SmallVector<size_t >{tile});
604
605
option.loopType .emplace_back (OuterLoopGenerationOption::LoopType::ForOp);
605
606
option.loopDim .emplace_back (SmallVector<size_t >{dim});
606
607
}
607
- if (KDimPos. size () == 1 ) {
608
+ if (llvm::hasSingleElement (KDimPos) ) {
608
609
option.nestedTileSizes .emplace_back (SmallVector<size_t >{cfg.KBlock });
609
610
option.loopType .emplace_back (OuterLoopGenerationOption::LoopType::ForOp);
610
611
option.loopDim .emplace_back (SmallVector<size_t >{KDimPos.back ()});
611
612
}
612
613
// Inner loop tile size
613
- if (MDimPos. size () == 1 ) {
614
+ if (llvm::hasSingleElement (MDimPos) ) {
614
615
option.nestedTileSizes .emplace_back (
615
616
SmallVector<size_t >{cfg.innerMostMBlock });
616
617
option.loopType .emplace_back (OuterLoopGenerationOption::LoopType::ForOp);
617
618
option.loopDim .emplace_back (SmallVector<size_t >{MDimPos.back ()});
618
619
}
619
- if (NDimPos. size () == 1 ) {
620
+ if (llvm::hasSingleElement (NDimPos) ) {
620
621
option.nestedTileSizes .emplace_back (
621
622
SmallVector<size_t >{cfg.innerMostNBlock });
622
623
option.loopType .emplace_back (OuterLoopGenerationOption::LoopType::ForOp);
@@ -656,7 +657,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
656
657
const linalg::ForallReductionTilingResult &result)
657
658
-> FailureOr<linalg::LinalgOp> {
658
659
ArrayRef<Value> initValue = result.initialValues ;
659
- if (initValue. size () == 1 &&
660
+ if (llvm::hasSingleElement (initValue) &&
660
661
isa<linalg::FillOp>(initValue[0 ].getDefiningOp ())) {
661
662
rewriter.replaceOp (initValue[0 ].getDefiningOp (),
662
663
dyn_cast<DestinationStyleOpInterface>(
@@ -706,7 +707,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
706
707
SmallVector<int64_t > AInnermostDims, BInnermostDims, CInnermostDims;
707
708
bool firstM = true , firstK = true , firstN = true ;
708
709
if (MDimNum > 1 ) {
709
- for (auto [idx, iter] : llvm::enumerate ((*operandDimTypes)[0 ])) {
710
+ for (auto && [idx, iter] : llvm::enumerate ((*operandDimTypes)[0 ])) {
710
711
if (iter == DimType::M && firstM) {
711
712
AInnermostDims.push_back (1 );
712
713
firstM = false ;
@@ -721,7 +722,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
721
722
}
722
723
firstM = true ;
723
724
firstN = true ;
724
- for (auto [idx, iter] : llvm::enumerate ((*operandDimTypes)[2 ])) {
725
+ for (auto && [idx, iter] : llvm::enumerate ((*operandDimTypes)[2 ])) {
725
726
if (iter == DimType::M && firstM) {
726
727
CInnermostDims.push_back (1 );
727
728
firstM = false ;
@@ -745,7 +746,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
745
746
if (NDimNum > 1 ) {
746
747
firstN = true ;
747
748
firstK = true ;
748
- for (auto [idx, iter] : llvm::enumerate ((*operandDimTypes)[1 ])) {
749
+ for (auto && [idx, iter] : llvm::enumerate ((*operandDimTypes)[1 ])) {
749
750
if (iter == DimType::N && firstN) {
750
751
BInnermostDims.push_back (1 );
751
752
firstN = false ;
@@ -768,13 +769,13 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
768
769
OpBuilder::InsertionGuard guard (rewriter);
769
770
rewriter.setInsertionPoint (currentOp);
770
771
mlir::Type dataType =
771
- dyn_cast<mlir::RankedTensorType >(currentOp.getDpsInputs ()[0 ].getType ())
772
+ dyn_cast<mlir::ShapedType >(currentOp.getDpsInputs ()[0 ].getType ())
772
773
.getElementType ();
773
774
mlir::Type weightType =
774
- dyn_cast<mlir::RankedTensorType >(currentOp.getDpsInputs ()[1 ].getType ())
775
+ dyn_cast<mlir::ShapedType >(currentOp.getDpsInputs ()[1 ].getType ())
775
776
.getElementType ();
776
777
mlir::Type resultType =
777
- dyn_cast<mlir::RankedTensorType >(currentOp.getDpsInits ()[0 ].getType ())
778
+ dyn_cast<mlir::ShapedType >(currentOp.getDpsInits ()[0 ].getType ())
778
779
.getElementType ();
779
780
780
781
// update the extractSlice to static size, replace it with
@@ -821,9 +822,8 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
821
822
currentOp.getDpsInits ()[0 ]);
822
823
// Create the brgemm op and replace the origin linalg op
823
824
linalg::LinalgOp matmul;
824
- if (dyn_cast<mlir::RankedTensorType>(weightOprand.getType ())
825
- .getShape ()
826
- .size () == 3 ) {
825
+ if (dyn_cast<mlir::ShapedType>(weightOprand.getType ()).getShape ().size () ==
826
+ 3 ) {
827
827
matmul = rewriter.create <linalg::BatchReduceMatmulOp>(
828
828
loc, resultOprand.getType (), ValueRange{dataOprand, weightOprand},
829
829
resultOprand);
@@ -843,7 +843,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
843
843
// fuse the low precision cast to the innermost body
844
844
rewriter.setInsertionPointAfter (currentOp);
845
845
Value cond;
846
- for (LoopLikeOpInterface loop : option.KLoopHandles ) {
846
+ for (LoopLikeOpInterface & loop : option.KLoopHandles ) {
847
847
Value induceVar = turnOpFoldResultIntoValue (
848
848
rewriter, loc, *loop.getSingleInductionVar ());
849
849
Value upBound = turnOpFoldResultIntoValue (rewriter, loc,
@@ -903,7 +903,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
903
903
Value cond;
904
904
arith::ConstantIndexOp zeroConst =
905
905
rewriter.create <arith::ConstantIndexOp>(loc, 0 );
906
- for (LoopLikeOpInterface loop : option.KLoopHandles ) {
906
+ for (LoopLikeOpInterface & loop : option.KLoopHandles ) {
907
907
Value induceVar = loop.getLoopRegions ().front ()->front ().getArgument (0 );
908
908
Value currentCond = rewriter.create <arith::CmpIOp>(
909
909
loc, arith::CmpIPredicate::eq, induceVar, zeroConst);
0 commit comments