
Commit 53a2b15

Andrew Grebenisan authored and facebook-github-bot committed
Ensure we can call custom ops from torch cadence lib (pytorch#14034)
Summary: Fixes mismatches between op registration names and implementation names, and fixes type issues in tests where types not matching the op definitions were passed in. Also fixes an incorrect layernorm meta op (normalized_shape should be a list, not an int). The tests are corrected accordingly and now exercise the torch cadence custom op library.

Reviewed By: hsharma35

Differential Revision: D81738196
1 parent d3c84b7 commit 53a2b15
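For context, here is a minimal sketch of what the updated tests can now do: invoke the Cadence custom ops through the torch op registry after importing the registrations. The `cadence` namespace and the exact `quantize_per_tensor` signature shown are assumptions inferred from this diff, not confirmed by it:

    import torch

    # Importing the registrations module (the new Buck dep added below) makes
    # the custom op schemas visible to torch; importing ref_implementations
    # attaches the Python reference kernels. Module paths are taken from the
    # dependency added in this commit.
    from executorch.backends.cadence.aot import ops_registrations  # noqa: F401
    from executorch.backends.cadence.aot import ref_implementations  # noqa: F401

    x = torch.randn(4, 8)
    # Assumed namespace and signature: cadence::quantize_per_tensor(
    #     Tensor input, float scale, int zero_point, int quant_min,
    #     int quant_max, ScalarType dtype) -> Tensor
    q = torch.ops.cadence.quantize_per_tensor(x, 0.5, 0, -128, 127, torch.int8)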

File tree: 4 files changed (+67 −90 lines)

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -615,6 +615,7 @@ python_unittest(
     typing = True,
     deps = [
         ":typing_stubs",
+        "//executorch/backends/cadence/aot:ops_registrations",
         "//executorch/backends/cadence/aot:ref_implementations",
         "//caffe2:torch",
     ]

backends/cadence/aot/ops_registrations.py

Lines changed: 2 additions & 2 deletions
@@ -1449,7 +1449,7 @@ def quantized_layer_norm_meta(
     input: torch.Tensor,
     X_scale: torch.Tensor,
     X_zero_point: torch.Tensor,
-    normalized_shape: int,
+    normalized_shape: list[int],
     weight: torch.Tensor,
     bias: torch.Tensor,
     eps: float,
@@ -1464,7 +1464,7 @@ def quantized_layer_norm_per_tensor_meta(
     input: torch.Tensor,
     X_scale: float,
     X_zero_point: int,
-    normalized_shape: int,
+    normalized_shape: list[int],
     weight: torch.Tensor,
     bias: torch.Tensor,
     eps: float,
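The list[int] type matches what torch.nn.functional.layer_norm actually expects: normalized_shape describes the trailing dimensions to normalize over, so a meta kernel typing it as a bare int would advertise a schema the eager kernel rejects. A minimal illustration in plain PyTorch, independent of the Cadence ops:

    import torch

    x = torch.randn(2, 5, 10)
    # normalized_shape is a sequence of trailing dims: [10] normalizes over
    # the last dim, [5, 10] over the last two.
    w, b = torch.ones(10), torch.zeros(10)
    y = torch.nn.functional.layer_norm(x, [10], w, b, eps=1e-5)
    assert y.shape == x.shape  # layer_norm preserves shape, as the meta op must report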

backends/cadence/aot/ref_implementations.py

Lines changed: 19 additions & 21 deletions
@@ -23,8 +23,6 @@
     ScalarType.QINT32: torch.qint32,
 }
 
-_Number = bool | int | float
-
 
 @impl(m, "quantize_per_tensor")
 def quantize_per_tensor(
@@ -66,9 +64,9 @@ def quantize_per_tensor(
         f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
     )
 
-    dequantized = torch.round(input_tensor * scale + zero_point).to(dtype)
+    quantized = torch.round(input_tensor * scale + zero_point).to(dtype)
     return torch.max(
-        torch.min(dequantized, torch.tensor(quant_max)),
+        torch.min(quantized, torch.tensor(quant_max)),
         torch.tensor(quant_min),
     )
 
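The hunk above only fixes a misleading variable name (the rounded value is quantized, not dequantized); the math is unchanged. The nested torch.max(torch.min(...)) is equivalent to a clamp, so the patched body could be sketched as follows. This is a paraphrase of the diff, not the file's actual code, and it mirrors the diff's scale-multiply convention:

    import torch

    def quantize_per_tensor_sketch(
        input_tensor: torch.Tensor,
        scale: float,
        zero_point: int,
        quant_min: int,
        quant_max: int,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        # Same arithmetic and ordering as the diff: scale, shift, round,
        # cast, then clamp into [quant_min, quant_max]. torch.clamp is the
        # idiomatic form of torch.max(torch.min(x, max), min).
        quantized = torch.round(input_tensor * scale + zero_point).to(dtype)
        return torch.clamp(quantized, quant_min, quant_max)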
@@ -249,12 +247,12 @@ def quantized_linear(
     ).reshape(*leading_dims, N)
 
 
-@impl(m, "quantized_layer_norm_per_tensor")
+@impl(m, "quantized_layer_norm.per_tensor")
 def quantized_layer_norm_per_tensor(
     input_tensor: torch.Tensor,
     X_scale: float,
     X_zero_point: int,
-    normalized_shape: int,
+    normalized_shape: list[int],
     weight: torch.Tensor,
     bias: torch.Tensor,
     eps: float,
@@ -285,7 +283,7 @@ def quantized_layer_norm_per_tensor(
         input_tensor, X_scale, X_zero_point, -128, 127, torch.float32
     )
     out = torch.nn.functional.layer_norm(
-        float_input_tensor, (normalized_shape,), weight, bias, eps=eps
+        float_input_tensor, normalized_shape, weight, bias, eps=eps
     )
 
     return quantize_per_tensor(
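The registration strings here switch from an underscore suffix to PyTorch's "op.overload" naming, which must match the op's declared schema exactly; otherwise @impl binds the Python kernel to a name the dispatcher never calls. A hypothetical, self-contained sketch of the same pattern using torch.library directly (the demo namespace and the schema below are invented for illustration):

    import torch
    from torch.library import Library, impl

    lib = Library("demo", "DEF")
    # Declare an op with an explicit ".per_tensor" overload name.
    lib.define(
        "quantized_layer_norm.per_tensor(Tensor input, float X_scale, "
        "int X_zero_point, int[] normalized_shape, Tensor weight, "
        "Tensor bias, float eps) -> Tensor"
    )

    # The string must be "quantized_layer_norm.per_tensor", matching the
    # schema; "quantized_layer_norm_per_tensor" would not resolve to this op.
    @impl(lib, "quantized_layer_norm.per_tensor", "CompositeExplicitAutograd")
    def quantized_layer_norm_per_tensor(
        input, X_scale, X_zero_point, normalized_shape, weight, bias, eps
    ):
        x = (input.to(torch.float32) - X_zero_point) * X_scale  # dequantize
        return torch.nn.functional.layer_norm(x, normalized_shape, weight, bias, eps=eps)

    out = torch.ops.demo.quantized_layer_norm.per_tensor(
        torch.zeros(2, 10, dtype=torch.int8), 0.1, 0, [10],
        torch.ones(10), torch.zeros(10), 1e-5,
    )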
@@ -367,7 +365,7 @@ def quantized_conv_per_tensor(
     )
 
 
-@impl(m, "quantized_conv_nchw_per_tensor")
+@impl(m, "quantized_conv_nchw.per_tensor")
 def quantized_conv_nchw_per_tensor(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
@@ -423,7 +421,7 @@ def quantized_conv_nchw_per_tensor(
     )
 
 
-@impl(m, "quantized_conv_nhwc_per_tensor")
+@impl(m, "quantized_conv_nhwc.per_tensor")
 def quantized_conv_nhwc_per_tensor(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
@@ -560,62 +558,62 @@ def variant(
     return decorator
 
 
-@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nchw", torch.int8, torch.int8)
 def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
 
 
-@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nchw", torch.uint8, torch.uint8)
 def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
-@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nhwc", torch.int8, torch.int8)
 def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
 
 
-@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
 def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
-@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nchw", torch.int8, torch.int8)
 def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
 
 
-@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nchw", torch.uint8, torch.uint8)
 def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
-@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nhwc", torch.int8, torch.int8)
 def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
 
 
-@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
 def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
-@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nchw", torch.int8, torch.int8)
 def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
 
 
-@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nchw", torch.uint8, torch.uint8)
 def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
 
 
-@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nhwc", torch.int8, torch.int8)
 def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
 
 
-@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
 def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
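The stub bodies ("...") in the block above are intentional: quantized_conv_variant is a decorator factory that replaces each stub with a dtype-specialized wrapper around a shared conv implementation, so only the registered name and the dtype pair vary across variants. A hypothetical reconstruction of that pattern (the names and the stand-in kernel are invented for illustration):

    import torch
    from typing import Callable

    def shared_conv_impl(x: torch.Tensor, layout: str) -> torch.Tensor:
        # Stand-in for the shared quantized conv kernel.
        return x

    def make_conv_variant(
        layout: str, in_dtype: torch.dtype, out_dtype: torch.dtype
    ) -> Callable:
        def decorator(stub: Callable) -> Callable:
            def wrapper(x: torch.Tensor) -> torch.Tensor:
                # Each generated variant validates its dtype contract,
                # then defers to the one shared implementation.
                assert x.dtype == in_dtype, f"{stub.__name__}: expected {in_dtype}"
                return shared_conv_impl(x, layout).to(out_dtype)
            wrapper.__name__ = stub.__name__
            return wrapper
        return decorator

    @make_conv_variant("nchw", torch.int8, torch.int8)
    def conv_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...

    y = conv_nchw_asym8sxsym8s_asym8s_per_tensor(
        torch.zeros(1, 3, 4, 4, dtype=torch.int8)
    )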