
Commit 8a1875e

Andrew Grebenisan authored and facebook-github-bot committed
Update channels last python reference to not use memory_format=channels_last (#14035)
Summary: Our implementation is supposed to assume that input shapes come in channels-last (NHWC) order, rather than relying on torch's channels_last memory format. The same applies to output shapes.

Differential Revision: D81842686
1 parent c3b842f commit 8a1875e
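For context (not part of the commit), here is a minimal sketch of the distinction the summary draws, with illustrative shapes assumed: `torch.channels_last` changes a tensor's strides while its shape stays NCHW, whereas the Cadence reference ops now expect the shape itself to be NHWC.

```python
import torch

nchw = torch.randn(1, 3, 8, 8)  # shape (N, C, H, W)

# channels_last is a memory layout: it reorders strides, not the shape.
cl = nchw.to(memory_format=torch.channels_last)
print(cl.shape)     # torch.Size([1, 3, 8, 8]) -- still NCHW-shaped
print(cl.stride())  # (192, 1, 24, 3) -- channel stride is now 1

# What the updated reference implementation assumes instead: the shape
# itself is (N, H, W, C).
nhwc = nchw.movedim(-3, -1).contiguous()
print(nhwc.shape)   # torch.Size([1, 8, 8, 3])
```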

File tree

2 files changed: +36 -11 lines

backends/cadence/aot/ref_implementations.py

Lines changed: 19 additions & 3 deletions
```diff
@@ -458,10 +458,21 @@ def quantized_conv_nhwc_per_tensor(
         - out_shift (int): Unused
     """
 
-    if not input_tensor.is_contiguous(memory_format=torch.channels_last):
-        raise ValueError("Input tensor must be in NHWC format")
+    # Convert to NCHW format to reuse the existing implementation
+    conv_is_1d = False
+    if len(input_tensor.shape) == 3:
+        conv_is_1d = True
+        input_tensor = input_tensor.movedim(-1, 1).contiguous()
+        if len(weight.shape) != 3:
+            raise ValueError("Weight tensor must be 3D if input is 3D")
+        weight = weight.movedim(-1, 1).contiguous()
+    else:
+        input_tensor = input_tensor.movedim(-1, -3)
+        if len(weight.shape) != 4:
+            raise ValueError("Weight tensor must be 4D if input is nd > 3")
+        weight = torch.permute(weight, (0, -1, 1, 2)).contiguous()
 
-    return quantized_conv_per_tensor(
+    nchw_out = quantized_conv_per_tensor(
         input_tensor,
         weight,
         bias,
@@ -478,6 +489,11 @@ def quantized_conv_nhwc_per_tensor(
         out_shift,
     )
 
+    if conv_is_1d:
+        return nchw_out.movedim(1, -1).contiguous()
+    else:
+        return nchw_out.movedim(-3, -1).contiguous()
+
 
 def quantized_conv_variant(
     layout: str,
```
backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 17 additions & 8 deletions
```diff
@@ -449,7 +449,7 @@ def test_quantized_layer_norm_per_tensor(
                 ),  # expected_output: [1+2, 2+3, 3+4] / 0.5 = [6, 10, 14]
                 memory_format,
             )
-            for memory_format in [torch.contiguous_format]
+            for memory_format in [torch.contiguous_format, torch.channels_last]
         ],
         # Test case 5: Multiple output channels
         *[
@@ -686,10 +686,13 @@ def test_quantized_conv_per_tensor(
     ) -> None:
         assert memory_format in [torch.contiguous_format, torch.channels_last]
 
-        if len(input_tensor.shape) == 3 and memory_format == torch.channels_last:
-            self.fail("Channels last format is not supported for 3D input tensors")
-
-        input_tensor = input_tensor.to(memory_format=memory_format)
+        if memory_format == torch.channels_last:
+            if input_tensor.ndim == 3:
+                input_tensor = input_tensor.movedim(1, -1)
+                weight = weight.movedim(1, -1)
+            else:
+                input_tensor = input_tensor.movedim(-3, -1)
+                weight = weight.movedim(-3, -1)
 
         convs = [
             (
@@ -701,7 +704,7 @@ def test_quantized_conv_per_tensor(
 
         optimized_convs = []
         if input_tensor.dtype == torch.int8 and weight.dtype == torch.int8:
-            if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
+            if memory_format == torch.contiguous_format:
                 optimized_convs = [
                     torch.ops.cadence.quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor,
                     torch.ops.cadence.quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor,
@@ -715,7 +718,7 @@ def test_quantized_conv_per_tensor(
                     torch.ops.cadence.quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor,
                 ]
         elif input_tensor.dtype == torch.uint8 and weight.dtype == torch.uint8:
-            if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
+            if memory_format == torch.contiguous_format:
                 optimized_convs = [
                     torch.ops.cadence.quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor,
                     torch.ops.cadence.quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor,
@@ -746,7 +749,13 @@ def test_quantized_conv_per_tensor(
             output_zero_point,
             out_multiplier,
             out_shift,
-        ).to(memory_format=torch.contiguous_format)
+        )
+
+        if memory_format == torch.channels_last:
+            if input_tensor.ndim == 3:
+                output = output.movedim(-1, 1)
+            else:
+                output = output.movedim(-1, -3)
 
         # Verify output properties
         self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}")
```
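A quick sanity check (a sketch, not from the test file) of why the test can convert the output back and compare element-wise against the NCHW reference: `movedim(-3, -1)` and `movedim(-1, -3)` are exact inverses.

```python
import torch

x = torch.arange(24).reshape(1, 2, 3, 4)       # (N, C, H, W)
roundtrip = x.movedim(-3, -1).movedim(-1, -3)  # NCHW -> NHWC -> NCHW
assert torch.equal(roundtrip, x)
```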
