Skip to content

Commit ba82f37

Browse files
committed
more cleanup
Signed-off-by: Yuan Yao <[email protected]>
1 parent f4ca510 commit ba82f37

File tree

10 files changed

+16
-6
lines changed

10 files changed

+16
-6
lines changed

docs/IR.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ It is common to represent a tensor as a nested list. This generally works fine,
421421

422422
|Group|Types|Description|
423423
|---|---|---|
424-
Floating Point Types|float16, float32, float64, bfloat16, float8e4m3fn, float8e5m2, float8e4m3fnuz, float8e5m2fnuz|Values adhering to the IEEE 754-2008 standard representation of floating-point data or defined in papers [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433) and [8-bit Numerical Formats for Deep Neural Networks](https://arxiv.org/abs/2206.02915)
424+
Floating Point Types|float16, float32, float64, bfloat16, float8e4m3fn, float8e5m2, float8e4m3fnuz, float8e5m2fnuz, float4e2m1|Values adhering to the IEEE 754-2008 standard representation of floating-point data or defined in papers [FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433) and [8-bit Numerical Formats for Deep Neural Networks](https://arxiv.org/abs/2206.02915)
425425
Signed Integer Types|int4, int8, int16, int32, int64|Signed integers are supported for 4-64 bit widths.
426426
Unsigned Integer Types|uint4, uint8, uint16, uint32, uint64|Unsigned integers are supported for 4-64 bit widths.
427427
Complex Types|complex64, complex128|A complex number with either 32- or 64-bit real and imaginary parts.

docs/docsgen/source/api/numpy_helper.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
.. autofunction:: onnx.numpy_helper.to_array
3434
```
3535

36-
As numpy does not support all the types defined in ONNX (float 8 types, bfloat16, int4, uint4),
36+
As numpy does not support all the types defined in ONNX (float 8 types, bfloat16, int4, uint4, float4e2m1),
3737
these two functions use a custom dtype defined in :mod:`onnx._custom_element_types`.
3838

3939
## sequence

onnx/common/ir_pb_converter.cc

+2-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ Tensor tensorProtoToTensor(const ONNX_NAMESPACE::TensorProto& tp) {
5151
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:
5252
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ:
5353
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:
54-
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ: {
54+
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ:
55+
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT4E2M1: {
5556
ret.int32s().reserve(tp.int32_data_size());
5657
for (int i = 0; i < tp.int32_data_size(); i++) {
5758
ret.int32s().push_back(tp.int32_data(i));

onnx/defs/parser.cc

+1
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,7 @@ Status OnnxParser::Parse(TensorProto& tensorProto, const TypeProto& tensorTypePr
473473
case TensorProto::DataType::TensorProto_DataType_FLOAT8E5M2:
474474
case TensorProto::DataType::TensorProto_DataType_FLOAT8E5M2FNUZ:
475475
case TensorProto::DataType::TensorProto_DataType_BOOL:
476+
case TensorProto::DataType::TensorProto_DataType_FLOAT4E2M1:
476477
PARSE_TOKEN(intval);
477478
// TODO: check values are in the correct range.
478479
tensorProto.add_int32_data(intval);

onnx/defs/parser.h

+1
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ class PrimitiveTypeNameMap : public StringIntMap<PrimitiveTypeNameMap> {
9696
map_["float8e5m2fnuz"] = TensorProto_DataType_FLOAT8E5M2FNUZ;
9797
map_["uint4"] = TensorProto_DataType_UINT4;
9898
map_["int4"] = TensorProto_DataType_INT4;
99+
map_["float4e2m1"] = TensorProto_DataType_FLOAT4E2M1;
99100
}
100101

101102
static bool IsTypeName(const std::string& dtype) {

onnx/numpy_helper.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def evaluate_float4e2m1_from_bits(x: np.uint8) -> np.float32:
229229
x: a uint8 element representing a float4e2m1 (using the 4 LSB)
230230
231231
Returns:
232-
Packed array with size `ceil(farray.size/2)` (single dimension).
232+
A float32 element representing the value of the float4e2m1 input.
233233
"""
234234
# x is stored in 4 LSB of int
235235
S = -1 if bool(x & 0x08) else 1
@@ -619,10 +619,13 @@ def from_array(tensor: np.ndarray, name: str | None = None) -> TensorProto:
619619
elif dt == custom_np_types.uint4 and dt.descr[0][0] == "uint4":
620620
to = TensorProto.UINT4
621621
dt_to = np.uint8 # type: ignore[assignment]
622+
elif dt == custom_np_types.float4e2m1 and dt.descr[0][0] == "float4e2m1":
623+
to = TensorProto.FLOAT4E2M1
624+
dt_to = np.uint8
622625
else:
623626
return _from_array(tensor, name)
624627

625-
if to in (TensorProto.UINT4, TensorProto.INT4):
628+
if to in (TensorProto.UINT4, TensorProto.INT4, TensorProto.FLOAT4E2M1):
626629
value = tensor.astype(dt_to).ravel()
627630
if value.size % 2 == 1:
628631
raise ValueError(

onnx/reference/custom_element_types.py

+2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from onnx._custom_element_types import (
1414
bfloat16,
15+
float4e2m1,
1516
float8e4m3fn,
1617
float8e4m3fnuz,
1718
float8e5m2,
@@ -22,6 +23,7 @@
2223

2324
_supported_types = [
2425
(bfloat16, "bfloat16", "bfloat16"),
26+
(float4e2m1, "float4e2m1", "float4_e2m1"),
2527
(float8e4m3fn, "e4m3fn", "float8_e4m3fn"),
2628
(float8e4m3fnuz, "e4m3fnuz", "float8_e4m3fnuz"),
2729
(float8e5m2, "e5m2", "float8_e5m2"),

onnx/reference/ops/op_cast.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def cast_to(x, to, saturate): # noqa: PLR0911
140140

141141
if to == TensorProto.FLOAT4E2M1:
142142
xf = x.astype(np.float32)
143-
y = subbyte.float32_to_float4e2m1_unpacked(xf)
143+
y = subbyte.float32_to_float4e2m1_unpacked(xf).astype(float4e2m1)
144144
return y.reshape(x.shape)
145145

146146
if to == TensorProto.STRING:

onnx/test/numpy_helper_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,7 @@ def test_to_array_from_array(self, att):
648648
def test_to_array_from_array_subtype(self):
649649
self._to_array_from_array(onnx.TensorProto.INT4)
650650
self._to_array_from_array(onnx.TensorProto.UINT4)
651+
self._to_array_from_array(onnx.TensorProto.FLOAT4E2M1)
651652

652653
def test_to_array_from_array_string(self):
653654
self._to_array_from_array(onnx.TensorProto.STRING, False)

onnx/test/parser_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ def test_parse_various_float_values(self, test_literal, expect_exception):
292292
("uint16", TensorProto.UINT16),
293293
("uint32", TensorProto.UINT32),
294294
("uint64", TensorProto.UINT64),
295+
("float4e2m1", TensorProto.FLOAT4E2M1),
295296
]
296297
)
297298
def test_parse_graph_types(self, name, itype) -> None:

0 commit comments

Comments
 (0)