Skip to content

Commit 1242a74

Browse files
committed
fix comments
1 parent d8e968f commit 1242a74

File tree

5 files changed

+215
-39
lines changed

5 files changed

+215
-39
lines changed
+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
//===-- NumericUtils.h - numeric utilities ----------------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef GC_TRANSFORMS_UTILS_NUMERICUTILS_H
#define GC_TRANSFORMS_UTILS_NUMERICUTILS_H
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include <limits>
#include <stdint.h>
#include <variant>

namespace mlir {
namespace gc {

/// Overlay of a `float` and a `uint32_t`, used by the conversion helpers
/// below to read and modify the raw bit pattern of a single-precision value.
union Float32Bits {
  uint32_t u; ///< Raw 32-bit pattern.
  float f;    ///< The same storage viewed as a float.
};

/// Converts a float to the 16-bit pattern of a half-precision (f16) value.
uint16_t float2half(float floatValue);

/// Converts the 16-bit pattern of a half-precision (f16) value to a float.
float half2float(uint16_t halfValue);

/// Converts a float to the 16-bit pattern of a bfloat16 value.
uint16_t float2bfloat(float floatValue);

/// Converts the 16-bit pattern of a bfloat16 value to a float.
float bfloat2float(uint16_t bfloatBits);

/// Returns the lowest value for the element type of \p type (or \p type itself
/// when it has no element type). Floating-point types yield a `float`, integer
/// types yield an `int64_t`. Unsupported types are a programming error.
std::variant<float, int64_t> numeric_limits_minimum(Type type);

/// Returns the highest value for the element type of \p type (or \p type
/// itself when it has no element type). Floating-point types yield a `float`,
/// integer types yield an `int64_t`. Unsupported types are a programming error.
std::variant<float, int64_t> numericLimitsMaximum(Type type);

} // namespace gc
} // namespace mlir

#endif // GC_TRANSFORMS_UTILS_NUMERICUTILS_H

include/gc/Transforms/Utils/VectorUtils.h

+2-11
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===-- VectorUtils.h - vector fusion analysis ------------------*- C++ -*-===//
1+
//===-- VectorUtils.h - vector utilities ------------------------*- C++ -*-===//
22
//
33
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -8,6 +8,7 @@
88

99
#ifndef GC_TRANSFORMS_UTILS_VECTORUTILS_H
1010
#define GC_TRANSFORMS_UTILS_VECTORUTILS_H
11+
#include "gc/Transforms/Utils/NumericUtils.h"
1112
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1213
#include "mlir/Dialect/Func/IR/FuncOps.h"
1314
#include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -96,16 +97,6 @@ int getNearestVectorStep(const int step);
9697
/// prev-op, may need to use result vectortype
9798
/// default will return the operation result type
9899
mlir::FailureOr<VectorType> getOperationMaxVectorType(Operation *op);
99-
union Float32Bits {
100-
uint32_t u;
101-
float f;
102-
};
103-
uint16_t float2half(float floatValue);
104-
float half2float(uint16_t halfValue);
105-
uint16_t float2bfloat(float floatValue);
106-
float bfloat2float(uint16_t bfloatBits);
107-
std::variant<float, int64_t> numeric_limits_minimum(Type type);
108-
std::variant<float, int64_t> numericLimitsMaximum(Type type);
109100

110101
template <typename T = float>
111102
T getInitValForReduce(vector::CombiningKind kind, Type t) {

lib/gc/Transforms/Utils/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ gc_add_mlir_library(GcUtilsIR
33
StructuredOpMatcher.cpp
44
ValueUtils.cpp
55
VectorUtils.cpp
6+
NumericUtils.cpp
67

78
DEPENDS
89
MLIRLinalgDialect
+164
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
//===- NumericUtils.cpp - numeric utilities ---------------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#include "gc/Transforms/Utils/NumericUtils.h"
9+
10+
namespace mlir {
11+
namespace gc {
12+
13+
const uint32_t kF32MantiBits = 23;
14+
const uint32_t kF32HalfMantiBitDiff = 13;
15+
const uint32_t kF32HalfBitDiff = 16;
16+
const Float32Bits kF32Magic = {113 << kF32MantiBits};
17+
const uint32_t kF32HalfExpAdjust = (127 - 15) << kF32MantiBits;
18+
const uint32_t kF32BfMantiBitDiff = 16;
19+
20+
/// Constructs the 16 bit representation for a half precision value from a float
21+
/// value. This implementation is adapted from Eigen.
22+
uint16_t float2half(float floatValue) {
23+
const Float32Bits inf = {255 << kF32MantiBits};
24+
const Float32Bits f16max = {(127 + 16) << kF32MantiBits};
25+
const Float32Bits denormMagic = {((127 - 15) + (kF32MantiBits - 10) + 1)
26+
<< kF32MantiBits};
27+
uint32_t signMask = 0x80000000u;
28+
uint16_t halfValue = static_cast<uint16_t>(0x0u);
29+
Float32Bits f;
30+
f.f = floatValue;
31+
uint32_t sign = f.u & signMask;
32+
f.u ^= sign;
33+
34+
if (f.u >= f16max.u) {
35+
const uint32_t halfQnan = 0x7e00;
36+
const uint32_t halfInf = 0x7c00;
37+
// Inf or NaN (all exponent bits set).
38+
halfValue = (f.u > inf.u) ? halfQnan : halfInf; // NaN->qNaN and Inf->Inf
39+
} else {
40+
// (De)normalized number or zero.
41+
if (f.u < kF32Magic.u) {
42+
// The resulting FP16 is subnormal or zero.
43+
//
44+
// Use a magic value to align our 10 mantissa bits at the bottom of the
45+
// float. As long as FP addition is round-to-nearest-even this works.
46+
f.f += denormMagic.f;
47+
48+
halfValue = static_cast<uint16_t>(f.u - denormMagic.u);
49+
} else {
50+
uint32_t mantOdd =
51+
(f.u >> kF32HalfMantiBitDiff) & 1; // Resulting mantissa is odd.
52+
53+
// Update exponent, rounding bias part 1. The following expressions are
54+
// equivalent to `f.u += ((unsigned int)(15 - 127) << kF32MantiBits) +
55+
// 0xfff`, but without arithmetic overflow.
56+
f.u += 0xc8000fffU;
57+
// Rounding bias part 2.
58+
f.u += mantOdd;
59+
halfValue = static_cast<uint16_t>(f.u >> kF32HalfMantiBitDiff);
60+
}
61+
}
62+
63+
halfValue |= static_cast<uint16_t>(sign >> kF32HalfBitDiff);
64+
return halfValue;
65+
}
66+
67+
/// Converts the 16 bit representation of a half precision value to a float
68+
/// value. This implementation is adapted from Eigen.
69+
float half2float(uint16_t halfValue) {
70+
const uint32_t shiftedExp =
71+
0x7c00 << kF32HalfMantiBitDiff; // Exponent mask after shift.
72+
73+
// Initialize the float representation with the exponent/mantissa bits.
74+
Float32Bits f = {
75+
static_cast<uint32_t>((halfValue & 0x7fff) << kF32HalfMantiBitDiff)};
76+
const uint32_t exp = shiftedExp & f.u;
77+
f.u += kF32HalfExpAdjust; // Adjust the exponent
78+
79+
// Handle exponent special cases.
80+
if (exp == shiftedExp) {
81+
// Inf/NaN
82+
f.u += kF32HalfExpAdjust;
83+
} else if (exp == 0) {
84+
// Zero/Denormal?
85+
f.u += 1 << kF32MantiBits;
86+
f.f -= kF32Magic.f;
87+
}
88+
89+
f.u |= (halfValue & 0x8000) << kF32HalfBitDiff; // Sign bit.
90+
return f.f;
91+
}
92+
93+
// Constructs the 16 bit representation for a bfloat value from a float value.
94+
// This implementation is adapted from Eigen.
95+
uint16_t float2bfloat(float floatValue) {
96+
if (std::isnan(floatValue))
97+
return std::signbit(floatValue) ? 0xFFC0 : 0x7FC0;
98+
99+
Float32Bits floatBits;
100+
floatBits.f = floatValue;
101+
uint16_t bfloatBits;
102+
103+
// Least significant bit of resulting bfloat.
104+
uint32_t lsb = (floatBits.u >> kF32BfMantiBitDiff) & 1;
105+
uint32_t roundingBias = 0x7fff + lsb;
106+
floatBits.u += roundingBias;
107+
bfloatBits = static_cast<uint16_t>(floatBits.u >> kF32BfMantiBitDiff);
108+
return bfloatBits;
109+
}
110+
111+
// Converts the 16 bit representation of a bfloat value to a float value. This
112+
// implementation is adapted from Eigen.
113+
float bfloat2float(uint16_t bfloatBits) {
114+
Float32Bits floatBits;
115+
floatBits.u = static_cast<uint32_t>(bfloatBits) << kF32BfMantiBitDiff;
116+
return floatBits.f;
117+
}
118+
119+
/// Returns the lowest value for the element type of \p type, as obtained from
/// getElementTypeOrSelf.
///
/// - f32 / bf16 / f16: negative infinity, held as `float` (for the 16-bit
///   types the value is round-tripped through the 16-bit encoding).
/// - signed 8-bit / 32-bit integers: -128 / INT32_MIN, held as `int64_t`.
/// - signless 8-bit / 32-bit integers: 0, held as `int64_t`.
///
/// Any other type reaches llvm_unreachable.
std::variant<float, int64_t> numeric_limits_minimum(Type type) {
  const Type elemTy = getElementTypeOrSelf(type);
  const float negInf = -std::numeric_limits<float>::infinity();

  if (elemTy.isF32())
    return negInf;
  if (elemTy.isBF16())
    return bfloat2float(float2bfloat(negInf));
  if (elemTy.isF16())
    return half2float(float2half(negInf));
  if (elemTy.isSignedInteger(8))
    return int64_t(-128);
  if (elemTy.isSignedInteger(32))
    return int64_t(std::numeric_limits<int32_t>::min());
  if (elemTy.isSignlessInteger(8) || elemTy.isSignlessInteger(32))
    return int64_t(0);

  llvm_unreachable("unsupported data type");
}
139+
140+
/// Returns the highest value for the element type of \p type, as obtained from
/// getElementTypeOrSelf.
///
/// - f32 / bf16 / f16: positive infinity, held as `float` (for the 16-bit
///   types the value is round-tripped through the 16-bit encoding).
/// - signed 8-bit / 32-bit integers: 127 / INT32_MAX, held as `int64_t`.
/// - signless 8-bit / 32-bit integers: 255 / UINT32_MAX, held as `int64_t`.
///
/// Any other type reaches llvm_unreachable.
std::variant<float, int64_t> numericLimitsMaximum(Type type) {
  const Type elemTy = getElementTypeOrSelf(type);
  const float posInf = std::numeric_limits<float>::infinity();

  if (elemTy.isF32())
    return posInf;
  if (elemTy.isBF16())
    return bfloat2float(float2bfloat(posInf));
  if (elemTy.isF16())
    return half2float(float2half(posInf));
  if (elemTy.isSignedInteger(8))
    return int64_t(127);
  if (elemTy.isSignedInteger(32))
    return int64_t(std::numeric_limits<int32_t>::max());
  if (elemTy.isSignlessInteger(8))
    return int64_t(255);
  // This branch used to repeat the isSignedInteger(32) test above, which made
  // it unreachable and sent signless 32-bit integers to llvm_unreachable.
  // Signless integers use the unsigned range, as the 8-bit case above does.
  if (elemTy.isSignlessInteger(32))
    return int64_t(std::numeric_limits<uint32_t>::max());

  llvm_unreachable("unsupported data type");
}
162+
163+
} // namespace gc
164+
} // namespace mlir

lib/gc/Transforms/Utils/VectorUtils.cpp

+14-28
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
//===- VectorUtils.cpp - analysis vector ops --------------------*- C++ -*-===//
1+
//===- VectorUtils.cpp - vector utilities -----------------------*- C++ -*-===//
22
//
33
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
88
#include "gc/Transforms/Utils/VectorUtils.h"
9+
#include "mlir/IR/Value.h"
910
#include "mlir/Support/LLVM.h"
10-
1111
namespace mlir {
1212
namespace gc {
1313

@@ -37,13 +37,10 @@ OPPRIORITY operator++(OPPRIORITY &c) {
3737
LogicalResult moveFront(Operation *op, IRRewriter &rewriter) {
3838
Operation *backOperation = nullptr;
3939
// check whether all operands are block arguments
40-
bool allBlockArgs = true;
41-
for (auto operand : op->getOperands()) {
42-
if (!isa<BlockArgument>(operand)) {
43-
allBlockArgs = false;
44-
break;
45-
}
46-
}
40+
bool allBlockArgs = llvm::all_of(op->getOperands(), [](Value operand) {
41+
return isa<BlockArgument>(operand);
42+
});
43+
4744
if (allBlockArgs) {
4845
moveOpBeginingOfBlock(op, rewriter);
4946
return success();
@@ -153,31 +150,20 @@ void getOperationPriority(
153150
// get the position of each operation
154151
func->walk<WalkOrder::PreOrder>([&](Operation *op) {
155152
TypeSwitch<Operation *, void>(op)
156-
.Case<affine::AffineApplyOp>([&](affine::AffineApplyOp affineOp) {
153+
.Case<affine::AffineApplyOp>([&](auto op) {
157154
candidateOps.push(std::make_pair(op, OPPRIORITY::FIRST));
158155
return;
159156
})
160-
.Case<tensor::ExtractSliceOp>([&](tensor::ExtractSliceOp extractOp) {
161-
candidateOps.push(std::make_pair(op, OPPRIORITY::SECOND));
162-
return;
163-
})
164-
.Case<tensor::EmptyOp>([&](tensor::EmptyOp emptyOp) {
165-
candidateOps.push(std::make_pair(op, OPPRIORITY::SECOND));
166-
return;
167-
})
168-
.Case<tensor::InsertSliceOp>([&](tensor::InsertSliceOp insertOp) {
169-
candidateOps.push(std::make_pair(op, OPPRIORITY::SECOND));
170-
return;
171-
})
172-
.Case<vector::TransferReadOp>([&](vector::TransferReadOp readOp) {
173-
candidateOps.push(std::make_pair(op, OPPRIORITY::LAST));
174-
return;
175-
})
176-
.Case<vector::TransferWriteOp>([&](vector::TransferWriteOp writeOp) {
157+
.Case<tensor::EmptyOp, tensor::InsertSliceOp, tensor::ExtractSliceOp>(
158+
[&](auto op) {
159+
candidateOps.push(std::make_pair(op, OPPRIORITY::SECOND));
160+
return;
161+
})
162+
.Case<vector::TransferWriteOp, vector::TransferReadOp>([&](auto op) {
177163
candidateOps.push(std::make_pair(op, OPPRIORITY::LAST));
178164
return;
179165
})
180-
.Case<vector::BroadcastOp>([&](vector::BroadcastOp bcOp) {
166+
.Case<vector::BroadcastOp>([&](auto op) {
181167
candidateOps.push(std::make_pair(op, OPPRIORITY::THIRD));
182168
return;
183169
})

0 commit comments

Comments
 (0)