@@ -64,9 +64,9 @@ def quantize_per_tensor(
6464 f"Unsupported dtype to quantize to. Supported dtypes must be one of { supported_quant_types } "
6565 )
6666
67- dequantized = torch .round (input_tensor * scale + zero_point ).to (dtype )
67+ quantized = torch .round (input_tensor * scale + zero_point ).to (dtype )
6868 return torch .max (
69- torch .min (dequantized , torch .tensor (quant_max )),
69+ torch .min (quantized , torch .tensor (quant_max )),
7070 torch .tensor (quant_min ),
7171 )
7272
@@ -247,12 +247,12 @@ def quantized_linear(
247247 ).reshape (* leading_dims , N )
248248
249249
250- @impl (m , "quantized_layer_norm_per_tensor " )
250+ @impl (m , "quantized_layer_norm.per_tensor " )
251251def quantized_layer_norm_per_tensor (
252252 input_tensor : torch .Tensor ,
253253 X_scale : float ,
254254 X_zero_point : int ,
255- normalized_shape : int ,
255+ normalized_shape : list [ int ] ,
256256 weight : torch .Tensor ,
257257 bias : torch .Tensor ,
258258 eps : float ,
@@ -283,7 +283,7 @@ def quantized_layer_norm_per_tensor(
283283 input_tensor , X_scale , X_zero_point , - 128 , 127 , torch .float32
284284 )
285285 out = torch .nn .functional .layer_norm (
286- float_input_tensor , ( normalized_shape ,) , weight , bias , eps = eps
286+ float_input_tensor , normalized_shape , weight , bias , eps = eps
287287 )
288288
289289 return quantize_per_tensor (
@@ -365,7 +365,7 @@ def quantized_conv_per_tensor(
365365 )
366366
367367
368- @impl (m , "quantized_conv_nchw_per_tensor " )
368+ @impl (m , "quantized_conv_nchw.per_tensor " )
369369def quantized_conv_nchw_per_tensor (
370370 input_tensor : torch .Tensor ,
371371 weight : torch .Tensor ,
@@ -421,7 +421,7 @@ def quantized_conv_nchw_per_tensor(
421421 )
422422
423423
424- @impl (m , "quantized_conv_nhwc_per_tensor " )
424+ @impl (m , "quantized_conv_nhwc.per_tensor " )
425425def quantized_conv_nhwc_per_tensor (
426426 input_tensor : torch .Tensor ,
427427 weight : torch .Tensor ,
@@ -558,62 +558,62 @@ def variant(
558558 return decorator
559559
560560
561- @impl (m , "quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor " )
561+ @impl (m , "quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor " )
562562@quantized_conv_variant ("nchw" , torch .int8 , torch .int8 )
563563def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
564564
565565
566- @impl (m , "quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor " )
566+ @impl (m , "quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor " )
567567@quantized_conv_variant ("nchw" , torch .uint8 , torch .uint8 )
568568def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
569569
570570
571- @impl (m , "quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor " )
571+ @impl (m , "quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor " )
572572@quantized_conv_variant ("nhwc" , torch .int8 , torch .int8 )
573573def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
574574
575575
576- @impl (m , "quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor " )
576+ @impl (m , "quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor " )
577577@quantized_conv_variant ("nhwc" , torch .uint8 , torch .uint8 )
578578def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
579579
580580
581- @impl (m , "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor " )
581+ @impl (m , "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor " )
582582@quantized_conv_variant ("nchw" , torch .int8 , torch .int8 )
583583def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
584584
585585
586- @impl (m , "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor " )
586+ @impl (m , "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor " )
587587@quantized_conv_variant ("nchw" , torch .uint8 , torch .uint8 )
588588def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
589589
590590
591- @impl (m , "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor " )
591+ @impl (m , "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor " )
592592@quantized_conv_variant ("nhwc" , torch .int8 , torch .int8 )
593593def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
594594
595595
596- @impl (m , "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor " )
596+ @impl (m , "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor " )
597597@quantized_conv_variant ("nhwc" , torch .uint8 , torch .uint8 )
598598def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
599599
600600
601- @impl (m , "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor " )
601+ @impl (m , "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor " )
602602@quantized_conv_variant ("nchw" , torch .int8 , torch .int8 )
603603def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
604604
605605
606- @impl (m , "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor " )
606+ @impl (m , "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor " )
607607@quantized_conv_variant ("nchw" , torch .uint8 , torch .uint8 )
608608def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
609609
610610
611- @impl (m , "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor " )
611+ @impl (m , "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor " )
612612@quantized_conv_variant ("nhwc" , torch .int8 , torch .int8 )
613613def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor () -> torch .Tensor : ...
614614
615615
616- @impl (m , "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor " )
616+ @impl (m , "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor " )
617617@quantized_conv_variant ("nhwc" , torch .uint8 , torch .uint8 )
618618def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor () -> torch .Tensor : ...
619619
0 commit comments