Skip to content

Commit

Permalink
format code
Browse files Browse the repository at this point in the history
  • Loading branch information
zhczhong committed Aug 6, 2024
1 parent 8e5d071 commit 21b4dfe
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 125 deletions.
7 changes: 3 additions & 4 deletions lib/gc/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS
MLIRIR
MLIRSupport)

add_mlir_library(GCAnalysis
gc_add_mlir_library(GCAnalysis
MatmulConfigAnalysis.cpp

ADDITIONAL_HEADER_DIRS
Expand All @@ -14,6 +14,5 @@ add_mlir_library(GCAnalysis
LINK_LIBS PUBLIC
${mlir_dialect_libs}
${MLIR_LINK_COMPONENTS}
)

set_property(GLOBAL APPEND PROPERTY GC_PASS_LIBS GCAnalysis)
GcInterface
)
82 changes: 33 additions & 49 deletions lib/gc/Analysis/MatmulConfigAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,10 @@ getCandidate(uint32_t num, uint32_t floor,
// factor
std::vector<uint32_t> candidates;
uint32_t upperbound = std::min(num, ceil);
for (uint32_t i = floor; i <= upperbound; i++) {
if (num % i == 0) {
for (uint32_t i = floor; i <= upperbound; i++)
if (num % i == 0)
candidates.push_back(i);
}
}

// the pow of 2
uint32_t candidate = 1U;
while (candidate < floor)
Expand All @@ -68,9 +67,8 @@ getCandidate(uint32_t num, uint32_t floor,
bool validateThreads(ArrayRef<uint32_t> threads, SystemDesc &sysDesc) {
uint32_t numThreads = sysDesc.getNumThreads();
uint32_t actualThreads = 1U;
for (uint32_t t : threads) {
for (uint32_t t : threads)
actualThreads *= t;
}
return actualThreads == numThreads;
}

Expand Down Expand Up @@ -154,9 +152,8 @@ double computationIntensityOnL2Cache(linalg::LinalgOp &linalgOp,
config.NBlock * config.KBlock +
config.MBlock * config.KBlock;
double computationIntensity = FLOPS / memoryConsumption;
if (memoryConsumption * dtypeSize > L2Cache * fullLoadRatio) {
if (memoryConsumption * dtypeSize > L2Cache * fullLoadRatio)
computationIntensity /= outOfCachePenalty;
}
return 1 / computationIntensity;
}

Expand All @@ -183,19 +180,17 @@ filterConfigByCostModel(ArrayRef<MatmulConfig> configs,
double thresholdCost = costs[idx[(size_t)(preserveRatio * configs.size())]];
thresholdCost =
threshold < thresholdCost && threshold > 0 ? threshold : thresholdCost;
for (const auto &i : idx) {
if (costs[i] <= thresholdCost) {
for (const auto &i : idx)
if (costs[i] <= thresholdCost)
result.push_back(configs[i]);
}
}

LLVM_DEBUG(llvm::dbgs() << "thresholdCost is: " << thresholdCost
<< "\nbest with cost: " << costs[idx[0]] << "\n"
<< configs[idx[0]] << "\n worst with cost: "
<< costs[idx[configs.size() - 1]] << "\n"
<< configs[idx[configs.size() - 1]] << "\n");
if (result.empty()) {
if (result.empty())
result = configs;
}
return result;
}

Expand Down Expand Up @@ -248,27 +243,23 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc,
for (uint32_t MThreads : MThreadsCandidates) {
for (uint32_t NThreads : NThreadsCandidates) {
for (uint32_t KThreads : KThreadsCandidates) {
if (!validateThreads({MThreads, NThreads, KThreads}, sysDesc)) {
if (!validateThreads({MThreads, NThreads, KThreads}, sysDesc))
continue;
}
for (uint32_t MBlock : MBlockCandidates) {
for (uint32_t innerMostMBlock : innerMostMBlockCandidates) {
if (MBlock % innerMostMBlock != 0 ||
shape[0] % innerMostMBlock != 0) {
shape[0] % innerMostMBlock != 0)
continue;
}
for (uint32_t NBlock : NBlockCandidates) {
for (uint32_t innerMostNBlock : innerMostNBlockCandidates) {
if (NBlock % innerMostNBlock != 0 ||
shape[1] % innerMostNBlock != 0) {
shape[1] % innerMostNBlock != 0)
continue;
}
for (uint32_t KBlock : KBlockCandidates) {
for (uint32_t innerMostKBlock : innerMostKBlockCandidates) {
if (KBlock % innerMostKBlock != 0 ||
shape[2] % innerMostKBlock != 0) {
shape[2] % innerMostKBlock != 0)
continue;
}
MatmulConfig config{
MThreads, NThreads, KThreads,
MBlock, NBlock, KBlock,
Expand All @@ -293,14 +284,12 @@ bool validateConfig(const MatmulConfig &cfg) {
if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
cfg.innerMostKBlock <= 0) {
cfg.innerMostKBlock <= 0)
return false;
}
if (cfg.MBlock % cfg.innerMostMBlock != 0 ||
cfg.NBlock % cfg.innerMostNBlock != 0 ||
cfg.KBlock % cfg.innerMostKBlock != 0) {
cfg.KBlock % cfg.innerMostKBlock != 0)
return false;
}
return true;
}

Expand Down Expand Up @@ -371,19 +360,16 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
uint32_t M = 1U, N = 1U, K = 1U;
for (auto &&[s, dimType] :
llvm::zip(linalgOp.getShape(linalgOp.getDpsInputOperand(0)),
oprandDimType[0])) {
if (dimType == DimType::M) {
oprandDimType[0]))
if (dimType == DimType::M)
M *= s;
}
}
for (auto &&[s, dimType] :
llvm::zip(linalgOp.getShape(linalgOp.getDpsInputOperand(1)),
oprandDimType[1])) {
if (dimType == DimType::N) {
if (dimType == DimType::N)
N *= s;
} else if (dimType == DimType::K) {
else if (dimType == DimType::K)
K *= s;
}
}

// innermost Block, if the layout is blockied layout, the innermost block
Expand All @@ -395,30 +381,30 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
SmallVector<uint32_t> givenInnermostBlock;
if (MDimTypeIdx.size() > 1) {
config.innerMostMBlock = 1;
for (size_t i = 1UL; i < MDimTypeIdx.size(); i++) {
config.innerMostMBlock *=
linalgOp.getShape(linalgOp.getDpsInputOperand(0))[MDimTypeIdx[i]];
}
for (auto &&[i, d] : llvm::enumerate(MDimTypeIdx))
if (i != 0)
config.innerMostMBlock *=
linalgOp.getShape(linalgOp.getDpsInputOperand(0))[d];
givenInnermostBlock.push_back(config.innerMostMBlock);
} else {
givenInnermostBlock.push_back(0);
}
if (NDimTypeIdx.size() > 1) {
config.innerMostNBlock = 1;
for (size_t i = 1UL; i < NDimTypeIdx.size(); i++) {
config.innerMostNBlock *=
linalgOp.getShape(linalgOp.getDpsInputOperand(1))[NDimTypeIdx[i]];
}
for (auto &&[i, d] : llvm::enumerate(NDimTypeIdx))
if (i != 0)
config.innerMostNBlock *=
linalgOp.getShape(linalgOp.getDpsInputOperand(1))[d];
givenInnermostBlock.push_back(config.innerMostNBlock);
} else {
givenInnermostBlock.push_back(0);
}
if (KDimTypeIdx.size() > 1) {
config.innerMostKBlock = 1;
for (size_t i = 1UL; i < KDimTypeIdx.size(); i++) {
config.innerMostKBlock *=
linalgOp.getShape(linalgOp.getDpsInputOperand(1))[KDimTypeIdx[i]];
}
for (auto &&[i, d] : llvm::enumerate(KDimTypeIdx))
if (i != 0)
config.innerMostKBlock *=
linalgOp.getShape(linalgOp.getDpsInputOperand(1))[d];
givenInnermostBlock.push_back(config.innerMostKBlock);
} else {
givenInnermostBlock.push_back(0);
Expand All @@ -444,13 +430,11 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
SmallVector<uint32_t> shape = {M, N, K};
std::vector<MatmulConfig> configCandidates =
prepareConfigCandidates(root, sysDesc, shape, givenInnermostBlock);
for (auto &&[fn, name, threshold] : costModelList) {
for (auto &&[fn, name, threshold] : costModelList)
configCandidates = filterConfigByCostModel(
configCandidates, linalgOp, shape, sysDesc, fn, 0.5, threshold);
}
if (!configCandidates.empty()) {
if (!configCandidates.empty())
config = configCandidates[0];
}
}

LLVM_DEBUG(llvm::dbgs()
Expand Down
Loading

0 comments on commit 21b4dfe

Please sign in to comment.