Empty file added fms_mo/aiu_addons/__init__.py
Empty file.
15 changes: 15 additions & 0 deletions fms_mo/aiu_addons/gptq/gptq_aiu_adapter.py
@@ -23,6 +23,7 @@

def _gptq_qweights_transpose_aiu(
input_sd: Mapping[str, torch.Tensor],
**kwargs, # pylint: disable=unused-argument
) -> Mapping[str, torch.Tensor]:
new_sd = {}
for name, param in input_sd.items():
@@ -41,6 +42,9 @@ def _gptq_qweights_transpose_aiu(
serialization.register_adapter_step(
"gpt_bigcode", "gptq_qweights_transpose_aiu", _gptq_qweights_transpose_aiu
)
serialization.register_adapter_step(
"granite", "gptq_qweights_transpose_aiu", _gptq_qweights_transpose_aiu
)
serialization.register_adapter(
"llama",
"hf_gptq_aiu",
@@ -57,3 +61,14 @@ def _gptq_qweights_transpose_aiu(
"hf_gptq_aiu",
["hf_to_fms_names", "weight_fusion", "gptq_qweights_transpose_aiu"],
)
serialization.register_adapter(
"granite",
"hf_gptq_aiu",
[
"hf_to_fms_names",
"hf_to_fms_rope",
"hf_gptq_fusion_check",
"weight_fusion",
"gptq_qweights_transpose_aiu",
],
)
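
For reference, this is the registration pattern the new granite entries follow: an adapter step is a callable that takes the checkpoint state dict (plus ignored kwargs) and returns a new one, is registered per architecture, and is then composed into a named adapter. Below is a minimal sketch assuming the fms serialization API used in this file; the step body, step name, and adapter name are placeholders, not the actual _gptq_qweights_transpose_aiu implementation.

    # Illustrative sketch only (assumes fms.utils.serialization, as used in this file)
    from typing import Mapping

    import torch
    from fms.utils import serialization


    def _example_step(
        input_sd: Mapping[str, torch.Tensor],
        **kwargs,  # extra adapter kwargs accepted but unused, as in the step above
    ) -> Mapping[str, torch.Tensor]:
        # placeholder rule: transpose packed qweight tensors, pass everything else through
        return {
            name: param.t().contiguous() if name.endswith("qweight") else param
            for name, param in input_sd.items()
        }


    # hypothetical registration, mirroring the calls above
    serialization.register_adapter_step("granite", "example_step", _example_step)
    serialization.register_adapter("granite", "example_adapter", ["hf_to_fms_names", "example_step"])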
13 changes: 10 additions & 3 deletions fms_mo/aiu_addons/gptq/gptq_aiu_linear.py
@@ -28,15 +28,18 @@
from fms.modules.tp import ShardType, TPModule
from fms.utils.gptq import GPTQLinearConfig
import torch
import torch.nn as nn

# Local
from fms_mo.aiu_addons.gptq.gptq_aiu_op import register_aiu_gptq_op

register_aiu_gptq_op()


class GPTQLinearAIU(nn.Module):
class GPTQLinearAIU(torch.nn.Module):
"""Simplified QLinear that wraps GPTQ W4A16 custom operation.
gptq_gemm.i4f16_fxinputs_aiu must have been pre-registered to use this class.
"""

def __init__(
self,
in_features: int,
@@ -112,6 +115,8 @@ def __init__(
self.aiu_op = torch.ops.gptq_gemm.i4f16_fxinputs_aiu

def forward(self, x):
"""Call pre-registered custom GPTQ operation"""

x = self.aiu_op(
x.half(),
self.qweight,
@@ -137,7 +142,9 @@ def get_gptq_aiu_linear(
out_features: int,
bias: bool,
linear_config: Optional[Mapping[str, Any]] = None,
):
) -> torch.nn.Module:
"""Retrieve a GPTQ W4A16 Linear module"""

gptq_config = GPTQLinearConfig(**linear_config)
if gptq_config.desc_act:
raise NotImplementedError(
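
For anyone trying the new module out, a hedged usage sketch of the factory added here; the linear_config keys other than desc_act are assumptions about GPTQLinearConfig, not values taken from this PR, and the construction only succeeds in an environment where the AIU GPTQ op registers cleanly.

    # Illustrative usage sketch; config keys besides desc_act are assumed, not from this PR
    from fms_mo.aiu_addons.gptq.gptq_aiu_linear import get_gptq_aiu_linear

    linear_config = {
        "group_size": 128,   # assumed GPTQLinearConfig field
        "desc_act": False,   # must stay False: desc_act=True raises NotImplementedError above
    }
    qlinear = get_gptq_aiu_linear(
        in_features=4096,
        out_features=4096,
        bias=False,
        linear_config=linear_config,
    )
    # forward path casts inputs to half and calls torch.ops.gptq_gemm.i4f16_fxinputs_aiu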
3 changes: 3 additions & 0 deletions fms_mo/aiu_addons/gptq/gptq_aiu_op.py
@@ -19,6 +19,9 @@
# Third Party
import torch

# pylint: disable=unused-argument
# gptq op must be registered with specific I/O, even if not in use by the op function

logger = logging.getLogger(__name__)


2 changes: 1 addition & 1 deletion fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py
@@ -23,6 +23,7 @@

def _int8_qparams_aiu(
input_sd: Mapping[str, torch.Tensor],
**kwargs, # pylint: disable=unused-argument
) -> Mapping[str, torch.Tensor]:
new_sd = {}
modules_seen = set()
@@ -94,7 +95,6 @@ def _add_defaults_and_concat(
sq_scale.to(torch.float32),
)
)
return


# registration of new adapter steps for each architecture
19 changes: 11 additions & 8 deletions fms_mo/aiu_addons/i8i8/i8i8_aiu_linear.py
@@ -28,7 +28,6 @@
from fms.modules.tp import ShardType, TPModule
from fms.utils.config import ModelConfig
import torch
import torch.nn as nn

# Local
from fms_mo.aiu_addons.i8i8.i8i8_aiu_op import register_aiu_i8i8_op
@@ -38,6 +37,8 @@

@dataclass
class W8A8LinearConfig(ModelConfig):
"""Configuration for W8A8 Linear module"""

linear_type: str = "int8"
bits: int = 8
weight_per_channel: bool = False
@@ -46,8 +47,10 @@ class W8A8LinearConfig(ModelConfig):
smoothquant_layers: Optional[list] = None


class W8A8LinearAIU(nn.Module):
"""Simplified QLinear that wraps quantize/dequantize operation"""
class W8A8LinearAIU(torch.nn.Module):
"""Simplified QLinear that wraps quantize/dequantize operation.
fms_mo.i8i8_aiu must have been pre-registered to use this class.
"""

def __init__(
self,
@@ -199,7 +202,9 @@ def get_int8_aiu_linear(
bias: bool,
linear_config: Optional[Mapping[str, Any]] = None,
use_smoothquant: bool = True,
):
) -> torch.nn.Module:
"""Retrieve a W8A8 Linear module"""

int8_config = W8A8LinearConfig(**linear_config)
linear = W8A8LinearAIU(
in_features=in_features,
@@ -216,8 +221,7 @@ def shard_int8_aiu_linear(
tp_module: TPModule,
module_sharding_info: dict[str, LinearModuleShardingInfo],
) -> Optional[set]:
"""
Set up INT8 (W8A8) quantization parameters to be sharded onto
"""Set up INT8 (W8A8) quantization parameters to be sharded onto
AIU-compliant linear modules

| GPU |
@@ -273,8 +277,7 @@ def shard_int8_aiu_linear(
)

raise NotImplementedError("TP not yet supported for INT8. Work in progress")

return unused_keys
# return unused_keys


register_linear_type_to_module_map("int8_aiu", get_int8_aiu_linear)
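
Similarly, a minimal construction sketch for the INT8 path; only the W8A8LinearConfig fields visible in this diff are set explicitly, everything else relies on the config defaults.

    # Illustrative construction sketch; only fields visible in this diff are set explicitly
    from fms_mo.aiu_addons.i8i8.i8i8_aiu_linear import get_int8_aiu_linear

    int8_linear = get_int8_aiu_linear(
        in_features=4096,
        out_features=4096,
        bias=False,
        linear_config={
            "linear_type": "int8",
            "bits": 8,
            "weight_per_channel": False,
        },
        use_smoothquant=True,
    )
    # "int8_aiu" is also registered via register_linear_type_to_module_map, so the same
    # module can be selected through a model's linear_config rather than built directly.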
40 changes: 26 additions & 14 deletions fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py
@@ -22,6 +22,13 @@

logger = logging.getLogger(__name__)

# pylint: disable=unused-argument
# i8i8 op must be registered with specific I/O, even if not in use by the op function

# pylint: disable=not-callable
# torch.nn.functional.linear not recognized as callable
# open issue in PyLint: https://github.com/pytorch/pytorch/issues/119482


def register_aiu_i8i8_op():
"""Register AIU-specific op to enable torch compile without graph break.
@@ -64,7 +71,8 @@ def i8i8_aiu(
dtype = x.dtype
out_feat, in_feat = weight.size()

w_cv, w_cvn, a_cv, a_cvn, zshift, sq = extract_qdata(
# unused returns are w_cvn and zero_shift
w_cv, _, a_cv, a_cvn, _, sq = extract_qdata(
qdata,
weight_quant_type,
activ_quant_type,
@@ -88,6 +96,8 @@ def i8i8_aiu_abstract(
activ_quant_type,
smoothquant,
):
"""OP template of I/O sizes"""

outshape = x.size()[:-1] + (weight.size(0),)
return torch.empty(
outshape, dtype=x.dtype, device=x.device, requires_grad=False
@@ -153,18 +163,19 @@ def dequant_weights(
w_cv: torch.Tensor,
sq: torch.Tensor,
weight_quant_type: str,
):
) -> torch.Tensor:
"""Dequantize integer weights based on quantizer type"""

if weight_quant_type == "per_tensor": # assume 8-bit symmetric W quantization
# w size: (out_feat, in_feat)
# sq size: (in_feat) or (1), no need to unsqueeze
return (weight * w_cv / 127) / sq
elif weight_quant_type == "per_channel":
if weight_quant_type == "per_channel":
# w_cv is (out_feat), need to unsqueeze to broadcast mul to weight
return (weight * w_cv.unsqueeze(dim=1) / 127) / sq
else:
raise NotImplementedError(
f"weight quantizantion type {weight_quant_type} is not supported"
)
raise NotImplementedError(
f"weight quantizantion type {weight_quant_type} is not supported"
)


def quant_dequant_activ(
@@ -173,8 +184,10 @@
a_cvn: torch.Tensor,
sq: torch.Tensor,
activ_quant_type: str,
):
) -> torch.Tensor:
"""
Quantize and dequantize activations based on quantizer type

x size (*, hid_dim)
sq size (hid_dim) or (1)
=> no need to unsqueeze to perform x / sq
@@ -183,18 +196,17 @@
scale_x = 127 / a_cv
x_int = torch.round(x / sq * scale_x).clamp(-127, 127)
return x_int / scale_x * sq
elif activ_quant_type == "per_tensor_asymm":
if activ_quant_type == "per_tensor_asymm":
scale_x = 255 / (a_cv - a_cvn)
zp_x = a_cvn * scale_x
x_int = torch.round(x / sq * scale_x - zp_x).clamp(0, 255)
return (x_int + zp_x) / scale_x * sq
elif activ_quant_type == "per_token":
if activ_quant_type == "per_token":
x_sq = x / sq
a_cv_per_token = x_sq.abs().max(dim=-1, keepdim=True)[0]
scale_x = 127 / a_cv_per_token
x_int = torch.round(x_sq * scale_x).clamp(-127, 127)
return x_int / scale_x * sq
else:
raise NotImplementedError(
f"activation quantizantion type {activ_quant_type} is not supported"
)
raise NotImplementedError(
f"activation quantizantion type {activ_quant_type} is not supported"
)
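
To make the math in quant_dequant_activ concrete, here is a self-contained restatement of its per_tensor_symm branch: activations are divided by the smoothquant scale, mapped into the symmetric int8 range using the clip value, rounded and clamped, then rescaled back. The input tensors and clip value below are made up for illustration, not taken from this PR.

    # Self-contained restatement of the per_tensor_symm branch, for illustration only
    import torch


    def fake_quant_per_tensor_symm(
        x: torch.Tensor, a_cv: torch.Tensor, sq: torch.Tensor
    ) -> torch.Tensor:
        # scale into the symmetric int8 range, round, clamp, then rescale back
        scale_x = 127 / a_cv
        x_int = torch.round(x / sq * scale_x).clamp(-127, 127)
        return x_int / scale_x * sq


    x = torch.randn(2, 8)
    a_cv = torch.tensor(3.0)  # made-up activation clip value (normally from calibration)
    sq = torch.ones(8)        # identity smoothquant scale
    x_dq = fake_quant_per_tensor_symm(x, a_cv, sq)
    print((x - x_dq).abs().max())  # quantization error; small for values within the clip range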
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -36,6 +36,8 @@ dependencies = [
"huggingface_hub",
"pandas",
"safetensors",
"ninja",
"ibm-fms>=0.0.8"
]

[project.optional-dependencies]