diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py index f4339c0c3..1da8e7f39 100644 --- a/torchchat/utils/quantize.py +++ b/torchchat/utils/quantize.py @@ -23,6 +23,7 @@ from __future__ import annotations import json +import types # from functools import reduce # from math import gcd @@ -32,6 +33,8 @@ import torch.nn as nn import torch.nn.functional as F +from torchao.quantization.quant_api import _linear_extra_repr + # AttributeError: '_OpNamespace' 'quantized_decomposed' object has no attribute 'quantize_per_channel_group' from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout @@ -111,6 +114,16 @@ def quantize_model( if isinstance(quantize_options, str): quantize_options = json.loads(quantize_options) + def _attach_extra_repr(module): + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + if not hasattr(child, 'extra_repr'): + child.extra_repr = types.MethodType(_linear_extra_repr, child) + else: + _attach_extra_repr(child) + + _attach_extra_repr(model) + for quantizer, q_kwargs in quantize_options.items(): if quantizer not in quantizer_class_dict: raise RuntimeError(f"unknown quantizer {quantizer} specified")