@@ -326,6 +326,7 @@ static void setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
326
326
Operation *op, bool isExtract,
327
327
SmallVector<int64_t > size,
328
328
int shrinDimNum = 0 ) {
329
+ llvm::outs () << " ^^^^^^^^^^^^^^setStaticSizeForExtractSliceOp^^^^^^^^^^\n " ;
329
330
OpBuilder::InsertionGuard guard (rewriter);
330
331
rewriter.setInsertionPoint (op);
331
332
if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(op)) {
@@ -335,6 +336,23 @@ static void setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
335
336
for (auto i = 0UL ; i < mixedSizes.size (); i++) {
336
337
mixedSizes[i] = getAsIndexOpFoldResult (rewriter.getContext (), size[i]);
337
338
}
339
+ llvm::outs () << " mixedOffsets: " ;
340
+ for (auto t : mixedOffsets) {
341
+ llvm::outs () << t << " , " ;
342
+ }
343
+ llvm::outs () << " \n " ;
344
+
345
+ llvm::outs () << " mixedSizes: " ;
346
+ for (auto t : mixedSizes) {
347
+ llvm::outs () << t << " , " ;
348
+ }
349
+ llvm::outs () << " \n " ;
350
+
351
+ llvm::outs () << " mixedStrides: " ;
352
+ for (auto t : mixedStrides) {
353
+ llvm::outs () << t << " , " ;
354
+ }
355
+ llvm::outs () << " \n " ;
338
356
if (shrinDimNum > 0 ) {
339
357
rewriter.replaceOpWithNewOp <tensor::ExtractSliceOp>(
340
358
extractSlice,
@@ -348,6 +366,7 @@ static void setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
348
366
mixedStrides);
349
367
}
350
368
}
369
+ llvm::outs () << " ^^^^^^^^^^^^^^setStaticSizeForExtractSliceOp^^^^^^^^^^\n " ;
351
370
}
352
371
353
372
static void setStaticSizeForInsertSliceOp (RewriterBase &rewriter, Operation *op,
@@ -398,6 +417,7 @@ struct OuterLoopGenerationResult {
398
417
static FailureOr<OuterLoopGenerationResult>
399
418
generateOuterLoop (RewriterBase &b, linalg::LinalgOp linalgOp,
400
419
const OuterLoopGenerationOption &option) {
420
+ llvm::outs () << " ======================================\n " ;
401
421
// TODO: handle the return value
402
422
OuterLoopGenerationResult result;
403
423
auto nestedTileSizes = option.nestedTileSizes ;
@@ -471,40 +491,82 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
471
491
else
472
492
tileSizes[d] = getAsIndexOpFoldResult (b.getContext (), tile);
473
493
}
494
+
495
+ llvm::outs () << " tileSizes: " ;
496
+ for (auto t : tileSizes) {
497
+ llvm::outs () << t << " , " ;
498
+ }
499
+ llvm::outs () << " \n " ;
500
+
501
+ llvm::outs () << " threads: " ;
502
+ for (auto t : threads) {
503
+ llvm::outs () << t << " , " ;
504
+ }
505
+ llvm::outs () << " \n " ;
506
+
474
507
SmallVector<Range> loopRanges =
475
508
cast<TilingInterface>(currentOp.getOperation ()).getIterationDomain (b);
476
509
OpBuilder::InsertionGuard guard (b);
477
510
b.setInsertionPoint (currentOp);
478
511
if (auto partialInterface =
479
512
dyn_cast<PartialReductionOpInterface>(currentOp.getOperation ())) {
513
+ llvm::outs () << " PartialReductionOpInterface\n " ;
480
514
for (auto [idx, tile] : llvm::enumerate (tileSizes)) {
481
515
if (isConstantIntValue (tile, 0 )) {
482
516
tileSizes[idx] = loopRanges[idx].size ;
483
517
}
484
518
}
485
-
519
+ llvm::outs () << " updated tileSizes: " ;
520
+ for (auto t : tileSizes) {
521
+ llvm::outs () << t << " , " ;
522
+ }
523
+ llvm::outs () << " \n " ;
486
524
SmallVector<OpFoldResult> newParallelDims;
487
525
for (auto i = 0UL ; i < reductionDims.size (); i++) {
488
526
newParallelDims.push_back (getAsIndexOpFoldResult (b.getContext (), i));
489
527
}
490
- auto tilingResult = linalgX::tileAllUsingForall (
491
- b, cast<PartialReductionOpInterface>(currentOp.getOperation ()), {},
492
- tileSizes, newParallelDims, std::nullopt);
493
- if (failed (tilingResult) &&
494
- tilingResult->parallelTiledOps .size () == 1UL )
495
- return failure ();
496
- currentOp =
497
- dyn_cast<linalg::LinalgOp>(tilingResult->parallelTiledOps .back ());
498
- if (!tilingResult->mergeOps .empty ()) {
499
- for (const auto &fn : option.finalReduceCallBacks ) {
500
- auto result = fn (b, currentOp.getLoc (), *tilingResult);
501
- if (succeeded (result)) {
502
- currentOp = *result;
528
+ if (currentTileSize.front () != 16 || true ) {
529
+ auto tilingResult = linalgX::tileAllUsingForall (
530
+ b, cast<PartialReductionOpInterface>(currentOp.getOperation ()),
531
+ {}, tileSizes, newParallelDims, std::nullopt);
532
+ if (failed (tilingResult) &&
533
+ tilingResult->parallelTiledOps .size () == 1UL )
534
+ return failure ();
535
+ currentOp =
536
+ dyn_cast<linalg::LinalgOp>(tilingResult->parallelTiledOps .back ());
537
+ if (!tilingResult->mergeOps .empty ()) {
538
+ llvm::outs () << " has merge ops\n " ;
539
+ for (const auto &fn : option.finalReduceCallBacks ) {
540
+ auto result = fn (b, currentOp.getLoc (), *tilingResult);
541
+ if (succeeded (result)) {
542
+ currentOp = *result;
543
+ }
503
544
}
504
545
}
546
+ } else {
547
+ llvm::outs () << " handle special cases\n " ;
548
+ OpBuilder::InsertionGuard g (b);
549
+
550
+ Location loc = currentOp.getLoc ();
551
+ SmallVector<Value> dest;
552
+ if (failed (tensor::getOrCreateDestinations (b, loc, currentOp, dest)))
553
+ return b.notifyMatchFailure (currentOp,
554
+ " failed to get destination tensors" );
555
+ arith::ConstantIndexOp lb = b.create <arith::ConstantIndexOp>(loc, 0 );
556
+ arith::ConstantIndexOp ub = b.create <arith::ConstantIndexOp>(loc, 2 );
557
+ arith::ConstantIndexOp step =
558
+ b.create <arith::ConstantIndexOp>(loc, 1 );
559
+
560
+ Operation *forallOp = b.create <scf::ForallOp>(
561
+ loc, ArrayRef<OpFoldResult>(lb->getResult (0 )),
562
+ ArrayRef<OpFoldResult>(ub->getResult (0 )),
563
+ ArrayRef<OpFoldResult>(step->getResult (0 )), dest, std::nullopt);
564
+ currentOp = dyn_cast<linalg::LinalgOp>(forallOp);
505
565
}
566
+
506
567
} else if (auto tilingInterface =
507
568
cast<TilingInterface>(currentOp.getOperation ())) {
569
+ llvm::outs () << " TilingInterface\n " ;
508
570
auto tilingResult = linalg::tileToForallOpUsingTileSizes (
509
571
b, tilingInterface, tileSizes, std::nullopt);
510
572
if (failed (tilingResult))
@@ -515,6 +577,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
515
577
}
516
578
}
517
579
result.tiledOps .emplace_back (currentOp);
580
+ llvm::outs () << " ======================================\n " ;
518
581
return result;
519
582
}
520
583
@@ -595,6 +658,11 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
595
658
auto NOuterBlockSize = NDimPos.size () > 1
596
659
? (cfg.NBlock - 1 ) / cfg.innerMostNBlock + 1
597
660
: cfg.NBlock ;
661
+ // Outermost Numa loop
662
+ option.nestedTileSizes .emplace_back (
663
+ SmallVector<size_t >{uint32_t (MFirstDim / 2 )});
664
+ option.loopType .emplace_back (OuterLoopGenerationOption::LoopType::ForallOp);
665
+ option.loopDim .emplace_back (SmallVector<size_t >{MDimPos[0 ]});
598
666
// Outer
599
667
option.nestedTileSizes .emplace_back (SmallVector<size_t >{
600
668
MParallelBlockSize, NParallelBlockSize, KParallelBlockSize});
0 commit comments