From 00873282792070ae076d36816ba529c72174e38f Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei"
Date: Tue, 16 Jul 2024 20:15:53 -0700
Subject: [PATCH] fix broadcast axis and binary target input

Pick the layout-propagation base operand as the first input carrying a
non-plain layout instead of hard-coding input[0], so binary ops whose
packed operand is not the first input still infer matching layouts.
Also apply the outer-axis permutation when packing linalg.broadcast
dimensions, and disable the tensor.expand_shape packing path for now.

---
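Notes:

getTargetInputIdx picks the operand whose layout every other operand
(and the result) is inferred against. A minimal standalone sketch of
the selection rule, using a hypothetical simplified Layout in place of
TensorLayout:

    #include <cstdio>
    #include <vector>

    // Hypothetical stand-in for TensorLayout; only the plain/packed
    // distinction matters for operand selection.
    struct Layout {
      bool plain;
      bool isPlainLayout() const { return plain; }
    };

    // The first input carrying a non-plain (packed) layout becomes the
    // inference base; fall back to input 0 when all layouts are plain.
    static size_t getTargetInputIdx(const std::vector<Layout> &layouts) {
      for (size_t i = 0; i < layouts.size(); ++i)
        if (!layouts[i].isPlainLayout())
          return i;
      return 0;
    }

    int main() {
      // e.g. a binary op whose second operand comes from a packed
      // producer: the packed operand (index 1) wins as the base.
      std::vector<Layout> layouts = {{/*plain=*/true}, {/*plain=*/false}};
      std::printf("base operand: %zu\n", getTargetInputIdx(layouts));
    }

With the old hard-coded input[0] base, a binary op whose packed operand
was not the first input inferred the remaining layouts from a plain
operand and dropped the packing.

For the broadcast fix: getPackedAxes now maps each broadcast dimension
through the layout's outer-axis order before appending the dimensions
introduced by inner tiling. Assuming getOuterAxis() maps an original
axis to its packed position, a broadcast over dim 1 under outer order
[0, 2, 1] becomes a broadcast over packed dim 2 (plus a trailing inner
dim if dim 1 is also tiled).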
 lib/gc/Analysis/GlobalAnalysis.cpp    | 38 +++++++++++----
 lib/gc/Transforms/PropagateLayout.cpp | 66 +++++++++++++++------------
 2 files changed, 66 insertions(+), 38 deletions(-)

diff --git a/lib/gc/Analysis/GlobalAnalysis.cpp b/lib/gc/Analysis/GlobalAnalysis.cpp
index d1886bacc..0cd217f23 100644
--- a/lib/gc/Analysis/GlobalAnalysis.cpp
+++ b/lib/gc/Analysis/GlobalAnalysis.cpp
@@ -154,6 +154,15 @@ inferTargetLayout(TensorLayout layoutBase,
   return TensorLayout(targetOuterAxis, targetInnerAxis, targetTileSizes);
 }
 
+static size_t getTargetInputIdx(ArrayRef<TensorLayout> curInputLayouts) {
+  for (size_t i = 0; i < curInputLayouts.size(); ++i) {
+    if (!curInputLayouts[i].isPlainLayout()) {
+      return i;
+    }
+  }
+  return 0;
+}
+
 GlobalAnalysis::GlobalAnalysis(Operation *root) {
   root->walk([&](Operation *op) {
     if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
@@ -201,23 +210,32 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
         layoutCache[linalgOp] = suggestedLayout;
       } else {
         SmallVector<TensorLayout> inputLayouts, outputLayouts;
-        inputLayouts.push_back(curInputLayouts[0]);
+        size_t targetIdx = getTargetInputIdx(curInputLayouts);
         // TODO(yifei): wisely choose the input format basis
         // Let's only refer to input[0] for now
-        for (size_t i = 1; i < curInputs.size(); ++i) {
+        for (size_t i = 0; i < curInputs.size(); ++i) {
+          std::cout << "inferring indexing map relation" << std::endl;
           // getMatchingIndexingMap
-          auto res = inferIndexingMapRelation(
-              linalgOp.getMatchingIndexingMap(curInputs[0]),
-              linalgOp.getMatchingIndexingMap(curInputs[i]));
-          TensorLayout inputLayout =
-              *inferTargetLayout(curInputLayouts[0], *res);
-          inputLayouts.push_back(inputLayout);
+          if (i != targetIdx) {
+            auto res = inferIndexingMapRelation(
+                linalgOp.getMatchingIndexingMap(curInputs[targetIdx]),
+                linalgOp.getMatchingIndexingMap(curInputs[i]));
+            for (auto tp : *res) {
+              std::cout << "target index: " << tp.first
+                        << " maps to base index: " << tp.second << std::endl;
+            }
+            TensorLayout inputLayout =
+                *inferTargetLayout(curInputLayouts[targetIdx], *res);
+            inputLayouts.push_back(inputLayout);
+          } else {
+            inputLayouts.push_back(curInputLayouts[targetIdx]);
+          }
         }
         auto res_out = inferIndexingMapRelation(
-            linalgOp.getMatchingIndexingMap(curInputs[0]),
+            linalgOp.getMatchingIndexingMap(curInputs[targetIdx]),
             linalgOp.getIndexingMapMatchingResult(curResults[0]));
         TensorLayout outputLayout =
-            *inferTargetLayout(curInputLayouts[0], *res_out);
+            *inferTargetLayout(curInputLayouts[targetIdx], *res_out);
         outputLayouts.push_back(outputLayout);
         OperatorLayout suggestedLayout(inputLayouts, outputLayouts);
         layoutCache[linalgOp] = suggestedLayout;
diff --git a/lib/gc/Transforms/PropagateLayout.cpp b/lib/gc/Transforms/PropagateLayout.cpp
index 0580bc8e6..44dbce9c4 100644
--- a/lib/gc/Transforms/PropagateLayout.cpp
+++ b/lib/gc/Transforms/PropagateLayout.cpp
@@ -38,6 +38,12 @@ using namespace mlir::tensor;
 static SmallVector<int64_t> getPackedAxes(ArrayRef<int64_t> dimensions,
                                           TensorLayout targetLayout) {
   SmallVector<int64_t> result(dimensions);
+  // permuting on outer axis
+  auto outerPerm = targetLayout.getOuterAxis();
+  for (size_t i = 0; i < dimensions.size(); ++i) {
+    result[i] = outerPerm[dimensions[i]];
+  }
+  // inserting inner axis
   auto innerPos = targetLayout.getInnerAxis();
   for (size_t i = 0; i < dimensions.size(); ++i) {
     if (std::find(innerPos.begin(), innerPos.end(), dimensions[i]) !=
@@ -153,8 +159,10 @@ FailureOr<linalg::LinalgOp> packNamedOp(RewriterBase &rewriter,
         loc, inits.getTypes(), inputs, inits, packedAxes);
     packedLinalgOp->getRegion(0).takeBody(linalgOp->getRegion(0));
   } else if (auto broadcastOp = dyn_cast<linalg::BroadcastOp>(&linalgOp)) {
-    packedLinalgOp = rewriter.create<linalg::BroadcastOp>(
-        loc, inputs[0], inits[0], broadcastOp->getDimensions());
+    SmallVector<int64_t> packedAxes =
+        getPackedAxes(broadcastOp->getDimensions(), initLayouts[0]);
+    packedLinalgOp = rewriter.create<linalg::BroadcastOp>(loc, inputs[0],
+                                                          inits[0], packedAxes);
   } else if (auto transposeOp = dyn_cast<linalg::TransposeOp>(&linalgOp)) {
     SmallVector<int64_t> packedPermAxes = getPackedPermAxes(
         transposeOp->getPermutation(), inputLayouts[0], initLayouts[0]);
@@ -237,32 +245,33 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
             return WalkResult::skip();
           }
         } else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
-          Location loc = expandShapeOp->getLoc();
-          auto inputLayout = opLayout->getSupportedInputLayouts()[0];
-          auto outputLayout = opLayout->getSupportedOutputLayouts()[0];
-          Value dest = tensor::PackOp::createDestinationTensor(
-              rewriter, loc, expandShapeOp.getSrc(), inputLayout.getTileSizes(),
-              inputLayout.getInnerAxis(), inputLayout.getOuterAxis());
-          Value packedSource = rewriter.create<tensor::PackOp>(
-              loc, expandShapeOp.getSrc(), dest, inputLayout.getInnerAxis(),
-              inputLayout.getTileSizes(), std::nullopt,
-              inputLayout.getOuterAxis());
-          auto resultType = RankedTensorType::get(
-              expandShapeOp.getStaticOutputShape(),
-              expandShapeOp.getSrcType().getElementType());
-          RankedTensorType resultPackType = tensor::PackOp::inferPackedType(
-              resultType, vector::getAsIntegers(outputLayout.getTileSizes()),
-              outputLayout.getInnerAxis(), outputLayout.getOuterAxis());
-          auto reassocExpand = getReassociationIndicesForReshape(
-              cast<RankedTensorType>(dest.getType()), resultPackType);
-          auto packedExpandShape = rewriter.create<tensor::ExpandShapeOp>(
-              loc, expandShapeOp.getSrcType().getElementType(), packedSource,
-              *reassocExpand);
-          Value result = rewriter.create<tensor::UnPackOp>(
-              packedExpandShape->getLoc(), packedExpandShape, packedExpandShape,
-              outputLayout.getInnerAxis(), outputLayout.getTileSizes(),
-              outputLayout.getOuterAxis());
-          rewriter.replaceOp(expandShapeOp, result);
+          // Location loc = expandShapeOp->getLoc();
+          // auto inputLayout = opLayout->getSupportedInputLayouts()[0];
+          // auto outputLayout = opLayout->getSupportedOutputLayouts()[0];
+          // Value dest = tensor::PackOp::createDestinationTensor(
+          //     rewriter, loc, expandShapeOp.getSrc(),
+          //     inputLayout.getTileSizes(), inputLayout.getInnerAxis(),
+          //     inputLayout.getOuterAxis());
+          // Value packedSource = rewriter.create<tensor::PackOp>(
+          //     loc, expandShapeOp.getSrc(), dest, inputLayout.getInnerAxis(),
+          //     inputLayout.getTileSizes(), std::nullopt,
+          //     inputLayout.getOuterAxis());
+          // auto resultType = RankedTensorType::get(
+          //     expandShapeOp.getStaticOutputShape(),
+          //     expandShapeOp.getSrcType().getElementType());
+          // RankedTensorType resultPackType = tensor::PackOp::inferPackedType(
+          //     resultType, vector::getAsIntegers(outputLayout.getTileSizes()),
+          //     outputLayout.getInnerAxis(), outputLayout.getOuterAxis());
+          // auto reassocExpand = getReassociationIndicesForReshape(
+          //     cast<RankedTensorType>(dest.getType()), resultPackType);
+          // auto packedExpandShape = rewriter.create<tensor::ExpandShapeOp>(
+          //     loc, expandShapeOp.getSrcType().getElementType(), packedSource,
+          //     *reassocExpand);
+          // Value result = rewriter.create<tensor::UnPackOp>(
+          //     packedExpandShape->getLoc(), packedExpandShape,
+          //     packedExpandShape, outputLayout.getInnerAxis(),
+          //     outputLayout.getTileSizes(), outputLayout.getOuterAxis());
+          // rewriter.replaceOp(expandShapeOp, result);
         }
       }
     }
@@ -270,6 +279,7 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
   });
 
   if (walk.wasSkipped())
    return failure();
+  graph->dump();
   return success();
 }