
Conversation

@brian-dellabetta (Collaborator) commented Aug 21, 2025

SUMMARY:
Prerequisites:

This allows for multi-modifier support by scoping the application of quantization config/status to only the modules in the model that match the given targets/ignore configuration, rather than all modules. Initialization of observers is moved to on_start (instead of on_initialize) to match their removal in on_end (rather than on_finalize). This prevents collisions during the multi-modifier lifecycle; see the sketch after the checklist below.

  • Update AWQ
  • Update QuantizationModifier
  • Update QuantizationMixin
  • Update GPTQ
  • Any others?
  • Should we enable/disable quantization for the entire model or only matching modules? See TODO here
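For illustration, here is a minimal sketch of the scoped lifecycle described above. It assumes match_named_modules(model, targets, ignore) yields (name, module) pairs for only the matching modules (as in the diff further down); the observer helper names and the import path are assumptions, not the actual QuantizationMixin API.

```python
# Illustrative sketch only: helper names are hypothetical and the import
# path is assumed; see the actual QuantizationMixin changes in this PR.
from compressed_tensors.utils import match_named_modules


class ScopedModifier:
    targets = ["Linear"]
    ignore = ["lm_head"]

    def on_start(self, state, event, **kwargs):
        # Initialize observers only for modules matched by this modifier's
        # targets/ignore, so a second modifier scoped to other modules
        # does not collide with this one.
        for _, module in match_named_modules(state.model, self.targets, self.ignore):
            self._initialize_observers(module)  # hypothetical helper

    def on_end(self, state, event, **kwargs):
        # Remove the same observers on_start added, leaving modules owned by
        # other modifiers untouched (removal pairs with on_end, not on_finalize).
        for _, module in match_named_modules(state.model, self.targets, self.ignore):
            self._remove_observers(module)  # hypothetical helper
```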

TEST PLAN:

  • Tests were added in neuralmagic/compressed-tensors#432 ("[Multi-Modifier] Scoped apply quantization config") to confirm correct application of multiple modifiers.
  • Added an example in this PR to show how AWQ and GPTQ can be applied heterogeneously to a model, along with a small README. Logs show alternating AWQ and GPTQ messages for the "sequential" pipeline and correct behavior for the "independent" pipeline. The model checkpoint for the sequential pipeline shows correct application of W8A8 to self_attn layers and W4A16 to mlp layers; config.json and safetensors weights all look as expected. A sketch of such a recipe is shown below.
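For reference, a minimal sketch of what such a heterogeneous recipe could look like, using the public oneshot/AWQModifier/GPTQModifier entrypoints. The model id, dataset, target regexes, and calibration settings below are illustrative, not copied from the example added in this PR.

```python
from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization import GPTQModifier

MODEL_ID = "meta-llama/Llama-3.2-1B-Instruct"  # placeholder model id

# Two modifiers, each scoped to a disjoint set of modules via targets/ignore,
# mirroring the checkpoint described above (W4A16 mlp, W8A8 self_attn).
recipe = [
    AWQModifier(
        targets=[r"re:.*mlp\.(gate|up|down)_proj$"],  # illustrative regex
        scheme="W4A16",
        ignore=["lm_head"],
    ),
    GPTQModifier(
        targets=[r"re:.*self_attn\.(q|k|v|o)_proj$"],  # illustrative regex
        scheme="W8A8",
        ignore=["lm_head"],
    ),
]

oneshot(
    model=MODEL_ID,
    dataset="open_platypus",  # any registered calibration dataset
    recipe=recipe,
    max_seq_length=2048,
    num_calibration_samples=256,
    pipeline="sequential",  # calibrate both modifiers in a single run
)
```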


👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review.

Note: This is required to complete the testing suite; please only add the label once the PR is code complete and local testing has been performed.

@brian-dellabetta force-pushed the bdellabe/scoped-quant-status branch 2 times, most recently from 5fec983 to 2f93072 on August 28, 2025 16:51
@brian-dellabetta changed the title from "[Multi-modifier] Support scoped appliation of quantization config/status" to "[Multi-modifier] Support scoped application of quantization config/status" on September 2, 2025
Signed-off-by: Brian Dellabetta <[email protected]>
Signed-off-by: Brian Dellabetta <[email protected]>
Signed-off-by: Brian Dellabetta <[email protected]>
@brian-dellabetta force-pushed the bdellabe/scoped-quant-status branch from 2f93072 to f99db2f on September 11, 2025 16:43
@@ -178,7 +182,7 @@ def on_start(self, state: State, event: Event, **kwargs):

         # register gptq hooks
         added_hook = False
-        for module in state.model.modules():
+        for _, module in match_named_modules(state.model, self.targets, self.ignore):
             if getattr_chain(module, "quantization_scheme.weights", None) is not None:
Collaborator:
Should this be changed into an assert rather than an if?

brian-dellabetta (author):
Yeah, I think that makes sense; I can update.

brian-dellabetta (author):
Hmm, on second thought, it doesn't look like there's anything in the validation layer confirming that each quantization args instance has a weights field. So if a user sets an invalid config where weight quantization isn't configured, it would error out here. Is that what we want?

Collaborator:
I'd prefer an explicit error rather than a silent skip here.

brian-dellabetta (author):
Shouldn't we do this in the validation layer, though? I can add a check at model validation and switch this to an assert.
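For context, a minimal sketch of the explicit-error variant under discussion. Whether the check ends up inline or in the validation layer is exactly the open question, so this is illustrative only; the import paths are assumed.

```python
# Illustrative only: the final placement (inline assert vs. recipe/config
# validation) is what's being discussed above; import paths are assumed.
from compressed_tensors.utils import getattr_chain, match_named_modules


def iter_weight_quantized_modules(model, targets, ignore):
    """Yield matched modules, failing loudly if one lacks weight quantization."""
    for name, module in match_named_modules(model, targets, ignore):
        weight_args = getattr_chain(module, "quantization_scheme.weights", None)
        # Explicit error instead of silently skipping the module
        assert weight_args is not None, (
            f"{name} matched this modifier's targets but has no weight "
            "quantization configured; check the recipe's config groups"
        )
        yield name, module
```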

Signed-off-by: Brian Dellabetta <[email protected]>
Signed-off-by: Brian Dellabetta <[email protected]>
@brian-dellabetta marked this pull request as ready for review on September 15, 2025 20:38
@brian-dellabetta added the ready (When a PR is ready for review) label on Sep 15, 2025
Signed-off-by: Brian Dellabetta <[email protected]>
@brian-dellabetta removed the ready (When a PR is ready for review) label on Sep 15, 2025
@kylesayrs (Collaborator) left a comment:
Consider adding some basic tests / common use cases; otherwise looks good!

@@ -0,0 +1,101 @@
+from datasets import load_dataset
Collaborator:
Can we maybe generalize this folder name to mixed-precision so that people associate this with enabling mixed precision workloads?

Collaborator:
Even though multi-modifier recipes don't necessarily need to be mixed precision? I just think mixed-precision is a stronger message than multi-modifier.

brian-dellabetta (author):
I can rename the folder to examples/mixed_precision if we find that to be more appropriate. Let's discuss in standup.

    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    # Option 1) run both modifiers in a single calibrated run
    pipeline="sequential",
Collaborator:
Is the pipeline not already sequential by default?

brian-dellabetta (author):
It actually defaults to "independent". Without pipeline set, it infers SequentialPipeline for GPTQ and runs just GPTQ independently, then infers SequentialPipeline for AWQ and runs just AWQ independently. Not sure what we want for default behavior, though; this just makes it explicit.
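To make the two options concrete, a small sketch reusing MODEL_ID and recipe from the sketch in the test plan above; other oneshot arguments are elided and illustrative.

```python
# Explicit "sequential": one calibrated run, AWQ and GPTQ applied together,
# so their log messages alternate as the run progresses.
oneshot(model=MODEL_ID, dataset="open_platypus", recipe=recipe, pipeline="sequential")

# Current default, "independent": each modifier gets its own full
# calibration pass, one after the other.
oneshot(model=MODEL_ID, dataset="open_platypus", recipe=recipe, pipeline="independent")
```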

print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = model_id.rstrip("/").split("/")[-1] + "-W4A16-G128"
Collaborator:
Add GPTQ and AWQ to the names?
