1
+ // ===- NumericUtils.cpp - numeric utilities ---------------------*- C++ -*-===//
2
+ //
3
+ // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4
+ // See https://llvm.org/LICENSE.txt for license information.
5
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
+ //
7
+ // ===----------------------------------------------------------------------===//
8
+ #include " gc/Transforms/Utils/NumericUtils.h"
9
+
10
+ namespace mlir {
11
+ namespace gc {
12
+
13
+ const uint32_t kF32MantiBits = 23 ;
14
+ const uint32_t kF32HalfMantiBitDiff = 13 ;
15
+ const uint32_t kF32HalfBitDiff = 16 ;
16
+ const Float32Bits kF32Magic = {113 << kF32MantiBits };
17
+ const uint32_t kF32HalfExpAdjust = (127 - 15 ) << kF32MantiBits ;
18
+ const uint32_t kF32BfMantiBitDiff = 16 ;
19
+
20
+ // / Constructs the 16 bit representation for a half precision value from a float
21
+ // / value. This implementation is adapted from Eigen.
22
+ uint16_t float2half (float floatValue) {
23
+ const Float32Bits inf = {255 << kF32MantiBits };
24
+ const Float32Bits f16max = {(127 + 16 ) << kF32MantiBits };
25
+ const Float32Bits denormMagic = {((127 - 15 ) + (kF32MantiBits - 10 ) + 1 )
26
+ << kF32MantiBits };
27
+ uint32_t signMask = 0x80000000u ;
28
+ uint16_t halfValue = static_cast <uint16_t >(0x0u );
29
+ Float32Bits f;
30
+ f.f = floatValue;
31
+ uint32_t sign = f.u & signMask;
32
+ f.u ^= sign;
33
+
34
+ if (f.u >= f16max.u ) {
35
+ const uint32_t halfQnan = 0x7e00 ;
36
+ const uint32_t halfInf = 0x7c00 ;
37
+ // Inf or NaN (all exponent bits set).
38
+ halfValue = (f.u > inf.u ) ? halfQnan : halfInf; // NaN->qNaN and Inf->Inf
39
+ } else {
40
+ // (De)normalized number or zero.
41
+ if (f.u < kF32Magic .u ) {
42
+ // The resulting FP16 is subnormal or zero.
43
+ //
44
+ // Use a magic value to align our 10 mantissa bits at the bottom of the
45
+ // float. As long as FP addition is round-to-nearest-even this works.
46
+ f.f += denormMagic.f ;
47
+
48
+ halfValue = static_cast <uint16_t >(f.u - denormMagic.u );
49
+ } else {
50
+ uint32_t mantOdd =
51
+ (f.u >> kF32HalfMantiBitDiff ) & 1 ; // Resulting mantissa is odd.
52
+
53
+ // Update exponent, rounding bias part 1. The following expressions are
54
+ // equivalent to `f.u += ((unsigned int)(15 - 127) << kF32MantiBits) +
55
+ // 0xfff`, but without arithmetic overflow.
56
+ f.u += 0xc8000fffU ;
57
+ // Rounding bias part 2.
58
+ f.u += mantOdd;
59
+ halfValue = static_cast <uint16_t >(f.u >> kF32HalfMantiBitDiff );
60
+ }
61
+ }
62
+
63
+ halfValue |= static_cast <uint16_t >(sign >> kF32HalfBitDiff );
64
+ return halfValue;
65
+ }
66
+
67
+ // / Converts the 16 bit representation of a half precision value to a float
68
+ // / value. This implementation is adapted from Eigen.
69
+ float half2float (uint16_t halfValue) {
70
+ const uint32_t shiftedExp =
71
+ 0x7c00 << kF32HalfMantiBitDiff ; // Exponent mask after shift.
72
+
73
+ // Initialize the float representation with the exponent/mantissa bits.
74
+ Float32Bits f = {
75
+ static_cast <uint32_t >((halfValue & 0x7fff ) << kF32HalfMantiBitDiff )};
76
+ const uint32_t exp = shiftedExp & f.u ;
77
+ f.u += kF32HalfExpAdjust ; // Adjust the exponent
78
+
79
+ // Handle exponent special cases.
80
+ if (exp == shiftedExp) {
81
+ // Inf/NaN
82
+ f.u += kF32HalfExpAdjust ;
83
+ } else if (exp == 0 ) {
84
+ // Zero/Denormal?
85
+ f.u += 1 << kF32MantiBits ;
86
+ f.f -= kF32Magic .f ;
87
+ }
88
+
89
+ f.u |= (halfValue & 0x8000 ) << kF32HalfBitDiff ; // Sign bit.
90
+ return f.f ;
91
+ }
92
+
93
+ // Constructs the 16 bit representation for a bfloat value from a float value.
94
+ // This implementation is adapted from Eigen.
95
+ uint16_t float2bfloat (float floatValue) {
96
+ if (std::isnan (floatValue))
97
+ return std::signbit (floatValue) ? 0xFFC0 : 0x7FC0 ;
98
+
99
+ Float32Bits floatBits;
100
+ floatBits.f = floatValue;
101
+ uint16_t bfloatBits;
102
+
103
+ // Least significant bit of resulting bfloat.
104
+ uint32_t lsb = (floatBits.u >> kF32BfMantiBitDiff ) & 1 ;
105
+ uint32_t roundingBias = 0x7fff + lsb;
106
+ floatBits.u += roundingBias;
107
+ bfloatBits = static_cast <uint16_t >(floatBits.u >> kF32BfMantiBitDiff );
108
+ return bfloatBits;
109
+ }
110
+
111
+ // Converts the 16 bit representation of a bfloat value to a float value. This
112
+ // implementation is adapted from Eigen.
113
+ float bfloat2float (uint16_t bfloatBits) {
114
+ Float32Bits floatBits;
115
+ floatBits.u = static_cast <uint32_t >(bfloatBits) << kF32BfMantiBitDiff ;
116
+ return floatBits.f ;
117
+ }
118
+
119
+ std::variant<float , int64_t > numeric_limits_minimum (Type type) {
120
+ Type t1 = getElementTypeOrSelf (type);
121
+ if (t1.isF32 ()) {
122
+ return -std::numeric_limits<float >::infinity ();
123
+ } else if (t1.isBF16 ()) {
124
+ return bfloat2float (float2bfloat (-std::numeric_limits<float >::infinity ()));
125
+ } else if (t1.isF16 ()) {
126
+ return (float )half2float (
127
+ float2half (-std::numeric_limits<float >::infinity ()));
128
+ } else if (t1.isSignedInteger (8 )) {
129
+ return int64_t (-128 );
130
+ } else if (t1.isSignedInteger (32 )) {
131
+ return int64_t (std::numeric_limits<int32_t >::min ());
132
+ } else if (t1.isSignlessInteger (8 ) or t1.isSignlessInteger (32 )) {
133
+ return int64_t (0 );
134
+ } else {
135
+ llvm_unreachable (" unsupported data type" );
136
+ return (int64_t )0 ;
137
+ }
138
+ }
139
+
140
+ std::variant<float , int64_t > numericLimitsMaximum (Type type) {
141
+ Type t1 = getElementTypeOrSelf (type);
142
+ if (t1.isF32 ()) {
143
+ return std::numeric_limits<float >::infinity ();
144
+ } else if (t1.isBF16 ()) {
145
+ return bfloat2float (float2bfloat (std::numeric_limits<float >::infinity ()));
146
+ } else if (t1.isF16 ()) {
147
+ return (float )half2float (
148
+ float2half (std::numeric_limits<float >::infinity ()));
149
+ } else if (t1.isSignedInteger (8 )) {
150
+ return int64_t (127 );
151
+ } else if (t1.isSignedInteger (32 )) {
152
+ return int64_t (std::numeric_limits<int32_t >::max ());
153
+ } else if (t1.isSignlessInteger (8 )) {
154
+ return int64_t (255 );
155
+ } else if (t1.isSignedInteger (32 )) {
156
+ return int64_t (std::numeric_limits<uint32_t >::max ());
157
+ } else {
158
+ llvm_unreachable (" unsupported data type" );
159
+ return (int64_t )0 ;
160
+ }
161
+ }
162
+
163
+ } // namespace gc
164
+ } // namespace mlir
0 commit comments