|
23 | 23 | ScalarType.QINT32: torch.qint32,
|
24 | 24 | }
|
25 | 25 |
|
26 |
| -_Number = bool | int | float |
27 |
| - |
28 | 26 |
|
29 | 27 | @impl(m, "quantize_per_tensor")
|
30 | 28 | def quantize_per_tensor(
|
@@ -66,9 +64,9 @@ def quantize_per_tensor(
|
66 | 64 | f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
|
67 | 65 | )
|
68 | 66 |
|
69 |
| - dequantized = torch.round(input_tensor * scale + zero_point).to(dtype) |
| 67 | + quantized = torch.round(input_tensor * scale + zero_point).to(dtype) |
70 | 68 | return torch.max(
|
71 |
| - torch.min(dequantized, torch.tensor(quant_max)), |
| 69 | + torch.min(quantized, torch.tensor(quant_max)), |
72 | 70 | torch.tensor(quant_min),
|
73 | 71 | )
|
74 | 72 |
|
@@ -249,12 +247,12 @@ def quantized_linear(
|
249 | 247 | ).reshape(*leading_dims, N)
|
250 | 248 |
|
251 | 249 |
|
252 |
| -@impl(m, "quantized_layer_norm_per_tensor") |
| 250 | +@impl(m, "quantized_layer_norm.per_tensor") |
253 | 251 | def quantized_layer_norm_per_tensor(
|
254 | 252 | input_tensor: torch.Tensor,
|
255 | 253 | X_scale: float,
|
256 | 254 | X_zero_point: int,
|
257 |
| - normalized_shape: int, |
| 255 | + normalized_shape: list[int], |
258 | 256 | weight: torch.Tensor,
|
259 | 257 | bias: torch.Tensor,
|
260 | 258 | eps: float,
|
@@ -285,7 +283,7 @@ def quantized_layer_norm_per_tensor(
|
285 | 283 | input_tensor, X_scale, X_zero_point, -128, 127, torch.float32
|
286 | 284 | )
|
287 | 285 | out = torch.nn.functional.layer_norm(
|
288 |
| - float_input_tensor, (normalized_shape,), weight, bias, eps=eps |
| 286 | + float_input_tensor, normalized_shape, weight, bias, eps=eps |
289 | 287 | )
|
290 | 288 |
|
291 | 289 | return quantize_per_tensor(
|
@@ -367,7 +365,7 @@ def quantized_conv_per_tensor(
|
367 | 365 | )
|
368 | 366 |
|
369 | 367 |
|
370 |
| -@impl(m, "quantized_conv_nchw_per_tensor") |
| 368 | +@impl(m, "quantized_conv_nchw.per_tensor") |
371 | 369 | def quantized_conv_nchw_per_tensor(
|
372 | 370 | input_tensor: torch.Tensor,
|
373 | 371 | weight: torch.Tensor,
|
@@ -423,7 +421,7 @@ def quantized_conv_nchw_per_tensor(
|
423 | 421 | )
|
424 | 422 |
|
425 | 423 |
|
426 |
| -@impl(m, "quantized_conv_nhwc_per_tensor") |
| 424 | +@impl(m, "quantized_conv_nhwc.per_tensor") |
427 | 425 | def quantized_conv_nhwc_per_tensor(
|
428 | 426 | input_tensor: torch.Tensor,
|
429 | 427 | weight: torch.Tensor,
|
@@ -560,62 +558,62 @@ def variant(
|
560 | 558 | return decorator
|
561 | 559 |
|
562 | 560 |
|
563 |
| -@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor") |
| 561 | +@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor") |
564 | 562 | @quantized_conv_variant("nchw", torch.int8, torch.int8)
|
565 | 563 | def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
|
566 | 564 |
|
567 | 565 |
|
568 |
| -@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor") |
| 566 | +@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor") |
569 | 567 | @quantized_conv_variant("nchw", torch.uint8, torch.uint8)
|
570 | 568 | def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
|
571 | 569 |
|
572 | 570 |
|
573 |
| -@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor") |
| 571 | +@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor") |
574 | 572 | @quantized_conv_variant("nhwc", torch.int8, torch.int8)
|
575 | 573 | def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
|
576 | 574 |
|
577 | 575 |
|
578 |
| -@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor") |
| 576 | +@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor") |
579 | 577 | @quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
|
580 | 578 | def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
|
581 | 579 |
|
582 | 580 |
|
583 |
| -@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor") |
| 581 | +@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor") |
584 | 582 | @quantized_conv_variant("nchw", torch.int8, torch.int8)
|
585 | 583 | def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
|
586 | 584 |
|
587 | 585 |
|
588 |
| -@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor") |
| 586 | +@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor") |
589 | 587 | @quantized_conv_variant("nchw", torch.uint8, torch.uint8)
|
590 | 588 | def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
|
591 | 589 |
|
592 | 590 |
|
593 |
| -@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor") |
| 591 | +@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor") |
594 | 592 | @quantized_conv_variant("nhwc", torch.int8, torch.int8)
|
595 | 593 | def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
|
596 | 594 |
|
597 | 595 |
|
598 |
| -@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor") |
| 596 | +@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor") |
599 | 597 | @quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
|
600 | 598 | def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
|
601 | 599 |
|
602 | 600 |
|
603 |
| -@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor") |
| 601 | +@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor") |
604 | 602 | @quantized_conv_variant("nchw", torch.int8, torch.int8)
|
605 | 603 | def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
|
606 | 604 |
|
607 | 605 |
|
608 |
| -@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor") |
| 606 | +@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor") |
609 | 607 | @quantized_conv_variant("nchw", torch.uint8, torch.uint8)
|
610 | 608 | def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
|
611 | 609 |
|
612 | 610 |
|
613 |
| -@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor") |
| 611 | +@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor") |
614 | 612 | @quantized_conv_variant("nhwc", torch.int8, torch.int8)
|
615 | 613 | def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
|
616 | 614 |
|
617 | 615 |
|
618 |
| -@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor") |
| 616 | +@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor") |
619 | 617 | @quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
|
620 | 618 | def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
|
621 | 619 |
|
|
0 commit comments