Commit 900a8fe

Andrew Grebenisan authored and facebook-github-bot committed
Remove quantized_conv and leave just per tensor variants (#13958)
Summary: As discussed offline, there is no need for the non-per-tensor variants of the channels-first/channels-last quantized conv ops; only the per-tensor variants remain.

Reviewed By: hsharma35

Differential Revision: D81649180
1 parent cf31f18 commit 900a8fe
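For context on the migration: the removed variants took their per-tensor quantization parameters wrapped in single-element tensors (and validated at runtime that they were scalars), while the surviving *_per_tensor variants take plain Python scalars. A minimal, hypothetical unwrapping sketch for a caller that still holds tensor-wrapped values (the variable names are illustrative, not from this commit):

import torch

# Tensor-wrapped parameters, as the removed quantized_conv_nchw expected them.
weight_zero_point_t = torch.tensor([0], dtype=torch.int8)
bias_scale_t = torch.tensor([1.0], dtype=torch.float32)

# Plain scalars, as quantized_conv_nchw_per_tensor now expects them.
# .item() raises if the tensor holds more than one element, which subsumes
# the removed "must be a scalar" ValueError checks.
weight_zero_point = int(weight_zero_point_t.item())  # -> 0
bias_scale = float(bias_scale_t.item())              # -> 1.0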

File tree

backends/cadence/aot/ref_implementations.py
backends/cadence/aot/tests/test_ref_implementations.py

2 files changed: +58 −72 lines

backends/cadence/aot/ref_implementations.py

Lines changed: 31 additions & 37 deletions
@@ -296,7 +296,7 @@ def quantized_layer_norm_per_tensor(
     )


-def quantized_conv(
+def quantized_conv_per_tensor(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor,
@@ -305,12 +305,12 @@ def quantized_conv(
     dilation: tuple[int, int],
     groups: int,
     in_zero_point: int,
-    weight_zero_point: torch.Tensor,
-    bias_scale: torch.Tensor,
+    weight_zero_point: int,
+    bias_scale: float,
     output_scale: float,
     output_zero_point: int,
-    out_multiplier: torch.Tensor,
-    out_shift: torch.Tensor,
+    out_multiplier: int,
+    out_shift: int,
 ) -> torch.Tensor:
     """
     Quantized convolution operation.
@@ -324,19 +324,13 @@ def quantized_conv(
     - dilation (Tuple[int]): The dilation of the convolution
     - groups (int): The number of groups
     - in_zero_point (int): The quantized mapping of zero for the input
-    - weight_zero_point (Tensor): The quantized mapping of zero for the weight
-    - bias_scale (Tensor): The quantized bias scale
+    - weight_zero_point (int): The quantized mapping of zero for the weight
+    - bias_scale (float): The quantized bias scale
     - output_scale (float): The scale of the output
     - output_zero_point (int): The zero point of the output
-    - out_multiplier (Tensor): Unused
-    - out_shift (Tensor): Unused
+    - out_multiplier (int): Unused
+    - out_shift (int): Unused
     """
-    if weight_zero_point.view(-1).shape != (1,):
-        raise ValueError("Weight zero point must be a scalar")
-
-    if bias_scale.view(-1).shape != (1,):
-        raise ValueError("Bias scale must be a scalar")
-
     if len(input_tensor.shape) == 3:
         float_out = torch.nn.functional.conv1d(
             (input_tensor - in_zero_point).float(),
@@ -371,8 +365,8 @@ def quantized_conv(
     )


-@impl(m, "quantized_conv_nchw")
-def quantized_conv_nchw(
+@impl(m, "quantized_conv_nchw_per_tensor")
+def quantized_conv_nchw_per_tensor(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor,
@@ -381,12 +375,12 @@ def quantized_conv_nchw(
     dilation: tuple[int, int],
     groups: int,
     in_zero_point: int,
-    weight_zero_point: torch.Tensor,
-    bias_scale: torch.Tensor,
+    weight_zero_point: int,
+    bias_scale: float,
     output_scale: float,
     output_zero_point: int,
-    out_multiplier: torch.Tensor,
-    out_shift: torch.Tensor,
+    out_multiplier: int,
+    out_shift: int,
 ) -> torch.Tensor:
     """
     Quantized convolution operation.
@@ -400,16 +394,16 @@ def quantized_conv_nchw(
     - dilation (Tuple[int]): The dilation of the convolution
     - groups (int): The number of groups
     - in_zero_point (int): The quantized mapping of zero for the input
-    - weight_zero_point (Tensor): The quantized mapping of zero for the weight
-    - bias_scale (Tensor): The quantized bias scale
+    - weight_zero_point (int): The quantized mapping of zero for the weight
+    - bias_scale (float): The quantized bias scale
     - output_scale (float): The scale of the output
     - output_zero_point (int): The zero point of the output
-    - out_multiplier (Tensor): Unused
-    - out_shift (Tensor): Unused
+    - out_multiplier (int): Unused
+    - out_shift (int): Unused
     """
     if not input_tensor.is_contiguous(memory_format=torch.contiguous_format):
         raise ValueError("Input tensor must be in NCHW format")
-    return quantized_conv(
+    return quantized_conv_per_tensor(
         input_tensor,
         weight,
         bias,
@@ -427,8 +421,8 @@ def quantized_conv_nchw(
     )


-@impl(m, "quantized_conv_nhwc")
-def quantized_conv_nhwc(
+@impl(m, "quantized_conv_nhwc_per_tensor")
+def quantized_conv_nhwc_per_tensor(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor,
@@ -437,12 +431,12 @@ def quantized_conv_nhwc(
     dilation: tuple[int, int],
     groups: int,
     in_zero_point: int,
-    weight_zero_point: torch.Tensor,
-    bias_scale: torch.Tensor,
+    weight_zero_point: int,
+    bias_scale: float,
     output_scale: float,
     output_zero_point: int,
-    out_multiplier: torch.Tensor,
-    out_shift: torch.Tensor,
+    out_multiplier: int,
+    out_shift: int,
 ) -> torch.Tensor:
     """
     Quantized convolution operation.
@@ -456,18 +450,18 @@ def quantized_conv_nhwc(
     - dilation (Tuple[int]): The dilation of the convolution
     - groups (int): The number of groups
     - in_zero_point (int): The quantized mapping of zero for the input
-    - weight_zero_point (Tensor): The quantized mapping of zero for the weight
-    - bias_scale (Tensor): The quantized bias scale
+    - weight_zero_point (int): The quantized mapping of zero for the weight
+    - bias_scale (float): The quantized bias scale
     - output_scale (float): The scale of the output
     - output_zero_point (int): The zero point of the output
-    - out_multiplier (Tensor): Unused
-    - out_shift (Tensor): Unused
+    - out_multiplier (int): Unused
+    - out_shift (int): Unused
     """

     if not input_tensor.is_contiguous(memory_format=torch.channels_last):
         raise ValueError("Input tensor must be in NHWC format")

-    return quantized_conv(
+    return quantized_conv_per_tensor(
         input_tensor,
         weight,
         bias,
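The arithmetic of the reference implementation is untouched by this commit: subtract the zero points, convolve in float, then requantize to the output scale and zero point. Only the conv1d branch is visible in the hunks above, so the following 2D sketch fills in the bias and requantization steps under stated assumptions (the function name is hypothetical, and the bias dequantization is assumed rather than taken from the source):

import torch

def quantized_conv2d_per_tensor_sketch(
    input_tensor: torch.Tensor,   # quantized NCHW input
    weight: torch.Tensor,         # quantized OIHW weights
    bias: torch.Tensor,           # quantized bias, one entry per output channel
    stride: tuple[int, int],
    padding: tuple[int, int],
    dilation: tuple[int, int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: int,
    bias_scale: float,
    output_scale: float,
    output_zero_point: int,
) -> torch.Tensor:
    # Shift input and weights to their real-valued zero, convolve in float.
    float_out = torch.nn.functional.conv2d(
        (input_tensor - in_zero_point).float(),
        (weight - weight_zero_point).float(),
        bias.float() * bias_scale,  # assumed: integer bias dequantized by its scale
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )
    # Requantize to the output scale/zero point, clamp to the output dtype range.
    info = torch.iinfo(input_tensor.dtype)
    out = torch.round(float_out / output_scale) + output_zero_point
    return out.clamp(info.min, info.max).to(input_tensor.dtype)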

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 27 additions & 35 deletions
@@ -15,8 +15,8 @@
     dequantize_per_tensor,
     quantize_per_tensor,
     quantized_add,
-    quantized_conv_nchw,
-    quantized_conv_nhwc,
+    quantized_conv_nchw_per_tensor,
+    quantized_conv_nhwc_per_tensor,
     quantized_layer_norm_per_tensor,
     quantized_linear,
     quantized_relu,
@@ -356,8 +356,8 @@ def test_quantized_layer_norm_per_tensor(
                (1, 1),  # dilation
                1,  # groups
                0,  # in_zero_point
-                torch.tensor([0], dtype=torch.int8),  # weight_zero_point
-                torch.tensor([1.0], dtype=torch.float32),  # bias_scale
+                0,  # weight_zero_point
+                1.0,  # bias_scale
                0.1,  # output_scale
                0,  # output_zero_point
                torch.tensor(
@@ -387,8 +387,8 @@ def test_quantized_layer_norm_per_tensor(
                (1, 1),  # dilation
                1,  # groups
                0,  # in_zero_point
-                torch.tensor([0], dtype=torch.int8),  # weight_zero_point
-                torch.tensor([1.0], dtype=torch.float32),  # bias_scale
+                0,  # weight_zero_point
+                1.0,  # bias_scale
                0.25,  # output_scale
                0,  # output_zero_point
                typing.cast(None, torch.Tensor),
@@ -416,8 +416,8 @@ def test_quantized_layer_norm_per_tensor(
                (1, 1),  # dilation
                1,  # groups
                128,  # in_zero_point
-                torch.tensor([128], dtype=torch.uint8),  # weight_zero_point
-                torch.tensor([0.1], dtype=torch.float32),  # bias_scale
+                128,  # weight_zero_point
+                0.1,  # bias_scale
                0.1,  # output_scale
                128,  # output_zero_point
                typing.cast(None, torch.Tensor),
@@ -447,8 +447,8 @@ def test_quantized_layer_norm_per_tensor(
                (1, 1),  # dilation (padding for 2D, actual dilation is dilation[1])
                1,  # groups
                0,  # in_zero_point
-                torch.tensor([0], dtype=torch.int8),  # weight_zero_point
-                torch.tensor([1.0], dtype=torch.float32),  # bias_scale
+                0,  # weight_zero_point
+                1.0,  # bias_scale
                0.5,  # output_scale
                0,  # output_zero_point
                typing.cast(None, torch.Tensor),
@@ -482,8 +482,8 @@ def test_quantized_layer_norm_per_tensor(
                (1, 1),  # dilation
                1,  # groups
                0,  # in_zero_point
-                torch.tensor([0], dtype=torch.int8),  # weight_zero_point
-                torch.tensor([1.0], dtype=torch.float32),  # bias_scale
+                0,  # weight_zero_point
+                1.0,  # bias_scale
                0.2,  # output_scale
                0,  # output_zero_point
                typing.cast(None, torch.Tensor),
@@ -523,8 +523,8 @@ def test_quantized_layer_norm_per_tensor(
                (1, 1),  # dilation
                1,  # groups
                0,  # in_zero_point
-                torch.tensor([0], dtype=torch.int16),  # weight_zero_point
-                torch.tensor([1.0], dtype=torch.float32),  # bias_scale
+                0,  # weight_zero_point
+                1.0,  # bias_scale
                0.1,  # output_scale
                0,  # output_zero_point
                typing.cast(None, torch.Tensor),
@@ -576,12 +576,8 @@ def test_quantized_layer_norm_per_tensor(
                (1, 1),  # dilation
                1,  # groups
                0,  # in_zero_point
-                torch.tensor(
-                    [0], dtype=torch.int16
-                ),  # weight_zero_point for each output channel
-                torch.tensor(
-                    [1.0], dtype=torch.float32
-                ),  # bias_scale for each channel
+                0,  # weight_zero_point
+                1.0,  # bias_scale
                0.05,  # output_scale
                0,  # output_zero_point
                typing.cast(None, torch.Tensor),
@@ -623,12 +619,8 @@ def test_quantized_layer_norm_per_tensor(
                (1, 1),  # dilation
                2,  # groups (grouped convolution)
                0,  # in_zero_point
-                torch.tensor(
-                    [0], dtype=torch.int8
-                ),  # weight_zero_point for each output channel
-                torch.tensor(
-                    [1.0], dtype=torch.float32
-                ),  # bias_scale for each channel
+                0,  # weight_zero_point
+                1.0,  # bias_scale
                0.2,  # output_scale
                0,  # output_zero_point
                typing.cast(None, torch.Tensor),
@@ -666,8 +658,8 @@ def test_quantized_layer_norm_per_tensor(
                (1, 1),  # dilation
                1,  # groups
                0,  # in_zero_point
-                torch.tensor([0], dtype=torch.int8),  # weight_zero_point
-                torch.tensor([1.0], dtype=torch.float32),  # bias_scale
+                0,  # weight_zero_point
+                1.0,  # bias_scale
                0.5,  # output_scale
                0,  # output_zero_point
                typing.cast(None, torch.Tensor),
@@ -682,7 +674,7 @@ def test_quantized_layer_norm_per_tensor(
            ],
        ]
    )
-    def test_quantized_conv(
+    def test_quantized_conv_per_tensor(
        self,
        input_tensor: torch.Tensor,
        weight: torch.Tensor,
@@ -692,12 +684,12 @@ def test_quantized_conv(
        dilation: tuple[int, int],
        groups: int,
        in_zero_point: int,
-        weight_zero_point: torch.Tensor,
-        bias_scale: torch.Tensor,
+        weight_zero_point: int,
+        bias_scale: float,
        output_scale: float,
        output_zero_point: int,
-        out_multiplier: torch.Tensor,
-        out_shift: torch.Tensor,
+        out_multiplier: int,
+        out_shift: int,
        dtype: torch.dtype,
        expected_output: torch.Tensor,
        memory_format: torch.memory_format,
@@ -710,9 +702,9 @@ def test_quantized_conv(
        input_tensor = input_tensor.to(memory_format=memory_format)

        conv = (
-            quantized_conv_nchw
+            quantized_conv_nchw_per_tensor
            if memory_format == torch.contiguous_format
-            else quantized_conv_nhwc
+            else quantized_conv_nhwc_per_tensor
        )

        output = conv(
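After the rename, a call in the style of the updated tests looks like this hypothetical identity check: with unit scales, zero offsets, and a 1x1 unit kernel, the quantized output should equal the input. The import path is assumed from the file paths in this commit, and the stride/padding argument positions are assumed from the standard conv signature (they fall outside the hunks shown):

import torch

# Assumed import path, based on the files touched in this commit.
from executorch.backends.cadence.aot.ref_implementations import (
    quantized_conv_nchw_per_tensor,
)

input_tensor = torch.tensor([[[[10, 20], [30, 40]]]], dtype=torch.int8)  # NCHW
weight = torch.tensor([[[[1]]]], dtype=torch.int8)  # 1x1 unit kernel
bias = torch.tensor([0], dtype=torch.int32)

output = quantized_conv_nchw_per_tensor(
    input_tensor,
    weight,
    bias,
    (1, 1),  # stride (assumed position)
    (0, 0),  # padding (assumed position)
    (1, 1),  # dilation
    1,       # groups
    0,       # in_zero_point
    0,       # weight_zero_point: a plain int after this change
    1.0,     # bias_scale: a plain float after this change
    1.0,     # output_scale
    0,       # output_zero_point
    0,       # out_multiplier (unused)
    0,       # out_shift (unused)
)
# Expected: identity mapping under these parameters.
assert torch.equal(output, input_tensor)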
