@@ -7,7 +7,7 @@
 # pyre-strict
 
 
-from typing import Optional
+from typing import Callable, Optional
 
 import torch
 
@@ -479,6 +479,153 @@ def quantized_conv_nhwc_per_tensor(
 )
 
 
+def quantized_conv_variant(
+    layout: str,
+    input_dtype: torch.dtype,
+    weight_dtype: torch.dtype,
+) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
+    """Decorator factory: builds a quantized conv variant that validates
+    input/weight/bias dtypes and dispatches to the matching layout kernel."""
+
+    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
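+        # The decorated stub's body is discarded; `variant` is registered instead.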
+        def variant(
+            input_tensor: torch.Tensor,
+            weight: torch.Tensor,
+            bias: torch.Tensor,
+            stride: tuple[int, int],
+            padding: tuple[int, int],
+            dilation: tuple[int, int],
+            groups: int,
+            in_zero_point: int,
+            weight_zero_point: int,
+            bias_scale: float,
+            output_scale: float,
+            output_zero_point: int,
+            out_multiplier: int,
+            out_shift: int,
+        ) -> torch.Tensor:
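+            # Each registered variant is specialized to one dtype combination.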
+            assert (
+                input_tensor.dtype == input_dtype
+            ), f"Expected input dtype {input_dtype}, got {input_tensor.dtype}"
+            assert (
+                weight.dtype == weight_dtype
+            ), f"Expected weight dtype {weight_dtype}, got {weight.dtype}"
+
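+            # Bias is int32 for every variant, regardless of activation dtype.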
+            assert (
+                bias.dtype == torch.int32
+            ), f"Expected bias dtype int32, got {bias.dtype}"
+
+            # Call the appropriate base function
+            match layout:
+                case "nchw":
+                    return quantized_conv_nchw_per_tensor(
+                        input_tensor,
+                        weight,
+                        bias,
+                        stride,
+                        padding,
+                        dilation,
+                        groups,
+                        in_zero_point,
+                        weight_zero_point,
+                        bias_scale,
+                        output_scale,
+                        output_zero_point,
+                        out_multiplier,
+                        out_shift,
+                    )
+                case "nhwc":
+                    return quantized_conv_nhwc_per_tensor(
+                        input_tensor,
+                        weight,
+                        bias,
+                        stride,
+                        padding,
+                        dilation,
+                        groups,
+                        in_zero_point,
+                        weight_zero_point,
+                        bias_scale,
+                        output_scale,
+                        output_zero_point,
+                        out_multiplier,
+                        out_shift,
+                    )
+                case _:
+                    raise ValueError(f"Unknown layout {layout}")
+
+        return variant
+
+    return decorator
+
+
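+# Concrete variants. Naming, as read from the dtype arguments: asym8s/sym8s are
+# (a)symmetric signed 8-bit (int8); asym8u/sym8u are unsigned 8-bit (uint8).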
+@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nchw", torch.int8, torch.int8)
+def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nchw", torch.uint8, torch.uint8)
+def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nhwc", torch.int8, torch.int8)
+def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
+def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
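+# The dilated and depthwise variants dispatch to the same base kernels, since
+# dilation and groups are ordinary arguments to the per-tensor implementations.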
+@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nchw", torch.int8, torch.int8)
+def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nchw", torch.uint8, torch.uint8)
+def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nhwc", torch.int8, torch.int8)
+def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
+def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nchw", torch.int8, torch.int8)
+def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nchw", torch.uint8, torch.uint8)
+def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nhwc", torch.int8, torch.int8)
+def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
+def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
 @impl(m, "quantized_relu")
 def quantized_relu(
     X: torch.Tensor,