@@ -72,7 +72,7 @@ def unpack_single_4bitx2(
72
72
return (x_low .astype (dtype ), x_high .astype (dtype ))
73
73
74
74
75
- def float32_to_float4e2m1_unpacked (x : np .ndarray | np .dtype ) -> np .ndarray :
75
+ def float32_to_float4e2m1_unpacked_slow (x : np .ndarray | np .dtype ) -> np .ndarray :
76
76
"""Cast float32 to float4e2m1 (without packing).
77
77
78
78
Args:
@@ -85,7 +85,7 @@ def float32_to_float4e2m1_unpacked(x: np.ndarray | np.dtype) -> np.ndarray:
85
85
def float32_to_float4e2m1 (value ):
86
86
if np .isnan (value ):
87
87
return 0x7
88
- s = 0x0 if value >= 0 else 0x8
88
+ s = 0x8 if np . signbit ( value ) else 0x0
89
89
magnitude = np .abs (value )
90
90
if np .isinf (magnitude ):
91
91
ret = 0x7
@@ -116,14 +116,38 @@ def float32_to_float4e2m1(value):
116
116
return y .astype (np .uint8 ) # type: ignore[no-any-return]
117
117
118
118
119
- def float32x2_to_float4e2m1x2 (val_low : np .dtype , val_high : np .dtype ) -> np .ndarray :
119
def float32_to_float4e2m1_unpacked(values: np.ndarray) -> np.ndarray:
    """Cast float32 to float4e2m1 (without packing).

    Vectorized conversion: every input element is mapped to a 4-bit code
    stored in the low nibble of a uint8 (bit 3 is the sign, bits 0-2 encode
    the magnitude).

    Args:
        values: element or array to be converted. Scalars and other
            array-likes are accepted and normalized with ``np.asarray``.

    Returns:
        An ndarray with unpacked float4e2m1 elements (as uint8), with the
        same shape as the input (0-d for a scalar input).
    """
    # Normalize first so that plain Python scalars and lists expose `.shape`
    # and support the boolean-mask assignments below.
    values = np.asarray(values)

    sign = np.where(np.signbit(values), 0x8, 0x0).astype(np.uint8)
    magnitude = np.abs(values)

    # Magnitudes <= 0.25 keep the initial code 0x0.
    res = np.zeros(values.shape, dtype=np.uint8)

    # The interval bounds are closed on the even codes (0x2, 0x4, 0x6) and
    # open on the odd ones, so a value sitting exactly on a threshold is
    # assigned the even code (round-half-to-even on the 3 magnitude bits).
    res[(magnitude > 0.25) & (magnitude < 0.75)] = 0x1
    res[(magnitude >= 0.75) & (magnitude <= 1.25)] = 0x2
    res[(magnitude > 1.25) & (magnitude < 1.75)] = 0x3
    res[(magnitude >= 1.75) & (magnitude <= 2.5)] = 0x4
    res[(magnitude > 2.5) & (magnitude < 3.5)] = 0x5
    res[(magnitude >= 3.5) & (magnitude <= 5.0)] = 0x6
    # Everything above the last threshold, including infinity, saturates.
    res[magnitude > 5.0] = 0x7

    res |= sign

    # NaN fails every comparison above; it is written last so the result is
    # 0x7 without a sign bit, whatever the sign of the NaN payload.
    res[np.isnan(values)] = 0x7
    return res
141
+
142
+
143
+ def float32x2_to_float4e2m1x2 (val_low : np .ndarray , val_high : np .ndarray ) -> np .ndarray :
120
144
"""Cast two elements to float4e2m1 and pack to a single byte
121
145
Args:
122
146
val_low: element to be packed in the 4 LSB
123
147
val_high: element to be packed in the 4 MSB
124
148
125
149
Returns:
126
- An ndarray with a single uint8 element , containing both float4e2m1 elements
150
+ An ndarray with uint8 elements , containing both float4e2m1 elements
127
151
"""
128
152
i8_high = float32_to_float4e2m1_unpacked (val_high )
129
153
i8_low = float32_to_float4e2m1_unpacked (val_low )
0 commit comments