@@ -64,9 +64,9 @@ def quantize_per_tensor(
             f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
         )

-    dequantized = torch.round(input_tensor * scale + zero_point).to(dtype)
+    quantized = torch.round(input_tensor * scale + zero_point).to(dtype)
     return torch.max(
-        torch.min(dequantized, torch.tensor(quant_max)),
+        torch.min(quantized, torch.tensor(quant_max)),
         torch.tensor(quant_min),
     )

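A minimal standalone sketch (not part of the diff) of what the renamed helper computes, using the signature and the `input * scale + zero_point` convention shown in the hunk above; the surrounding dtype validation is omitted.

import torch

def quantize_per_tensor_sketch(
    input_tensor: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    # Round to the target grid, cast, then clamp to [quant_min, quant_max].
    quantized = torch.round(input_tensor * scale + zero_point).to(dtype)
    return torch.max(
        torch.min(quantized, torch.tensor(quant_max)),
        torch.tensor(quant_min),
    )

x = torch.tensor([-3.0, 0.4, 3.0])
q = quantize_per_tensor_sketch(x, scale=10.0, zero_point=0, quant_min=-25, quant_max=25, dtype=torch.int8)
print(q)  # values round to [-30, 4, 30] and are then clamped to [-25, 4, 25]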
@@ -247,12 +247,12 @@ def quantized_linear(
     ).reshape(*leading_dims, N)


-@impl(m, "quantized_layer_norm_per_tensor")
+@impl(m, "quantized_layer_norm.per_tensor")
 def quantized_layer_norm_per_tensor(
     input_tensor: torch.Tensor,
     X_scale: float,
     X_zero_point: int,
-    normalized_shape: int,
+    normalized_shape: list[int],
     weight: torch.Tensor,
     bias: torch.Tensor,
     eps: float,
@@ -283,7 +283,7 @@ def quantized_layer_norm_per_tensor(
         input_tensor, X_scale, X_zero_point, -128, 127, torch.float32
     )
     out = torch.nn.functional.layer_norm(
-        float_input_tensor, (normalized_shape,), weight, bias, eps=eps
+        float_input_tensor, normalized_shape, weight, bias, eps=eps
    )

    return quantize_per_tensor(
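A quick check (not from the diff) that `torch.nn.functional.layer_norm` accepts the `normalized_shape` list directly, which is what the change above relies on; the old `(normalized_shape,)` form could only express a single trailing dimension.

import torch

x = torch.randn(2, 4, 8)
weight = torch.ones(4, 8)
bias = torch.zeros(4, 8)

# With normalized_shape passed as a list, the last two dimensions are
# normalized together in one call.
out = torch.nn.functional.layer_norm(x, [4, 8], weight, bias, eps=1e-5)
print(out.shape)  # torch.Size([2, 4, 8])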
@@ -365,7 +365,7 @@ def quantized_conv_per_tensor(
     )


-@impl(m, "quantized_conv_nchw_per_tensor")
+@impl(m, "quantized_conv_nchw.per_tensor")
 def quantized_conv_nchw_per_tensor(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
@@ -421,7 +421,7 @@ def quantized_conv_nchw_per_tensor(
     )


-@impl(m, "quantized_conv_nhwc_per_tensor")
+@impl(m, "quantized_conv_nhwc.per_tensor")
 def quantized_conv_nhwc_per_tensor(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
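An illustration (not from the diff) of the layout difference between the two registrations above: `conv2d` expects NCHW, so an NHWC activation has to be permuted before and after the call. Shapes and values here are made up for the example.

import torch

x_nhwc = torch.randn(1, 16, 16, 3)   # N, H, W, C
weight = torch.randn(8, 3, 3, 3)     # out_ch, in_ch, kH, kW
bias = torch.zeros(8)

x_nchw = x_nhwc.permute(0, 3, 1, 2)  # N, C, H, W
y_nchw = torch.nn.functional.conv2d(x_nchw, weight, bias, stride=1, padding=1)
y_nhwc = y_nchw.permute(0, 2, 3, 1)  # back to N, H, W, C
print(y_nhwc.shape)  # torch.Size([1, 16, 16, 8])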
@@ -558,62 +558,62 @@ def variant(
     return decorator


-@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nchw", torch.int8, torch.int8)
 def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nchw", torch.uint8, torch.uint8)
 def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nhwc", torch.int8, torch.int8)
 def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
 def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nchw", torch.int8, torch.int8)
 def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nchw", torch.uint8, torch.uint8)
 def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nhwc", torch.int8, torch.int8)
 def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
 def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nchw", torch.int8, torch.int8)
 def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nchw", torch.uint8, torch.uint8)
 def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nhwc", torch.int8, torch.int8)
 def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
 def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...

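A sketch (not from the diff) of what the "dilated" and "depthwise" variants registered above correspond to in plain `conv2d` terms; the shapes are made up for illustration.

import torch

x = torch.randn(1, 4, 16, 16)

# Dilated: same kernel size, larger receptive field via dilation > 1.
w_dilated = torch.randn(8, 4, 3, 3)
y_dilated = torch.nn.functional.conv2d(x, w_dilated, dilation=2, padding=2)

# Depthwise: groups == in_channels, so each input channel gets its own filter.
w_depthwise = torch.randn(4, 1, 3, 3)
y_depthwise = torch.nn.functional.conv2d(x, w_depthwise, groups=4, padding=1)

print(y_dilated.shape, y_depthwise.shape)
# torch.Size([1, 8, 16, 16]) torch.Size([1, 4, 16, 16])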