Skip to content

Commit 46dc2f7

Browse files
committed
update cast, castlike, Q/DQ
Signed-off-by: Yuan Yao <[email protected]>
1 parent c057d17 commit 46dc2f7

File tree

139 files changed

+1057
-48
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

139 files changed

+1057
-48
lines changed

docs/Changelog.md

+271-5

docs/Operators.md

+144-18

docs/TestCoverage.md

+124-2

onnx/backend/test/case/node/cast.py

+59-1
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,11 @@
1818
make_tensor,
1919
tensor_dtype_to_field,
2020
)
21-
from onnx.numpy_helper import float8e4m3_to_float32, float8e5m2_to_float32
21+
from onnx.numpy_helper import (
22+
float8e4m3_to_float32,
23+
float8e5m2_to_float32,
24+
unpacked_float4e2m1_to_float32,
25+
)
2226

2327

2428
class Cast(Base):
@@ -62,6 +66,10 @@ def export() -> None:
6266
("INT4", "FLOAT"),
6367
("INT4", "FLOAT16"),
6468
("INT4", "INT8"),
69+
("FLOAT4E2M1", "FLOAT"),
70+
("FLOAT4E2M1", "FLOAT16"),
71+
("FLOAT", "FLOAT4E2M1"),
72+
("FLOAT16", "FLOAT4E2M1"),
6573
]
6674

6775
vect_float32_to_float8e4m3 = np.vectorize(float32_to_float8e4m3)
@@ -278,7 +286,57 @@ def export() -> None:
278286
output_type_proto = onnx.helper.make_tensor_type_proto(
279287
getattr(TensorProto, to_type), input_shape
280288
)
289+
elif from_type == "FLOAT4E2M1" or to_type == "FLOAT4E2M1":
290+
np_fp32 = np.array(
291+
[
292+
"0.48",
293+
"0.25",
294+
"1.05",
295+
"-3.5",
296+
"-8",
297+
"9",
298+
"1000000",
299+
"1e-7",
300+
"NaN",
301+
"INF",
302+
"+INF",
303+
"-INF",
304+
"-4",
305+
"0.01",
306+
"-0.0",
307+
],
308+
dtype=np.float32,
309+
)
310+
input_shape = (3, 5)
311+
if from_type == "FLOAT":
312+
input_values = np_fp32
313+
input = make_tensor(
314+
"x", TensorProto.FLOAT, input_shape, input_values.tolist()
315+
)
316+
elif from_type == "FLOAT16":
317+
input_values = np_fp32.astype(np.float16).astype(np.float32)
318+
input = make_tensor(
319+
"x", TensorProto.FLOAT16, input_shape, input_values.tolist()
320+
)
321+
elif from_type == "FLOAT4E2M1":
322+
input = make_tensor(
323+
"x", TensorProto.FLOAT4E2M1, input_shape, np_fp32.tolist()
324+
)
325+
else:
326+
raise ValueError(
327+
f"Conversion from {from_type} to {to_type} is not tested."
328+
)
281329

330+
if to_type not in ("FLOAT", "FLOAT16", "FLOAT4E2M1"):
331+
raise ValueError(
332+
f"Conversion from {from_type} to {to_type} is not tested."
333+
)
334+
expected = unpacked_float4e2m1_to_float32(
335+
subbyte.float32_to_float4e2m1_unpacked(np_fp32)
336+
)
337+
output = make_tensor(
338+
"y", getattr(TensorProto, to_type), input_shape, expected.tolist()
339+
)
282340
elif from_type != "STRING":
283341
input = np.random.random_sample(shape).astype(
284342
helper.tensor_dtype_to_np_dtype(getattr(TensorProto, from_type))

onnx/backend/test/case/node/dequantizelinear.py

+22
Original file line number | Diff line number | Diff line change
@@ -235,6 +235,28 @@ def export_int4() -> None:
235235
name="test_dequantizelinear_int4",
236236
)
237237

238+
@staticmethod
def export_float4e2m1() -> None:
    """Export the ``test_dequantizelinear_float4e2m1`` node test case.

    Feeds a DequantizeLinear node (``axis=0``) with a five-element
    FLOAT4E2M1 tensor, a float32 scale of 2 and a single-element FLOAT4E2M1
    zero point of 0. The expected float32 output is every input value
    multiplied by the scale.
    """
    # Scalar scale and a one-element zero point.
    scale = np.float32(2)
    zero_point = make_tensor("x_zero_point", TensorProto.FLOAT4E2M1, (1,), [0])

    quantized = make_tensor("x", TensorProto.FLOAT4E2M1, [5], [0, 1, -1, 1.5, -4])
    dequantized = np.array([0, 2, -2, 3, -8], dtype=np.float32)

    expect(
        onnx.helper.make_node(
            "DequantizeLinear",
            inputs=["x", "x_scale", "x_zero_point"],
            outputs=["y"],
            axis=0,
        ),
        inputs=[quantized, scale, zero_point],
        outputs=[dequantized],
        name="test_dequantizelinear_float4e2m1",
    )
259+
238260
@staticmethod
239261
def export_blocked() -> None:
240262
node = onnx.helper.make_node(

onnx/backend/test/case/node/quantizelinear.py

+38
Original file line number | Diff line number | Diff line change
@@ -276,6 +276,44 @@ def export_int4() -> None:
276276
name="test_quantizelinear_int4",
277277
)
278278

279+
@staticmethod
def export_float4e2m1() -> None:
    """Export the ``test_quantizelinear_float4e2m1`` node test case.

    Quantizes a 3x4 float32 matrix with a QuantizeLinear node (``axis=0``),
    using one float32 scale per row and an all-zero FLOAT4E2M1 zero point of
    the same shape as the scale. The expected output is a FLOAT4E2M1 tensor
    with the input's shape.
    """
    scales = np.asarray([2.0, 3.0, 4.0], dtype=np.float32)
    data = np.array(
        [
            [0.0, 2.5, 4.8, 8.6],
            [-30, -20, 6, 9],
            [-0.0, -2.5, -4.8, -8.6],
        ]
    ).astype(np.float32)

    zero_points = make_tensor(
        "y_zero_point",
        TensorProto.FLOAT4E2M1,
        scales.shape,
        np.zeros_like(scales),
    )
    # NOTE(review): these values look like each row of `data` divided by its
    # scale and then rounded/saturated to a representable FLOAT4E2M1 value
    # (e.g. -30 / 3 -> -6) -- confirm against the operator specification.
    quantized = make_tensor(
        "y",
        TensorProto.FLOAT4E2M1,
        data.shape,
        [0, 1, 2, 4, -6, -6, 2, 3, 0, -0.5, -1, -2],
    )

    node = onnx.helper.make_node(
        "QuantizeLinear",
        inputs=["x", "y_scale", "y_zero_point"],
        outputs=["y"],
        axis=0,
    )
    expect(
        node,
        inputs=[data, scales, zero_points],
        outputs=[quantized],
        name="test_quantizelinear_float4e2m1",
    )
316+
279317
@staticmethod
280318
def export_blocked_asymmetric() -> None:
281319
node = onnx.helper.make_node(
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+

2+
*'�o�h�x�������������������B��Bx
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
*
2+
�w�By
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
*
2+
�w�Bx
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
*
2+
�w�Bx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
*
2+
�w�By
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
* :Bx
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
* d�T��By
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

onnx/defs/operator_sets.h

+9-2
Original file line number | Diff line number | Diff line change
@@ -1291,11 +1291,18 @@ class OpSet_Onnx_ver22 {
12911291
};
12921292

12931293
// Iterate over schema from ai.onnx version 23
1294+
// Forward declarations of the four operator schema classes that this hunk
// adds to ai.onnx opset 23: Cast, CastLike, DequantizeLinear, QuantizeLinear.
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, Cast);
1295+
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, CastLike);
1296+
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, DequantizeLinear);
1297+
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, QuantizeLinear);
1298+
12941299
// Holder for the opset-23 schema list; ForEachSchema hands each schema to `fn`.
class OpSet_Onnx_ver23 {
12951300
public:
12961301
// Invokes `fn` once per schema obtained from GetOpSchema<...>().
// The lines marked "-" below (the TODO and the `(void)fn;` placeholder) are
// the ones this diff removes; the lines marked "+" replace them.
static void ForEachSchema(std::function<void(OpSchema&&)> fn) {
1297-
// TODO: Remove after introducing the first schema to opset 23
1298-
(void)fn;
1302+
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, Cast)>());
1303+
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, CastLike)>());
1304+
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, DequantizeLinear)>());
1305+
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 23, QuantizeLinear)>());
12991306
}
13001307
};
13011308

0 commit comments

Comments
 (0)