@@ -136,13 +136,14 @@ MatmulConfig getDefaultMatmulConfig(linalg::LinalgOp &linalgOp) {
   cfg.KBlock = 64;
   cfg.MThreads = 2;
   cfg.NThreads = 2;
-  cfg.KThreads = 2;
+  cfg.KThreads = 1;
   return cfg;
 }

-static Value tensorViewRankedTensor(RewriterBase &rewriter,
-                                    RankedTensorType outTensorType,
-                                    Value value) {
+static Value
+tensorViewRankedTensor(RewriterBase &rewriter, RankedTensorType outTensorType,
+                       Value value,
+                       ArrayRef<int64_t> permutation = SmallVector<int64_t>{}) {
   // TODO: add support for plain layout transpose
   Value result, currentValue = value;
   auto loc = currentValue.getLoc();
@@ -175,33 +176,57 @@ static Value tensorViewRankedTensor(RewriterBase &rewriter,

   if (outShape.size() < inShape.size()) {
     SmallVector<ReassociationIndices> reassocIndices;
-    ReassociationIndices firstEntry;
-    for (auto i = 0UL; i < inShape.size() - outShape.size() + 1; i++) {
-      firstEntry.push_back(i);
-    }
-    reassocIndices.push_back(firstEntry);
-    for (auto i = inShape.size() - outShape.size() + 1UL; i < inShape.size();
-         i++) {
-      reassocIndices.push_back({(int)i});
+    uint64_t outIdx = 0UL, inIdx = 0UL;
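+    // Greedily fold consecutive input dims into one reassociation group until
+    // their product matches the current output dim (a unit output dim maps to
+    // a single input dim).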
+    while (inIdx < inShape.size() && outIdx < outShape.size()) {
+      ReassociationIndices firstEntry;
+      auto remaining = outShape[outIdx++];
+      if (remaining == 1) {
+        firstEntry.push_back(inIdx++);
+        reassocIndices.push_back(firstEntry);
+        continue;
+      }
+      while (remaining > 1) {
+        remaining /= inShape[inIdx];
+        firstEntry.push_back(inIdx++);
+      }
+      reassocIndices.push_back(firstEntry);
     }
     result = rewriter.create<tensor::CollapseShapeOp>(
         loc, outTensorType, currentValue, reassocIndices);
   } else if (outShape.size() > inShape.size()) {
     SmallVector<ReassociationIndices> reassocIndices;
-    ReassociationIndices firstEntry;
-    for (auto i = 0UL; i < outShape.size() - inShape.size() + 1; i++) {
-      firstEntry.push_back((int)i);
-    }
-    reassocIndices.push_back(firstEntry);
-    for (auto i = outShape.size() - inShape.size() + 1UL; i < outShape.size();
-         i++) {
-      reassocIndices.push_back({(int)i});
+    uint64_t outIdx = 0UL, inIdx = 0UL;
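+    // Same greedy grouping, with the roles of input and output dims swapped
+    // for the expand case.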
+    while (outIdx < outShape.size() && inIdx < inShape.size()) {
+      ReassociationIndices firstEntry;
+      auto remaining = inShape[inIdx++];
+      if (remaining == 1) {
+        firstEntry.push_back(outIdx++);
+        reassocIndices.push_back(firstEntry);
+        continue;
+      }
+      while (remaining > 1) {
+        remaining /= outShape[outIdx];
+        firstEntry.push_back(outIdx++);
+      }
+      reassocIndices.push_back(firstEntry);
     }
     result = rewriter.create<tensor::ExpandShapeOp>(
         loc, outTensorType, currentValue, reassocIndices);
   } else {
     result = rewriter.create<tensor::CastOp>(loc, outTensorType, currentValue);
   }
+
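+  // An optional permutation additionally transposes the reshaped tensor into
+  // a fresh init tensor of the permuted shape.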
+  if (!permutation.empty()) {
+    SmallVector<int64_t> transposeShape;
+    for (auto idx : permutation) {
+      transposeShape.push_back(outShape[idx]);
+    }
+    auto initOp = rewriter.create<tensor::EmptyOp>(loc, transposeShape,
+                                                   tensorElementType);
+    auto transposeOp = rewriter.create<linalg::TransposeOp>(
+        loc, result, initOp->getResult(0), permutation);
+    result = transposeOp->getResult(0);
+  }
   return result;
 }

@@ -345,6 +370,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
     return b.notifyMatchFailure(
         linalgOp, "currentOp should not has pure buffer semantics");
   linalg::LinalgOp currentOp = linalgOp;
+
   for (auto loopTypeIter : llvm::enumerate(loopType)) {
     auto [i, loopType] = loopTypeIter;
     auto currentDim = loopDim[i];
@@ -486,6 +512,8 @@ static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
                                                      bool isExtract,
                                                      SmallVector<int64_t> size,
                                                      int shrinDimNum = 0) {
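+  // Materialize any replacement ops right at the slice op; the guard restores
+  // the previous insertion point on exit.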
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
   if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(op)) {
     SmallVector<OpFoldResult> mixedOffsets = extractSlice.getMixedOffsets();
     SmallVector<OpFoldResult> mixedSizes = extractSlice.getMixedSizes();
@@ -514,6 +542,8 @@ static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
 static LogicalResult setStaticSizeForInsertSliceOp(RewriterBase &rewriter,
                                                    Operation *op, Value source,
                                                    SmallVector<int64_t> size) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(op);
   if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(op)) {
     SmallVector<OpFoldResult> mixedOffsets = insertSlice.getMixedOffsets();
     SmallVector<OpFoldResult> mixedSizes = insertSlice.getMixedSizes();
@@ -575,35 +605,34 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
     linalgOp.getReductionDims(KDimPos);
     getMatmulParallelDims(linalgOp, 0, MDimPos);
     getMatmulParallelDims(linalgOp, 1, NDimPos);
-    bool useBlockedLayout = KDimPos.size() > 1;

     OuterLoopGenerationOption option;
     auto iteratorTypes = linalgOp.getIteratorTypesArray();
     auto KFirstDim = (int)getOprandDim(linalgOp, KDimPos[0], 1);
     auto MFirstDim = (int)getOprandDim(linalgOp, MDimPos[0], 0);
     auto NFirstDim = (int)getOprandDim(linalgOp, NDimPos[0], 1);
     auto KParallelBlockSize =
-        useBlockedLayout
+        KDimPos.size() > 1
             ? divAndCeil(KFirstDim, cfg.KThreads)
             : divAndCeil(divAndCeil(KFirstDim, cfg.KBlock), cfg.KThreads) *
                   cfg.KBlock;
     auto MParallelBlockSize =
-        useBlockedLayout
+        MDimPos.size() > 1
             ? divAndCeil(MFirstDim, cfg.MThreads)
             : divAndCeil(divAndCeil(MFirstDim, cfg.MBlock), cfg.MThreads) *
                   cfg.MBlock;
     auto NParallelBlockSize =
-        useBlockedLayout
+        NDimPos.size() > 1
             ? divAndCeil(NFirstDim, cfg.NThreads)
             : divAndCeil(divAndCeil(NFirstDim, cfg.NBlock), cfg.NThreads) *
                   cfg.NBlock;
-    auto KOuterBlockSize = useBlockedLayout
+    auto KOuterBlockSize = KDimPos.size() > 1
                                ? (cfg.KBlock - 1) / cfg.innerMostKBlock + 1
                                : cfg.KBlock;
-    auto MOuterBlockSize = useBlockedLayout
+    auto MOuterBlockSize = MDimPos.size() > 1
                                ? (cfg.MBlock - 1) / cfg.innerMostMBlock + 1
                                : cfg.MBlock;
-    auto NOuterBlockSize = useBlockedLayout
+    auto NOuterBlockSize = NDimPos.size() > 1
                                ? (cfg.NBlock - 1) / cfg.innerMostNBlock + 1
                                : cfg.NBlock;
     // Outer
@@ -631,11 +660,23 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
       option.loopDim.emplace_back(SmallVector<int>{dim});
     }
     // Inner
-    if (!useBlockedLayout) {
+    if (KDimPos.size() == 1) {
       option.nestedTileSizes.emplace_back(SmallVector<int>{cfg.KBlock});
       option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
       option.loopDim.emplace_back(SmallVector<int>{(int)KDimPos.back()});
     }
+    if (MDimPos.size() == 1) {
+      option.nestedTileSizes.emplace_back(
+          SmallVector<int>{cfg.innerMostMBlock});
+      option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
+      option.loopDim.emplace_back(SmallVector<int>{(int)MDimPos.back()});
+    }
+    if (NDimPos.size() == 1) {
+      option.nestedTileSizes.emplace_back(
+          SmallVector<int>{cfg.innerMostNBlock});
+      option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
+      option.loopDim.emplace_back(SmallVector<int>{(int)NDimPos.back()});
+    }
     for (auto dim = 0UL; dim < linalgOp.getNumLoops(); dim++) {
       if (dim != MDimPos.back() && dim != NDimPos.back() &&
           iteratorTypes[dim] != mlir::utils::IteratorType::reduction) {
@@ -658,17 +699,24 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
                                     linalg::LinalgOp originOp,
                                     linalg::LinalgOp currentOp,
                                     innerBodyGenerationOption &option) const {
+
     mlir::easybuild::EasyBuilder eb{rewriter, originOp.getLoc()};
     auto operandDimTypes = getOprandDimType(originOp);
     MatmulConfig cfg = getDefaultMatmulConfig(originOp);
     auto AShape = originOp.getShape(originOp.getDpsInputOperand(0));
     auto BShape = originOp.getShape(originOp.getDpsInputOperand(1));
     auto CShape = originOp.getShape(originOp.getDpsInitOperand(0));
-    bool useBlockedLayout = BShape.size() > 2;
+
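+    // Count the M/N dims carried by each operand; more than one of them means
+    // that operand uses a blocked (packed) layout.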
+    auto MDimNum = std::count_if((*operandDimTypes)[0].begin(),
+                                 (*operandDimTypes)[0].end(),
+                                 [](DimType d) { return d == DimType::M; });
+    auto NDimNum = std::count_if((*operandDimTypes)[1].begin(),
+                                 (*operandDimTypes)[1].end(),
+                                 [](DimType d) { return d == DimType::N; });
     // TODO: support plain in/block out format
     SmallVector<int64_t> AInnermostDims, BInnermostDims, CInnermostDims;
-    if (useBlockedLayout) {
-      bool firstM = true, firstK = true, firstN = true;
+    bool firstM = true, firstK = true, firstN = true;
+    if (MDimNum > 1) {
       for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[0])) {
         if (iter == DimType::M && firstM) {
           AInnermostDims.push_back(1);
@@ -682,21 +730,6 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
           AInnermostDims.push_back(AShape[idx]);
         }
       }
-      firstN = true;
-      firstK = true;
-      for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[1])) {
-        if (iter == DimType::N && firstN) {
-          BInnermostDims.push_back(1);
-          firstN = false;
-        } else if (iter == DimType::Batch) {
-          BInnermostDims.push_back(1);
-        } else if (iter == DimType::K && firstK) {
-          BInnermostDims.push_back(cfg.KBlock / cfg.innerMostKBlock);
-          firstK = false;
-        } else {
-          BInnermostDims.push_back(BShape[idx]);
-        }
-      }
       firstM = true;
       firstN = true;
       for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[2])) {
@@ -716,74 +749,94 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
       AInnermostDims = SmallVector<int64_t>{cfg.innerMostMBlock,
                                             cfg.KBlock / cfg.innerMostKBlock *
                                                 cfg.innerMostKBlock};
+      CInnermostDims =
+          SmallVector<int64_t>{cfg.innerMostMBlock, cfg.innerMostNBlock};
+    }
+    if (NDimNum > 1) {
+      firstN = true;
+      firstK = true;
+      for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[1])) {
+        if (iter == DimType::N && firstN) {
+          BInnermostDims.push_back(1);
+          firstN = false;
+        } else if (iter == DimType::Batch) {
+          BInnermostDims.push_back(1);
+        } else if (iter == DimType::K && firstK) {
+          BInnermostDims.push_back(cfg.KBlock / cfg.innerMostKBlock);
+          firstK = false;
+        } else {
+          BInnermostDims.push_back(BShape[idx]);
+        }
+      }
+    } else {
       BInnermostDims = SmallVector<int64_t>{cfg.KBlock / cfg.innerMostKBlock *
                                                 cfg.innerMostKBlock,
                                             cfg.innerMostNBlock};
-      CInnermostDims =
-          SmallVector<int64_t>{cfg.innerMostMBlock, cfg.innerMostNBlock};
     }

     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPoint(currentOp);
     auto dataType =
-        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInputs()[0].getType());
+        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInputs()[0].getType())
+            .getElementType();
     auto weightType =
-        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInputs()[1].getType());
+        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInputs()[1].getType())
+            .getElementType();
     auto resultType =
-        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInits()[0].getType());
-    // use shrink layout when it is able to be converted to brgemm
-    bool useShrinkedLayout = (BInnermostDims.size() == 4);
+        dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInits()[0].getType())
+            .getElementType();

     // update the extractSlice to static size, replace it with
     // useBlockedLayout when
     if (failed(setStaticSizeForExtractSliceOp(
             rewriter, currentOp.getDpsInits()[0].getDefiningOp(), true,
-            CInnermostDims, useShrinkedLayout ? 2 : 0)) ||
+            CInnermostDims, MDimNum > 1 ? 2 : 0)) ||
         failed(setStaticSizeForExtractSliceOp(
             rewriter, currentOp.getDpsInputs()[1].getDefiningOp(), true,
-            BInnermostDims, useShrinkedLayout)) ||
+            BInnermostDims, NDimNum > 1)) ||
         failed(setStaticSizeForExtractSliceOp(
             rewriter, currentOp.getDpsInputs()[0].getDefiningOp(), true,
-            AInnermostDims, useShrinkedLayout))) {
+            AInnermostDims, MDimNum > 1))) {
       return failure();
     }
-
     // View the tensor to brgemm required format
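+    // For a plain (single M dim) A operand, the view is (M, Kouter, Kinner)
+    // permuted to (Kouter, M, Kinner) so the reduction dim leads for brgemm.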
     Value dataOprand = tensorViewRankedTensor(
         rewriter,
         mlir::RankedTensorType::get(
-            useBlockedLayout
-                ? SmallVector<int64_t>(AInnermostDims.begin() + 1,
-                                       AInnermostDims.end())
-                : SmallVector<int64_t>{1, AInnermostDims[0], AInnermostDims[1]},
-            dataType.getElementType()),
-        currentOp.getDpsInputs()[0]);
+            MDimNum > 1 ? SmallVector<int64_t>(AInnermostDims.begin() + 1,
+                                               AInnermostDims.end())
+                        : SmallVector<int64_t>{cfg.innerMostMBlock,
+                                               cfg.KBlock / cfg.innerMostKBlock,
+                                               cfg.innerMostKBlock},
+            dataType),
+        currentOp.getDpsInputs()[0],
+        MDimNum == 1 ? SmallVector<int64_t>{1, 0, 2} : SmallVector<int64_t>{});
     Value weightOprand = tensorViewRankedTensor(
         rewriter,
         mlir::RankedTensorType::get(
-            useBlockedLayout
-                ? SmallVector<int64_t>(BInnermostDims.begin() + 1,
-                                       BInnermostDims.end())
-                : SmallVector<int64_t>{1, BInnermostDims[0], BInnermostDims[1]},
-            weightType.getElementType()),
+            NDimNum > 1 ? SmallVector<int64_t>(BInnermostDims.begin() + 1,
+                                               BInnermostDims.end())
+                        : SmallVector<int64_t>{cfg.KBlock / cfg.innerMostKBlock,
+                                               cfg.innerMostKBlock,
+                                               cfg.innerMostNBlock},
+            weightType),
         currentOp.getDpsInputs()[1]);
     Value resultOprand = tensorViewRankedTensor(
         rewriter,
         mlir::RankedTensorType::get(
-            SmallVector<int64_t>(CInnermostDims.begin() +
-                                     (useBlockedLayout ? 2 : 0),
+            SmallVector<int64_t>(CInnermostDims.begin() + (MDimNum > 1 ? 2 : 0),
                                  CInnermostDims.end()),
-            resultType.getElementType()),
+            resultType),
         currentOp.getDpsInits()[0]);
-
     // Create the brgemm op and replace the origin linalg op
     linalg::LinalgOp matmul;
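+    // A rank-3 weight view selects the plain batch-reduce matmul; otherwise
+    // the VNNI-packed variant is emitted.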
-    if (BInnermostDims.size() == 4 || BInnermostDims.size() == 2) {
+    if (dyn_cast<mlir::RankedTensorType>(weightOprand.getType())
+            .getShape()
+            .size() == 3) {
       matmul = rewriter.create<linalg::BatchReduceMatmulOp>(
           resultOprand.getLoc(), resultOprand.getType(),
           ValueRange{dataOprand, weightOprand}, resultOprand);
     } else {
-      IRMapping mapping;
       matmul = rewriter.create<linalgx::BatchReduceMatmulVnniOp>(
           resultOprand.getLoc(), resultOprand.getType(),
           ValueRange{dataOprand, weightOprand}, resultOprand);