Skip to content

Commit 6cadff8

Browse files
authoredAug 6, 2024··
Centralize target description query through DLTI and add verifier pass (#210)
* Add target description query and verifier pass * replace the sysDesc in fusion pass
1 parent 3be8dec commit 6cadff8

15 files changed

+421
-82
lines changed
 
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
//===-- TargetDescriptionAnalysis.h - target description class --*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_ANALYSIS_TARGETDESCRIPTIONANALYSIS_H
10+
#define MLIR_ANALYSIS_TARGETDESCRIPTIONANALYSIS_H
11+
12+
#include "gc/Dialect/Linalgx/LinalgxOps.h"
13+
#include "mlir/Dialect/DLTI/DLTI.h"
14+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
15+
#include "mlir/Interfaces/DataLayoutInterfaces.h"
16+
#include "llvm/ADT/StringRef.h"
17+
18+
namespace mlir {
19+
namespace gc {
20+
21+
using namespace mlir;
22+
23+
// Devices whose target description can be queried; CPU is the only one
// defined here.
enum DeviceType { CPU = 0 };
24+
25+
class TargetDescriptionAnalysisBase {
26+
public:
27+
TargetDescriptionAnalysisBase(Operation *op, DeviceType device)
28+
: ctx(op->getContext()), device(device),
29+
layout(isa<ModuleOp>(op) ? dyn_cast<ModuleOp>(op)
30+
: op->getParentOfType<ModuleOp>()),
31+
loc(op->getLoc()) {}
32+
33+
// get the device ID
34+
DeviceType getDevice() { return device; }
35+
36+
// get the MLIR context
37+
MLIRContext *getContext() { return ctx; }
38+
39+
// get the data layout
40+
DataLayout getLayout() { return layout; }
41+
42+
// get the property value by key
43+
std::optional<Attribute> getPropertyValue(StringRef key);
44+
45+
// get the location
46+
Location getLocation() { return loc; }
47+
48+
// check if the property exists
49+
bool hasProperty(StringRef key) { return getPropertyValue(key).has_value(); }
50+
51+
// emit warning if the property is not found
52+
template <typename T>
53+
void emitNotFoundWarning(Location loc, StringRef key, T value);
54+
55+
// the map from device type to device string
56+
static llvm::DenseMap<DeviceType, std::string> DeviceKeyMap;
57+
58+
private:
59+
MLIRContext *ctx;
60+
DeviceType device;
61+
DataLayout layout;
62+
Location loc;
63+
};
64+
65+
class CPUTargetDescriptionAnalysis : public TargetDescriptionAnalysisBase {
66+
public:
67+
static constexpr StringLiteral kL1CacheSize = "L1_cache_size_in_bytes";
68+
static constexpr StringLiteral kL2CacheSize = "L2_cache_size_in_bytes";
69+
static constexpr StringLiteral kL3CacheSize = "L3_cache_size_in_bytes";
70+
static constexpr StringLiteral kMaxVectorWidth = "max_vector_width";
71+
static constexpr StringLiteral kNumThreads = "num_threads";
72+
73+
// get runtime OMP_NUM_THREADS
74+
unsigned getNumThreads();
75+
76+
// get cache size by cacheLevel
77+
unsigned getCacheSize(uint8_t cacheLevel);
78+
79+
// get the maximum vector length in bits
80+
unsigned getMaxVectorWidth();
81+
82+
CPUTargetDescriptionAnalysis(Operation *op)
83+
: TargetDescriptionAnalysisBase(op, DeviceType::CPU) {}
84+
};
85+
86+
} // namespace gc
87+
} // namespace mlir
88+
89+
#endif

‎include/gc/Transforms/Passes.td

+26-25
Original file line numberDiff line numberDiff line change
@@ -19,42 +19,30 @@ def TileLinalgNamed : Pass<"tile-named-linalg", "func::FuncOp"> {
1919

2020
#ifdef GC_HAS_ONEDNN_DIALECT
2121
def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> {
22-
let summary = "Lower the operations from the oneDNN Graph dialect into Linalg";
23-
let description = [{
24-
Lowers the `onednn_graph` ops to `linalg` ops.
25-
}];
22+
let summary =
23+
"Lower the operations from the oneDNN Graph dialect into Linalg";
24+
let description = [{Lowers the `onednn_graph` ops to `linalg` ops.}];
2625
let dependentDialects = [
27-
"func::FuncDialect",
28-
"math::MathDialect",
29-
"arith::ArithDialect",
30-
"tensor::TensorDialect",
31-
"linalg::LinalgDialect",
32-
"linalgx::LinalgxDialect"
26+
"func::FuncDialect", "math::MathDialect", "arith::ArithDialect",
27+
"tensor::TensorDialect", "linalg::LinalgDialect", "linalgx::LinalgxDialect"
3328
];
3429
}
3530
#endif
3631

3732
#ifdef GC_USE_IMEX
3833
def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
3934
let summary = "Convert linalg dialect to XeGPU dialect.";
40-
let description = [{
41-
Lower linalg ops to XeGPU dialect.
42-
}];
43-
let dependentDialects = ["linalg::LinalgDialect",
44-
"gpu::GPUDialect",
45-
"xegpu::XeGPUDialect",
46-
"scf::SCFDialect",
47-
"memref::MemRefDialect",
48-
"arith::ArithDialect",
49-
"math::MathDialect",
50-
"vector::VectorDialect"];
35+
let description = [{Lower linalg ops to XeGPU dialect.}];
36+
let dependentDialects = [
37+
"linalg::LinalgDialect", "gpu::GPUDialect", "xegpu::XeGPUDialect",
38+
"scf::SCFDialect", "memref::MemRefDialect", "arith::ArithDialect",
39+
"math::MathDialect", "vector::VectorDialect"
40+
];
5141
let options = [
5242
Option<"kTile", "k-tile", "int64_t",
53-
/*default=*/"32",
54-
"GEMM tile size for reduction dimension.">,
43+
/*default=*/"32", "GEMM tile size for reduction dimension.">,
5544
Option<"stages", "stages", "int64_t",
56-
/*default=*/"1",
57-
"Number of cooperative prefetch stages.">,
45+
/*default=*/"1", "Number of cooperative prefetch stages.">,
5846
ListOption<"dpasTile", "dpas-tile", "int64_t",
5947
"DPAS register block sizes MxNxK">,
6048
];
@@ -93,4 +81,17 @@ def IterativeTilingAndFusion : Pass<"iterative-tiling-and-fusion",
9381
];
9482
}
9583

84+
// Module-level pass that validates the DLTI target description before the
// rest of the pipeline consumes it.
def VerifyTargetDescription : Pass<"verify-target-description", "ModuleOp"> {
  let summary = "Verify the target description from ModuleOp DLTI attribute.";
  let description = [{
    Verify the target description stored in the ModuleOp DLTI attribute.
    An error is raised for unexpected input (such as a negative `num_threads`).
    A warning is emitted for each missing field, and a default value is used
    in its place (such as 32K for `L1_cache_size_in_bytes`).
  }];
  let dependentDialects = ["DLTIDialect"];
  let options = [
    Option<"device", "device", "std::string",
           /*default=*/"\"CPU\"",
           "The device to verify. Supported device: CPU.">,
  ];
}
96+
9697
#endif // GC_DIALECT_GC_PASSES

‎lib/gc/Analysis/CMakeLists.txt

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# MLIR components stored in MLIR_LINK_COMPONENTS; the variable is passed to
# LINK_LIBS of GcAnalysis below.
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS
2+
MLIRIR
3+
MLIRSupport)
4+
5+
# Analysis library built from TargetDescriptionAnalysis.cpp.
gc_add_mlir_library(GcAnalysis
6+
TargetDescriptionAnalysis.cpp
7+
8+
# NOTE(review): presumably ensures the generated pass headers exist before
# this library is compiled - confirm.
DEPENDS
9+
GraphCompilerPassIncGen
10+
11+
LINK_LIBS PUBLIC
12+
${mlir_dialect_libs}
13+
${MLIR_LINK_COMPONENTS}
14+
GcInterface
15+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
//===-- TargetDescriptionAnalysis.cpp - target description impl -*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "gc/Analysis/TargetDescriptionAnalysis.h"
10+
#include <limits>
11+
#include <llvm/Support/Debug.h>
12+
#include <regex>
13+
14+
namespace mlir {
15+
namespace gc {
16+
17+
#define DEBUG_TYPE "target-description-analysis"
18+
19+
// Definition of the device-type -> device-string table; the string is used as
// the device ID when querying the data layout.
llvm::DenseMap<DeviceType, std::string>
    TargetDescriptionAnalysisBase::DeviceKeyMap = {
        {DeviceType::CPU, "CPU"}};
23+
24+
template <typename T>
25+
void TargetDescriptionAnalysisBase::emitNotFoundWarning(Location loc,
26+
StringRef key,
27+
T value) {
28+
mlir::emitWarning(loc) << key << " not found, using default value " << value;
29+
}
30+
31+
// Returns true iff `token` is an optionally signed decimal integer: at most
// one leading '+' or '-', followed by one or more digits and nothing else.
// Implemented as a linear scan so that no std::regex has to be compiled on
// every call.
static bool isIntegerNumber(const std::string &token) {
  std::string::size_type pos = 0;
  if (!token.empty() && (token[0] == '+' || token[0] == '-'))
    pos = 1;
  // A lone sign or an empty string carries no digits.
  if (pos == token.size())
    return false;
  for (; pos < token.size(); ++pos)
    if (token[pos] < '0' || token[pos] > '9')
      return false;
  return true;
}
34+
35+
// Extracts an integer from `attr`.
//
// Accepted inputs are an IntegerAttr (signed, unsigned or signless) and a
// StringAttr whose text is an optionally signed decimal integer. Any other
// attribute comes from a malformed target description, i.e. from user input,
// so it is reported with report_fatal_error, which aborts in every build
// mode. llvm_unreachable would be undefined behaviour in release builds.
//
// NOTE(review): unsigned values above INT64_MAX are converted to a negative
// int64_t here; confirm whether such values can occur.
static int64_t getIntFromAttribute(Attribute attr) {
  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
    Type type = intAttr.getType();
    if (type.isSignedInteger())
      return intAttr.getSInt();
    if (type.isUnsignedInteger())
      return intAttr.getUInt();
    return intAttr.getInt();
  }
  if (auto strAttr = dyn_cast<StringAttr>(attr)) {
    std::string str = strAttr.getValue().str();
    if (isIntegerNumber(str))
      return std::stoll(str);
  }
  llvm::report_fatal_error(
      "Not an integer attribute or integer like string attribute");
}
50+
51+
std::optional<Attribute>
52+
TargetDescriptionAnalysisBase::getPropertyValue(StringRef key) {
53+
return layout.getDevicePropertyValue(
54+
Builder(getContext())
55+
.getStringAttr(DeviceKeyMap[getDevice()] /* device ID*/),
56+
Builder(getContext()).getStringAttr(key));
57+
}
58+
59+
unsigned CPUTargetDescriptionAnalysis::getNumThreads() {
60+
static const unsigned defaultNumThreads = 1;
61+
std::optional<Attribute> numThreads = getPropertyValue(kNumThreads);
62+
63+
if (numThreads)
64+
return getIntFromAttribute(*numThreads);
65+
emitNotFoundWarning(getLocation(), kNumThreads, defaultNumThreads);
66+
return defaultNumThreads;
67+
}
68+
69+
unsigned CPUTargetDescriptionAnalysis::getCacheSize(uint8_t cacheLevel) {
70+
assert(cacheLevel > 0 && cacheLevel < 4 && "Invalid cache level");
71+
llvm::DenseMap<StringRef, unsigned> CPUTargetCacheSizeValueMap = {
72+
{CPUTargetDescriptionAnalysis::kL1CacheSize, 32 * 1024},
73+
{CPUTargetDescriptionAnalysis::kL2CacheSize, 1024 * 1024},
74+
{CPUTargetDescriptionAnalysis::kL3CacheSize, 32 * 1024 * 1024},
75+
};
76+
StringLiteral key = "";
77+
if (cacheLevel == 1)
78+
key = kL1CacheSize;
79+
else if (cacheLevel == 2)
80+
key = kL2CacheSize;
81+
else if (cacheLevel == 3)
82+
key = kL3CacheSize;
83+
84+
std::optional<Attribute> cacheSize = getPropertyValue(key);
85+
if (cacheSize)
86+
return getIntFromAttribute(*cacheSize);
87+
88+
emitNotFoundWarning(getLocation(), key, CPUTargetCacheSizeValueMap[key]);
89+
return CPUTargetCacheSizeValueMap[key];
90+
}
91+
92+
unsigned CPUTargetDescriptionAnalysis::getMaxVectorWidth() {
93+
static const unsigned defaultMaxVectorWidth = 512;
94+
std::optional<Attribute> maxVectorWidth = getPropertyValue(kMaxVectorWidth);
95+
if (maxVectorWidth)
96+
return getIntFromAttribute(*maxVectorWidth);
97+
emitNotFoundWarning(getLocation(), kMaxVectorWidth, defaultMaxVectorWidth);
98+
return defaultMaxVectorWidth;
99+
}
100+
101+
} // namespace gc
102+
} // namespace mlir

‎lib/gc/CAPI/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
set(GC_ALL_LIBS
22
${GC_ONEDNN_DIALECT_LIB_NAME}
33
GcPasses
4+
GcAnalysis
45
MLIRCPURuntimeTransforms)
56

67
if(GC_ENABLE_IMEX)

‎lib/gc/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
add_subdirectory(Analysis)
12
add_subdirectory(CAPI)
23
add_subdirectory(Dialect)
34
add_subdirectory(Transforms)

‎lib/gc/ExecutionEngine/Driver/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,5 @@ gc_add_mlir_library(GcJitWrapper
3939
${dialect_libs}
4040
${conversion_libs}
4141
${GC_PASSES}
42+
GcAnalysis
4243
)

‎lib/gc/Transforms/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ gc_add_mlir_library(GcPasses
1515
TileNamed.cpp
1616
IterativeTilingAndFusion.cpp
1717
TilingUsingInterfaceX.cpp
18+
VerifyTargetDescription.cpp
1819

1920
DEPENDS
2021
GraphCompilerPassIncGen

‎lib/gc/Transforms/IterativeTilingAndFusion.cpp

+3-57
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "gc/Analysis/TargetDescriptionAnalysis.h"
910
#include "gc/Transforms/Passes.h"
1011
#include "mlir/Analysis/TopologicalSortUtils.h"
1112
#include "mlir/Dialect/DLTI/Traits.h"
@@ -579,62 +580,6 @@ static LogicalResult isSelfTiledOp(Operation *targetOp) {
579580
return success(walkResult.wasInterrupted());
580581
}
581582

582-
struct SystemDesc {
583-
// get runtime OMP_NUM_THREADS
584-
uint32_t getNumThreads() {
585-
std::optional<Attribute> numThreads = layout.getDevicePropertyValue(
586-
Builder(ctx).getStringAttr("CPU" /* device ID*/),
587-
Builder(ctx).getStringAttr("num_threads"));
588-
if (numThreads && isa<IntegerAttr>(*numThreads)) {
589-
return dyn_cast<IntegerAttr>(*numThreads).getInt();
590-
}
591-
return 1;
592-
}
593-
// get cache size by cacheLevel
594-
size_t getCacheSize(uint8_t cacheLevel) {
595-
if (cacheLevel == 1) {
596-
std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
597-
Builder(ctx).getStringAttr("CPU" /* device ID*/),
598-
Builder(ctx).getStringAttr("L1_cache_size_in_bytes"));
599-
if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
600-
return dyn_cast<IntegerAttr>(*cacheSize).getInt();
601-
}
602-
} else if (cacheLevel == 2) {
603-
std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
604-
Builder(ctx).getStringAttr("CPU" /* device ID*/),
605-
Builder(ctx).getStringAttr("L2_cache_size_in_bytes"));
606-
if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
607-
return dyn_cast<IntegerAttr>(*cacheSize).getInt();
608-
}
609-
} else if (cacheLevel == 3) {
610-
std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
611-
Builder(ctx).getStringAttr("CPU" /* device ID*/),
612-
Builder(ctx).getStringAttr("L3_cache_size_in_bytes"));
613-
if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
614-
return dyn_cast<IntegerAttr>(*cacheSize).getInt();
615-
}
616-
}
617-
return 0;
618-
}
619-
620-
// get the maximum vector length in bits
621-
size_t getMaxVectorLength() {
622-
std::optional<Attribute> maxVectorLength = layout.getDevicePropertyValue(
623-
Builder(ctx).getStringAttr("CPU" /* device ID*/),
624-
Builder(ctx).getStringAttr("max_vector_width"));
625-
if (maxVectorLength && isa<IntegerAttr>(*maxVectorLength)) {
626-
return dyn_cast<IntegerAttr>(*maxVectorLength).getInt();
627-
}
628-
return 512;
629-
}
630-
631-
SystemDesc(ModuleOp m) : layout(m), ctx(m->getContext()) {}
632-
633-
private:
634-
DataLayout layout;
635-
MLIRContext *ctx;
636-
};
637-
638583
using OpTileSizeMap = std::unordered_map<std::string, SmallVector<int64_t>>;
639584

640585
template <typename OpTy>
@@ -806,7 +751,8 @@ struct IterativeTilingAndFusion
806751
// Get funcOp
807752
func::FuncOp func = getOperation();
808753
// Get system descriptor
809-
SystemDesc sysDesc(func->getParentOfType<ModuleOp>());
754+
CPUTargetDescriptionAnalysis sysDesc =
755+
getAnalysis<CPUTargetDescriptionAnalysis>();
810756
// Flexible options to control which candidate slice would be selected from
811757
// the view of both validity and performance.
812758
CandidateSliceOptions sliceOptions;

‎lib/gc/Transforms/Pipeline.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ void populateLLVMPasses(mlir::OpPassManager &pm) {
133133
}
134134

135135
void populateCPUPipeline(mlir::OpPassManager &pm) {
136+
// verify the target description attribute
137+
pm.addNestedPass<func::FuncOp>(createVerifyTargetDescription());
136138
// front-end, oneDNN graph dialect
137139
populateFrontendPasses(pm);
138140
// middle-end, LinalgX/Linalg/tensor dialects
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
//===-- VerifyTargetDescription.cpp - Verify target desc --------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "gc/Analysis/TargetDescriptionAnalysis.h"
10+
#include "gc/Transforms/Passes.h"
11+
#include "mlir/Dialect/DLTI/DLTI.h"
12+
#include "mlir/IR/BuiltinOps.h"
13+
14+
#include "mlir/Pass/Pass.h"
15+
using namespace mlir;
16+
17+
namespace mlir {
18+
namespace gc {
19+
#define GEN_PASS_DEF_VERIFYTARGETDESCRIPTION
20+
#include "gc/Transforms/Passes.h.inc"
21+
22+
namespace {
23+
24+
// Verifies the CPU target description reachable from `op`.
//
// The values are checked in a fixed order: num_threads, the L1/L2/L3 cache
// sizes, then max_vector_width. The first value that is smaller than 1 is
// reported as an error at the location of `op` and failure is returned.
// Missing properties are not errors here: the getters emit a warning and
// return their default. `rewriter` is not used.
static LogicalResult verifyCPUTargetDescription(RewriterBase &rewriter,
                                                Operation *op) {
  CPUTargetDescriptionAnalysis cpuTargetDesc(op);
  Location loc = op->getLoc();

  // num_threads has to be at least 1.
  const unsigned numThreads = cpuTargetDesc.getNumThreads();
  if (numThreads < 1) {
    mlir::emitError(loc)
        << "num_threads must be a greater than 0 integer, but get "
        << numThreads;
    return failure();
  }

  // L1 cache size has to be at least 1 byte.
  const unsigned l1CacheSize = cpuTargetDesc.getCacheSize(1);
  if (l1CacheSize < 1) {
    mlir::emitError(loc)
        << "L1_cache_size_in_bytes must be a greater than 0 integer, but get "
        << l1CacheSize;
    return failure();
  }

  // L2 cache size has to be at least 1 byte.
  const unsigned l2CacheSize = cpuTargetDesc.getCacheSize(2);
  if (l2CacheSize < 1) {
    mlir::emitError(loc)
        << "L2_cache_size_in_bytes must be a greater than 0 integer, but get "
        << l2CacheSize;
    return failure();
  }

  // L3 cache size has to be at least 1 byte.
  const unsigned l3CacheSize = cpuTargetDesc.getCacheSize(3);
  if (l3CacheSize < 1) {
    mlir::emitError(loc)
        << "L3_cache_size_in_bytes must be a greater than 0 integer, but get "
        << l3CacheSize;
    return failure();
  }

  // max_vector_width has to be at least 1.
  const unsigned maxVectorWidth = cpuTargetDesc.getMaxVectorWidth();
  if (maxVectorWidth < 1) {
    mlir::emitError(loc)
        << "max_vector_width must be a greater than 0 integer, but get "
        << maxVectorWidth;
    return failure();
  }

  return success();
}
70+
71+
class VerifyTargetDescription
72+
: public impl::VerifyTargetDescriptionBase<VerifyTargetDescription> {
73+
using Base::Base;
74+
void runOnOperation() override {
75+
Operation *module = getOperation();
76+
MLIRContext *ctx = &getContext();
77+
IRRewriter rewriter(ctx);
78+
if (device == "CPU") {
79+
if (failed(verifyCPUTargetDescription(rewriter, module))) {
80+
mlir::emitError(module->getLoc())
81+
<< "Failed to verify the target description";
82+
signalPassFailure();
83+
}
84+
}
85+
}
86+
};
87+
88+
} // namespace
89+
} // namespace gc
90+
} // namespace mlir

‎src/gc-opt/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ target_link_libraries(gc-opt PRIVATE
4444
${conversion_libs}
4545
${MLIR_LINK_COMPONENTS}
4646
GcPasses
47+
GcAnalysis
4748
)
4849

4950
if(GC_ENABLE_IMEX)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Unit tests for the target description analysis.
add_mlir_unittest(GCAnalysisTests
2+
TargetDescriptionAnalysisTest.cpp
3+
)
4+
# The tests link against the analysis library and GcJitWrapper.
target_link_libraries(GCAnalysisTests
5+
PRIVATE
6+
GcAnalysis
7+
GcJitWrapper)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
//===-- TargetDescriptionAnalysisTest.cpp -----------------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#include "gc/Analysis/TargetDescriptionAnalysis.h"
9+
#include "gc/ExecutionEngine/Driver/Driver.h"
10+
#include "mlir/AsmParser/AsmParser.h"
11+
#include "mlir/ExecutionEngine/MemRefUtils.h"
12+
#include "mlir/IR/AsmState.h"
13+
#include "mlir/IR/BuiltinOps.h"
14+
#include "mlir/IR/MLIRContext.h"
15+
#include "mlir/Parser/Parser.h"
16+
#include "mlir/Pass/PassManager.h"
17+
#include "llvm/Support/ErrorOr.h"
18+
#include "llvm/Support/MemoryBuffer.h"
19+
#include "llvm/Support/SourceMgr.h"
20+
#include "llvm/Support/raw_ostream.h"
21+
#include "gtest/gtest.h"
22+
#include <memory>
23+
24+
using namespace mlir;
25+
26+
// Fixture: empty module whose "CPU" device spec sets all five properties
// queried by CPUNormal, using several integer types and one string-encoded
// value (the L3 cache size).
static const char code1[] = R"mlir(
27+
module attributes {
28+
dlti.target_system_spec = #dlti.target_system_spec<
29+
"CPU": #dlti.target_device_spec<
30+
#dlti.dl_entry<"L1_cache_size_in_bytes", 49152 : ui32>,
31+
#dlti.dl_entry<"L2_cache_size_in_bytes", 2097152 : ui64>,
32+
#dlti.dl_entry<"L3_cache_size_in_bytes", "110100480">,
33+
#dlti.dl_entry<"num_threads", 56 : i32>,
34+
#dlti.dl_entry<"max_vector_width", 512 : i64>>
35+
>} {}
36+
)mlir";
37+
38+
// Every property is present in the fixture, so each getter must return the
// value written in the module attribute.
TEST(TargetDescriptionAnalysis, CPUNormal) {
  MLIRContext ctx{gc::initCompilerAndGetDialects()};

  // Hand the fixture text to a source manager and parse it as a module.
  llvm::SourceMgr sourceMgr;
  sourceMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(code1),
                               llvm::SMLoc());
  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::parseSourceFile<mlir::ModuleOp>(sourceMgr, &ctx);
  ASSERT_TRUE(module);

  gc::CPUTargetDescriptionAnalysis cpuTargetDesc(module.get());
  ASSERT_EQ(cpuTargetDesc.getNumThreads(), 56);
  ASSERT_EQ(cpuTargetDesc.getCacheSize(1), 49152);
  ASSERT_EQ(cpuTargetDesc.getCacheSize(2), 2097152);
  ASSERT_EQ(cpuTargetDesc.getCacheSize(3), 110100480);
  ASSERT_EQ(cpuTargetDesc.getMaxVectorWidth(), 512);
}
55+
56+
// Fixture: empty module whose "CPU" device spec only sets the L1 and L2 cache
// sizes; every other property is absent.
static const char code2[] = R"mlir(
57+
module attributes {
58+
dlti.target_system_spec = #dlti.target_system_spec<
59+
"CPU": #dlti.target_device_spec<
60+
#dlti.dl_entry<"L1_cache_size_in_bytes", 49152 : ui32>,
61+
#dlti.dl_entry<"L2_cache_size_in_bytes", 2097152 : ui32>>
62+
>} {}
63+
)mlir";
64+
65+
// Only the L1 and L2 cache sizes are present in the fixture. The remaining
// getters must fall back to their defaults: 1 thread, 32 MiB of L3 cache and
// a 512-bit maximum vector width.
TEST(TargetDescriptionAnalysis, CPUMissingValue) {
  MLIRContext ctx{gc::initCompilerAndGetDialects()};
  std::unique_ptr<llvm::MemoryBuffer> ir_buffer =
      llvm::MemoryBuffer::getMemBuffer(code2);
  // Parse the input mlir.
  llvm::SourceMgr sourceMgr;
  sourceMgr.AddNewSourceBuffer(std::move(ir_buffer), llvm::SMLoc());
  mlir::OwningOpRef<mlir::ModuleOp> module =
      mlir::parseSourceFile<mlir::ModuleOp>(sourceMgr, &ctx);
  ASSERT_TRUE(module);
  auto CPUTagetDesc = gc::CPUTargetDescriptionAnalysis(module.get());
  ASSERT_EQ(CPUTagetDesc.getNumThreads(), 1);
  ASSERT_EQ(CPUTagetDesc.getCacheSize(1), 49152);
  ASSERT_EQ(CPUTagetDesc.getCacheSize(2), 2097152);
  // L3 is missing: the default used by getCacheSize(3) is 32 MiB, not the
  // 1 MiB default that belongs to L2.
  ASSERT_EQ(CPUTagetDesc.getCacheSize(3), 32 * 1024 * 1024);
  ASSERT_EQ(CPUTagetDesc.getMaxVectorWidth(), 512);
}

‎test/mlir/unittests/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ function(add_mlir_unittest test_dirname)
99
add_unittest(GCUnitTests ${test_dirname} ${ARGN})
1010
endfunction()
1111

12+
add_subdirectory(Analysis)
1213
add_subdirectory(Example)
1314
add_subdirectory(ExecutionEngine)
1415

0 commit comments

Comments
 (0)
Please sign in to comment.