 # pyre-strict


-from typing import Optional
+from typing import Callable, Optional

 import torch

@@ -481,6 +481,145 @@ def quantized_conv_nhwc_per_tensor(
     )


+def quantized_conv_variant(
+    layout: str,
+    input_dtype: torch.dtype,
+    weight_dtype: torch.dtype,
+) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
+    """Create a quantized conv variant with type checking."""
+
+    def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+        def variant(
+            input_tensor: torch.Tensor,
+            weight: torch.Tensor,
+            bias: torch.Tensor,
+            stride: tuple[int, int],
+            padding: tuple[int, int],
+            dilation: tuple[int, int],
+            groups: int,
+            in_zero_point: int,
+            weight_zero_point: int,
+            bias_scale: float,
+            output_scale: float,
+            output_zero_point: int,
+            out_multiplier: int,
+            out_shift: int,
+        ) -> torch.Tensor:
+            assert (
+                input_tensor.dtype == input_dtype
+            ), f"Expected input dtype {input_dtype}, got {input_tensor.dtype}"
+            assert (
+                weight.dtype == weight_dtype
+            ), f"Expected weight dtype {weight_dtype}, got {weight.dtype}"
+
+            assert (
+                bias.dtype == torch.int32
+            ), f"Expected bias dtype int32, got {bias.dtype}"
+
+            # Call the appropriate base function
+            match layout:
+                case "nchw":
+                    return quantized_conv_nchw_per_tensor(
+                        input_tensor,
+                        weight,
+                        bias,
+                        stride,
+                        padding,
+                        dilation,
+                        groups,
+                        in_zero_point,
+                        weight_zero_point,
+                        bias_scale,
+                        output_scale,
+                        output_zero_point,
+                        out_multiplier,
+                        out_shift,
+                    )
+                case "nhwc":
+                    return quantized_conv_nhwc_per_tensor(
+                        input_tensor,
+                        weight,
+                        bias,
+                        stride,
+                        padding,
+                        dilation,
+                        groups,
+                        in_zero_point,
+                        weight_zero_point,
+                        bias_scale,
+                        output_scale,
+                        output_zero_point,
+                        out_multiplier,
+                        out_shift,
+                    )
+                case _:
+                    raise ValueError(f"Unknown layout {layout}")
+
+        return variant
+
+    return decorator
+
+
+@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nchw", torch.int8, torch.int8)
+def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nchw", torch.uint8, torch.uint8)
+def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nhwc", torch.int8, torch.int8)
+def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
+def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nchw", torch.int8, torch.int8)
+def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nchw", torch.uint8, torch.uint8)
+def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nhwc", torch.int8, torch.int8)
+def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
+def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nchw", torch.int8, torch.int8)
+def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nchw", torch.uint8, torch.uint8)
+def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor")
+@quantized_conv_variant("nhwc", torch.int8, torch.int8)
+def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
+
+
+@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor")
+@quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
+def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
+
+
 @impl(m, "quantized_relu")
 def quantized_relu(
     X: torch.Tensor,
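For context, a minimal usage sketch of one of the newly registered variants follows. The import path `quantized_ops` is a hypothetical placeholder for wherever this file lives, and the tensor shapes and quantization parameters are illustrative values only; the numerics come from the existing quantized_conv_nchw_per_tensor reference kernel that each variant dispatches to, while the variant itself only adds the dtype checks shown in the diff.

import torch

# Hypothetical module name; use the actual import path of this ops file.
from quantized_ops import quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor

# NCHW int8 activations, int8 weights, int32 bias (dtypes enforced by the variant's asserts).
x = torch.zeros(1, 3, 8, 8, dtype=torch.int8)
w = torch.zeros(4, 3, 3, 3, dtype=torch.int8)
b = torch.zeros(4, dtype=torch.int32)

out = quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor(
    x, w, b,
    (1, 1),  # stride
    (0, 0),  # padding
    (1, 1),  # dilation
    1,       # groups
    0,       # in_zero_point
    0,       # weight_zero_point
    1.0,     # bias_scale (illustrative)
    1.0,     # output_scale (illustrative)
    0,       # output_zero_point
    1,       # out_multiplier (illustrative)
    0,       # out_shift (illustrative)
)

# Passing a float32 input instead of int8 trips the new check:
#   AssertionError: Expected input dtype torch.int8, got torch.float32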