Skip to content

Commit 7c4efe2

Browse files
Andrew Grebenisan authored
and facebook-github-bot committed
Update channels last python reference to not use memory_format=channels_last (#14035)
Summary: The default overload of custom channels last assumes that inputs and weights are permuted and contiguous in memory. Differential Revision: D81842686
1 parent cb943e6 commit 7c4efe2

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

backends/cadence/aot/ref_implementations.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -457,9 +457,12 @@ def quantized_conv_nhwc_per_tensor(
457457
- out_multiplier (int): Unused
458458
- out_shift (int): Unused
459459
"""
460-
461-
if not input_tensor.is_contiguous(memory_format=torch.channels_last):
462-
raise ValueError("Input tensor must be in NHWC format")
460+
assert input_tensor.is_contiguous(memory_format=torch.contiguous_format)
461+
assert weight.is_contiguous(memory_format=torch.contiguous_format)
462+
input_tensor = torch.permute(input_tensor, (0, -1, 1, 2)).to(
463+
memory_format=torch.channels_last
464+
)
465+
weight = torch.permute(weight, (0, -1, 1, 2))
463466

464467
return quantized_conv_per_tensor(
465468
input_tensor,

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,9 @@ def test_quantized_conv_per_tensor(
689689
if len(input_tensor.shape) == 3 and memory_format == torch.channels_last:
690690
self.fail("Channels last format is not supported for 3D input tensors")
691691

692-
input_tensor = input_tensor.to(memory_format=memory_format)
692+
if memory_format == torch.channels_last:
693+
input_tensor = torch.permute(input_tensor, (0, 2, 3, 1)).contiguous()
694+
weight = torch.permute(weight, (0, 2, 3, 1)).contiguous()
693695

694696
convs = [
695697
(
@@ -701,7 +703,7 @@ def test_quantized_conv_per_tensor(
701703

702704
optimized_convs = []
703705
if input_tensor.dtype == torch.int8 and weight.dtype == torch.int8:
704-
if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
706+
if memory_format == torch.contiguous_format:
705707
optimized_convs = [
706708
torch.ops.cadence.quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor,
707709
torch.ops.cadence.quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor,
@@ -715,7 +717,7 @@ def test_quantized_conv_per_tensor(
715717
torch.ops.cadence.quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor,
716718
]
717719
elif input_tensor.dtype == torch.uint8 and weight.dtype == torch.uint8:
718-
if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
720+
if memory_format == torch.contiguous_format:
719721
optimized_convs = [
720722
torch.ops.cadence.quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor,
721723
torch.ops.cadence.quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor,

0 commit comments

Comments
 (0)