@@ -449,7 +449,7 @@ def test_quantized_layer_norm_per_tensor(
449
449
), # expected_output: [1+2, 2+3, 3+4] / 0.5 = [6, 10, 14]
450
450
memory_format ,
451
451
)
452
- for memory_format in [torch .contiguous_format ]
452
+ for memory_format in [torch .contiguous_format , torch . channels_last ]
453
453
],
454
454
# Test case 5: Multiple output channels
455
455
* [
@@ -686,10 +686,13 @@ def test_quantized_conv_per_tensor(
686
686
) -> None :
687
687
assert memory_format in [torch .contiguous_format , torch .channels_last ]
688
688
689
- if len (input_tensor .shape ) == 3 and memory_format == torch .channels_last :
690
- self .fail ("Channels last format is not supported for 3D input tensors" )
691
-
692
- input_tensor = input_tensor .to (memory_format = memory_format )
689
+ if memory_format == torch .channels_last :
690
+ if input_tensor .ndim == 3 :
691
+ input_tensor = input_tensor .movedim (1 , - 1 )
692
+ weight = weight .movedim (1 , - 1 )
693
+ else :
694
+ input_tensor = input_tensor .movedim (- 3 , - 1 )
695
+ weight = weight .movedim (- 3 , - 1 )
693
696
694
697
convs = [
695
698
(
@@ -701,7 +704,7 @@ def test_quantized_conv_per_tensor(
701
704
702
705
optimized_convs = []
703
706
if input_tensor .dtype == torch .int8 and weight .dtype == torch .int8 :
704
- if input_tensor . is_contiguous ( memory_format = torch .contiguous_format ) :
707
+ if memory_format == torch .contiguous_format :
705
708
optimized_convs = [
706
709
torch .ops .cadence .quantized_conv_nchw_asym8sxsym8s_asym8s .per_tensor ,
707
710
torch .ops .cadence .quantized_conv_nchw_dilated_asym8sxsym8s_asym8s .per_tensor ,
@@ -715,7 +718,7 @@ def test_quantized_conv_per_tensor(
715
718
torch .ops .cadence .quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s .per_tensor ,
716
719
]
717
720
elif input_tensor .dtype == torch .uint8 and weight .dtype == torch .uint8 :
718
- if input_tensor . is_contiguous ( memory_format = torch .contiguous_format ) :
721
+ if memory_format == torch .contiguous_format :
719
722
optimized_convs = [
720
723
torch .ops .cadence .quantized_conv_nchw_asym8uxsym8u_asym8u .per_tensor ,
721
724
torch .ops .cadence .quantized_conv_nchw_dilated_asym8uxsym8u_asym8u .per_tensor ,
@@ -746,7 +749,13 @@ def test_quantized_conv_per_tensor(
746
749
output_zero_point ,
747
750
out_multiplier ,
748
751
out_shift ,
749
- ).to (memory_format = torch .contiguous_format )
752
+ )
753
+
754
+ if memory_format == torch .channels_last :
755
+ if input_tensor .ndim == 3 :
756
+ output = output .movedim (- 1 , 1 )
757
+ else :
758
+ output = output .movedim (- 1 , - 3 )
750
759
751
760
# Verify output properties
752
761
self .assertEqual (output .dtype , dtype , f"Output dtype should be { dtype } " )
0 commit comments