@@ -22,27 +22,15 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
2222 SmallVector<int64_t > innerAxis = layoutCache.getInnerAxis ();
2323 SmallVector<OpFoldResult> tileSizes = layoutCache.getTileSizes ();
2424 ss << " [" ;
25- for (size_t i = 0 ; i < outerAxis.size (); ++i) {
26- if (i != 0 ) {
27- ss << " , " ;
28- }
29- ss << outerAxis[i];
30- }
31- for (size_t i = 0 ; i < innerAxis.size (); ++i) {
32- ss << (i == 0 ? " ; " : " , " );
33- ss << innerAxis[i];
25+ llvm::interleaveComma (outerAxis, ss);
26+ if (!innerAxis.empty ()) {
27+ ss << " ; " ;
28+ llvm::interleaveComma (innerAxis, ss);
3429 }
3530 ss << " ]" ;
3631 if (!tileSizes.empty ()) {
3732 ss << " ; {" ;
38- for (size_t i = 0 ; i < tileSizes.size (); ++i) {
39- if (i != 0 ) {
40- ss << " , " ;
41- }
42- if (getConstantIntValue (tileSizes[i]).has_value ()) {
43- ss << *getConstantIntValue (tileSizes[i]);
44- }
45- }
33+ llvm::interleaveComma (tileSizes, ss);
4634 ss << " }" ;
4735 }
4836 return ss;
@@ -58,11 +46,11 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
5846 const OperatorLayout &opLayout) {
5947 for (auto &&[idx, layoutCache] :
6048 llvm::enumerate (opLayout.getSupportedInputLayouts ())) {
61- ss << " input " << idx << " 's layoutCache : " << layoutCache << " \n " ;
49+ ss << " input " << idx << " 's layout : " << layoutCache << " \n " ;
6250 }
6351 for (auto &&[idx, layoutCache] :
6452 llvm::enumerate (opLayout.getSupportedOutputLayouts ())) {
65- ss << " output " << idx << " 's layoutCache : " << layoutCache << " \n " ;
53+ ss << " output " << idx << " 's layout : " << layoutCache << " \n " ;
6654 }
6755 return ss;
6856}
@@ -156,15 +144,15 @@ inferTargetLayout(TensorLayout layoutBase,
156144}
157145
158146static size_t getTargetInputIdx (ArrayRef<TensorLayout> curInputLayouts) {
159- for (auto i = 0 ; i < curInputLayouts.size (); ++i) {
147+ for (size_t i = 0 ; i < curInputLayouts.size (); ++i) {
160148 if (!curInputLayouts[i].isPlainLayout ()) {
161149 return i;
162150 }
163151 }
164152 return 0 ;
165153}
166154
167- static bool supportedContractionOpList (linalg::LinalgOp &linalgOp) {
155+ static bool supportedContractionNamedOpList (linalg::LinalgOp &linalgOp) {
168156 if (isa<linalg::MatmulOp, linalg::MatmulTransposeAOp,
169157 linalg::MatmulTransposeBOp, linalg::BatchMatmulOp,
170158 linalg::BatchMatmulTransposeAOp, linalg::BatchMatmulTransposeBOp>(
@@ -211,7 +199,7 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
211199 // ------ Get Current Op's Suggested Layout & Do Propagation ------
212200 IRRewriter rewriter (linalgOp);
213201 // TODO: extend to packed/vnni matmul ops
214- if (supportedContractionOpList (linalgOp)) {
202+ if (supportedContractionNamedOpList (linalgOp)) {
215203 // get input and output rank
216204 auto ARank = cast<ShapedType>(linalgOp.getDpsInputs ()[0 ].getType ())
217205 .getShape ()
@@ -253,7 +241,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
253241 rewriter.getIndexAttr (iin)});
254242 OperatorLayout suggestedLayout ({ALayout, BLayout}, {CLayout});
255243 layoutCache[linalgOp] = suggestedLayout;
256- } else if (!mlir::linalg::isaContractionOpInterface (linalgOp)) {
244+ } else if (!mlir::linalg::isaContractionOpInterface (linalgOp) &&
245+ !supportedContractionNamedOpList (linalgOp)) {
257246 SmallVector<TensorLayout> inputLayouts, outputLayouts;
258247 size_t targetIdx = getTargetInputIdx (curInputLayouts);
259248 // TODO(yifei): wisely choose the input format basis
@@ -345,11 +334,12 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
345334
346335namespace utils {
347336bool isPackableNamedOp (Operation *op) {
348- if ((isa<linalg::LinalgOp>(op) &&
349- !mlir::linalg::isaContractionOpInterface (
350- dyn_cast<linalg::LinalgOp>(op)) &&
351- !isa<linalgx::Mm4DVnniOp>(op)) ||
352- isa<tensor::ExpandShapeOp>(op) || isa<tensor::PadOp>(op))
337+ if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
338+ if (!supportedContractionNamedOpList (linalgOp)) {
339+ return true ;
340+ }
341+ } else if (isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp, tensor::PadOp>(
342+ op))
353343 return true ;
354344 return false ;
355345}
0 commit comments