diff --git a/fms_mo/aiu_addons/__init__.py b/fms_mo/aiu_addons/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/fms_mo/aiu_addons/gptq/gptq_aiu_adapter.py b/fms_mo/aiu_addons/gptq/gptq_aiu_adapter.py
index 4ba43df4..bceb5415 100644
--- a/fms_mo/aiu_addons/gptq/gptq_aiu_adapter.py
+++ b/fms_mo/aiu_addons/gptq/gptq_aiu_adapter.py
@@ -23,6 +23,7 @@ def _gptq_qweights_transpose_aiu(
     input_sd: Mapping[str, torch.Tensor],
+    **kwargs,  # pylint: disable=unused-argument
 ) -> Mapping[str, torch.Tensor]:
     new_sd = {}
     for name, param in input_sd.items():
@@ -41,6 +42,9 @@ def _gptq_qweights_transpose_aiu(
 serialization.register_adapter_step(
     "gpt_bigcode", "gptq_qweights_transpose_aiu", _gptq_qweights_transpose_aiu
 )
+serialization.register_adapter_step(
+    "granite", "gptq_qweights_transpose_aiu", _gptq_qweights_transpose_aiu
+)
 serialization.register_adapter(
     "llama",
     "hf_gptq_aiu",
@@ -57,3 +61,14 @@
     "hf_gptq_aiu",
     ["hf_to_fms_names", "weight_fusion", "gptq_qweights_transpose_aiu"],
 )
+serialization.register_adapter(
+    "granite",
+    "hf_gptq_aiu",
+    [
+        "hf_to_fms_names",
+        "hf_to_fms_rope",
+        "hf_gptq_fusion_check",
+        "weight_fusion",
+        "gptq_qweights_transpose_aiu",
+    ],
+)
diff --git a/fms_mo/aiu_addons/gptq/gptq_aiu_linear.py b/fms_mo/aiu_addons/gptq/gptq_aiu_linear.py
index 9aa30db3..ef9e5977 100644
--- a/fms_mo/aiu_addons/gptq/gptq_aiu_linear.py
+++ b/fms_mo/aiu_addons/gptq/gptq_aiu_linear.py
@@ -28,7 +28,6 @@
 from fms.modules.tp import ShardType, TPModule
 from fms.utils.gptq import GPTQLinearConfig
 import torch
-import torch.nn as nn
 
 # Local
 from fms_mo.aiu_addons.gptq.gptq_aiu_op import register_aiu_gptq_op
@@ -36,7 +35,11 @@
 register_aiu_gptq_op()
 
 
-class GPTQLinearAIU(nn.Module):
+class GPTQLinearAIU(torch.nn.Module):
+    """Simplified QLinear that wraps GPTQ W4A16 custom operation.
+    gptq_gemm.i4f16_fxinputs_aiu must have been pre-registered to use this class.
+    """
+
     def __init__(
         self,
         in_features: int,
@@ -112,6 +115,8 @@ def __init__(
         self.aiu_op = torch.ops.gptq_gemm.i4f16_fxinputs_aiu
 
     def forward(self, x):
+        """Call pre-registered custom GPTQ operation"""
+
         x = self.aiu_op(
             x.half(),
             self.qweight,
@@ -137,7 +142,9 @@ def get_gptq_aiu_linear(
     out_features: int,
     bias: bool,
     linear_config: Optional[Mapping[str, Any]] = None,
-):
+) -> torch.nn.Module:
+    """Retrieve a GPTQ W4A16 Linear module"""
+
     gptq_config = GPTQLinearConfig(**linear_config)
     if gptq_config.desc_act:
         raise NotImplementedError(
diff --git a/fms_mo/aiu_addons/gptq/gptq_aiu_op.py b/fms_mo/aiu_addons/gptq/gptq_aiu_op.py
index 7c15c591..d958f38e 100644
--- a/fms_mo/aiu_addons/gptq/gptq_aiu_op.py
+++ b/fms_mo/aiu_addons/gptq/gptq_aiu_op.py
@@ -19,6 +19,9 @@
 # Third Party
 import torch
 
+# pylint: disable=unused-argument
+# gptq op must be registered with specific I/O, even if not in use by the op function
+
 logger = logging.getLogger(__name__)
 
 
diff --git a/fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py b/fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py
index 7dbb9377..ded1c7cd 100644
--- a/fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py
+++ b/fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py
@@ -23,6 +23,7 @@ def _int8_qparams_aiu(
     input_sd: Mapping[str, torch.Tensor],
+    **kwargs,  # pylint: disable=unused-argument
 ) -> Mapping[str, torch.Tensor]:
     new_sd = {}
     modules_seen = set()
@@ -94,7 +95,6 @@ def _add_defaults_and_concat(
             sq_scale.to(torch.float32),
         )
     )
-    return
 
 
 # registration of new adapter steps for each architecture
diff --git a/fms_mo/aiu_addons/i8i8/i8i8_aiu_linear.py b/fms_mo/aiu_addons/i8i8/i8i8_aiu_linear.py
index a4ebbf88..13844a92 100644
--- a/fms_mo/aiu_addons/i8i8/i8i8_aiu_linear.py
+++ b/fms_mo/aiu_addons/i8i8/i8i8_aiu_linear.py
@@ -28,7 +28,6 @@
 from fms.modules.tp import ShardType, TPModule
 from fms.utils.config import ModelConfig
 import torch
-import torch.nn as nn
 
 # Local
 from fms_mo.aiu_addons.i8i8.i8i8_aiu_op import register_aiu_i8i8_op
@@ -38,6 +37,8 @@
 
 @dataclass
 class W8A8LinearConfig(ModelConfig):
+    """Configuration for W8A8 Linear module"""
+
     linear_type: str = "int8"
     bits: int = 8
     weight_per_channel: bool = False
@@ -46,8 +47,10 @@ class W8A8LinearConfig(ModelConfig):
     smoothquant_layers: Optional[list] = None
 
 
-class W8A8LinearAIU(nn.Module):
-    """Simplified QLinear that wraps quantize/dequantize operation"""
+class W8A8LinearAIU(torch.nn.Module):
+    """Simplified QLinear that wraps quantize/dequantize operation.
+    fms_mo.i8i8_aiu must have been pre-registered to use this class.
+    """
 
     def __init__(
         self,
@@ -199,7 +202,9 @@ def get_int8_aiu_linear(
     bias: bool,
     linear_config: Optional[Mapping[str, Any]] = None,
     use_smoothquant: bool = True,
-):
+) -> torch.nn.Module:
+    """Retrieve a W8A8 Linear module"""
+
     int8_config = W8A8LinearConfig(**linear_config)
     linear = W8A8LinearAIU(
         in_features=in_features,
@@ -216,8 +221,7 @@ def shard_int8_aiu_linear(
     tp_module: TPModule,
     module_sharding_info: dict[str, LinearModuleShardingInfo],
 ) -> Optional[set]:
-    """
-    Set up INT8 (W8A8) quantization parameters to be sharded onto
+    """Set up INT8 (W8A8) quantization parameters to be sharded onto
     AIU-compliant linear modules
 
                   |  GPU  |
@@ -273,8 +277,7 @@ def shard_int8_aiu_linear(
     )
 
     raise NotImplementedError("TP not yet supported for INT8. Work in progress")
-
-    return unused_keys
+    # return unused_keys
 
 
 register_linear_type_to_module_map("int8_aiu", get_int8_aiu_linear)
diff --git a/fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py b/fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py
index b872b328..f2aa7a1b 100644
--- a/fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py
+++ b/fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py
@@ -22,6 +22,13 @@
 logger = logging.getLogger(__name__)
 
+# pylint: disable=unused-argument
+# i8i8 op must be registered with specific I/O, even if not in use by the op function
+
+# pylint: disable=not-callable
+# torch.nn.functional.linear not recognized as callable
+# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482
+
 
 def register_aiu_i8i8_op():
     """Register AIU-specific op to enable torch compile without graph break.
@@ -64,7 +71,8 @@ def i8i8_aiu(
     dtype = x.dtype
     out_feat, in_feat = weight.size()
 
-    w_cv, w_cvn, a_cv, a_cvn, zshift, sq = extract_qdata(
+    # unused returns are w_cvn and zero_shift
+    w_cv, _, a_cv, a_cvn, _, sq = extract_qdata(
         qdata,
         weight_quant_type,
         activ_quant_type,
@@ -88,6 +96,8 @@ def i8i8_aiu_abstract(
     activ_quant_type,
     smoothquant,
 ):
+    """OP template of I/O sizes"""
+
     outshape = x.size()[:-1] + (weight.size(0),)
     return torch.empty(
         outshape, dtype=x.dtype, device=x.device, requires_grad=False
@@ -153,18 +163,19 @@ def dequant_weights(
     w_cv: torch.Tensor,
     sq: torch.Tensor,
     weight_quant_type: str,
-):
+) -> torch.Tensor:
+    """Dequantize integer weights based on quantizer type"""
+
     if weight_quant_type == "per_tensor":
         # assume 8-bit symmetric W quantization
         # w size: (out_feat, in_feat)
         # sq size: (in_feat) or (1), no need to unsqueeze
         return (weight * w_cv / 127) / sq
-    elif weight_quant_type == "per_channel":
+    if weight_quant_type == "per_channel":
         # w_cv is (out_feat), need to unsqueeze to broadcast mul to weight
         return (weight * w_cv.unsqueeze(dim=1) / 127) / sq
-    else:
-        raise NotImplementedError(
-            f"weight quantizantion type {weight_quant_type} is not supported"
-        )
+    raise NotImplementedError(
+        f"weight quantization type {weight_quant_type} is not supported"
+    )
 
 
 def quant_dequant_activ(
     x: torch.Tensor,
@@ -173,8 +184,10 @@ def quant_dequant_activ(
     a_cvn: torch.Tensor,
     sq: torch.Tensor,
     activ_quant_type: str,
-):
+) -> torch.Tensor:
     """
+    Quantize and dequantize activations based on quantizer type
+
     x size (*, hid_dim)
     sq size (hid_dim) or (1)
     => no need to unsqueeze to perform x / sq
@@ -183,18 +196,17 @@
         scale_x = 127 / a_cv
         x_int = torch.round(x / sq * scale_x).clamp(-127, 127)
         return x_int / scale_x * sq
-    elif activ_quant_type == "per_tensor_asymm":
+    if activ_quant_type == "per_tensor_asymm":
         scale_x = 255 / (a_cv - a_cvn)
         zp_x = a_cvn * scale_x
         x_int = torch.round(x / sq * scale_x - zp_x).clamp(0, 255)
         return (x_int + zp_x) / scale_x * sq
-    elif activ_quant_type == "per_token":
+    if activ_quant_type == "per_token":
         x_sq = x / sq
         a_cv_per_token = x_sq.abs().max(dim=-1, keepdim=True)[0]
         scale_x = 127 / a_cv_per_token
         x_int = torch.round(x_sq * scale_x).clamp(-127, 127)
         return x_int / scale_x * sq
-    else:
-        raise NotImplementedError(
-            f"activation quantizantion type {activ_quant_type} is not supported"
-        )
+    raise NotImplementedError(
+        f"activation quantization type {activ_quant_type} is not supported"
+    )
diff --git a/pyproject.toml b/pyproject.toml
index d8eeeaf2..20677f4a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,6 +36,7 @@ dependencies = [
 "huggingface_hub",
 "pandas",
 "safetensors",
+"ibm-fms>=0.0.8"
 ]
 
 [project.optional-dependencies]