
Commit 0e7794c

use CRTP and a type trait to avoid virtual functions and improve compile performance

1 parent 919dd11 commit 0e7794c
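
The pattern named in the commit message, sketched as a minimal standalone C++ example (names such as OpA and CanonicalizerA are illustrative, not from this repo): a type trait maps each operation type to its concrete canonicalizer, and the base class casts `this` to that type, so the call is resolved at compile time instead of through a vtable.

    // Forward declarations, mirroring the ones added in TilingVector.h.
    class CanonicalizerA;
    class CanonicalizerB;

    struct OpA {};
    struct OpB {};

    // Type trait: maps each op type to its concrete canonicalizer.
    template <typename T> struct OpTraits;
    template <> struct OpTraits<OpA> { using DerivedT = CanonicalizerA; };
    template <> struct OpTraits<OpB> { using DerivedT = CanonicalizerB; };

    template <class T> class CanonicalizerBase {
      using DerivedT = typename OpTraits<T>::DerivedT;

    public:
      // Non-virtual: statically dispatches to the derived implementation.
      void prepareInfo() { static_cast<DerivedT *>(this)->prepareInfoImpl(); }
    };

    class CanonicalizerA : public CanonicalizerBase<OpA> {
    public:
      void prepareInfoImpl() { /* A-specific initialization */ }
    };

    class CanonicalizerB : public CanonicalizerBase<OpB> {
    public:
      void prepareInfoImpl() { /* B-specific initialization */ }
    };

Unlike classic CRTP (CanonicalizerBase<CanonicalizerA>), routing through the trait keeps the base parameterized on the operation type T, which the real SpecialOperationCanonicalizer already uses for its candidate-op storage.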

File tree: 7 files changed, +106 -47 lines

include/gc/Analysis/VectorBasedFusionAnalysis.h

Lines changed: 12 additions & 8 deletions
@@ -10,6 +10,7 @@
 #define MLIR_ANALYSIS_VECTORBASEDFUSIONANALYSIS_H

 #include "gc/Dialect/Linalgx/LinalgxOps.h"
+#include "gc/Dialect/Linalgx/Utils.h"
 #include "gc/Dialect/Microkernel/MicrokernelOps.h"
 #include "gc/Transforms/Passes.h"
 #include "gc/Transforms/Utils/VectorUtils.h"
@@ -28,8 +29,7 @@ namespace gc {

 /// record hardware information
 struct HardWareInfo {
-  bool favx512f = true;
-  bool favx2 = true;
+  size_t vectorWidth = 0;
 };

 /// Vector type conversion helper class
@@ -66,6 +66,7 @@ enum class ReturnTypeKind {
   RT_InGroup,
 };

+/// Base class of vector-based fusion.
 class VectorFusionBase {

 private:
@@ -257,16 +258,19 @@ Operation *GroupOperationFusion::getNextTargetOperationInCurrentGroup(

   while (!tmpOpQueue.empty()) {
     auto frontOp = tmpOpQueue.front();
-    if (isa<Target>(frontOp)) {
-      for (auto x : frontOp->getOperands())
-        if (x.getDefiningOp() == curOp)
-          return frontOp;
-    }
     tmpOpQueue.pop();
+    if (not isa<Target>(frontOp))
+      continue;
+    for (auto x : frontOp->getOperands())
+      if (x.getDefiningOp() == curOp)
+        return frontOp;
   }
   return nullptr;
 }

+/// Analyzes each operation group.
+/// Currently it runs vector-based fusion, detects empty groups, and computes
+/// each operation group's max vectorized step.
 class GroupOperationAnalysis {
 private:
   /// vector-based fusion related data
@@ -282,7 +286,7 @@ class GroupOperationAnalysis {
   void analysisGroupMaxSteps();
   /// get fusion strategy
   GroupOperationFusion &getGroupOperationFusion() { return fusionStrategy; }
-
+  /// run the vector-based fusion
   void run() { fusionStrategy.run(); }
 };
 } // namespace gc

lib/gc/Transforms/TilingVector.hpp renamed to include/gc/Transforms/TilingVector.h

Lines changed: 31 additions & 6 deletions
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/Passes.h"
 #include "mlir/IR/Visitors.h"
@@ -120,9 +121,30 @@ struct GenerateLoopHelper {
 //===----------------------------------------------------------------------===//
 // vectorize operation class
 //===----------------------------------------------------------------------===//
+class MultiReductionCanonicalizer;
+class BroadcastCanonicalizer;
+class TransposeCanonicalizer;
+class ShapeCastCanonicalizer;
+
+// type trait mapping each special operation to its concrete canonicalizer
+template <typename T> struct SpecialOpTraits;
+template <> struct SpecialOpTraits<vector::MultiDimReductionOp> {
+  using DerivedSpecialT = MultiReductionCanonicalizer;
+};
+template <> struct SpecialOpTraits<vector::BroadcastOp> {
+  using DerivedSpecialT = BroadcastCanonicalizer;
+};
+template <> struct SpecialOpTraits<vector::TransposeOp> {
+  using DerivedSpecialT = TransposeCanonicalizer;
+};
+template <> struct SpecialOpTraits<vector::ShapeCastOp> {
+  using DerivedSpecialT = ShapeCastCanonicalizer;
+};

 /// base class of special operation
 template <class T> class SpecialOperationCanonicalizer {
+  using DerivedT = typename SpecialOpTraits<T>::DerivedSpecialT;
+
 private:
   /// store current special operation
   SmallVector<T, 4> candidateRdOps;
@@ -148,9 +170,12 @@ template <class T> class SpecialOperationCanonicalizer {
   SpecialOperationCanonicalizer(const SmallVector<T, 4> &candidateRdOps,
                                 SpecialOperationKind kind, size_t step)
       : candidateRdOps(candidateRdOps), vectorStep(step), kind(kind) {}
-  llvm::SmallVector<T, 4> &getCandidateOps();
+  SmallVector<T, 4> &getCandidateOps();
   virtual ~SpecialOperationCanonicalizer() {}
-  virtual void prepareSpecialOperationInfo() = 0;
+  /// call the derived canonicalizer's init method (static CRTP dispatch)
+  void prepareSpecialOperationInfo() {
+    static_cast<DerivedT *>(this)->prepareSpecialInfo();
+  }
   /// get kind of special operation
   SpecialOperationKind getKind() noexcept { return kind; }
   /// set current operation group vectorize step
@@ -241,7 +266,7 @@ class MultiReductionCanonicalizer

   /// initialize parallel axes, reduction axes, the reduction operation type,
   /// and whether the last dim is a reduction axis
-  void prepareSpecialOperationInfo() override;
+  void prepareSpecialInfo();

   static bool classof(SpecialOperationCanonicalizer *canonicalizer) {
     return canonicalizer->getKind() ==
@@ -259,7 +284,7 @@ class BroadcastCanonicalizer
       : SpecialOperationCanonicalizer<vector::BroadcastOp>(
             candidateBcOps, SpecialOperationKind::OP_Broadcast, steps){};
   virtual ~BroadcastCanonicalizer() noexcept {}
-  void prepareSpecialOperationInfo() override {}
+  void prepareSpecialInfo() {}
   static bool classof(SpecialOperationCanonicalizer *canonicalizer) {
     return canonicalizer->getKind() == SpecialOperationKind::OP_Broadcast;
   }
@@ -278,7 +303,7 @@ class TransposeCanonicalizer
       : SpecialOperationCanonicalizer<vector::TransposeOp>(
             candidateTpOps, SpecialOperationKind::OP_Transpose, steps){};
   virtual ~TransposeCanonicalizer() noexcept {}
-  void prepareSpecialOperationInfo() override{};
+  void prepareSpecialInfo() {}
   static bool classof(SpecialOperationCanonicalizer *canonicalizer) {
     return canonicalizer->getKind() == SpecialOperationKind::OP_Transpose;
   }
@@ -306,7 +331,7 @@ class ShapeCastCanonicalizer
       : SpecialOperationCanonicalizer<vector::ShapeCastOp>(
             candidateScOps, SpecialOperationKind::OP_ShapeCast, steps){};
   virtual ~ShapeCastCanonicalizer() {}
-  void prepareSpecialOperationInfo() override {}
+  void prepareSpecialInfo() {}
   static bool classof(SpecialOperationCanonicalizer *canonicalizer) {
     return canonicalizer->getKind() == SpecialOperationKind::OP_ShapeCast;
   }
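
One design note on the hunks above: dropping the virtual prepareSpecialOperationInfo does not affect the classof hooks, because LLVM-style RTTI (isa<>, dyn_cast<>) never relied on a vtable; it goes through the non-virtual getKind()/classof pair. A reduced sketch of that mechanism (hypothetical Kind/Base/BroadcastLike names, assuming the usual llvm/Support/Casting.h machinery):

    #include "llvm/Support/Casting.h"

    enum class Kind { Broadcast, Transpose };

    struct Base {
      Kind kind;
      explicit Base(Kind k) : kind(k) {}
      Kind getKind() const { return kind; }
    };

    struct BroadcastLike : Base {
      BroadcastLike() : Base(Kind::Broadcast) {}
      // llvm::isa<BroadcastLike>(b) calls this; no virtual function involved.
      static bool classof(const Base *b) {
        return b->getKind() == Kind::Broadcast;
      }
    };

    bool isBroadcast(const Base *b) { return llvm::isa<BroadcastLike>(b); }

Note that the real base class keeps its virtual destructor, so the hierarchy is not entirely vtable-free; the gain is that the per-op prepareSpecialOperationInfo call no longer dispatches dynamically.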

include/gc/Transforms/Utils/VectorUtils.h

Lines changed: 2 additions & 3 deletions
@@ -54,9 +54,8 @@ namespace gc {
 /// block.
 /// insert_slice ops are just moved before the first operation that
 /// uses them.
-void moveSomeInterferenceOperation(
-    func::FuncOp *func, MLIRContext *ctx,
-    std::function<bool(Operation *)> &conditionalFunc);
+void moveOpsFrontOrBack(func::FuncOp *func, MLIRContext *ctx,
+                        std::function<bool(Operation *)> &conditionalFunc);

 /// build a constant operation of index type
 Value makeIndexArithConstantOp(OpBuilder &opBuilder, const Location &loc,

lib/gc/Analysis/VectorBasedFusionAnalysis.cpp

Lines changed: 1 addition & 12 deletions
@@ -6,7 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 #include "gc/Analysis/VectorBasedFusionAnalysis.h"
-#include "gc/Dialect/Linalgx/Utils.h"

 namespace mlir {
 namespace gc {
@@ -397,17 +396,7 @@ int TypeHelper::generateValidSteps(int steps, VectorType type) {
 // Get the maximum number of elements of the current data type a register can hold
 [[nodiscard]] int TypeHelper::getDataTypeMAXSIMDLength(VectorType type) {
   auto typebits = type.getElementTypeBitWidth();
-  const int favx512bits = 512;
-  const int favx2bits = 256;
-  if (info.favx512f)
-    return favx512bits / typebits;
-
-  if (info.favx2)
-    return favx2bits / typebits;
-
-  // invalid hardware
-  llvm_unreachable("Invalid hardware.");
-  return -1;
+  return info.vectorWidth / typebits;
 }

 /// Get an appropriate for-loop step for the current vector type
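
The effect of the new vectorWidth field is easy to check with numbers: a 512-bit register holds 512 / 32 = 16 f32 lanes, a 256-bit register holds 8. A minimal standalone sketch of the simplified computation (HardWareInfo matches the struct above; maxSIMDLength is a hypothetical stand-in for getDataTypeMAXSIMDLength):

    #include <cassert>
    #include <cstddef>
    #include <cstdio>

    // One register width in bits replaces the per-ISA feature flags.
    struct HardWareInfo {
      size_t vectorWidth = 0; // e.g. 512 for AVX-512, 256 for AVX2
    };

    // Maximum number of lanes a register holds for a given element width.
    int maxSIMDLength(const HardWareInfo &info, int elementBits) {
      assert(elementBits > 0 && info.vectorWidth % elementBits == 0);
      return static_cast<int>(info.vectorWidth) / elementBits;
    }

    int main() {
      HardWareInfo avx512{512}, avx2{256};
      std::printf("%d\n", maxSIMDLength(avx512, 32)); // 16 f32 lanes
      std::printf("%d\n", maxSIMDLength(avx2, 16));   // 16 bf16 lanes
    }

Besides being shorter, this removes the llvm_unreachable fallback: any width reported by CPUTargetDescriptionAnalysis is handled, not just the two hard-coded ISAs.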

lib/gc/Transforms/CPUPhysicalRegisterPass.cpp

Lines changed: 14 additions & 15 deletions
@@ -5,7 +5,7 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-#include "TilingVector.hpp"
+#include "gc/Transforms/TilingVector.h"

 namespace mlir {
 namespace gc {
@@ -1802,7 +1802,7 @@ bool MultiReductionCanonicalizer::hasLastDimReduction() {
   return res;
 }

-void MultiReductionCanonicalizer::prepareSpecialOperationInfo() {
+void MultiReductionCanonicalizer::prepareSpecialInfo() {
   if (getCandidateOps().empty())
     return;

@@ -2110,9 +2110,8 @@ void ForLoopGenerator::setOperationCorrectOperand(
         loopHelperParam
             .loopIterArgs[loopHelperParam.currentLoopStateIdxMap.at(loopArg)]);
   }
-  int offset = isa<vector::TransferWriteOp>(op) ? 2 : 1;
-  if (dyn_cast<vector::TransferWriteOp>(op) ||
-      dyn_cast<vector::TransferReadOp>(op)) {
+  int operandOffset = isa<vector::TransferWriteOp>(op) ? 2 : 1;
+  if (isReadOrWriteOperation(op)) {
     if (not opPermuationMap.contains(op))
       llvm_unreachable("Map must contains operation.");

@@ -2133,7 +2132,7 @@ void ForLoopGenerator::setOperationCorrectOperand(
     }

     ShapedType tensorType =
-        cast<ShapedType>(op->getOperandTypes()[offset - 1]);
+        cast<ShapedType>(op->getOperandTypes()[operandOffset - 1]);
     int64_t varIdx = dim;
     if (tensorType.getRank() >
         (int64_t)loopHelperParam.inductionVars.size()) {
@@ -2146,11 +2145,12 @@ void ForLoopGenerator::setOperationCorrectOperand(
     }
     if (loopHelperParam.indiceLoopMap.contains(op))
       op->setOperand(
-          dim + offset,
+          dim + operandOffset,
          loopHelperParam
              .inductionVars[loopHelperParam.indiceLoopMap[op][varIdx]]);
     else
-      op->setOperand(dim + offset, loopHelperParam.inductionVars[varIdx]);
+      op->setOperand(dim + operandOffset,
+                     loopHelperParam.inductionVars[varIdx]);
   }
   if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
     size_t grpIdx = getVectorBasedFusion().getOpGroupIndexMap()[op];
@@ -2780,8 +2780,8 @@ void GroupOperationFusionImpl::broadcastFromElements(Operation *op,
           op->getLoc(), newOperandType, op->getOperands()[0]);
       removeOpInCurrentGroups(grpIdx, op, bcastOp);
       std::function<bool(Operation *)> candidateFunc = isBroadcastOp;
-      moveSomeInterferenceOperation(&getGroupOperationFusion().getFunction(),
-                                    op->getContext(), candidateFunc);
+      moveOpsFrontOrBack(&getGroupOperationFusion().getFunction(),
+                         op->getContext(), candidateFunc);
     }
   }
 }
@@ -2946,22 +2946,21 @@ struct CPUPhysicalRegisterPass
     }
     // affineApply operation is always used by other operations.
     std::function<bool(Operation *)> candidateFunc = isProducerOp;
-    moveSomeInterferenceOperation(&func, ctx, candidateFunc);
+    moveOpsFrontOrBack(&func, ctx, candidateFunc);
     candidateFunc = isCandidateMoveOperations;
-    moveSomeInterferenceOperation(&func, ctx, candidateFunc);
+    moveOpsFrontOrBack(&func, ctx, candidateFunc);
     // canonicalize vector operations; by default, use the vector-based
     // fusion strategy.
     HardWareInfo hwInfo;
     CPUTargetDescriptionAnalysis sysDesc =
         getAnalysis<CPUTargetDescriptionAnalysis>();
-    hwInfo.favx512f = sysDesc.getMaxVectorWidth() >= 512;
-    hwInfo.favx2 = sysDesc.getMaxVectorWidth() >= 256;
+    hwInfo.vectorWidth = sysDesc.getMaxVectorWidth();
     VectorOperationCanonicalizer canonicalizer(
         func, hwInfo, CanonicalizerKind::GroupOperations);
     canonicalizer.run();

     candidateFunc = isReadOrWriteOperation;
-    moveSomeInterferenceOperation(&func, ctx, candidateFunc);
+    moveOpsFrontOrBack(&func, ctx, candidateFunc);

     // transpose kernel
     vector::VectorTransformsOptions transposeOptions =
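
(Background for the operandOffset logic in the setOperationCorrectOperand hunks above: vector.transfer_write's indices start at operand 2, after the stored vector and the destination tensor, while vector.transfer_read's indices start at operand 1, right after the source; hence the 2-versus-1 offset.)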

lib/gc/Transforms/Utils/VectorUtils.cpp

Lines changed: 2 additions & 3 deletions
@@ -154,9 +154,8 @@ void moveCandidateOperation(
 // block.
 // insert_slice ops are just moved before the first operation that
 // uses them.
-void moveSomeInterferenceOperation(
-    func::FuncOp *func, MLIRContext *ctx,
-    std::function<bool(Operation *)> &conditionalFunc) {
+void moveOpsFrontOrBack(func::FuncOp *func, MLIRContext *ctx,
+                        std::function<bool(Operation *)> &conditionalFunc) {
   // Pre-order traversal of each op.
   // Record each operation's position so we know which operation the
   // current one should be moved after.

test/mlir/test/gc/Transforms/cpu-phyaical-register.mlir renamed to test/mlir/test/gc/Transforms/cpu-physical-register.mlir

Lines changed: 44 additions & 0 deletions
@@ -664,3 +664,47 @@ func.func @add_small_tensor_test14(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -
   %2 = linalg.max ins(%1, %cst : tensor<2xf32>, tensor<2xf32>) outs(%0: tensor<2xf32>) -> tensor<2xf32>
   return %2 : tensor<2xf32>
 }
+
+// CHECK-LABEL: func @broadcast_add_test15
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C64:.*]] = arith.constant 64 : index
+// CHECK: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: scf.for %[[arg2:.*]] = %[[C0]] to %[[C64]] step %[[C1]] iter_args(%[[arg3:.*]] = {{.*}}) -> (tensor<64x64xf32>)
+// CHECK: scf.for %[[arg4:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg5:.*]] = %[[arg3]]) -> (tensor<64x64xf32>)
+// CHECK: %[[READ0:.*]] = vector.transfer_read %[[arg5]][%[[arg2]], %[[arg4]]], %[[CST]] {in_bounds = [true]} : tensor<64x64xf32>, vector<16xf32>
+// CHECK: %[[READ1:.*]] = vector.transfer_read %arg0[%[[arg4]]], %[[CST]] {in_bounds = [true]} : tensor<64xf32>, vector<16xf32>
+// CHECK: %[[ADD0:.*]] = arith.addf %[[READ1]], %[[READ0]] : vector<16xf32>
+// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[ADD0]], %[[arg5]][%[[arg2]], %[[arg4]]] {in_bounds = [true]} : vector<16xf32>, tensor<64x64xf32>
+func.func @broadcast_add_test15(%arg0: tensor<64xf32>, %arg1: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %0 = tensor.empty() : tensor<64x64xf32>
+  %bcast = linalg.broadcast
+      ins(%arg0: tensor<64xf32>)
+      outs(%0: tensor<64x64xf32>)
+      dimensions = [0]
+  %out3 = linalg.add ins(%bcast, %arg1: tensor<64x64xf32>, tensor<64x64xf32>)
+      outs(%arg1: tensor<64x64xf32>) -> tensor<64x64xf32>
+  return %out3: tensor<64x64xf32>
+}
+
+// CHECK-LABEL: func @broadcast_single_test16
+// CHECK: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[C64:.*]] = arith.constant 64 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<64x64xf32>
+// CHECK: scf.for %[[arg1:.*]] = %[[C0]] to %[[C64]] step %[[C1]] iter_args(%[[arg2:.*]] = %[[EMPTY0]]) -> (tensor<64x64xf32>)
+// CHECK: scf.for %[[arg3:.*]] = %[[C0]] to %[[C64]] step %[[C16]] iter_args(%[[arg4:.*]] = %[[arg2]]) -> (tensor<64x64xf32>)
+// CHECK: %[[READ0:.*]] = vector.transfer_read %arg0[%[[arg3]]], %[[CST]] {in_bounds = [true]} : tensor<64xf32>, vector<16xf32>
+// CHECK: %[[WRITE0:.*]] = vector.transfer_write %[[READ0]], %[[arg4]][%[[arg1]], %[[arg3]]] {in_bounds = [true]} : vector<16xf32>, tensor<64x64xf32>
+func.func @broadcast_single_test16(%arg0: tensor<64xf32>) -> tensor<64x64xf32> {
+  %0 = tensor.empty() : tensor<64x64xf32>
+  %bcast = linalg.broadcast
+      ins(%arg0: tensor<64xf32>)
+      outs(%0: tensor<64x64xf32>)
+      dimensions = [0]
+  return %bcast: tensor<64x64xf32>
+}
+
