Skip to content

Commit 01e31b7

Browse files
authored
fix(compress.py): use single value table for per-tensor quantized tensors (#3025)
Compress using a single value table when a tensor is per-tensor quantized, as indicated by the presence of only one quantization scale and zero point. Update unit tests accordingly and augment `test_models` to accommodate additional quantization fields. Abandon the logic that a tensor should be compressed along the NHWC channel dimension if the quantization parameters do not specify an axis. Instead, fail with an error if the compression axis cannot be inferred from the quantization parameters. The interpreter already expects a single value table when a tensor is per-tensor quantized. BUG=part of #2636
1 parent e3ac890 commit 01e31b7

File tree

3 files changed

+164
-36
lines changed

3 files changed

+164
-36
lines changed

tensorflow/lite/micro/compression/compress.py

+62-27
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import bitarray.util
2121
from dataclasses import dataclass, field
2222
import sys
23-
from typing import ByteString, Iterable
23+
from typing import ByteString, Iterable, Optional
2424

2525
import absl.app
2626
import absl.flags
@@ -107,7 +107,7 @@ def _add_subgraph(self):
107107

108108
@dataclass
109109
class _LutCompressedArray:
110-
compression_axis: int = 0
110+
compression_axis: Optional[int] = None
111111
lookup_tables: list[np.ndarray] = field(default_factory=list)
112112
indices: np.ndarray = field(default_factory=lambda: np.array([]))
113113

@@ -121,27 +121,46 @@ def index_bitwidth(self) -> int:
121121
return max_index.bit_length() or 1
122122

123123

124-
def _lut_compress_array(tensor: np.ndarray, axis: int) -> _LutCompressedArray:
125-
"""Compresses using a lookup table per subarray along the given axis.
124+
def _lut_compress_array(tensor: np.ndarray,
125+
axis: Optional[int]) -> _LutCompressedArray:
126+
"""Compresses the given tensor using lookup tables.
126127
127-
Compressing a tensor with a lookup table per subarray along a particular axis
128-
is analogous to quantizing a tensor with different quantization parameters
129-
per subarray along a particular axis (dimension).
128+
Args:
129+
tensor (np.ndarray): The tensor to be compressed.
130+
131+
axis (Optional[int]): The axis along which to compress the tensor. If an
132+
axis is given, a lookup table is created for each slice along the
133+
axis. If axis is None, a single lookup table is used for the entire
134+
tensor.
135+
136+
Compressing a tensor with a lookup table per slice along a
137+
particular axis is analogous to quantizing a tensor with different
138+
quantization parameters per slice along a particular axis (dimension).
139+
140+
Returns:
141+
_LutCompressedArray: An object containing the compressed tensor data,
142+
including the lookup tables and indices.
130143
"""
131144
compressed = _LutCompressedArray()
132145
compressed.compression_axis = axis
133146

134-
# Iterate over subarrays along the compression axis
135-
subarray_indices = []
136-
for subarray in np.moveaxis(tensor, axis, 0):
137-
values, indices = np.unique(subarray, return_inverse=True)
147+
if axis is None:
148+
# Compute unique values and indices for the entire tensor
149+
values, indices = np.unique(tensor, return_inverse=True)
138150
compressed.lookup_tables.append(values)
139-
indices = indices.reshape(subarray.shape)
140-
subarray_indices.append(indices)
141-
142-
# Reconstruct a tensor of indices from the subarrays
143-
stacked = np.stack(subarray_indices, axis=0)
144-
compressed.indices = np.moveaxis(stacked, 0, axis)
151+
compressed.indices = indices.reshape(tensor.shape)
152+
else:
153+
# Iterate over slices along the compression axis
154+
slice_indices = []
155+
for slice in np.moveaxis(tensor, axis, 0):
156+
values, indices = np.unique(slice, return_inverse=True)
157+
compressed.lookup_tables.append(values)
158+
indices = indices.reshape(slice.shape)
159+
slice_indices.append(indices)
160+
161+
# Reconstruct a tensor of indices from the slices
162+
stacked = np.stack(slice_indices, axis=0)
163+
compressed.indices = np.moveaxis(stacked, 0, axis)
145164

146165
return compressed
147166

@@ -155,18 +174,34 @@ def _check_lut_compression(compression) -> spec.LookUpTableCompression:
155174
return compression[0]
156175

157176

158-
def _identify_compression_axis(tensor: model_facade._Tensor) -> int:
159-
"""Finds the axis along which to compress.
177+
def _identify_compression_axis(tensor: model_facade._Tensor) -> Optional[int]:
178+
"""Determines the axis along which to compress.
179+
180+
The axis along which to compress is inferred from the tensor's quantization
181+
parameters.
182+
183+
Returns:
184+
The axis along which to compress, or None to indicate one value table for
185+
the entire tensor.
160186
161-
Use the quantization axis, else the NWHC channel dimension. If necessary,
162-
an user-specified override could be added to the compression spec schema.
187+
Raises:
188+
CompressionError: If the axis cannot be determined.
163189
"""
164-
if tensor.quantization is not None:
165-
axis = tensor.quantization.quantizedDimension
166-
else:
167-
axis = tensor.array.ndim - 1
190+
q = tensor.quantization
191+
if q is not None \
192+
and q.scale is not None \
193+
and q.quantizedDimension < len(tensor.shape):
194+
quantization_channels = len(q.scale)
195+
if quantization_channels == 1:
196+
# Use one value table for the entire tensor
197+
return None
198+
199+
if quantization_channels == tensor.shape[q.quantizedDimension]:
200+
return q.quantizedDimension
168201

169-
return axis
202+
raise CompressionError(
203+
f"Invalid or no quanitzation parameters from which to "
204+
f"infer the axis along which tensor should be compressed.")
170205

171206

172207
def _check_bitwidth(compressed: int, specified: int, spec: spec.Tensor):
@@ -204,7 +239,7 @@ def _pack_lookup_tables(tables: list[np.ndarray], table_len: int) -> bytearray:
204239
205240
Pack the value tables of a LutCompressedArray into a bytes object in the
206241
format writable to a value_table buffer in the .tflite flatbuffer. The
207-
tables, one per subarray, are concatinated.
242+
tables are concatenated.
208243
"""
209244
buffer = bytearray()
210245
for t in tables:

tensorflow/lite/micro/compression/compress_test.py

+95-4
Original file line numberDiff line numberDiff line change
@@ -200,40 +200,87 @@ def test_multiple_tables_with_padding(self):
200200
"shape": (16, 1),
201201
"type": tflite.TensorType.UINT8,
202202
"buffer": 1,
203+
"quantization": {
204+
"quantized_dimension": 1,
205+
"scale": (1,),
206+
"zero_point": (0,),
207+
},
203208
},
204209
1: {
205210
"shape": (16, 1),
206211
"type": tflite.TensorType.INT8,
207212
"buffer": 2,
213+
"quantization": {
214+
"quantized_dimension": 1,
215+
"scale": (1,),
216+
"zero_point": (0,),
217+
},
208218
},
209219
2: {
210220
"shape": (16, 1),
211221
"type": tflite.TensorType.INT16,
212222
"buffer": 3,
223+
"quantization": {
224+
"quantized_dimension": 1,
225+
"scale": (1,),
226+
"zero_point": (0,),
227+
},
213228
},
214229
3: {
215230
"shape": (16, 1),
216231
"type": tflite.TensorType.INT32,
217232
"buffer": 4,
233+
"quantization": {
234+
"quantized_dimension": 1,
235+
"scale": (1,),
236+
"zero_point": (0,),
237+
},
218238
},
219239
4: {
220240
"shape": (16, 1),
221241
"type": tflite.TensorType.INT32,
222242
"buffer": 5,
243+
"quantization": {
244+
"quantized_dimension": 1,
245+
"scale": (1,),
246+
"zero_point": (0,),
247+
},
223248
},
224249
5: {
225250
"shape": (4, 5),
226251
"type": tflite.TensorType.INT16,
227252
"buffer": 6,
253+
"quantization": {
254+
"quantized_dimension": 1,
255+
"scale": (1, 1, 1, 1, 1),
256+
"zero_point": (0, 0, 0, 0, 0),
257+
},
228258
},
229259
6: {
230260
"shape": (5, 4),
231261
"type": tflite.TensorType.INT16,
232262
"buffer": 7,
233263
"quantization": {
234264
"quantized_dimension": 0,
265+
"scale": (1, 1, 1, 1, 1),
266+
"zero_point": (0, 0, 0, 0, 0),
235267
},
236268
},
269+
7: {
270+
"shape": (5, 4),
271+
"type": tflite.TensorType.INT16,
272+
"buffer": 8,
273+
"quantization": {
274+
"quantized_dimension": 0,
275+
"scale": (1,),
276+
"zero_point": (0,),
277+
},
278+
},
279+
8: {
280+
"shape": (16, 1),
281+
"type": tflite.TensorType.UINT8,
282+
"buffer": 9,
283+
},
237284
},
238285
},
239286
},
@@ -260,6 +307,14 @@ def test_multiple_tables_with_padding(self):
260307
(9, 10, 11, 12),
261308
(13, 14, 15, 16),
262309
(17, 18, 19, 20)), dtype=np.dtype("<i2")),
310+
311+
8: np.array(((1, 2, 3, 4),
312+
(1, 2, 3, 4),
313+
(1, 2, 3, 4),
314+
(1, 2, 3, 4),
315+
(1, 2, 3, 4)), dtype=np.dtype("<i2")),
316+
317+
9: np.array(range(16), dtype=np.dtype("<u1")),
263318
},
264319
}
265320

@@ -297,6 +352,11 @@ def test_multiple_tables_with_padding(self):
297352
tensor=6,
298353
compression=[spec.LookUpTableCompression(index_bitwidth=2)],
299354
),
355+
spec.Tensor( # spec 6
356+
subgraph=0,
357+
tensor=7,
358+
compression=[spec.LookUpTableCompression(index_bitwidth=2)],
359+
),
300360
]
301361
# yapf: enable
302362

@@ -362,6 +422,18 @@ def test_invalid_tensor_spec(self):
362422
self.assertRaises(compress.CompressionError,
363423
lambda: compress.compress(self.flatbuffer, specs))
364424

425+
def test_no_axis(self):
426+
"""Raises if no quantization from which to infer compression axis."""
427+
specs = [
428+
spec.Tensor(
429+
subgraph=0,
430+
tensor=8,
431+
compression=[spec.LookUpTableCompression(index_bitwidth=4)],
432+
),
433+
]
434+
self.assertRaises(compress.CompressionError,
435+
lambda: compress.compress(self.flatbuffer, specs))
436+
365437

366438
class TestLutCompressedArray(tf.test.TestCase):
367439

@@ -519,8 +591,8 @@ def test_compressed_int32(self):
519591
expected_values = np.array(range(-160_016, -160_000), dtype="<i4")
520592
self.assertAllEqual(values, expected_values)
521593

522-
def test_channel_axis(self):
523-
"""Compression along the NWHC channel axis when no quanitzation axis."""
594+
def test_axis_1(self):
595+
"""Compression along quantized_dimension == 1."""
524596
bitwidth, indices, values = self._get_compressed(subgraph=0, tensor=5)
525597
self.assertEqual(bitwidth, 2)
526598

@@ -537,8 +609,8 @@ def test_channel_axis(self):
537609
expected_values = np.array(range(1, 21), dtype=np.dtype("<i2"))
538610
self.assertAllEqual(values, expected_values)
539611

540-
def test_quantization_axis(self):
541-
"""Compression along the quanitzation axis."""
612+
def test_axis_0(self):
613+
"""Compression along quantized_dimension == 0."""
542614
bitwidth, indices, values = self._get_compressed(subgraph=0, tensor=6)
543615
self.assertEqual(bitwidth, 2)
544616

@@ -556,6 +628,25 @@ def test_quantization_axis(self):
556628
expected_values = np.array(range(1, 21), dtype=np.dtype("<i2"))
557629
self.assertAllEqual(values, expected_values)
558630

631+
def test_per_tensor(self):
632+
"""Compression with one value table per tensor."""
633+
bitwidth, indices, values = self._get_compressed(subgraph=0, tensor=7)
634+
self.assertEqual(bitwidth, 2)
635+
636+
# yapf: disable
637+
expected_indices = self._make_indices("""
638+
00 01 10 11
639+
00 01 10 11
640+
00 01 10 11
641+
00 01 10 11
642+
00 01 10 11
643+
""")
644+
# yapf: enable
645+
self.assertEqual(indices, expected_indices)
646+
647+
expected_values = np.array(range(1, 5), dtype=np.dtype("<i2"))
648+
self.assertAllEqual(values, expected_values)
649+
559650

560651
if __name__ == "__main__":
561652
tf.test.main()

tensorflow/lite/micro/compression/test_models.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -157,12 +157,14 @@ def build(model_definition: dict) -> bytearray:
157157
tensor_t.type = tensor["type"]
158158
tensor_t.buffer = tensor["buffer"]
159159

160-
try:
161-
d = tensor["quantization"]["quantized_dimension"]
160+
if "quantization" in tensor:
162161
tensor_t.quantization = tflite.QuantizationParametersT()
163-
tensor_t.quantization.quantizedDimension = d
164-
except KeyError:
165-
tensor_t.quantization = None
162+
tensor_t.quantization.quantizedDimension = \
163+
tensor["quantization"].get("quantized_dimension", None)
164+
tensor_t.quantization.scale = \
165+
tensor["quantization"].get("scale", None)
166+
tensor_t.quantization.zeroPoint = \
167+
tensor["quantization"].get("zero_point", None)
166168

167169
subgraph_t.tensors.append(tensor_t)
168170

0 commit comments

Comments
 (0)