
Commit d3c84b7

Andrew Grebenisan authored and facebook-github-bot committed
Add uint8/int8 specializations for conv per tensor (pytorch#14033)
Summary: Continued support for adding custom Cadence Python references.

Reviewed By: hsharma35

Differential Revision: D81720359
1 parent c615770 commit d3c84b7

File tree

3 files changed: +285, -40 lines


backends/cadence/aot/ops_registrations.py

Lines changed: 60 additions & 0 deletions
@@ -873,6 +873,11 @@ def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
     out_channels, _, *kernel_size = weight.shape
 
     in_size = input.shape
@@ -917,6 +922,11 @@ def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
     out_channels, _, *kernel_size = weight.shape
 
     in_size = input.shape
@@ -961,6 +971,11 @@ def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
     out_channels, *kernel_size, _ = weight.shape
 
     in_size = input.shape
@@ -1005,6 +1020,11 @@ def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
     out_channels, *kernel_size, _ = weight.shape
 
     in_size = input.shape
@@ -1049,6 +1069,11 @@ def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
     out_channels, _, *kernel_size = weight.shape
 
     in_size = input.shape
@@ -1093,6 +1118,11 @@ def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
     out_channels, _, *kernel_size = weight.shape
 
     in_size = input.shape
@@ -1137,6 +1167,11 @@ def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
     out_channels, *kernel_size, _ = weight.shape
 
     in_size = input.shape
@@ -1181,6 +1216,11 @@ def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
     out_channels, *kernel_size, _ = weight.shape
 
     in_size = input.shape
@@ -1225,6 +1265,11 @@ def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
     out_channels, _, *kernel_size = weight.shape
 
     in_size = input.shape
@@ -1269,6 +1314,11 @@ def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
     out_channels, _, *kernel_size = weight.shape
 
     in_size = input.shape
@@ -1313,6 +1363,11 @@ def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
     out_channels, *kernel_size, _ = weight.shape
 
     in_size = input.shape
@@ -1357,6 +1412,11 @@ def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
     out_channels, *kernel_size, _ = weight.shape
 
     in_size = input.shape
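
All twelve asserts above encode the same contract: asym8s variants take int8 activations and weights, asym8u variants take uint8, and both take an int32 bias. A minimal runnable sketch of that contract (hypothetical tensor shapes, not part of this commit):

import torch

# Hypothetical NCHW int8 tensors matching the asym8sxsym8s_asym8s contract.
input = torch.zeros(1, 3, 8, 8, dtype=torch.int8)   # N, C, H, W
weight = torch.zeros(4, 3, 3, 3, dtype=torch.int8)  # out_ch, in_ch, kH, kW
bias = torch.zeros(4, dtype=torch.int32)

assert (
    input.dtype == torch.int8
    and weight.dtype == torch.int8
    and bias.dtype == torch.int32
)

# Shape unpacking as in the NCHW meta functions above;
# the NHWC ones use: out_channels, *kernel_size, _ = weight.shape
out_channels, _, *kernel_size = weight.shape
print(out_channels, kernel_size)  # 4 [3, 3]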

backends/cadence/aot/ref_implementations.py

Lines changed: 140 additions & 1 deletion
@@ -7,7 +7,7 @@
 # pyre-strict
 
 
-from typing import Optional
+from typing import Callable, Optional
 
 import torch
 
@@ -481,6 +481,145 @@ def quantized_conv_nhwc_per_tensor(
     )
 
 
+def quantized_conv_variant(
+    layout: str,
+    input_dtype: torch.dtype,
+    weight_dtype: torch.dtype,
+) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
+    """Create a quantized conv variant with type checking."""
+
+    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+        def variant(
+            input_tensor: torch.Tensor,
+            weight: torch.Tensor,
+            bias: torch.Tensor,
+            stride: tuple[int, int],
+            padding: tuple[int, int],
+            dilation: tuple[int, int],
+            groups: int,
+            in_zero_point: int,
+            weight_zero_point: int,
+            bias_scale: float,
+            output_scale: float,
+            output_zero_point: int,
+            out_multiplier: int,
+            out_shift: int,
+        ) -> torch.Tensor:
+            assert (
+                input_tensor.dtype == input_dtype
+            ), f"Expected input dtype {input_dtype}, got {input_tensor.dtype}"
+            assert (
+                weight.dtype == weight_dtype
+            ), f"Expected weight dtype {weight_dtype}, got {weight.dtype}"
+
+            assert (
+                bias.dtype == torch.int32
+            ), f"Expected bias dtype int32, got {bias.dtype}"
+
+            # Call the appropriate base function
+            match layout:
+                case "nchw":
+                    return quantized_conv_nchw_per_tensor(
+                        input_tensor,
+                        weight,
+                        bias,
+                        stride,
+                        padding,
+                        dilation,
+                        groups,
+                        in_zero_point,
+                        weight_zero_point,
+                        bias_scale,
+                        output_scale,
+                        output_zero_point,
+                        out_multiplier,
+                        out_shift,
+                    )
+                case "nhwc":
+                    return quantized_conv_nhwc_per_tensor(
+                        input_tensor,
+                        weight,
+                        bias,
+                        stride,
+                        padding,
+                        dilation,
+                        groups,
+                        in_zero_point,
+                        weight_zero_point,
+                        bias_scale,
+                        output_scale,
+                        output_zero_point,
+                        out_multiplier,
+                        out_shift,
+                    )
+                case _:
+                    raise ValueError(f"Unknown layout {layout}")
+
+        return variant
+
+    return decorator
+
+
+@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nchw", torch.int8, torch.int8)
+def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nchw", torch.uint8, torch.uint8)
+def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nhwc", torch.int8, torch.int8)
+def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
+def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nchw", torch.int8, torch.int8)
+def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nchw", torch.uint8, torch.uint8)
+def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nhwc", torch.int8, torch.int8)
+def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
+def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nchw", torch.int8, torch.int8)
+def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nchw", torch.uint8, torch.uint8)
+def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nhwc", torch.int8, torch.int8)
+def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
+def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
 @impl(m, "quantized_relu")
 def quantized_relu(
     X: torch.Tensor,
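
The stubs above are intentionally empty: quantized_conv_variant returns a decorator that discards the decorated function and substitutes a dtype-checked wrapper around the shared nchw/nhwc reference implementation. A minimal self-contained sketch of this decorator-factory pattern (simplified, hypothetical names; not the actual Cadence registration code):

from typing import Callable

import torch

def typed_variant(
    dtype: torch.dtype,
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
    # As in quantized_conv_variant: the returned decorator ignores the
    # stub it wraps and returns a fresh, dtype-checked implementation.
    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
        def variant(x: torch.Tensor) -> torch.Tensor:
            assert x.dtype == dtype, f"Expected {dtype}, got {x.dtype}"
            return x
        return variant
    return decorator

@typed_variant(torch.int8)
def identity_asym8s() -> torch.Tensor: ...  # stub body never runs

print(identity_asym8s(torch.zeros(2, dtype=torch.int8)).dtype)  # torch.int8

One wrapper per (layout, dtype) pair keeps the twelve registered ops at three lines each while centralizing the type checks in a single place.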
