Skip to content

Commit 9ba8a86

Browse files
committed
vectorize cast function; type annotation changes; add to exclusion list
Signed-off-by: Yuan Yao <[email protected]>
1 parent 26c05d4 commit 9ba8a86

File tree

16 files changed

+63
-31
lines changed

16 files changed

+63
-31
lines changed

docs/Operators.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -3718,7 +3718,7 @@ for from_type, to_type in test_cases:
37183718
"-INF",
37193719
"-4",
37203720
"0.01",
3721-
"-1000000",
3721+
"-0.0",
37223722
],
37233723
dtype=np.float32,
37243724
)
@@ -3746,7 +3746,7 @@ for from_type, to_type in test_cases:
37463746
raise ValueError(
37473747
f"Conversion from {from_type} to {to_type} is not tested."
37483748
)
3749-
expected = evaluate_float4e2m1_from_bits(
3749+
expected = unpacked_float4e2m1_to_float32(
37503750
subbyte.float32_to_float4e2m1_unpacked(np_fp32)
37513751
)
37523752
output = make_tensor(

docs/TestCoverage.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -2587,7 +2587,7 @@ for from_type, to_type in test_cases:
25872587
"-INF",
25882588
"-4",
25892589
"0.01",
2590-
"-1000000",
2590+
"-0.0",
25912591
],
25922592
dtype=np.float32,
25932593
)
@@ -2615,7 +2615,7 @@ for from_type, to_type in test_cases:
26152615
raise ValueError(
26162616
f"Conversion from {from_type} to {to_type} is not tested."
26172617
)
2618-
expected = evaluate_float4e2m1_from_bits(
2618+
expected = unpacked_float4e2m1_to_float32(
26192619
subbyte.float32_to_float4e2m1_unpacked(np_fp32)
26202620
)
26212621
output = make_tensor(

onnx/backend/test/case/node/cast.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
tensor_dtype_to_field,
2020
)
2121
from onnx.numpy_helper import (
22-
evaluate_float4e2m1_from_bits,
2322
float8e4m3_to_float32,
2423
float8e5m2_to_float32,
24+
unpacked_float4e2m1_to_float32,
2525
)
2626

2727

@@ -303,7 +303,7 @@ def export() -> None:
303303
"-INF",
304304
"-4",
305305
"0.01",
306-
"-1000000",
306+
"-0.0",
307307
],
308308
dtype=np.float32,
309309
)
@@ -331,7 +331,7 @@ def export() -> None:
331331
raise ValueError(
332332
f"Conversion from {from_type} to {to_type} is not tested."
333333
)
334-
expected = evaluate_float4e2m1_from_bits(
334+
expected = unpacked_float4e2m1_to_float32(
335335
subbyte.float32_to_float4e2m1_unpacked(np_fp32)
336336
)
337337
output = make_tensor(
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11

2-
*'�o�h�x�������������������B��Bx
2+
*'�o�h�x�������������������B��Bx
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
*
2-
�w�By
2+
�w�By
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
*
2-
�w�Bx
2+
�w�Bx
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
*
2-
�w�Bx
2+
�w�Bx
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
*
2-
�w�By
2+
�w�By

onnx/numpy_helper.py

+15-11
Original file line numberDiff line numberDiff line change
@@ -221,22 +221,26 @@ def unpack_int4(
221221
return res
222222

223223

224-
def evaluate_float4e2m1_from_bits(x: np.ndarray[np.uint8]) -> np.ndarray[np.float32]:
225-
"""Evaluate the numerical value of a single float4e2m1 element represented as uint8
224+
def unpacked_float4e2m1_to_float32(x: np.ndarray) -> np.ndarray:
225+
"""Evaluate the numerical value of an array of unpacked float4e2m1 values (as uint8)
226226
See :ref:`onnx-detail-int4` for technical details.
227227
228228
Args:
229-
x: a uint8 element representing a float4e2m1 (using the 4 LSB)
229+
x: an array of uint8 elements representing a float4e2m1 (using the 4 LSB)
230230
231231
Returns:
232-
A float32 element representing the value of the float4e2m1 input.
232+
An array of float32 elements representing the values of the float4e2m1 input.
233233
"""
234234
# x is stored in 4 LSB of int
235-
S = np.where(np.bitwise_and(x, 0x08), -1, 1)
236-
M = x & 0x01
237-
E = (x & 0x06) >> 1
238-
239-
val = np.where(E==0, S*(M/2.0), S*(1.0+M/2.0) *2.0 **(E-1)) # denormalized, normalized
235+
sign = np.where(np.bitwise_and(x, 0x08), -1, 1)
236+
mantissa = x & 0x01
237+
exponent = (x & 0x06) >> 1
238+
239+
val = np.where(
240+
exponent == 0,
241+
sign * (mantissa / 2.0),
242+
sign * (1.0 + mantissa / 2.0) * 2.0 ** (exponent - 1),
243+
) # denormalized, normalized
240244
return val
241245

242246

@@ -258,8 +262,8 @@ def unpack_float4e2m1(
258262
res_high, res_low = subbyte.unpack_single_4bitx2(data.ravel(), False)
259263
res = np.empty((res_high.size + res_low.size,), dtype=np.float32)
260264

261-
res[0::2] = evaluate_float4e2m1_from_bits(res_high)
262-
res[1::2] = evaluate_float4e2m1_from_bits(res_low)
265+
res[0::2] = unpacked_float4e2m1_to_float32(res_high)
266+
res[1::2] = unpacked_float4e2m1_to_float32(res_low)
263267

264268
if (
265269
res.size == np.prod(dims) + 1

onnx/reference/ops/op_cast.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
)
2525
from onnx.numpy_helper import (
2626
bfloat16_to_float32,
27-
evaluate_float4e2m1_from_bits,
2827
float8e4m3_to_float32,
2928
float8e5m2_to_float32,
29+
unpacked_float4e2m1_to_float32,
3030
)
3131
from onnx.onnx_pb import TensorProto
3232
from onnx.reference.op_run import OpRun
@@ -131,7 +131,7 @@ def cast_to(x, to, saturate): # noqa: PLR0911
131131
if x.dtype == float4e2m1 and x.dtype.descr[0][0] == "float4e2m1":
132132
if to == TensorProto.FLOAT4E2M1:
133133
return x
134-
res = evaluate_float4e2m1_from_bits(x)
134+
res = unpacked_float4e2m1_to_float32(x)
135135
if to == TensorProto.FLOAT:
136136
return res.astype(np.float32)
137137
elif to == TensorProto.FLOAT16:

onnx/reference/ops/op_dequantize_linear.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
)
1818
from onnx.helper import np_dtype_to_tensor_dtype
1919
from onnx.numpy_helper import (
20-
evaluate_float4e2m1_from_bits,
2120
float8e4m3_to_float32,
2221
float8e5m2_to_float32,
22+
unpacked_float4e2m1_to_float32,
2323
)
2424
from onnx.reference.op_run import OpRun
2525
from onnx.reference.ops.op_quantize_linear import reshape_input
@@ -93,7 +93,7 @@ def _run(
9393
elif x_type == TensorProto.FLOAT8E5M2FNUZ:
9494
dx = float8e5m2_to_float32(x, fn=True, uz=True)
9595
elif x_type == TensorProto.FLOAT4E2M1:
96-
dx = evaluate_float4e2m1_from_bits(x)
96+
dx = unpacked_float4e2m1_to_float32(x)
9797
else:
9898
dx = x.astype(np.float32)
9999
y = dx * reshape_input(x_scale, x.shape, axis, block_size)

onnx/subbyte.py

+28-4
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def unpack_single_4bitx2(
7272
return (x_low.astype(dtype), x_high.astype(dtype))
7373

7474

75-
def float32_to_float4e2m1_unpacked(x: np.ndarray | np.dtype) -> np.ndarray:
75+
def float32_to_float4e2m1_unpacked_slow(x: np.ndarray | np.dtype) -> np.ndarray:
7676
"""Cast float32 to float4e2m1 (without packing).
7777
7878
Args:
@@ -85,7 +85,7 @@ def float32_to_float4e2m1_unpacked(x: np.ndarray | np.dtype) -> np.ndarray:
8585
def float32_to_float4e2m1(value):
8686
if np.isnan(value):
8787
return 0x7
88-
s = 0x0 if value >= 0 else 0x8
88+
s = 0x8 if np.signbit(value) else 0x0
8989
magnitude = np.abs(value)
9090
if np.isinf(magnitude):
9191
ret = 0x7
@@ -116,14 +116,38 @@ def float32_to_float4e2m1(value):
116116
return y.astype(np.uint8) # type: ignore[no-any-return]
117117

118118

119-
def float32x2_to_float4e2m1x2(val_low: np.dtype, val_high: np.dtype) -> np.ndarray:
119+
def float32_to_float4e2m1_unpacked(values: np.ndarray) -> np.ndarray:
120+
"""Cast float32 to float4e2m1 (without packing).
121+
122+
Args:
123+
values: element or array to be converted
124+
125+
Returns:
126+
An ndarray with unpacked float4e2m1 elements (as uint8)
127+
"""
128+
sign = np.where(np.signbit(values), 0x8, 0x0).astype(np.uint8)
129+
magnitude = np.abs(values)
130+
res = np.zeros(values.shape, dtype=np.uint8)
131+
res[(magnitude > 0.25) & (magnitude < 0.75)] = 0x1
132+
res[(magnitude >= 0.75) & (magnitude <= 1.25)] = 0x2
133+
res[(magnitude > 1.25) & (magnitude < 1.75)] = 0x3
134+
res[(magnitude >= 1.75) & (magnitude <= 2.5)] = 0x4
135+
res[(magnitude > 2.5) & (magnitude < 3.5)] = 0x5
136+
res[(magnitude >= 3.5) & (magnitude <= 5.0)] = 0x6
137+
res[magnitude > 5.0] = 0x7
138+
res |= sign
139+
res[np.isnan(values)] = 0x7
140+
return res
141+
142+
143+
def float32x2_to_float4e2m1x2(val_low: np.ndarray, val_high: np.ndarray) -> np.ndarray:
120144
"""Cast two elements to float4e2m1 and pack to a single byte
121145
Args:
122146
val_low: element to be packed in the 4 LSB
123147
val_high: element to be packed in the 4 MSB
124148
125149
Returns:
126-
An ndarray with a single uint8 element, containing both float4e2m1 elements
150+
An ndarray with uint8 elements, containing both float4e2m1 elements
127151
"""
128152
i8_high = float32_to_float4e2m1_unpacked(val_high)
129153
i8_low = float32_to_float4e2m1_unpacked(val_low)

onnx/test/test_backend_onnxruntime.py

+4
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,10 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
164164
"|test_cast_UINT4_to_FLOAT16" # No corresponding Numpy type for Tensor Type.
165165
"|test_cast_INT4_to_FLOAT16" # No corresponding Numpy type for Tensor Type.
166166
"|test_maxpool_2d_ceil_output_size_reduce_by_one" # TODO: remove after https://github.com/microsoft/onnxruntime/pull/18377 in Ort release.
167+
"|test_quantizeLinear_float4e2m1" # No corresponding Numpy type for Tensor Type.
168+
"|test_dequantizeLinear_float4e2m1" # No corresponding Numpy type for Tensor Type.
169+
"|cast_float4e2m1" # No corresponding Numpy type for Tensor Type.
170+
"|to_float4e2m1" # No corresponding Numpy type for Tensor Type.
167171
")"
168172
)
169173

0 commit comments

Comments
 (0)