Support dynamically quantized 2D convolutions #10248


Conversation

@keyprocedure (Contributor) commented Apr 16, 2025

Summary

Add initial support for dynamically quantized Conv2d in XNNPACK:

  • Add conv to DYNAMIC_OPS for annotation
  • Update partitioner to support dynamically quantized Conv2d
  • Add checks to ensure only 2D, non-depthwise dynamically quantized convs are partitioned and annotated
  • Update NHWC permute node insertion to trace back to the original input for dynamically quantized inputs
  • Compute num_nonbatch_dims based on whether the node feeds into a conv
  • Remove the num_nonbatch_dims check from XNNCompiler
  • Add unit tests for channels-last permute and single, sequential, and parallel dynamically quantized 2D convs

Fixes #9021

Test plan

python -m unittest backends.xnnpack.test.ops.test_conv2d.TestConv2d.test_dq_conv2d
python -m unittest backends.xnnpack.test.ops.test_conv2d.TestConv2d.test_dq_conv2d_seq
python -m unittest backends.xnnpack.test.ops.test_conv2d.TestConv2d.test_dq_conv2d_parallel
python -m unittest backends.xnnpack.test.passes.test_channels_last_tagged_reshape.TestChannelsLastTaggedReshapePass.test_dq_conv2d_channels_last_tagged_reshape_pass


pytorch-bot bot commented Apr 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/10248

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 9693061 with merge base b7eee0c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the "CLA Signed" label Apr 16, 2025. (This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.)
@keyprocedure (Contributor Author):

@pytorchbot label "topic: not user facing"

@mcr229 (Contributor) left a comment:

Overall looks really good. Just a few minor comments.

@@ -283,14 +284,26 @@ def input_to_nhwc(
            ]
        else:
            # Need to create NHWC node
            # Check if input uses dynamic quantization
@mcr229 (Contributor):

Nice! The change made here looks simpler than I expected, though I suspect it was still pretty hard to navigate and figure out.

Do you mind adding a new test here as well:
https://github.com/pytorch/executorch/blob/d7f74bd4adf10950224d2d975a6e23e92e7be6f3/backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py

Essentially we just check that if there is a dynamically quantized convolution, we place one permute before the dynamic_q chain and one permute after the conv.
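
For illustration, a minimal sketch of the kind of check being described (the helper name and the permute op target are assumptions, not the exact test that landed in the PR):

import torch

# Hedged sketch: after the channels-last tagged reshape pass runs, a
# dynamically quantized conv should be bracketed by two layout permutes --
# one inserted before the quantize/dequantize chain feeding it, and one
# after the conv.
def count_layout_permutes(gm: torch.fx.GraphModule) -> int:
    return sum(
        1
        for n in gm.graph.nodes
        if n.op == "call_function"
        and n.target == torch.ops.aten.permute_copy.default
    )

# A test would then assert count_layout_permutes(gm) == 2 for a single dq conv.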

@keyprocedure (Contributor Author) commented Apr 20, 2025:

This was definitely one of those changes where the time-to-LOC ratio was off the charts haha

Test added in test_channels_last_tagged_reshape.py

@mcr229 (Contributor):

Yeah, this is a good data point. I want to help improve the contributability flow for the XNNPACK backend, and some of that definitely means improving/refactoring passes like this, which are way too complex.

q_input_val = q_input.meta.get("val", None)
q_input_shape = getattr(q_input_val, "shape", None)
if q_input_shape is not None:
    num_nonbatch_dims = max(len(q_input_shape) - 1, 1)
@mcr229 (Contributor):

This isn't necessarily the case for linear layers; their input always has exactly 1 non-batch dimension: (x, y, z, input_channels). The rank of the tensor can be arbitrary, and we always interpret the last dimension as input channels and every other dimension as a batch dimension.

For now let's just add a check: if convolution then 3, if linear then 1. This issue stems from the fact that we are injecting per_tensor quant nodes, but our intent is to do per_batch. We are adding affine quantization, which should tell us how many batch and how many non-batch dimensions there are in the quant node, so later on we will fix it to use that. For now it might just have to be hard-coding the conv and linear case :(
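
A minimal sketch of that hard-coded rule (the helper name and the users-based lookup are illustrative assumptions):

import torch

def get_num_nonbatch_dims(q_input: torch.fx.Node) -> int:
    # Convs consume (N, C, H, W) activations, so 3 non-batch dims (C, H, W);
    # linear always has exactly 1 non-batch dim (input_channels), regardless
    # of the tensor rank.
    feeds_conv = any(
        user.target == torch.ops.aten.convolution.default
        for user in q_input.users
    )
    return 3 if feeds_conv else 1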

@keyprocedure (Contributor Author):

Good to know. I added a check to determine if the node feeds into a conv, and set the number of non-batch dimensions to 3 if it does.

weight_shape = getattr(weight_val, "shape", None)

# Skip if not a 4D weight tensor (i.e. not conv2d)
if weight_shape is not None and len(weight_shape) != 4:
@mcr229 (Contributor):

Nice. Let's also add a skip if the convolution is depthwise, since XNNPACK can't handle that case yet.

@keyprocedure (Contributor Author):

Is there a way to get the group number for the depthwise check here? I'm currently defaulting to 1 group.

@mcr229 (Contributor):

I do believe the default is 1 if it is not an arg.
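
For reference, a hedged sketch of reading groups off an aten.convolution node (arg order taken from the aten.convolution schema; the helper name is an assumption):

import torch

def get_groups(node: torch.fx.Node) -> int:
    # aten.convolution args: (input, weight, bias, stride, padding,
    # dilation, transposed, output_padding, groups) -- groups is index 8.
    if len(node.args) > 8:
        return int(node.args[8])
    return 1  # default when groups is not passed explicitly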

@@ -1172,7 +1167,7 @@ Error defineStaticTransposeNode(
  ET_CHECK_OR_RETURN_ERROR(
      status == xnn_status_success,
      Internal,
-     "Failed to create sigmoid node %i with code: %s",
+     "Failed to create static transpose node %i with code: %s",
@mcr229 (Contributor):

thanks :)

Comment on lines 180 to 181
self.conv.weight.requires_grad = False
self.conv.bias.requires_grad = False
@mcr229 (Contributor):

Is this necessary?

@keyprocedure (Contributor Author):

Good catch, removed.

@@ -169,6 +173,20 @@ def get_inputs(self):
    return (torch.randn(2, 2, 4, 4),)


class Conv2dDynamicQuant(torch.nn.Module):
@mcr229 (Contributor):

Can you just add two more, where we have:

  1. two convolutions in sequence: inp --> conv --> conv --> out
  2. two convolutions running in parallel:
     inp --> conv1 --> out1
         \
          --> conv2 --> out2

@keyprocedure (Contributor Author):

Added the two unit tests.
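
A minimal sketch of what the two modules might look like (Conv2dDQSeq matches the test name referenced below; Conv2dDQParallel, the channel sizes, and the input shapes are assumptions):

import torch

class Conv2dDQSeq(torch.nn.Module):
    # inp --> conv --> conv --> out
    def __init__(self):
        super().__init__()
        self.first = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.second = torch.nn.Conv2d(8, 10, kernel_size=3)

    def forward(self, x):
        return self.second(self.first(x))

    def get_inputs(self):
        return (torch.randn(1, 3, 8, 8),)

class Conv2dDQParallel(torch.nn.Module):
    # inp --> conv1 --> out1 and inp --> conv2 --> out2
    def __init__(self):
        super().__init__()
        self.first = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.second = torch.nn.Conv2d(3, 10, kernel_size=3)

    def forward(self, x):
        return self.first(x), self.second(x)

    def get_inputs(self):
        return (torch.randn(1, 3, 8, 8),)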

@@ -358,6 +358,11 @@ def check_constraints(self, node: torch.fx.Node, ep: ExportedProgram) -> bool:
            why(node, "Only support 1D + 2D Conv")
            return False  # Only support 1D + 2D Conv

        precision = self._detect_precision(node)
        if precision == ConfigPrecisionType.DYNAMIC_QUANT and len(conv_stride) != 2:
@mcr229 (Contributor):

Can you also add the depthwise check here?

@keyprocedure (Contributor Author):

Added the depthwise check. If it looks good, I can move the check to a helper function since it's being used in op_conv2d.py, gemm_configs.py, and xnnpack_quantizer_utils.py

Comment on lines 366 to 378

is_transpose = node.args[6]

if is_transpose:
    group_input_channels = int(kernel_shape[0] / groups)
    group_output_channels = kernel_shape[1]
else:
    group_input_channels = kernel_shape[1]
    group_output_channels = int(kernel_shape[0] / groups)

is_depthwise = (
    group_input_channels == 1
    and group_output_channels % group_input_channels == 0
)
@mcr229 (Contributor):

Maybe we can move this check into xnnpack/utils/utils.py.
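
A hedged sketch of that helper (the function name and its placement in xnnpack/utils/utils.py are assumptions; the logic mirrors the snippet above):

def is_depthwise_conv(kernel_shape, groups, is_transpose):
    # A conv is depthwise when each group sees exactly one input channel.
    if is_transpose:
        group_input_channels = int(kernel_shape[0] / groups)
        group_output_channels = kernel_shape[1]
    else:
        group_input_channels = kernel_shape[1]
        group_output_channels = int(kernel_shape[0] / groups)
    return (
        group_input_channels == 1
        and group_output_channels % group_input_channels == 0
    )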

@mcr229 (Contributor) commented Apr 21, 2025:

Looks good, let's rebase and just let the CI run. Thank you!

@digantdesai (Contributor) left a comment:

Thank you!

@@ -169,6 +173,55 @@ def get_inputs(self):
    return (torch.randn(2, 2, 4, 4),)


class Conv2dDQ(torch.nn.Module):
@digantdesai (Contributor):

Nit:

Suggested change:
- class Conv2dDQ(torch.nn.Module):
+ class Conv2d(torch.nn.Module):


DynamicallyQuantizedPartitioner = XnnpackPartitioner(
    config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
    per_op_mode=True,
@digantdesai (Contributor):

Nit: also check without this?
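
Presumably something like the following, so the convs land in a single delegated partition (the name is an assumption):

DynamicallyQuantizedFullGraphPartitioner = XnnpackPartitioner(
    config_precisions=ConfigPrecisionType.DYNAMIC_QUANT,
)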


def test_dq_conv2d_seq(self) -> None:
    model = Conv2dDQSeq()
    self._test_dq(model, conv_count=2)
@digantdesai (Contributor):

Nit: get conv count from the model rather than hardcoding
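
A minimal sketch of that refactor (assuming the model exposes its convs as standard torch.nn.Conv2d submodules):

def test_dq_conv2d_seq(self) -> None:
    model = Conv2dDQSeq()
    conv_count = sum(
        1 for m in model.modules() if isinstance(m, torch.nn.Conv2d)
    )
    self._test_dq(model, conv_count)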

pssrawat and others added 2 commits April 21, 2025 16:55
Differential Revision: D72503552

Pull Request resolved: pytorch#9923
…h#9926)

The output verification sometimes fails for the mm tests on U85. Add
pytest.mark.flaky decorators to the tests to prevent sporadic failures.


Co-authored-by: Martin Lindström <[email protected]>
kirklandsign and others added 10 commits April 21, 2025 16:55
Differential Revision: D73292616

Pull Request resolved: pytorch#10312
Currently, we generate every combination of inputs for each module with the export_delegate_program script:
- extract_segments=True, False
- delegate_alignment=None,1024

Planning to add another flag, 'external_constants', which will move constants into a separate file to test program-data separation for delegated programs.

This test only requires pte and ptd with default settings, so we refactor the export script to generate only based on the args, and update the genrule to generate what the test requires.

Differential Revision: [D73278562](https://our.internmc.facebook.com/intern/diff/D73278562/)
…ch#10340)

When using attention bias, don't override seq length for causal attention.

Differential Revision: [D73222733](https://our.internmc.facebook.com/intern/diff/D73222733/)
… attention mask (pytorch#10341)

Previously we assumed that the custom sdpa always does causal attention. This diff adds an option to this module swap pass to make custom sdpa leverage the attention mask instead of causal attention.

Differential Revision: [D73222736](https://our.internmc.facebook.com/intern/diff/D73222736/)
Skip those with spaces that aren't actually xrefs
Labels: CLA Signed (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed), topic: not user facing

Successfully merging this pull request may close these issues: Support Dynamically Quantized Convolutions