1111// ===----------------------------------------------------------------------===//
1212
1313#include " ./Tiling.hpp"
14+ #include " gc/Dialect/Arith/Utils/EasyBuild.h"
15+ #include " gc/IR/EasyBuild.h"
16+ #include " gc/IR/EasyBuildSCF.h"
1417#include " mlir/AsmParser/AsmParser.h"
1518#include " mlir/Dialect/Affine/IR/AffineOps.h"
1619#include " mlir/Dialect/Func/IR/FuncOps.h"
@@ -179,6 +182,7 @@ struct OuterLoopGenerationResult {
179182 SmallVector<Operation *> tiledOps;
180183 // / The `scf.for` operations that iterate over the tiles.
181184 SmallVector<LoopLikeOpInterface> loops;
185+ SmallVector<LoopLikeOpInterface> reductionLoops;
182186 // / Values to use as replacements for the untiled op. Is the same size as the
183187 // / number of results of the untiled op.
184188 SmallVector<Value> replacements;
@@ -192,6 +196,8 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
192196 auto nestedTileSizes = option.nestedTileSizes ;
193197 auto loopType = option.loopType ;
194198 auto loopDim = option.loopDim ;
199+ SmallVector<mlir::utils::IteratorType> iteratorTypes =
200+ linalgOp.getIteratorTypesArray ();
195201
196202 if (loopType.size () != loopDim.size () ||
197203 loopDim.size () != nestedTileSizes.size ()) {
@@ -228,6 +234,13 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
228234 return failure ();
229235 b.replaceOp (currentOp, tilingResult->replacements );
230236 currentOp = dyn_cast<linalg::LinalgOp>(tilingResult->tiledOps .back ());
237+
238+ for (auto [dim, loop] : llvm::zip (currentDim, tilingResult->loops )) {
239+ if (iteratorTypes[dim] == mlir::utils::IteratorType::reduction) {
240+ result.reductionLoops .push_back (loop);
241+ }
242+ result.loops .push_back (loop);
243+ }
231244 } else if (type == OuterLoopGenerationOption::LoopType::ForallOp) {
232245 SmallVector<OpFoldResult> tileSizes (
233246 currentOp.getNumLoops (), getAsIndexOpFoldResult (b.getContext (), 0 ));
@@ -395,16 +408,16 @@ getOprandDimType(linalg::LinalgOp &linalgOp) {
395408}
396409
397410/*
398- forall([PM, PN]: [MThreads, NThreads) {
399- for(PK : KThreads) {
411+ matmul(A, B) -> C
412+ ---------------->
413+ forall([PM, PN, PK]: [MThreads, NThreads, KThreads]) {
400414 CSlice = [KThreads, PM * MOuterBlock: (PM + 1) * MOuterBlock,
401415 PN * NOuterBlock: (PN + 1) * NOuterBlock]
402416 ASlice = A[PM * MOuterBlock: (PM + 1) * MOuterBlock, PK * KOuterBlock: (PK
403417+ 1) * KOuterBlock]
404418 BSlice = B[PK * KOuterBlock: (PK + 1) * KOuterBlock, PN *
405419NOuterBlock: (PN + 1) * NOuterBlock] CSlice2 = CSlice[PK, PM * MOuterBlock: (PM
406420+ 1) * MOuterBlock, PN * NOuterBlock: (PN + 1) * NOuterBlock]
407-
408421 MNumBlock = MOuterBlock / MBlock
409422 NNumBlock = NOuterBlock / NBlock
410423 KNumBlock = KOuterBlock / KBlock
@@ -426,9 +439,8 @@ iin_block_: (in + 1) * iin_block_] (init with 0 when ok == 0)
426439A=ASlice3, B=BSlice3, C=CSlice4, onlyUpdate=(ok!=0));
427440 }
428441 }
429- }
430- C = final_reduce(CSlice)
431442}
443+ C = final_reduce(CSlice)
432444*/
433445struct deepTileMatmul : public OpInterfaceRewritePattern <linalg::LinalgOp> {
434446 using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
@@ -508,12 +520,14 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
// Options passed to innerBodyGeneration that describe how the accumulator
// initialisation (linalg.fill) of the matmul must be handled.
508520 struct innerBodyGenerationOption {
// Set by the caller when the matmul's first init operand is defined by a
// linalg.fill op.
509521 bool hasFillOp = false ;
// Value handed to the linalg.fill that innerBodyGeneration creates when
// hasFillOp is set.
510522 Value fillValue;
// Loops whose first body-block argument is compared against 0; the fill is
// only emitted on the branch where all of them are 0. The caller passes the
// reduction loops produced by the outer loop generation.
523+ SmallVector<LoopLikeOpInterface> KLoopHandles;
511524 };
512525
513526 LogicalResult
514527 innerBodyGeneration (RewriterBase &rewriter, linalg::LinalgOp originOp,
515528 linalg::LinalgOp currentOp,
516529 const innerBodyGenerationOption &option) const {
530+ mlir::easybuild::EasyBuilder eb{rewriter, originOp.getLoc ()};
517531 auto operandDimTypes = getOprandDimType (originOp);
518532 MatmulConfig cfg = getDefaultMatmulConfig (originOp);
519533 auto AShape = originOp.getShape (originOp.getDpsInputOperand (0 ));
@@ -656,18 +670,31 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
656670 currentOp = matmul;
657671
658672 if (option.hasFillOp ) {
659- // TODO: support partial K in sinsngle threads, control flow may need
660- // easy builder support
661673 rewriter.setInsertionPointAfter (currentOp);
662- auto fillOp = rewriter.create <linalg::FillOp>(
663- currentOp->getLoc (), option.fillValue , currentOp.getDpsInits ()[0 ]);
664- IRMapping mapping;
665- mapping.map (currentOp.getDpsInits ()[0 ], fillOp.getResult (0 ));
666- auto res = rewriter.clone (*(currentOp.getOperation ()), mapping);
667- rewriter.replaceOp (currentOp, res);
668- currentOp = dyn_cast<linalg::LinalgOp>(res);
674+
675+ auto cond = eb (true );
676+ for (auto loop : option.KLoopHandles ) {
677+ auto induceVar = eb.wrap <mlir::easybuild::EBUnsigned>(
678+ loop.getLoopRegions ().front ()->front ().getArgument (0 ));
679+ auto currentCond = induceVar == eb.toIndex (0 );
680+ cond = cond & currentCond;
681+ }
682+ EB_scf_if (cond, {currentOp.getDpsInits ()[0 ].getType ()}) {
683+ auto fillOp = rewriter.create <linalg::FillOp>(
684+ currentOp->getLoc (), option.fillValue , currentOp.getDpsInits ()[0 ]);
685+ IRMapping mapping;
686+ mapping.map (currentOp.getDpsInits ()[0 ], fillOp.getResult (0 ));
687+ auto res = rewriter.clone (*(currentOp.getOperation ()), mapping);
688+ eb.yield (res->getResult (0 ));
689+ }
690+ EB_else {
691+ auto res = rewriter.clone (*(currentOp.getOperation ()));
692+ eb.yield (res->getResult (0 ));
693+ }
694+ auto ifOp = eb.getLastOperaion ();
695+ rewriter.replaceOp (currentOp, ifOp);
696+ ifOp->getParentOfType <func::FuncOp>().dump ();
669697 }
670- currentOp.getOperation ()->getParentOfType <func::FuncOp>().dump ();
671698 return success ();
672699 }
673700
@@ -685,7 +712,6 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
685712 // if-else block)
686713 bool hasFillOp = false ;
687714 Value fillValue;
688- SmallVector<LoopLikeOpInterface> KLoopHandle;
689715 if (auto op = dyn_cast<linalg::FillOp>(
690716 linalgOp.getDpsInits ()[0 ].getDefiningOp ())) {
691717 hasFillOp = true ;
@@ -707,7 +733,8 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
707733 // Step 3 inner loop generation, convert the linalg.generic to brgemm
708734 if (failed (innerBodyGeneration (
709735 rewriter, matmulOp, linalgOp,
710- innerBodyGenerationOption{hasFillOp, fillValue}))) {
736+ innerBodyGenerationOption{hasFillOp, fillValue,
737+ outerLoopResult->reductionLoops }))) {
711738 return failure ();
712739 }
713740 return success ();
0 commit comments