diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index e44d9d037..8715896ad 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -950,27 +950,11 @@ jobs: run: | export TORCHCHAT_ROOT=${PWD} echo "et-git-hash=$(cat ${TORCHCHAT_ROOT}/install/.pins/et-pin.txt)" >> "$GITHUB_ENV" - - name: Load or install ET - id: install-et - uses: actions/cache@v4 - with: - path: | - ./et-build - ./torchchat/utils/scripts - key: et-build-${{runner.os}}-${{runner.arch}}-${{env.et-git-hash}}-${{ hashFiles('**/install_et.sh', '**/build_native.sh') }} - - if: ${{ steps.install-et.outputs.cache-hit != 'true' }} - continue-on-error: true + - name: Install ExecuTorch run: | echo "Installing ExecuTorch" + export TORCHCHAT_ROOT=${PWD} bash torchchat/utils/scripts/install_et.sh - - name: Install ExecuTorch python - run: | - echo "Install ExecuTorch python" - export TORCHCHAT_ROOT=$PWD - export ET_BUILD_DIR="et-build" - ENABLE_ET_PYBIND="${1:-true}" - source "torchchat/utils/scripts/install_utils.sh" - install_executorch_python_libs $ENABLE_ET_PYBIND - name: Install runner run: | echo "Installing runner" diff --git a/install/.pins/torchao-pin.txt b/install/.pins/torchao-pin.txt index c1b84754c..3e1fb8927 100644 --- a/install/.pins/torchao-pin.txt +++ b/install/.pins/torchao-pin.txt @@ -1 +1 @@ -711fa0809f06fc97febd0c3fe72563c3fe227e51 +7513042f39515af4c643bc1f9399952ad7f4f904 diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py index aca81ef5c..f4339c0c3 100644 --- a/torchchat/utils/quantize.py +++ b/torchchat/utils/quantize.py @@ -34,10 +34,15 @@ # AttributeError: '_OpNamespace' 'quantized_decomposed' object has no attribute 'quantize_per_channel_group' from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa +from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout +from torchao.experimental.quant_api import EmbeddingQuantizer +from torchao.quantization.granularity import PerAxis, PerGroup from torchao.quantization.quant_api import ( int4_weight_only, Int4WeightOnlyQuantizer, Int8DynActInt4WeightQuantizer, + Int8DynamicActivationIntxWeightConfig, + MappingType, quantize_, ) from torchao.utils import unwrap_tensor_subclass @@ -50,18 +55,6 @@ state_dict_device, use_et_backend, ) -from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import ( - PackedLinearInt8DynamicActivationIntxWeightLayout, -) -from torchao.experimental.quant_api import ( - int8_dynamic_activation_intx_weight, - IntxWeightEmbeddingQuantizer, -) -from torchao.quantization.granularity import ( - PerGroup, - PerRow, -) -from torchao.dtypes import PlainLayout # Flag for whether the a8wxdq quantizer is available. @@ -87,7 +80,7 @@ def get_named_parameters(func: Callable) -> List[str]: return named_params def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: Optional[str] = None) -> Dict[str, Any]: - for key in q_kwargs.keys(): + for key in list(q_kwargs.keys()): if key not in named_params: print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.") del q_kwargs[key] @@ -137,29 +130,34 @@ def quantize_model( group_size = q_kwargs["groupsize"] bit_width = q_kwargs["bitwidth"] has_weight_zeros = q_kwargs["has_weight_zeros"] - granularity = PerRow() if group_size == -1 else PerGroup(group_size) + granularity = PerAxis() if group_size == -1 else PerGroup(group_size) weight_dtype = getattr(torch, f"int{bit_width}") + weight_mapping_type = ( + MappingType.ASYMMETRIC + if has_weight_zeros + else MappingType.SYMMETRIC + ) try: quantize_( - model, - int8_dynamic_activation_intx_weight( + model, + Int8DynamicActivationIntxWeightConfig( weight_dtype=weight_dtype, - granularity=granularity, - has_weight_zeros=has_weight_zeros, + weight_granularity=granularity, + weight_mapping_type=weight_mapping_type, layout=PackedLinearInt8DynamicActivationIntxWeightLayout(), ), ) except Exception as e: print("Encountered error during quantization: {e}") - print("Trying with PlainLayout") + print("Trying with QDQLayout") quantize_( - model, - int8_dynamic_activation_intx_weight( + model, + Int8DynamicActivationIntxWeightConfig( weight_dtype=weight_dtype, - granularity=granularity, - has_weight_zeros=has_weight_zeros, - layout=PlainLayout(), + weight_granularity=granularity, + weight_mapping_type=weight_mapping_type, + layout=QDQLayout(), ), ) @@ -174,6 +172,22 @@ def quantize_model( print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.") set_precision(torch.float32) + group_size = q_kwargs["groupsize"] + bit_width = q_kwargs["bitwidth"] + has_weight_zeros = q_kwargs.get("has_weight_zeros", True) + q_kwargs["granularity"] = ( + PerAxis() if group_size == -1 else PerGroup(group_size) + ) + q_kwargs["weight_dtype"] = getattr(torch, f"int{bit_width}") + q_kwargs["mapping_type"] = ( + MappingType.ASYMMETRIC + if has_weight_zeros + else MappingType.SYMMETRIC + ) + q_kwargs["use_fallback"] = False + del q_kwargs["groupsize"] + del q_kwargs["bitwidth"] + if quantizer == "linear:afpwx" and device != "mps": raise RuntimeError("linear:afpwx quantization can only run on mps device!") @@ -188,7 +202,10 @@ def quantize_model( # Handle tokenizer for scenarios where the quantizer needs to tokenizer sample inputs if "tokenizer" in named_params: q_kwargs["tokenizer"] = tokenizer - quant_handler = q(device=device, precision=precision, **q_kwargs) + if quantizer == "embedding:wx": + quant_handler = q(**q_kwargs) + else: + quant_handler = q(device=device, precision=precision, **q_kwargs) # quantize model model = quant_handler.quantize(model) @@ -939,7 +956,7 @@ def quantized_model(self) -> nn.Module: # class references quantizer_class_dict = { "embedding": EmbeddingOnlyQuantHandler, - "embedding:wx": IntxWeightEmbeddingQuantizer, + "embedding:wx": EmbeddingQuantizer, "linear:int8": WeightOnlyInt8QuantHandler, "precision": PrecisionHandler, "executor": ExecutorHandler, @@ -979,5 +996,19 @@ def quantized_model(self) -> nn.Module: except Exception as e: print("Unable to load torchao mps ops library.") + torchao_experimental_mps_op_lib_spec = importlib.util.spec_from_file_location( + "torchao_experimental_mps_op_lib", + f"{torchao_build_path}/src/ao/torchao/experimental/ops/mps/mps_op_lib.py", + ) + torchao_experimental_mps_op_lib = importlib.util.module_from_spec( + torchao_experimental_mps_op_lib_spec + ) + sys.modules["torchao_experimental_mps_op_lib"] = torchao_experimental_mps_op_lib + torchao_experimental_mps_op_lib_spec.loader.exec_module( + torchao_experimental_mps_op_lib + ) + from torchao_experimental_mps_op_lib import * + + except Exception as e: print("Unable to import torchao experimental quant_api with error: ", e)