@@ -151,10 +151,10 @@ static void buildLinalgRegion(Operation *op, bool createTemporaryOp = false) {
151151//  Check if the linalgOp need to be legalized to f32 accumulation type
152152static  bool  needToLegalizeDtype (linalg::LinalgOp linalgOp) {
153153  mlir::Type dataType =
154-       dyn_cast<mlir::RankedTensorType >(linalgOp.getDpsInputs ()[0 ].getType ())
154+       dyn_cast<mlir::ShapedType >(linalgOp.getDpsInputs ()[0 ].getType ())
155155          .getElementType ();
156156  mlir::Type resultType =
157-       dyn_cast<mlir::RankedTensorType >(linalgOp.getDpsInits ()[0 ].getType ())
157+       dyn_cast<mlir::ShapedType >(linalgOp.getDpsInits ()[0 ].getType ())
158158          .getElementType ();
159159  return  (dataType.isBF16 () || dataType.isF16 ()) && dataType == resultType;
160160}
@@ -372,7 +372,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
372372  linalg::LinalgOp currentOp = linalgOp;
373373
374374  bool  hasFullResult = !option.isPartialResult ;
375-   for  (auto  [i, loopType] : llvm::enumerate (loopType)) {
375+   for  (auto  && [i, loopType] : llvm::enumerate (loopType)) {
376376    ArrayRef<size_t > currentDim = loopDim[i];
377377    ArrayRef<size_t > currentTileSize = nestedTileSizes[i];
378378    if  (loopType == OuterLoopGenerationOption::LoopType::ForOp) {
@@ -420,7 +420,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
420420          cast<TilingInterface>(currentOp.getOperation ()).getIterationDomain (b);
421421      currentOp.getReductionDims (reductionDims);
422422      bool  tileOnReduction = false ;
423-       for  (auto  [d, tile] : llvm::zip (currentDim, currentTileSize)) {
423+       for  (auto  && [d, tile] : llvm::zip (currentDim, currentTileSize)) {
424424        if  (llvm::find (reductionDims, d) != reductionDims.end () && tile != 0  &&
425425            (!getConstantIntValue (loopRanges[d].size ) ||
426426             tile != static_cast <size_t >(
@@ -438,22 +438,23 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
438438      OpBuilder::InsertionGuard guard (b);
439439      b.setInsertionPoint (currentOp);
440440      if  (tileOnReduction) {
441-         for  (auto  [idx, tile] : llvm::enumerate (tileSizes)) {
441+         for  (auto  && [idx, tile] : llvm::enumerate (tileSizes)) {
442442          if  (isConstantIntValue (tile, 0 ) &&
443443              llvm::find (reductionDims, idx) != reductionDims.end ()) {
444444            tileSizes[idx] = loopRanges[idx].size ;
445445          }
446446        }
447447        SmallVector<OpFoldResult> newParallelDims;
448-         for  (size_t  i = 0UL ; i < reductionDims.size (); i++) {
449-           newParallelDims.push_back (getAsIndexOpFoldResult (b.getContext (), i));
448+         for  (auto  iter : llvm::enumerate (reductionDims)) {
449+           newParallelDims.push_back (
450+               getAsIndexOpFoldResult (b.getContext (), iter.index ()));
450451        }
451452        FailureOr<linalg::ForallReductionTilingResult> tilingResult =
452453            linalgX::tileReductionUsingForall (
453454                b, cast<PartialReductionOpInterface>(currentOp.getOperation ()),
454455                {}, tileSizes, newParallelDims, std::nullopt );
455456        if  (failed (tilingResult) &&
456-             tilingResult->parallelTiledOps . size () ==  1UL )
457+             llvm::hasSingleElement ( tilingResult->parallelTiledOps ) )
457458          return  failure ();
458459        currentOp =
459460            dyn_cast<linalg::LinalgOp>(tilingResult->parallelTiledOps .back ());
@@ -585,7 +586,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
585586                                 : cfg.NBlock ;
586587
587588    //  Outer loop tile size
588-     for  (auto  [tile, dim] :
589+     for  (auto  && [tile, dim] :
589590         llvm::zip (SmallVector<size_t >{KParallelBlockSize, MParallelBlockSize,
590591                                       NParallelBlockSize},
591592                   SmallVector<size_t >{KDimPos[0 ], MDimPos[0 ], NDimPos[0 ]})) {
@@ -596,27 +597,27 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
596597    }
597598
598599    //  Middle loop tile size
599-     for  (auto  [tile, dim] :
600+     for  (auto  && [tile, dim] :
600601         llvm::zip (SmallVector<size_t >{MOuterBlockSize, NOuterBlockSize,
601602                                       KOuterBlockSize},
602603                   SmallVector<size_t >{MDimPos[0 ], NDimPos[0 ], KDimPos[0 ]})) {
603604      option.nestedTileSizes .emplace_back (SmallVector<size_t >{tile});
604605      option.loopType .emplace_back (OuterLoopGenerationOption::LoopType::ForOp);
605606      option.loopDim .emplace_back (SmallVector<size_t >{dim});
606607    }
607-     if  (KDimPos. size () ==  1 ) {
608+     if  (llvm::hasSingleElement (KDimPos) ) {
608609      option.nestedTileSizes .emplace_back (SmallVector<size_t >{cfg.KBlock });
609610      option.loopType .emplace_back (OuterLoopGenerationOption::LoopType::ForOp);
610611      option.loopDim .emplace_back (SmallVector<size_t >{KDimPos.back ()});
611612    }
612613    //  Inner loop tile size
613-     if  (MDimPos. size () ==  1 ) {
614+     if  (llvm::hasSingleElement (MDimPos) ) {
614615      option.nestedTileSizes .emplace_back (
615616          SmallVector<size_t >{cfg.innerMostMBlock });
616617      option.loopType .emplace_back (OuterLoopGenerationOption::LoopType::ForOp);
617618      option.loopDim .emplace_back (SmallVector<size_t >{MDimPos.back ()});
618619    }
619-     if  (NDimPos. size () ==  1 ) {
620+     if  (llvm::hasSingleElement (NDimPos) ) {
620621      option.nestedTileSizes .emplace_back (
621622          SmallVector<size_t >{cfg.innerMostNBlock });
622623      option.loopType .emplace_back (OuterLoopGenerationOption::LoopType::ForOp);
@@ -656,7 +657,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
656657              const  linalg::ForallReductionTilingResult &result)
657658          -> FailureOr<linalg::LinalgOp> {
658659        ArrayRef<Value> initValue = result.initialValues ;
659-         if  (initValue. size () ==  1  &&
660+         if  (llvm::hasSingleElement (initValue)  &&
660661            isa<linalg::FillOp>(initValue[0 ].getDefiningOp ())) {
661662          rewriter.replaceOp (initValue[0 ].getDefiningOp (),
662663                             dyn_cast<DestinationStyleOpInterface>(
@@ -706,7 +707,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
706707    SmallVector<int64_t > AInnermostDims, BInnermostDims, CInnermostDims;
707708    bool  firstM = true , firstK = true , firstN = true ;
708709    if  (MDimNum > 1 ) {
709-       for  (auto  [idx, iter] : llvm::enumerate ((*operandDimTypes)[0 ])) {
710+       for  (auto  && [idx, iter] : llvm::enumerate ((*operandDimTypes)[0 ])) {
710711        if  (iter == DimType::M && firstM) {
711712          AInnermostDims.push_back (1 );
712713          firstM = false ;
@@ -721,7 +722,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
721722      }
722723      firstM = true ;
723724      firstN = true ;
724-       for  (auto  [idx, iter] : llvm::enumerate ((*operandDimTypes)[2 ])) {
725+       for  (auto  && [idx, iter] : llvm::enumerate ((*operandDimTypes)[2 ])) {
725726        if  (iter == DimType::M && firstM) {
726727          CInnermostDims.push_back (1 );
727728          firstM = false ;
@@ -745,7 +746,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
745746    if  (NDimNum > 1 ) {
746747      firstN = true ;
747748      firstK = true ;
748-       for  (auto  [idx, iter] : llvm::enumerate ((*operandDimTypes)[1 ])) {
749+       for  (auto  && [idx, iter] : llvm::enumerate ((*operandDimTypes)[1 ])) {
749750        if  (iter == DimType::N && firstN) {
750751          BInnermostDims.push_back (1 );
751752          firstN = false ;
@@ -768,13 +769,13 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
768769    OpBuilder::InsertionGuard guard (rewriter);
769770    rewriter.setInsertionPoint (currentOp);
770771    mlir::Type dataType =
771-         dyn_cast<mlir::RankedTensorType >(currentOp.getDpsInputs ()[0 ].getType ())
772+         dyn_cast<mlir::ShapedType >(currentOp.getDpsInputs ()[0 ].getType ())
772773            .getElementType ();
773774    mlir::Type weightType =
774-         dyn_cast<mlir::RankedTensorType >(currentOp.getDpsInputs ()[1 ].getType ())
775+         dyn_cast<mlir::ShapedType >(currentOp.getDpsInputs ()[1 ].getType ())
775776            .getElementType ();
776777    mlir::Type resultType =
777-         dyn_cast<mlir::RankedTensorType >(currentOp.getDpsInits ()[0 ].getType ())
778+         dyn_cast<mlir::ShapedType >(currentOp.getDpsInits ()[0 ].getType ())
778779            .getElementType ();
779780
780781    //  update the extractSlice to static size, replace it with
@@ -821,9 +822,8 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
821822        currentOp.getDpsInits ()[0 ]);
822823    //  Create the brgemm op and replace the origin linalg op
823824    linalg::LinalgOp matmul;
824-     if  (dyn_cast<mlir::RankedTensorType>(weightOprand.getType ())
825-             .getShape ()
826-             .size () == 3 ) {
825+     if  (dyn_cast<mlir::ShapedType>(weightOprand.getType ()).getShape ().size () ==
826+         3 ) {
827827      matmul = rewriter.create <linalg::BatchReduceMatmulOp>(
828828          loc, resultOprand.getType (), ValueRange{dataOprand, weightOprand},
829829          resultOprand);
@@ -843,7 +843,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
843843      //  fuse the low precision cast to the innermost body
844844      rewriter.setInsertionPointAfter (currentOp);
845845      Value cond;
846-       for  (LoopLikeOpInterface loop : option.KLoopHandles ) {
846+       for  (LoopLikeOpInterface & loop : option.KLoopHandles ) {
847847        Value induceVar = turnOpFoldResultIntoValue (
848848            rewriter, loc, *loop.getSingleInductionVar ());
849849        Value upBound = turnOpFoldResultIntoValue (rewriter, loc,
@@ -903,7 +903,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
903903      Value cond;
904904      arith::ConstantIndexOp zeroConst =
905905          rewriter.create <arith::ConstantIndexOp>(loc, 0 );
906-       for  (LoopLikeOpInterface loop : option.KLoopHandles ) {
906+       for  (LoopLikeOpInterface & loop : option.KLoopHandles ) {
907907        Value induceVar = loop.getLoopRegions ().front ()->front ().getArgument (0 );
908908        Value currentCond = rewriter.create <arith::CmpIOp>(
909909            loc, arith::CmpIPredicate::eq, induceVar, zeroConst);
0 commit comments