Skip to content

Commit 7598f1e

Browse files
committed
Add target description query and verifier pass
1 parent f5bde39 commit 7598f1e

File tree

13 files changed

+449
-25
lines changed

13 files changed

+449
-25
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
//===-- TargetDescriptionAnalysis.h - target description class --*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_ANALYSIS_TARGETDESCRIPTIONANALYSIS_H
10+
#define MLIR_ANALYSIS_TARGETDESCRIPTIONANALYSIS_H
11+
12+
#include "gc/Dialect/Linalgx/LinalgxOps.h"
13+
#include "mlir/Dialect/DLTI/DLTI.h"
14+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
15+
#include "mlir/Interfaces/DataLayoutInterfaces.h"
16+
#include "llvm/ADT/StringRef.h"
17+
18+
namespace mlir {
19+
namespace gc {
20+
21+
using namespace mlir;
22+
23+
enum DeviceType { CPU = 0 };
24+
25+
class TargetDescriptionAnalysisBase {
26+
public:
27+
TargetDescriptionAnalysisBase(Operation *op, DeviceType device)
28+
: ctx(op->getContext()), device(device),
29+
layout(isa<ModuleOp>(op) ? dyn_cast<ModuleOp>(op)
30+
: op->getParentOfType<ModuleOp>()),
31+
loc(op->getLoc()) {}
32+
33+
// get the device ID
34+
DeviceType getDevice() { return device; }
35+
36+
// get the MLIR context
37+
MLIRContext *getContext() { return ctx; }
38+
39+
// get the data layout
40+
DataLayout getLayout() { return layout; }
41+
42+
// get the property value by key
43+
std::optional<Attribute> getPropertyValue(StringRef key);
44+
45+
// get the location
46+
Location getLocation() { return loc; }
47+
48+
// check if the property exists
49+
bool hasProperty(StringRef key) { return getPropertyValue(key).has_value(); }
50+
51+
// emit warning if the property is not found
52+
template <typename T>
53+
void emitNotFoundWarning(Location loc, StringRef key, T value);
54+
55+
// the map from device type to device string
56+
static llvm::DenseMap<DeviceType, std::string> DeviceKeyMap;
57+
58+
private:
59+
MLIRContext *ctx;
60+
DeviceType device;
61+
DataLayout layout;
62+
Location loc;
63+
};
64+
65+
class CPUTargetDescriptionAnalysis : public TargetDescriptionAnalysisBase {
66+
public:
67+
static constexpr StringLiteral kL1CacheSize = "L1_cache_size_in_bytes";
68+
static constexpr StringLiteral kL2CacheSize = "L2_cache_size_in_bytes";
69+
static constexpr StringLiteral kL3CacheSize = "L3_cache_size_in_bytes";
70+
static constexpr StringLiteral kMaxVectorWidth = "max_vector_width";
71+
static constexpr StringLiteral kNumThreads = "num_threads";
72+
73+
// get runtime OMP_NUM_THREADS
74+
size_t getNumThreads();
75+
76+
// get cache size by cacheLevel
77+
size_t getCacheSize(uint8_t cacheLevel);
78+
79+
// get the maximum vector length in bits
80+
size_t getMaxVectorWidth();
81+
82+
// get the default value map(attr key, default value)
83+
static llvm::DenseMap<StringRef, int64_t> CPUTargetDeafultValueMap;
84+
85+
CPUTargetDescriptionAnalysis(Operation *op)
86+
: TargetDescriptionAnalysisBase(op, DeviceType::CPU) {}
87+
};
88+
89+
} // namespace gc
90+
} // namespace mlir
91+
92+
#endif

include/gc/Transforms/Passes.td

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,46 +19,47 @@ def TileLinalgNamed : Pass<"tile-named-linalg", "func::FuncOp"> {
1919

2020
#ifdef GC_HAS_ONEDNN_DIALECT
2121
def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> {
22-
let summary = "Lower the operations from the oneDNN Graph dialect into Linalg";
23-
let description = [{
24-
Lowers the `onednn_graph` ops to `linalg` ops.
25-
}];
22+
let summary =
23+
"Lower the operations from the oneDNN Graph dialect into Linalg";
24+
let description = [{Lowers the `onednn_graph` ops to `linalg` ops.}];
2625
let dependentDialects = [
27-
"func::FuncDialect",
28-
"math::MathDialect",
29-
"arith::ArithDialect",
30-
"tensor::TensorDialect",
31-
"linalg::LinalgDialect",
32-
"linalgx::LinalgxDialect"
26+
"func::FuncDialect", "math::MathDialect", "arith::ArithDialect",
27+
"tensor::TensorDialect", "linalg::LinalgDialect", "linalgx::LinalgxDialect"
3328
];
3429
}
3530
#endif
3631

3732
#ifdef GC_USE_IMEX
3833
def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
3934
let summary = "Convert linalg dialect to XeGPU dialect.";
40-
let description = [{
41-
Lower linalg ops to XeGPU dialect.
42-
}];
43-
let dependentDialects = ["linalg::LinalgDialect",
44-
"gpu::GPUDialect",
45-
"xegpu::XeGPUDialect",
46-
"scf::SCFDialect",
47-
"memref::MemRefDialect",
48-
"arith::ArithDialect",
49-
"math::MathDialect",
50-
"vector::VectorDialect"];
35+
let description = [{Lower linalg ops to XeGPU dialect.}];
36+
let dependentDialects = [
37+
"linalg::LinalgDialect", "gpu::GPUDialect", "xegpu::XeGPUDialect",
38+
"scf::SCFDialect", "memref::MemRefDialect", "arith::ArithDialect",
39+
"math::MathDialect", "vector::VectorDialect"
40+
];
5141
let options = [
5242
Option<"kTile", "k-tile", "int64_t",
53-
/*default=*/"32",
54-
"GEMM tile size for reduction dimension.">,
43+
/*default=*/"32", "GEMM tile size for reduction dimension.">,
5544
Option<"stages", "stages", "int64_t",
56-
/*default=*/"1",
57-
"Number of cooperative prefetch stages.">,
45+
/*default=*/"1", "Number of cooperative prefetch stages.">,
5846
ListOption<"dpasTile", "dpas-tile", "int64_t",
5947
"DPAS register block sizes MxNxK">,
6048
];
6149
}
6250
#endif
6351

52+
def VerifyTargetDescription : Pass<"verify-target-description", "ModuleOp"> {
53+
let summary = "Verify the target description from ModuleOp DLTI attribute.";
54+
let description = [{
55+
Verify the target description from ModuleOp DLTI attribute. Raise error for unexpected input(such as a negative number of num_threads), and raise warn for missing fields, and provide a default value(such as 32K for L1_cache_size).
56+
}];
57+
let dependentDialects = ["DLTIDialect"];
58+
let options = [
59+
Option<"device", "device", "std::string",
60+
/*default=*/"\"CPU\"",
61+
"The device to verify. Supported device: CPU, ">,
62+
];
63+
}
64+
6465
#endif // GC_DIALECT_GC_PASSES

lib/gc/Analysis/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS
2+
MLIRIR
3+
MLIRSupport)
4+
5+
gc_add_mlir_library(GcAnalysis
6+
TargetDescriptionAnalysis.cpp
7+
8+
DEPENDS
9+
GraphCompilerPassIncGen
10+
11+
LINK_LIBS PUBLIC
12+
${mlir_dialect_libs}
13+
${MLIR_LINK_COMPONENTS}
14+
GcInterface
15+
)
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
//===-- TargetDescriptionAnalysis.cpp - target description impl -*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "gc/Analysis/TargetDescriptionAnalysis.h"
10+
#include <limits>
11+
#include <llvm/Support/Debug.h>
12+
#include <regex>
13+
14+
namespace mlir {
15+
namespace gc {
16+
17+
#define DEBUG_TYPE "target-description-analysis"
18+
19+
llvm::DenseMap<DeviceType, std::string>
20+
TargetDescriptionAnalysisBase::DeviceKeyMap = {
21+
{CPU, "CPU"},
22+
};
23+
24+
// default values for properties
25+
llvm::DenseMap<StringRef, int64_t>
26+
CPUTargetDescriptionAnalysis::CPUTargetDeafultValueMap = {
27+
{CPUTargetDescriptionAnalysis::kNumThreads, 1},
28+
{CPUTargetDescriptionAnalysis::kL1CacheSize, 32 * 1024},
29+
{CPUTargetDescriptionAnalysis::kL2CacheSize, 1024 * 1024},
30+
{CPUTargetDescriptionAnalysis::kL3CacheSize, 32 * 1024 * 1024},
31+
{CPUTargetDescriptionAnalysis::kMaxVectorWidth, 512},
32+
};
33+
34+
template <typename T>
35+
void TargetDescriptionAnalysisBase::emitNotFoundWarning(Location loc,
36+
StringRef key,
37+
T value) {
38+
mlir::emitWarning(loc) << key << " not found, using default value " << value;
39+
}
40+
41+
static bool isIntegerNumber(const std::string &token) {
42+
return std::regex_match(token, std::regex(("(\\+|-)?[[:digit:]]+")));
43+
}
44+
45+
static int64_t getIntFromAttribute(Attribute attr) {
46+
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
47+
if (intAttr.getType().isSignedInteger())
48+
return intAttr.getSInt();
49+
else if (intAttr.getType().isUnsignedInteger())
50+
return intAttr.getUInt();
51+
else
52+
return intAttr.getInt();
53+
} else if (auto strAttr = dyn_cast<StringAttr>(attr)) {
54+
std::string str = strAttr.getValue().str();
55+
if (isIntegerNumber(str))
56+
return std::stoll(str);
57+
}
58+
llvm_unreachable("Not an integer attribute or integer like string attribute");
59+
}
60+
61+
std::optional<Attribute>
62+
TargetDescriptionAnalysisBase::getPropertyValue(StringRef key) {
63+
return layout.getDevicePropertyValue(
64+
Builder(getContext())
65+
.getStringAttr(DeviceKeyMap[getDevice()] /* device ID*/),
66+
Builder(getContext()).getStringAttr(key));
67+
}
68+
69+
size_t CPUTargetDescriptionAnalysis::getNumThreads() {
70+
std::optional<Attribute> numThreads = getPropertyValue(kNumThreads);
71+
72+
if (numThreads)
73+
return getIntFromAttribute(*numThreads);
74+
emitNotFoundWarning(getLocation(), kNumThreads,
75+
CPUTargetDeafultValueMap[kNumThreads]);
76+
return CPUTargetDeafultValueMap[kNumThreads];
77+
}
78+
79+
size_t CPUTargetDescriptionAnalysis::getCacheSize(uint8_t cacheLevel) {
80+
assert(cacheLevel > 0 && cacheLevel < 4 && "Invalid cache level");
81+
StringLiteral key = "";
82+
if (cacheLevel == 1)
83+
key = kL1CacheSize;
84+
else if (cacheLevel == 2)
85+
key = kL2CacheSize;
86+
else if (cacheLevel == 3)
87+
key = kL3CacheSize;
88+
89+
std::optional<Attribute> cacheSize = getPropertyValue(key);
90+
if (cacheSize)
91+
return getIntFromAttribute(*cacheSize);
92+
93+
emitNotFoundWarning(getLocation(), key, CPUTargetDeafultValueMap[key]);
94+
return CPUTargetDeafultValueMap[key];
95+
}
96+
97+
size_t CPUTargetDescriptionAnalysis::getMaxVectorWidth() {
98+
std::optional<Attribute> maxVectorWidth = getPropertyValue(kMaxVectorWidth);
99+
if (maxVectorWidth)
100+
return getIntFromAttribute(*maxVectorWidth);
101+
emitNotFoundWarning(getLocation(), kMaxVectorWidth,
102+
CPUTargetDeafultValueMap[kMaxVectorWidth]);
103+
return CPUTargetDeafultValueMap[kMaxVectorWidth];
104+
}
105+
106+
} // namespace gc
107+
} // namespace mlir

lib/gc/CAPI/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
set(GC_ALL_LIBS
22
${GC_ONEDNN_DIALECT_LIB_NAME}
33
GcPasses
4+
GcAnalysis
45
MLIRCPURuntimeTransforms)
56

67
if(GC_ENABLE_IMEX)

lib/gc/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
add_subdirectory(Analysis)
12
add_subdirectory(CAPI)
23
add_subdirectory(Dialect)
34
add_subdirectory(Transforms)

lib/gc/ExecutionEngine/Driver/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,5 @@ gc_add_mlir_library(GcJitWrapper
3939
${dialect_libs}
4040
${conversion_libs}
4141
${GC_PASSES}
42+
GcAnalysis
4243
)

lib/gc/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ gc_add_mlir_library(GcPasses
1313
OneDNNGraphToLinalg.cpp
1414
Pipeline.cpp
1515
TileNamed.cpp
16+
VerifyTargetDescription.cpp
1617

1718
DEPENDS
1819
GraphCompilerPassIncGen

0 commit comments

Comments
 (0)