
Commit 144f679

Andrew Grebenisan authored and facebook-github-bot committed

Ensure we can call custom ops from torch cadence lib (#14034)

Summary: Fixes mismatches between op registration names and implementation names, and fixes type issues in tests where arguments not matching the op definitions were passed in. Also fixes an incorrect layernorm meta op (normalized_shape should be a list, not an int). The tests are corrected accordingly and now use the torch cadence custom op library.

Reviewed By: hsharma35

Differential Revision: D81738196

1 parent 57173d9 commit 144f679
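For context on the renames below: in PyTorch's operator registry, a dot in a registration string names an overload of an existing op (quantized_layer_norm.per_tensor is the per_tensor overload of quantized_layer_norm), while an underscore (quantized_layer_norm_per_tensor) would name a distinct op. The old underscore names meant the @impl calls never attached to the schemas declared in ops_registrations.py. A minimal, self-contained sketch of the pattern — the my_ns namespace, my_op schema, and dispatch key are illustrative assumptions, not the actual cadence library:

import torch
from torch.library import Library, impl

# Hypothetical library; the real cadence lib defines its own schemas.
m = Library("my_ns", "DEF")
m.define("my_op(Tensor x) -> Tensor")
m.define("my_op.per_tensor(Tensor x, float scale) -> Tensor")

# The registration string must name the overload ("my_op.per_tensor"),
# not a separate op ("my_op_per_tensor"); otherwise lookups through
# torch.ops.my_ns.my_op.per_tensor will not find this implementation.
@impl(m, "my_op.per_tensor", "CompositeExplicitAutograd")
def my_op_per_tensor(x: torch.Tensor, scale: float) -> torch.Tensor:
    return x * scale

print(torch.ops.my_ns.my_op.per_tensor(torch.ones(2), 0.5))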

File tree

4 files changed: +67 -88 lines changed


backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions

@@ -614,6 +614,7 @@ python_unittest(
     typing = True,
     deps = [
         ":typing_stubs",
+        "//executorch/backends/cadence/aot:ops_registrations",
         "//executorch/backends/cadence/aot:ref_implementations",
         "//caffe2:torch",
     ]
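Why the test needs this dep: importing ops_registrations is what defines the custom op schemas at runtime, and ref_implementations then registers Python kernels against those schemas. A sketch of the imports the test presumably relies on — the Python module paths are inferred from the Buck targets above, since the test file itself is not shown in this commit view:

import torch

# Importing the registration module defines the op schemas as a side
# effect; without it, torch.ops.cadence.<op> lookups would fail.
import executorch.backends.cadence.aot.ops_registrations  # noqa: F401
# The reference (Python) kernels attach to those schemas on import.
import executorch.backends.cadence.aot.ref_implementations  # noqa: F401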

backends/cadence/aot/ops_registrations.py

Lines changed: 2 additions & 2 deletions

@@ -1449,7 +1449,7 @@ def quantized_layer_norm_meta(
     input: torch.Tensor,
     X_scale: torch.Tensor,
     X_zero_point: torch.Tensor,
-    normalized_shape: int,
+    normalized_shape: list[int],
     weight: torch.Tensor,
     bias: torch.Tensor,
     eps: float,
@@ -1464,7 +1464,7 @@ def quantized_layer_norm_per_tensor_meta(
     input: torch.Tensor,
     X_scale: float,
     X_zero_point: int,
-    normalized_shape: int,
+    normalized_shape: list[int],
     weight: torch.Tensor,
     bias: torch.Tensor,
     eps: float,
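The signature fix matches torch.nn.functional.layer_norm, where normalized_shape is a sequence of trailing dimension sizes to normalize over, not a single int. A quick standalone illustration:

import torch

x = torch.randn(4, 8, 16)
# normalized_shape lists the trailing dims to normalize over and must
# equal x.shape[-len(normalized_shape):].
out_last = torch.nn.functional.layer_norm(x, [16])      # over dim -1
out_last2 = torch.nn.functional.layer_norm(x, [8, 16])  # over dims -2, -1
print(out_last.shape, out_last2.shape)  # both torch.Size([4, 8, 16])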

backends/cadence/aot/ref_implementations.py

Lines changed: 19 additions & 19 deletions

@@ -64,9 +64,9 @@ def quantize_per_tensor(
             f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
         )

-    dequantized = torch.round(input_tensor * scale + zero_point).to(dtype)
+    quantized = torch.round(input_tensor * scale + zero_point).to(dtype)
     return torch.max(
-        torch.min(dequantized, torch.tensor(quant_max)),
+        torch.min(quantized, torch.tensor(quant_max)),
         torch.tensor(quant_min),
     )

@@ -247,12 +247,12 @@ def quantized_linear(
     ).reshape(*leading_dims, N)


-@impl(m, "quantized_layer_norm_per_tensor")
+@impl(m, "quantized_layer_norm.per_tensor")
 def quantized_layer_norm_per_tensor(
     input_tensor: torch.Tensor,
     X_scale: float,
     X_zero_point: int,
-    normalized_shape: int,
+    normalized_shape: list[int],
     weight: torch.Tensor,
     bias: torch.Tensor,
     eps: float,
@@ -283,7 +283,7 @@ def quantized_layer_norm_per_tensor(
         input_tensor, X_scale, X_zero_point, -128, 127, torch.float32
     )
     out = torch.nn.functional.layer_norm(
-        float_input_tensor, (normalized_shape,), weight, bias, eps=eps
+        float_input_tensor, normalized_shape, weight, bias, eps=eps
     )

     return quantize_per_tensor(
@@ -365,7 +365,7 @@ def quantized_conv_per_tensor(
     )


-@impl(m, "quantized_conv_nchw_per_tensor")
+@impl(m, "quantized_conv_nchw.per_tensor")
 def quantized_conv_nchw_per_tensor(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
@@ -421,7 +421,7 @@ def quantized_conv_nchw_per_tensor(
     )


-@impl(m, "quantized_conv_nhwc_per_tensor")
+@impl(m, "quantized_conv_nhwc.per_tensor")
 def quantized_conv_nhwc_per_tensor(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
@@ -558,62 +558,62 @@ def variant(
     return decorator


-@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nchw", torch.int8, torch.int8)
 def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nchw", torch.uint8, torch.uint8)
 def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nhwc", torch.int8, torch.int8)
 def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
 def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nchw", torch.int8, torch.int8)
 def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nchw", torch.uint8, torch.uint8)
 def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nhwc", torch.int8, torch.int8)
 def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
 def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nchw", torch.int8, torch.int8)
 def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nchw", torch.uint8, torch.uint8)
 def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nhwc", torch.int8, torch.int8)
 def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
 def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
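The first hunk also fixes a misnamed local: the value being computed is the quantized tensor, not a dequantized one. For reference, a standalone sketch of the same round-then-clamp arithmetic — note that, as in the diff, the scale multiplies the input, so scale here acts as the reciprocal of the usual dequantization step size; this sketch follows the diff's convention, not a general one:

import torch

def quantize_per_tensor_sketch(
    x: torch.Tensor, scale: float, zero_point: int,
    quant_min: int, quant_max: int, dtype: torch.dtype,
) -> torch.Tensor:
    # Round to the nearest quantized level, cast, then clamp into range,
    # mirroring the torch.max(torch.min(...)) pattern in the diff.
    q = torch.round(x * scale + zero_point).to(dtype)
    return torch.clamp(q, quant_min, quant_max)

x = torch.tensor([-2.0, 0.0, 3.5])
print(quantize_per_tensor_sketch(x, 10.0, 0, -128, 127, torch.int8))
# tensor([-20,   0,  35], dtype=torch.int8)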
