Commit 505067c

Andrew Grebenisan authored and facebook-github-bot committed
Add uint8/int8 specializations for conv per tensor (#14033)
Summary: Continued support for adding custom Cadence Python references

Reviewed By: hsharma35

Differential Revision: D81720359
1 parent 900a8fe commit 505067c

3 files changed: +285 −40 lines

backends/cadence/aot/ops_registrations.py

Lines changed: 60 additions & 0 deletions
@@ -873,6 +873,11 @@ def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
     out_channels, _, *kernel_size = weight.shape
 
     in_size = input.shape
@@ -917,6 +922,11 @@ def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
     out_channels, _, *kernel_size = weight.shape
 
     in_size = input.shape
@@ -961,6 +971,11 @@ def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
     out_channels, *kernel_size, _ = weight.shape
 
     in_size = input.shape
@@ -1005,6 +1020,11 @@ def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
     out_channels, *kernel_size, _ = weight.shape
 
     in_size = input.shape
@@ -1049,6 +1069,11 @@ def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
     out_channels, _, *kernel_size = weight.shape
 
     in_size = input.shape
@@ -1093,6 +1118,11 @@ def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
     out_channels, _, *kernel_size = weight.shape
 
     in_size = input.shape
@@ -1137,6 +1167,11 @@ def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
     out_channels, *kernel_size, _ = weight.shape
 
     in_size = input.shape
@@ -1181,6 +1216,11 @@ def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
     out_channels, *kernel_size, _ = weight.shape
 
     in_size = input.shape
@@ -1225,6 +1265,11 @@ def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
     out_channels, _, *kernel_size = weight.shape
 
     in_size = input.shape
@@ -1269,6 +1314,11 @@ def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
     out_channels, _, *kernel_size = weight.shape
 
     in_size = input.shape
@@ -1313,6 +1363,11 @@ def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.int8
+        and weight.dtype == torch.int8
+        and bias.dtype == torch.int32
+    )
     out_channels, *kernel_size, _ = weight.shape
 
     in_size = input.shape
@@ -1357,6 +1412,11 @@ def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_meta(
     out_multiplier: int,
     out_shift: int,
 ) -> torch.Tensor:
+    assert (
+        input.dtype == torch.uint8
+        and weight.dtype == torch.uint8
+        and bias.dtype == torch.int32
+    )
     out_channels, *kernel_size, _ = weight.shape
 
     in_size = input.shape
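These asserts live in the meta (shape-inference) kernels, so a dtype mismatch surfaces at export/trace time rather than on device. A minimal standalone reduction of the signed-int8 NCHW pattern, for illustration only (the function name and the elided shape math are not from this commit):

import torch


def conv_meta_int8(
    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    # The guard this commit adds to each *_asym8sxsym8s_* meta kernel:
    # reject wrong dtypes during tracing, before any shape math runs.
    assert (
        input.dtype == torch.int8
        and weight.dtype == torch.int8
        and bias.dtype == torch.int32
    )
    out_channels, _, *kernel_size = weight.shape  # NCHW weight layout
    # Output spatial dims elided; the real kernels derive them from
    # stride/padding/dilation before allocating the output.
    return input.new_empty((input.shape[0], out_channels))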

backends/cadence/aot/ref_implementations.py

Lines changed: 140 additions & 1 deletion
@@ -7,7 +7,7 @@
 # pyre-strict
 
 
-from typing import Optional
+from typing import Callable, Optional
 
 import torch
 
@@ -479,6 +479,145 @@ def quantized_conv_nhwc_per_tensor(
     )
 
 
+def quantized_conv_variant(
+    layout: str,
+    input_dtype: torch.dtype,
+    weight_dtype: torch.dtype,
+) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
+    """Create a quantized conv variant with type checking."""
+
+    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+        def variant(
+            input_tensor: torch.Tensor,
+            weight: torch.Tensor,
+            bias: torch.Tensor,
+            stride: tuple[int, int],
+            padding: tuple[int, int],
+            dilation: tuple[int, int],
+            groups: int,
+            in_zero_point: int,
+            weight_zero_point: int,
+            bias_scale: float,
+            output_scale: float,
+            output_zero_point: int,
+            out_multiplier: int,
+            out_shift: int,
+        ) -> torch.Tensor:
+            assert (
+                input_tensor.dtype == input_dtype
+            ), f"Expected input dtype {input_dtype}, got {input_tensor.dtype}"
+            assert (
+                weight.dtype == weight_dtype
+            ), f"Expected weight dtype {weight_dtype}, got {weight.dtype}"
+
+            assert (
+                bias.dtype == torch.int32
+            ), f"Expected bias dtype int32, got {bias.dtype}"
+
+            # Call the appropriate base function
+            match layout:
+                case "nchw":
+                    return quantized_conv_nchw_per_tensor(
+                        input_tensor,
+                        weight,
+                        bias,
+                        stride,
+                        padding,
+                        dilation,
+                        groups,
+                        in_zero_point,
+                        weight_zero_point,
+                        bias_scale,
+                        output_scale,
+                        output_zero_point,
+                        out_multiplier,
+                        out_shift,
+                    )
+                case "nhwc":
+                    return quantized_conv_nhwc_per_tensor(
+                        input_tensor,
+                        weight,
+                        bias,
+                        stride,
+                        padding,
+                        dilation,
+                        groups,
+                        in_zero_point,
+                        weight_zero_point,
+                        bias_scale,
+                        output_scale,
+                        output_zero_point,
+                        out_multiplier,
+                        out_shift,
+                    )
+                case _:
+                    raise ValueError(f"Unknown layout {layout}")
+
+        return variant
+
+    return decorator
+
+
+@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nchw", torch.int8, torch.int8)
+def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nchw", torch.uint8, torch.uint8)
+def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nhwc", torch.int8, torch.int8)
+def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
+def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nchw", torch.int8, torch.int8)
+def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nchw", torch.uint8, torch.uint8)
+def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nhwc", torch.int8, torch.int8)
+def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
+def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nchw", torch.int8, torch.int8)
+def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nchw", torch.uint8, torch.uint8)
+def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nhwc", torch.int8, torch.int8)
+def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
+def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
 @impl(m, "quantized_relu")
 def quantized_relu(
     X: torch.Tensor,
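The stub bodies (`...`) are intentional: `quantized_conv_variant` discards the decorated function and substitutes the type-checked `variant` closure, so each named op shares one implementation while enforcing its own dtypes. A self-contained sketch of the same decorator-factory technique, with a hypothetical `conv_variant`/`conv_asym8s` standing in for the real registration machinery:

from typing import Callable

import torch


def conv_variant(
    input_dtype: torch.dtype, weight_dtype: torch.dtype
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
    """Bind expected dtypes; the decorated stub's body is never called."""

    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
        def variant(
            input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
        ) -> torch.Tensor:
            assert input_tensor.dtype == input_dtype
            assert weight.dtype == weight_dtype
            assert bias.dtype == torch.int32
            # Stand-in for the shared quantized_conv_*_per_tensor reference.
            return input_tensor

        return variant

    return decorator


@conv_variant(torch.int8, torch.int8)
def conv_asym8s() -> torch.Tensor: ...


x = torch.zeros(1, 3, 8, 8, dtype=torch.int8)
w = torch.zeros(4, 3, 3, 3, dtype=torch.int8)
b = torch.zeros(4, dtype=torch.int32)
out = conv_asym8s(x, w, b)       # passes the dtype checks
# conv_asym8s(x.float(), w, b)   # would raise AssertionError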
