45 changes: 0 additions & 45 deletions src/compressed_tensors/quantization/quant_args.py
@@ -12,13 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import warnings
 from enum import Enum
 from typing import Any, Dict, List, Optional, Union
 
 import torch
 from compressed_tensors.utils import Aliasable
-from compressed_tensors.utils.helpers import deprecated
 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 
 
@@ -263,8 +261,6 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
         block_structure = model.block_structure
         actorder = model.actorder
         dynamic = model.dynamic
-        observer = model.observer
-        dynamic = model.dynamic
 
         # infer strategy
         if strategy is None:
@@ -316,45 +312,8 @@ def validate_model_after(model: "QuantizationArgs") -> "QuantizationArgs":
                     "activation ordering"
                 )
 
-        # infer observer w.r.t. dynamic
-        if dynamic:
-            supported_strategies = (
-                QuantizationStrategy.TOKEN,
-                QuantizationStrategy.TENSOR,
-                QuantizationStrategy.TENSOR_GROUP,
-                QuantizationStrategy.GROUP,
-            )
-            if strategy not in supported_strategies:
-                raise ValueError(
-                    f"One of {supported_strategies} must be used for dynamic quant."
-                )
-
-            if (
-                dynamic == DynamicType.LOCAL
-                and strategy != QuantizationStrategy.TENSOR_GROUP
-            ):
-                raise ValueError("local is only supported for strategy tensor_group")
-
-            if observer is not None:
-                if dynamic is True:  # checking if dynamic is True, not "local"
-                    if (
-                        observer != "memoryless"
-                    ):  # avoid annoying users with old configs
-                        warnings.warn(
-                            "No observer is used for dynamic quant., setting to None"
-                        )
-                        observer = None
-                else:
-                    if dynamic == DynamicType.LOCAL:
-                        observer = "minmax"
-
-        elif observer is None:
-            # default to minmax for non-dynamic cases
-            observer = "minmax"
-
         # write back modified values
         model.strategy = strategy
-        model.observer = observer
         return model
 
     def pytorch_dtype(self) -> torch.dtype:
@@ -373,10 +332,6 @@ def pytorch_dtype(self) -> torch.dtype:
         else:
             raise ValueError(f"Invalid quantization type {self.type}")
 
-    @deprecated("QuantizationArgs.observer")
-    def get_observer(self) -> str:
-        return self.observer
-
     model_config = ConfigDict(extra="forbid")
 

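With the observer-inference block and the deprecated get_observer() shim deleted, QuantizationArgs no longer rewrites its own observer field during validation; defaults are now applied by the owning QuantizationScheme (see _validate_observers below). A minimal sketch of the resulting call-site behavior, assuming the package's public import path re-exports QuantizationArgs:

from compressed_tensors.quantization import QuantizationArgs

args = QuantizationArgs(num_bits=8, symmetric=True)

# Before this diff, the validator defaulted observer to "minmax" (or
# cleared it for dynamic quantization), and call sites could use the
# deprecated args.get_observer() shim.
# After this diff, the field is read directly, and a standalone
# QuantizationArgs keeps observer=None until a QuantizationScheme
# fills in the default.
observer = args.observer
assert observer is None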
33 changes: 31 additions & 2 deletions src/compressed_tensors/quantization/quant_scheme.py
@@ -59,6 +59,7 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
         weights = model.weights
         format = model.format
 
+        # validate input args
         if inputs is not None:
             if inputs.strategy not in (
                 QuantizationStrategy.TOKEN,
@@ -84,15 +85,18 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
             if inputs.actorder is not None:
                 raise ValueError("Cannot apply actorder to input activations")
 
+        # validate output args
         if outputs is not None:
             if outputs.actorder is not None:
                 raise ValueError("Cannot apply actorder to output activations")
 
+        # validate format
         if format == CompressionFormat.mixed_precision.value:
             raise ValueError(
                 "mixed-precision cannot be set as a format for a QuantizationScheme"
             )
 
+        # validate matching group sizes
         if (
             inputs
             and weights
@@ -110,8 +114,35 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
                 stacklevel=2,
             )
 
+        # set observer defaults
+        model._validate_observers()
+
         return model
 
+    def _validate_observers(self):
+        inputs = self.input_activations
+        weights = self.weights
+        outputs = self.output_activations
+
+        if inputs is not None and inputs.observer is None:
+            if inputs.dynamic is True:
+                inputs.observer = "memoryless_minmax"
+            else:
+                inputs.observer = "static_minmax"
+
+        if weights is not None and weights.observer is None:
+            weights.observer = "memoryless_minmax"
+
+        if outputs is not None and outputs.observer is None:
+            if outputs.dynamic is True:
+                outputs.observer = "memoryless_minmax"
+            else:
+                outputs.observer = "static_minmax"
+
+        self.input_activations = inputs
+        self.weights = weights
+        self.output_activations = outputs
+
     model_config = ConfigDict(extra="forbid")
 
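The defaulting rules above are easiest to read off a worked example. A sketch, assuming QuantizationScheme and QuantizationArgs are re-exported from compressed_tensors.quantization and that targets is the scheme's module-matching field:

from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=4, group_size=128),
    input_activations=QuantizationArgs(num_bits=8, dynamic=True),
    output_activations=QuantizationArgs(num_bits=8),
)

# Defaults filled in by _validate_observers when the model validator runs:
assert scheme.weights.observer == "memoryless_minmax"  # weights: always memoryless
assert scheme.input_activations.observer == "memoryless_minmax"  # dynamic is True
assert scheme.output_activations.observer == "static_minmax"  # static activations

Note that DynamicType.LOCAL deliberately falls through to the static branch, since `dynamic is True` is an identity check rather than a truthiness check.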

@@ -172,7 +203,6 @@ def is_preset_scheme(name: str) -> bool:
         symmetric=True,
         dynamic=False,
         group_size=16,
-        observer="static_minmax",
     ),
     input_activations=QuantizationArgs(
         num_bits=4,
@@ -181,7 +211,6 @@ def is_preset_scheme(name: str) -> bool:
         symmetric=True,
         dynamic=DynamicType.LOCAL,
         group_size=16,
-        observer="static_minmax",
     ),
 )
 
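Dropping the explicit observer="static_minmax" kwargs from this preset relies on the new scheme-level defaults. A hypothetical standalone restatement of the rule from _validate_observers, to trace what the preset now resolves to ("local" stands in for DynamicType.LOCAL; the helper name is illustrative, not part of the library):

from typing import Union

def default_observer(dynamic: Union[bool, str], is_weights: bool = False) -> str:
    # Mirrors _validate_observers: weights are always memoryless; activations
    # are memoryless only when dynamic is exactly True (not "local").
    if is_weights:
        return "memoryless_minmax"
    return "memoryless_minmax" if dynamic is True else "static_minmax"

# The preset's DynamicType.LOCAL input activations resolve to the same
# observer the removed kwarg pinned, while its static weights now default
# to the memoryless observer:
assert default_observer(dynamic="local") == "static_minmax"
assert default_observer(dynamic=False, is_weights=True) == "memoryless_minmax"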