@@ -200,40 +200,87 @@ def test_multiple_tables_with_padding(self):
200
200
"shape" : (16 , 1 ),
201
201
"type" : tflite .TensorType .UINT8 ,
202
202
"buffer" : 1 ,
203
+ "quantization" : {
204
+ "quantized_dimension" : 1 ,
205
+ "scale" : (1 ,),
206
+ "zero_point" : (0 ,),
207
+ },
203
208
},
204
209
1 : {
205
210
"shape" : (16 , 1 ),
206
211
"type" : tflite .TensorType .INT8 ,
207
212
"buffer" : 2 ,
213
+ "quantization" : {
214
+ "quantized_dimension" : 1 ,
215
+ "scale" : (1 ,),
216
+ "zero_point" : (0 ,),
217
+ },
208
218
},
209
219
2 : {
210
220
"shape" : (16 , 1 ),
211
221
"type" : tflite .TensorType .INT16 ,
212
222
"buffer" : 3 ,
223
+ "quantization" : {
224
+ "quantized_dimension" : 1 ,
225
+ "scale" : (1 ,),
226
+ "zero_point" : (0 ,),
227
+ },
213
228
},
214
229
3 : {
215
230
"shape" : (16 , 1 ),
216
231
"type" : tflite .TensorType .INT32 ,
217
232
"buffer" : 4 ,
233
+ "quantization" : {
234
+ "quantized_dimension" : 1 ,
235
+ "scale" : (1 ,),
236
+ "zero_point" : (0 ,),
237
+ },
218
238
},
219
239
4 : {
220
240
"shape" : (16 , 1 ),
221
241
"type" : tflite .TensorType .INT32 ,
222
242
"buffer" : 5 ,
243
+ "quantization" : {
244
+ "quantized_dimension" : 1 ,
245
+ "scale" : (1 ,),
246
+ "zero_point" : (0 ,),
247
+ },
223
248
},
224
249
5 : {
225
250
"shape" : (4 , 5 ),
226
251
"type" : tflite .TensorType .INT16 ,
227
252
"buffer" : 6 ,
253
+ "quantization" : {
254
+ "quantized_dimension" : 1 ,
255
+ "scale" : (1 , 1 , 1 , 1 , 1 ),
256
+ "zero_point" : (0 , 0 , 0 , 0 , 0 ),
257
+ },
228
258
},
229
259
6 : {
230
260
"shape" : (5 , 4 ),
231
261
"type" : tflite .TensorType .INT16 ,
232
262
"buffer" : 7 ,
233
263
"quantization" : {
234
264
"quantized_dimension" : 0 ,
265
+ "scale" : (1 , 1 , 1 , 1 , 1 ),
266
+ "zero_point" : (0 , 0 , 0 , 0 , 0 ),
235
267
},
236
268
},
269
+ 7 : {
270
+ "shape" : (5 , 4 ),
271
+ "type" : tflite .TensorType .INT16 ,
272
+ "buffer" : 8 ,
273
+ "quantization" : {
274
+ "quantized_dimension" : 0 ,
275
+ "scale" : (1 ,),
276
+ "zero_point" : (0 ,),
277
+ },
278
+ },
279
+ 8 : {
280
+ "shape" : (16 , 1 ),
281
+ "type" : tflite .TensorType .UINT8 ,
282
+ "buffer" : 9 ,
283
+ },
237
284
},
238
285
},
239
286
},
@@ -260,6 +307,14 @@ def test_multiple_tables_with_padding(self):
260
307
(9 , 10 , 11 , 12 ),
261
308
(13 , 14 , 15 , 16 ),
262
309
(17 , 18 , 19 , 20 )), dtype = np .dtype ("<i2" )),
310
+
311
+ 8 : np .array (((1 , 2 , 3 , 4 ),
312
+ (1 , 2 , 3 , 4 ),
313
+ (1 , 2 , 3 , 4 ),
314
+ (1 , 2 , 3 , 4 ),
315
+ (1 , 2 , 3 , 4 )), dtype = np .dtype ("<i2" )),
316
+
317
+ 9 : np .array (range (16 ), dtype = np .dtype ("<u1" )),
263
318
},
264
319
}
265
320
@@ -297,6 +352,11 @@ def test_multiple_tables_with_padding(self):
297
352
tensor = 6 ,
298
353
compression = [spec .LookUpTableCompression (index_bitwidth = 2 )],
299
354
),
355
+ spec .Tensor ( # spec 6
356
+ subgraph = 0 ,
357
+ tensor = 7 ,
358
+ compression = [spec .LookUpTableCompression (index_bitwidth = 2 )],
359
+ ),
300
360
]
301
361
# yapf: enable
302
362
@@ -362,6 +422,18 @@ def test_invalid_tensor_spec(self):
362
422
self .assertRaises (compress .CompressionError ,
363
423
lambda : compress .compress (self .flatbuffer , specs ))
364
424
425
+ def test_no_axis (self ):
426
+ """Raises if no quantization from which to infer compression axis."""
427
+ specs = [
428
+ spec .Tensor (
429
+ subgraph = 0 ,
430
+ tensor = 8 ,
431
+ compression = [spec .LookUpTableCompression (index_bitwidth = 4 )],
432
+ ),
433
+ ]
434
+ self .assertRaises (compress .CompressionError ,
435
+ lambda : compress .compress (self .flatbuffer , specs ))
436
+
365
437
366
438
class TestLutCompressedArray (tf .test .TestCase ):
367
439
@@ -519,8 +591,8 @@ def test_compressed_int32(self):
519
591
expected_values = np .array (range (- 160_016 , - 160_000 ), dtype = "<i4" )
520
592
self .assertAllEqual (values , expected_values )
521
593
522
- def test_channel_axis (self ):
523
- """Compression along the NWHC channel axis when no quanitzation axis ."""
594
+ def test_axis_1 (self ):
595
+ """Compression along quanitzation_dimension == 1 ."""
524
596
bitwidth , indices , values = self ._get_compressed (subgraph = 0 , tensor = 5 )
525
597
self .assertEqual (bitwidth , 2 )
526
598
@@ -537,8 +609,8 @@ def test_channel_axis(self):
537
609
expected_values = np .array (range (1 , 21 ), dtype = np .dtype ("<i2" ))
538
610
self .assertAllEqual (values , expected_values )
539
611
540
- def test_quantization_axis (self ):
541
- """Compression along the quanitzation axis ."""
612
+ def test_axis_0 (self ):
613
+ """Compression along quanitzation_dimension == 0 ."""
542
614
bitwidth , indices , values = self ._get_compressed (subgraph = 0 , tensor = 6 )
543
615
self .assertEqual (bitwidth , 2 )
544
616
@@ -556,6 +628,25 @@ def test_quantization_axis(self):
556
628
expected_values = np .array (range (1 , 21 ), dtype = np .dtype ("<i2" ))
557
629
self .assertAllEqual (values , expected_values )
558
630
631
+ def test_per_tensor (self ):
632
+ """Compression with one value table per tensor."""
633
+ bitwidth , indices , values = self ._get_compressed (subgraph = 0 , tensor = 7 )
634
+ self .assertEqual (bitwidth , 2 )
635
+
636
+ # yapf: disable
637
+ expected_indices = self ._make_indices ("""
638
+ 00 01 10 11
639
+ 00 01 10 11
640
+ 00 01 10 11
641
+ 00 01 10 11
642
+ 00 01 10 11
643
+ """ )
644
+ # yapf: enable
645
+ self .assertEqual (indices , expected_indices )
646
+
647
+ expected_values = np .array (range (1 , 5 ), dtype = np .dtype ("<i2" ))
648
+ self .assertAllEqual (values , expected_values )
649
+
559
650
560
651
if __name__ == "__main__" :
561
652
tf .test .main ()
0 commit comments