diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index d9a92d0d..2675d9b8 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -12,13 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import warnings
 from enum import Enum
 from typing import Any, Dict, List, Optional, Union
 
 import torch
 from compressed_tensors.utils import Aliasable
-from compressed_tensors.utils.helpers import deprecated
 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 
 
@@ -263,8 +261,6 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
         block_structure = model.block_structure
         actorder = model.actorder
         dynamic = model.dynamic
-        observer = model.observer
-        dynamic = model.dynamic
 
         # infer strategy
         if strategy is None:
@@ -316,45 +312,8 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
                 "activation ordering"
             )
 
-        # infer observer w.r.t. dynamic
-        if dynamic:
-            supported_strategies = (
-                QuantizationStrategy.TOKEN,
-                QuantizationStrategy.TENSOR,
-                QuantizationStrategy.TENSOR_GROUP,
-                QuantizationStrategy.GROUP,
-            )
-            if strategy not in supported_strategies:
-                raise ValueError(
-                    f"One of {supported_strategies} must be used for dynamic quant."
-                )
-
-        if (
-            dynamic == DynamicType.LOCAL
-            and strategy != QuantizationStrategy.TENSOR_GROUP
-        ):
-            raise ValueError("local is only supported for strategy tensor_group")
-
-        if observer is not None:
-            if dynamic is True:  # checking if dynamic is True, not "local"
-                if (
-                    observer != "memoryless"
-                ):  # avoid annoying users with old configs
-                    warnings.warn(
-                        "No observer is used for dynamic quant., setting to None"
-                    )
-                observer = None
-            else:
-                if dynamic == DynamicType.LOCAL:
-                    observer = "minmax"
-
-        elif observer is None:
-            # default to minmax for non-dynamic cases
-            observer = "minmax"
-
         # write back modified values
         model.strategy = strategy
-        model.observer = observer
         return model
 
     def pytorch_dtype(self) -> torch.dtype:
@@ -373,10 +332,6 @@ def pytorch_dtype(self) -> torch.dtype:
         else:
             raise ValueError(f"Invalid quantization type {self.type}")
 
-    @deprecated("QuantizationArgs.observer")
-    def get_observer(self) -> str:
-        return self.observer
-
     model_config = ConfigDict(extra="forbid")
 
 
diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index 79db8d28..09f5901f 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -59,6 +59,7 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
         weights = model.weights
         format = model.format
 
+        # validate input args
         if inputs is not None:
             if inputs.strategy not in (
                 QuantizationStrategy.TOKEN,
@@ -84,15 +85,18 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
             if inputs.actorder is not None:
                 raise ValueError("Cannot apply actorder to input activations")
 
+        # validate output args
         if outputs is not None:
             if outputs.actorder is not None:
                 raise ValueError("Cannot apply actorder to output activations")
 
+        # validate format
         if format == CompressionFormat.mixed_precision.value:
             raise ValueError(
                 "mixed-precision cannot be set as a format for a QuantizationScheme"
             )
 
+        # validate matching group sizes
         if (
             inputs
             and weights
@@ -110,8 +114,35 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
                 stacklevel=2,
             )
 
+        # set observer defaults
+        model._validate_observers()
+
         return model
 
+    def _validate_observers(self):
+        inputs = self.input_activations
+        weights = self.weights
+        outputs = self.output_activations
+
+        if inputs is not None and inputs.observer is None:
+            if inputs.dynamic is True:
+                inputs.observer = "memoryless_minmax"
+            else:
+                inputs.observer = "static_minmax"
+
+        if weights is not None and weights.observer is None:
+            weights.observer = "memoryless_minmax"
+
+        if outputs is not None and outputs.observer is None:
+            if outputs.dynamic is True:
+                outputs.observer = "memoryless_minmax"
+            else:
+                outputs.observer = "static_minmax"
+
+        self.input_activations = inputs
+        self.weights = weights
+        self.output_activations = outputs
+
     model_config = ConfigDict(extra="forbid")
 
 
@@ -172,7 +203,6 @@ def is_preset_scheme(name: str) -> bool:
         symmetric=True,
         dynamic=False,
         group_size=16,
-        observer="static_minmax",
     ),
     input_activations=QuantizationArgs(
         num_bits=4,
@@ -181,7 +211,6 @@ def is_preset_scheme(name: str) -> bool:
         symmetric=True,
         dynamic=DynamicType.LOCAL,
         group_size=16,
-        observer="static_minmax",
     ),
 )