@@ -794,11 +794,10 @@ FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
794794 return tileLinalgOpImpl<LoopTy>(b, op, tileSizeVector, options);
795795}
796796
797- FailureOr<linalg::ForallReductionTilingResult>
798- tileAllUsingForall (RewriterBase &b, PartialReductionOpInterface op,
799- ArrayRef<OpFoldResult> numThreads,
800- ArrayRef<OpFoldResult> tileSizes,
801- std::optional<ArrayAttr> mapping) {
797+ FailureOr<linalg::ForallReductionTilingResult> tileAllUsingForall (
798+ RewriterBase &b, PartialReductionOpInterface op,
799+ ArrayRef<OpFoldResult> threadNums, ArrayRef<OpFoldResult> tileSizes,
800+ ArrayRef<OpFoldResult> newParallelDims, std::optional<ArrayAttr> mapping) {
802801 Location loc = op.getLoc ();
803802 OpBuilder::InsertionGuard g (b);
804803
@@ -834,6 +833,24 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
834833 if (iteratorType == utils::IteratorType::reduction)
835834 redDims.push_back (idx);
836835 }
836+
837+ SmallVector<OpFoldResult> numThreads (threadNums.begin (), threadNums.end ());
838+ if (numThreads.empty ()) {
839+ SmallVector<Range> loopRanges = tilingInterfaceOp.getIterationDomain (b);
840+ unsigned nLoops = loopRanges.size ();
841+ numThreads.reserve (nLoops);
842+ AffineExpr s0, s1;
843+ bindSymbols (b.getContext (), s0, s1);
844+ AffineExpr divExpr = s0.ceilDiv (s1);
845+ for (const auto &it : llvm::zip (tileSizes, loopRanges)) {
846+ OpFoldResult numTiles = std::get<0 >(it);
847+ if (!isConstantIntValue (numTiles, 0 ))
848+ numTiles = makeComposedFoldedAffineApply (
849+ b, op.getLoc (), divExpr, {std::get<1 >(it).size , std::get<0 >(it)});
850+ numThreads.push_back (numTiles);
851+ }
852+ }
853+
837854 bool hasReductionThreads = false ;
838855 for (auto dim : redDims) {
839856 if (!isConstantIntValue (numThreads[dim], 0 ) &&
@@ -850,13 +867,24 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
850867 if ((unsigned )redDims.front () >= numThreads.size ())
851868 return b.notifyMatchFailure (
852869 op, " reduction dimension must be mapped to threads" );
853-
870+ SmallVector<int > constantNewParallelDims;
871+ for (auto dim : newParallelDims) {
872+ if (getConstantIntValue (dim) == std::nullopt )
873+ return b.notifyMatchFailure (
874+ op, " Expected new parallel dims to be constant integers." );
875+ constantNewParallelDims.push_back (*getConstantIntValue (dim));
876+ }
877+ if (newParallelDims.empty ())
878+ constantNewParallelDims = redDims;
879+ if (constantNewParallelDims.size () != redDims.size ())
880+ return b.notifyMatchFailure (
881+ op, " reduction dimension must be mapped to new parallel dims" );
854882 // 1. Create the initial tensor value.
855883 FailureOr<Operation *> identityTensor = nullptr ;
856884 if (hasReductionThreads) {
857885 identityTensor = LinalgOpPartialReductionInterface::
858- generateInitialTensorForPartialReduction (op, b, loc, numThreads,
859- redDims, {} );
886+ generateInitialTensorForPartialReduction (
887+ op, b, loc, numThreads, redDims, constantNewParallelDims );
860888 }
861889 if (failed (identityTensor))
862890 return b.notifyMatchFailure (op,
@@ -866,7 +894,6 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
866894 SmallVector<Value> dest;
867895 if (failed (tensor::getOrCreateDestinations (b, loc, op, dest)))
868896 return b.notifyMatchFailure (op, " failed to get destination tensors" );
869-
870897 Operation *tiledOp = nullptr ;
871898
872899 SmallVector<OpFoldResult> nonZeroNumThreads =
@@ -875,20 +902,21 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
875902 }));
876903 SmallVector<Value> materializedNonZeroNumThreads =
877904 getValueOrCreateConstantIndexOp (b, loc, nonZeroNumThreads);
878-
879905 // 2. Create the ForallOp with an empty region.
880906 scf::ForallOp forallOp = b.create <scf::ForallOp>(
881907 loc, getAsOpFoldResult (materializedNonZeroNumThreads),
882908 hasReductionThreads ? (*identityTensor)->getResults () : dest, mapping);
883-
884909 // 3. Calculate the tile offsets and sizes for the subsequent loop that will
885910 // be nested under `forallOp`.
886911 SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
912+ std::optional<ArrayRef<OpFoldResult>> nominalTileSizes = std::nullopt ;
913+ if (!tileSizes.empty () && threadNums.empty ()) {
914+ nominalTileSizes = tileSizes;
915+ }
887916 calculateTileOffsetsAndSizes (b, loc, forallOp, numThreads, iterationDomain,
888917 /* omitTileOffsetBoundsCheck =*/ false ,
889- /* nominalTileSizes=*/ tileSizes, tiledOffsets,
890- tiledSizes);
891-
918+ /* nominalTileSizes=*/ nominalTileSizes,
919+ tiledOffsets, tiledSizes);
892920 // 4. Clone the tileable op and update its destination operands to use the
893921 // output bbArgs of the ForallOp.
894922 SmallVector<Value> tilingResults;
@@ -907,20 +935,26 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
907935 SmallVector<OpFoldResult> strides (numThreads.size (), b.getIndexAttr (1 ));
908936 SmallVector<OpFoldResult> outOffsets (numThreads.size (),
909937 b.getIndexAttr (0 ));
910- SmallVector<OpFoldResult> sizes;
911- for (auto s :
912- cast<RankedTensorType>(destBbArgs[destNum].getType ()).getShape ()) {
913- sizes.emplace_back (getAsIndexOpFoldResult (b.getContext (), (int )s));
914- }
915- for (auto dim : redDims) {
916- sizes[dim] = b.getIndexAttr (1 );
938+ SmallVector<OpFoldResult> sizes = tiledSizes;
939+ for (const auto &iteratorType : llvm::enumerate (
940+ cast<RankedTensorType>(destBbArgs[destNum].getType ())
941+ .getShape ())) {
942+ sizes[iteratorType.index ()] =
943+ getAsIndexOpFoldResult (b.getContext (), iteratorType.value ());
944+ if (llvm::find (constantNewParallelDims, iteratorType.index ()) !=
945+ constantNewParallelDims.end ()) {
946+ sizes[iteratorType.index ()] = b.getIndexAttr (1 );
947+ }
917948 }
918949
919950 auto nonZeroDimIdx = 0 ;
920- for (auto dim = 0UL ; dim < numThreads.size (); dim++) {
921- if (!isConstantIntValue (numThreads[dim], 0 )) {
922- if (llvm::find (redDims, dim) != redDims.end ())
923- outOffsets[dim] = forallOp.getInductionVars ()[nonZeroDimIdx];
951+ auto currentReductionIdx = 0 ;
952+ for (const auto &iteratorType : llvm::enumerate (numThreads)) {
953+ if (!isConstantIntValue (iteratorType.value (), 0 )) {
954+ if (llvm::find (redDims, iteratorType.index ()) != redDims.end ()) {
955+ outOffsets[constantNewParallelDims[currentReductionIdx++]] =
956+ forallOp.getInductionVars ()[nonZeroDimIdx];
957+ }
924958 nonZeroDimIdx++;
925959 }
926960 }
@@ -929,7 +963,10 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
929963 loc, cast<RankedTensorType>(initOperand.getType ()),
930964 destBbArgs[destNum], outOffsets, sizes, strides));
931965 } else {
932- tiledDpsInitOperands.push_back (initOperand);
966+ auto *it = llvm::find (dest, initOperand);
967+ assert (it != dest.end () && " dest operand not found in dest" );
968+ unsigned destNum = std::distance (dest.begin (), it);
969+ tiledDpsInitOperands.push_back (destBbArgs[destNum]);
933970 }
934971 }
935972
@@ -944,19 +981,35 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
944981 initOperandPtr.set (tiledInitValue);
945982 }
946983 });
947-
948984 // 5. Tile the cloned op and delete the clone.
949- FailureOr<TilingResult> tilingResult =
950- cast<TilingInterface>(clonedOp).getTiledImplementation (b, tiledOffsets,
951- tiledSizes);
952- if (failed (tilingResult))
953- return clonedOp->emitError (" Failed to tile op: " );
954- if (tilingResult->tiledOps .size () != 1 ) {
955- return clonedOp->emitError (" expected a single produced tiled op, got " )
956- << tilingResult->tiledOps .size ();
985+ if (tileSizes.empty () || threadNums.empty ()) {
986+ FailureOr<TilingResult> tilingResult =
987+ cast<TilingInterface>(clonedOp).getTiledImplementation (
988+ b, tiledOffsets, tiledSizes);
989+ if (failed (tilingResult))
990+ return clonedOp->emitError (" Failed to tile op: " );
991+ if (tilingResult->tiledOps .size () != 1 ) {
992+ return clonedOp->emitError (" expected a single produced tiled op, got " )
993+ << tilingResult->tiledOps .size ();
994+ }
995+ tiledOp = tilingResult->tiledOps .front ();
996+ tilingResults = tilingResult->tiledValues ;
997+ } else {
998+ LinalgTilingOptions options;
999+ FailureOr<TiledLinalgOp> maybeTiled = tileLinalgOpImpl<scf::ForOp>(
1000+ b, cast<LinalgOp>(clonedOp), tileSizes, options);
1001+ if (failed (maybeTiled))
1002+ return b.notifyMatchFailure (op, " failed tileLinalgOpImpl" );
1003+
1004+ SmallVector<Value> ids = forallOp.getInductionVars ();
1005+ mapLoopToProcessorIds (cast<scf::ForOp>(maybeTiled->loops .back ()), ids,
1006+ materializedNonZeroNumThreads);
1007+ if (maybeTiled->loops .size () != 1 ) {
1008+ return clonedOp->emitError (" expected a single produced loop" );
1009+ }
1010+ tiledOp = maybeTiled->op ;
1011+ tilingResults = maybeTiled->loops .front ()->getResults ();
9571012 }
958- tiledOp = tilingResult->tiledOps .front ();
959- tilingResults = tilingResult->tiledValues ;
9601013
9611014 b.eraseOp (clonedOp);
9621015 }
@@ -974,23 +1027,33 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
9741027 return op->emitOpError (" output offsets couldn't be calculated" );
9751028 SmallVector<OpFoldResult> resultOffsetsRank, resultSizesRank;
9761029 int64_t offIdx = 0 ;
977- int64_t sizeIdx = 0 ;
9781030 int64_t nonZeroDimIdx = 0 ;
1031+ SmallVector<Value> reductionInductionVars;
9791032 for (auto i = 0UL ; i < numThreads.size (); ++i) {
980- if (llvm::find (redDims, i) != redDims.end ()) {
1033+ if (llvm::find (constantNewParallelDims, i) !=
1034+ constantNewParallelDims.end ()) {
9811035 if (hasReductionThreads) {
982- resultOffsetsRank.push_back (
983- forallOp.getInductionVars ()[nonZeroDimIdx]);
1036+ resultOffsetsRank.push_back (b.getIndexAttr (1 ));
9841037 resultSizesRank.push_back (b.getIndexAttr (1 ));
9851038 }
986- nonZeroDimIdx++;
987- continue ;
1039+ } else {
1040+ resultOffsetsRank.push_back (resultOffsets[offIdx]);
1041+ resultSizesRank.push_back (resultSizes[offIdx++]);
1042+ }
1043+ if (llvm::find (redDims, i) != redDims.end ()) {
1044+ reductionInductionVars.push_back (
1045+ forallOp.getInductionVars ()[nonZeroDimIdx]);
9881046 }
9891047 if (!isConstantIntValue (numThreads[i], 0 )) {
9901048 nonZeroDimIdx++;
9911049 }
992- resultOffsetsRank.push_back (resultOffsets[offIdx++]);
993- resultSizesRank.push_back (resultSizes[sizeIdx++]);
1050+ }
1051+ if (hasReductionThreads) {
1052+ for (auto [parallelDims, redVar] :
1053+ llvm::zip (constantNewParallelDims, reductionInductionVars)) {
1054+ resultOffsetsRank[parallelDims] = redVar;
1055+ resultSizesRank[parallelDims] = b.getIndexAttr (1 );
1056+ }
9941057 }
9951058 SmallVector<OpFoldResult> strides (resultSizesRank.size (),
9961059 b.getIndexAttr (1 ));
@@ -1001,18 +1064,16 @@ tileAllUsingForall(RewriterBase &b, PartialReductionOpInterface op,
10011064 b.create <tensor::ParallelInsertSliceOp>(
10021065 loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides);
10031066 }
1004-
10051067 // 7. Merge the partial reductions.
10061068 Operation *mergeOp = nullptr ;
10071069 b.setInsertionPointAfter (forallOp);
10081070 if (hasReductionThreads) {
1009- Operation *mergeOp =
1010- op. mergeReductions (b, loc, forallOp-> getResults (), redDims );
1071+ Operation *mergeOp = op. mergeReductions (b, loc, forallOp-> getResults (),
1072+ constantNewParallelDims );
10111073 b.replaceOp (op, mergeOp->getResults ());
10121074 } else {
10131075 b.replaceOp (op, forallOp->getResults ());
10141076 }
1015-
10161077 // 8. Return.
10171078 ForallReductionTilingResult results;
10181079 results.initialOp = *identityTensor;
0 commit comments