2 changes: 1 addition & 1 deletion fms_mo/aiu_addons/gptq/gptq_aiu_linear.py
@@ -141,7 +141,7 @@ def get_gptq_aiu_linear(
     in_features: int,
     out_features: int,
     bias: bool,
-    linear_config: Optional[Mapping[str, Any]] = None,
+    linear_config: Mapping[str, Any],
 ) -> torch.nn.Module:
     """Retrieve a GPTQ W4A16 Linear module"""
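The change above removes the Optional[...] = None default, so linear_config must now be supplied by the caller instead of being guarded against inside the function. A minimal sketch of the difference, using stand-in factories (not the actual fms_mo implementation):

from typing import Any, Mapping, Optional

import torch


def make_linear_old(
    in_features: int,
    out_features: int,
    bias: bool,
    linear_config: Optional[Mapping[str, Any]] = None,
) -> torch.nn.Module:
    # With a None default, every call path has to guard against a missing config.
    if linear_config is None:
        raise ValueError("linear_config must be provided")
    return torch.nn.Linear(in_features, out_features, bias=bias)


def make_linear_new(
    in_features: int,
    out_features: int,
    bias: bool,
    linear_config: Mapping[str, Any],
) -> torch.nn.Module:
    # A required argument surfaces the error at call time and in static checks,
    # so the body can assume the config is present.
    return torch.nn.Linear(in_features, out_features, bias=bias)

Calling make_linear_new(4, 8, True) without a config now fails immediately with a TypeError rather than deep inside module construction.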
4 changes: 2 additions & 2 deletions fms_mo/aiu_addons/i8i8/i8i8_aiu_adapter.py
@@ -47,8 +47,8 @@ def _int8_qparams_aiu(


 def _add_defaults_and_concat(
-    new_sd: Mapping[str, torch.Tensor],
-    modules_seen: set,
+    new_sd: dict[str, torch.Tensor],
+    modules_seen: set[str],
 ) -> None:
     """
     Add default activation clip values, zero_shift, and smoothquant_scale (if not
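The switch from Mapping to dict (and set to set[str]) matches what the function actually does: it inserts default entries into new_sd, and typing.Mapping only advertises read access, so a type checker rejects item assignment through it. A small illustration of that distinction, separate from the adapter code itself (the key names below are illustrative):

from typing import Mapping

import torch


def read_only_view(sd: Mapping[str, torch.Tensor]) -> torch.Tensor:
    # Mapping supports lookups but not assignment; mypy would reject sd["k"] = ...
    return sd["weight"]


def add_defaults(sd: dict[str, torch.Tensor]) -> None:
    # dict advertises mutability, so inserting default entries is well-typed.
    sd.setdefault("zero_shift", torch.tensor([0.0]))
    sd.setdefault("smoothquant_scale", torch.tensor([1.0]))


state = {"weight": torch.zeros(2, 2)}
add_defaults(state)
assert "zero_shift" in state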
2 changes: 1 addition & 1 deletion fms_mo/aiu_addons/i8i8/i8i8_aiu_linear.py
@@ -200,7 +200,7 @@ def get_int8_aiu_linear(
     in_features: int,
     out_features: int,
     bias: bool,
-    linear_config: Optional[Mapping[str, Any]] = None,
+    linear_config: Mapping[str, Any],
     use_smoothquant: bool = True,
 ) -> torch.nn.Module:
     """Retrieve a W8A8 Linear module"""
12 changes: 6 additions & 6 deletions fms_mo/aiu_addons/i8i8/i8i8_aiu_op.py
@@ -84,7 +84,7 @@ def i8i8_aiu(
     x_dq = quant_dequant_activ(x, a_cv, a_cvn, sq, activ_quant_type)
     w_dq = dequant_weights(weight, w_cv, sq, weight_quant_type)

-    return F.linear(x_dq.to(dtype), w_dq.to(dtype), bias)
+    return F.linear(x_dq.to(dtype), w_dq.to(dtype), bias.to(dtype))

 @torch.library.impl_abstract(op_namespace_id)
 def i8i8_aiu_abstract(
@@ -114,7 +114,7 @@ def extract_qdata(
     w_in_feat: int,
     w_out_feat: int,
     smoothquant: bool,
-) -> tuple[torch.Tensor]:
+) -> tuple[torch.Tensor, ...]:
     """6 tensors are to be de-concatenated from qdata:
     w_clip_val  [    : idx1]
     w_clip_valn [idx1: idx2]
@@ -194,19 +194,19 @@ def quant_dequant_activ(
     """
     if activ_quant_type == "per_tensor_symm":
         scale_x = 127 / a_cv
-        x_int = torch.round(x / sq * scale_x).clamp(-127, 127)
-        return x_int / scale_x * sq
+        x_int = torch.round(x / sq * scale_x).clamp(-127, 127).to(torch.int8)
Comment on the line above:

Collaborator:
Is this type casting really necessary? It seems like the next line applies a division with a float, which would automatically upcast again.

Collaborator (Author):
It is not needed. I added the casting during debugging, as I was observing numerical discrepancies between this function's output and a reference output. I will remove it.

+        return x_int.div(scale_x).mul(sq)
     if activ_quant_type == "per_tensor_asymm":
         scale_x = 255 / (a_cv - a_cvn)
         zp_x = a_cvn * scale_x
         x_int = torch.round(x / sq * scale_x - zp_x).clamp(0, 255)
-        return (x_int + zp_x) / scale_x * sq
+        return x_int.add(zp_x).div(scale_x).mul(sq)
     if activ_quant_type == "per_token":
         x_sq = x / sq
         a_cv_per_token = x_sq.abs().max(dim=-1, keepdim=True)[0]
         scale_x = 127 / a_cv_per_token
         x_int = torch.round(x_sq * scale_x).clamp(-127, 127)
-        return x_int / scale_x * sq
+        return x_int.div(scale_x).mul(sq)
     raise NotImplementedError(
         f"activation quantizantion type {activ_quant_type} is not supported"
     )
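Regarding the review exchange above: PyTorch type promotion turns an integer tensor divided by a floating-point tensor back into a floating-point result, so the intermediate .to(torch.int8) cast does not change the simulated dequantized values. A standalone sketch of the per_tensor_symm round trip (clip value and smoothquant scale set to 1.0 for illustration; this is not the fms_mo op itself):

import torch

x = torch.randn(4, 8, dtype=torch.float16).clamp(-1, 1)
a_cv = torch.tensor(1.0)   # activation clip value (illustrative)
sq = torch.tensor(1.0)     # smoothquant scale, identity here

scale_x = 127 / a_cv
x_int = torch.round(x / sq * scale_x).clamp(-127, 127)

# The int8 cast is promoted back to a float dtype by the division below,
# so both paths produce the same dequantized tensor.
dq_via_int8 = x_int.to(torch.int8).div(scale_x).mul(sq)
dq_via_float = x_int.div(scale_x).mul(sq)

assert dq_via_int8.dtype.is_floating_point
assert torch.allclose(dq_via_int8, dq_via_float)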
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -26,7 +26,7 @@ dependencies = [
     "accelerate>=0.20.3,!=0.34,<1.4",
     "transformers>=4.45,<4.49",
     "torch>=2.2.0,<2.5",
-    "triton>=3.0,<3.2",
+    "triton>=3.0,<3.2",
     "tqdm>=4.66.2,<5.0",
     "datasets>=3.0.0,<4.0",
     "ninja>=1.11.1.1,<2.0",
90 changes: 43 additions & 47 deletions tests/aiu_addons/conftest.py
@@ -13,6 +13,9 @@
 # limitations under the License.
 """Pytest configuration file with fixtures for add-ons functionality testing"""

+# Standard
+from pathlib import Path
+
 # Third Party
 import pytest
 import torch
@@ -67,65 +70,58 @@ def get_gptq_gemm_inputs(request) -> tuple[torch.Tensor, ...]:

 i8i8_metadata = [
     {
-        "bs": 4,
-        "seq_len": 7,
-        "hid_dim": 256,
-        "out_feat": 512,
-        "dtype": torch.float16,
         "wtype": "per_tensor",  # per_channel
         "atype": "per_tensor_symm",  # per_tensor_asymm, per_token
         "smoothquant": False,
-    }
+    },
+    # {
+    #     "wtype": "per_channel",  # per_channel
+    #     "atype": "per_tensor_symm",  # per_tensor_asymm, per_token
+    #     "smoothquant": False,
+    # },
 ]


 @pytest.fixture(scope="session", params=i8i8_metadata)
 def get_i8i8_gemm_inputs(
     request,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, str, bool]:
+) -> tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    str,
+    str,
+    bool,
+    torch.Tensor,
+]:
     """pytest fixture returning test inputs for INT8xINT8 op"""

     data = request.param
-    x = torch.randn(
-        (data["bs"], data["seq_len"], data["hid_dim"]),
-        dtype=data["dtype"],
-    ).clamp(-1, 1)
-    w_int = torch.randint(
-        low=-8,
-        high=8,
-        size=(data["out_feat"], data["hid_dim"]),
-        dtype=torch.int8,
+
+    filename = (
+        f"ref_w-{data['wtype']}_"
+        f"a-{data['atype']}_"
+        f"sq-{'Y' if data['smoothquant'] else 'N'}.pt"
     )
-    b = torch.zeros(data["out_feat"], dtype=data["dtype"])
-    qdata = create_qdata(
-        data["wtype"],
-        data["atype"],
-        data["hid_dim"],
-        data["out_feat"],
-        data["smoothquant"],
-        data["dtype"],
+    addon_references = Path("tests/artifacts/aiu_addons")
+    i8i8_data = torch.load(addon_references / filename, weights_only=True)
+
+    assert isinstance(i8i8_data, dict)
+    assert data["wtype"] == i8i8_data["weight_quant_type"]
+    assert data["atype"] == i8i8_data["activ_quant_type"]
+    assert data["smoothquant"] == i8i8_data["smoothquant"]
+    assert all(
+        item in i8i8_data for item in ["x", "w_int", "bias", "qdata", "reference_out"]
    )

-    return (x, w_int, b, qdata, data["wtype"], data["atype"], data["smoothquant"])
-
-
-def create_qdata(
-    wtype: str,
-    atype: str,
-    in_feat: int,
-    out_feat: int,
-    smoothquant: bool,
-    dtype: torch.dtype,
-) -> torch.Tensor:
-    """Generate dummy qdata tensor based on the provided quantization configuration"""
-
-    qdata_len = 2 if wtype == "per_tensor" else 2 * out_feat  # weight clips
-    qdata_len += 2  # activation clips
-    qdata_len += out_feat if atype == "per_tensor_asymm" else 1  # zero shift
-    qdata_len += in_feat if smoothquant else 1  # smoothquant scales
-
-    # TODO: improve dummy generation
-    qdata = torch.ones(qdata_len, dtype=dtype)
-    qdata[1] = -qdata[0]  # !!! temporary solution to enforce clip symmetry
-    qdata[3] = -qdata[2]
-    return qdata
+    return (
+        i8i8_data["x"],
+        i8i8_data["w_int"],
+        i8i8_data["bias"],
+        i8i8_data["qdata"],
+        i8i8_data["weight_quant_type"],
+        i8i8_data["activ_quant_type"],
+        i8i8_data["smoothquant"],
+        i8i8_data["reference_out"],
+    )
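The fixture now reads a pre-computed reference bundle rather than constructing dummy tensors on the fly. A hedged sketch of how such a .pt artifact could be produced, using the key names the asserts above expect and the tensor shapes from the previous dummy setup (bs=4, seq_len=7, hid_dim=256, out_feat=512); the actual generation script is not part of this PR, and a real reference_out would come from a trusted reference run rather than the zeros used here:

from pathlib import Path

import torch

wtype, atype, smoothquant = "per_tensor", "per_tensor_symm", False

bundle = {
    "x": torch.randn(4, 7, 256, dtype=torch.float16).clamp(-1, 1),
    "w_int": torch.randint(-8, 8, (512, 256), dtype=torch.int8),
    "bias": torch.zeros(512, dtype=torch.float16),
    # per_tensor weight clips (2) + activation clips (2) + zero shift (1) + sq scale (1)
    "qdata": torch.tensor([1.0, -1.0, 1.0, -1.0, 0.0, 1.0], dtype=torch.float16),
    "weight_quant_type": wtype,
    "activ_quant_type": atype,
    "smoothquant": smoothquant,
    # placeholder; a real artifact stores the output of a trusted reference implementation
    "reference_out": torch.zeros(4, 7, 512, dtype=torch.float16),
}

out_dir = Path("tests/artifacts/aiu_addons")
out_dir.mkdir(parents=True, exist_ok=True)
filename = f"ref_w-{wtype}_a-{atype}_sq-{'Y' if smoothquant else 'N'}.pt"
torch.save(bundle, out_dir / filename)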
21 changes: 18 additions & 3 deletions tests/aiu_addons/test_int8_addon.py
@@ -33,10 +33,17 @@ def test_i8i8_registration() -> None:

 def test_i8i8_op(
     get_i8i8_gemm_inputs: tuple[
-        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, str, bool
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        str,
+        str,
+        bool,
+        torch.Tensor,
     ],
 ) -> None:
-    """Validate output shapes of INT8xINT8 matmul.
+    """Validate output shapes and content of INT8xINT8 matmul.
     Computations are simulated, using quantized/dequantized tensors.
     """

@@ -48,8 +55,13 @@ def test_i8i8_op(
         weight_quant_type,
         activ_quant_type,
         smoothquant,
+        reference_out,
     ) = get_i8i8_gemm_inputs

+    # enforce fp16 dtype on all fp parameters for this test
+    x = x.to(torch.float16)
+    qdata = qdata.to(torch.float16)
+
     out = torch.ops.fms_mo.i8i8_aiu(
         x,
         weight,
@@ -60,4 +72,7 @@
         smoothquant,
     )

-    assert out.size() == torch.Size((x.size()[:-1] + (weight.size(0),)))
+    error_tolerance = 1e-4  # TODO: this needs adjusting
+    assert out.size() == x.size()[:-1] + (weight.size(0),)
+    assert torch.all((out - reference_out).abs() < error_tolerance)
+    # assert torch.linalg.norm(out - reference_out) < error_tolerance  # alternative check
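On the new assertions: the element-wise bound and the commented-out norm bound are not interchangeable, since a norm aggregates error over all entries; torch.allclose with rtol=0 is an equivalent way to express the element-wise absolute check. A small self-contained comparison (the tolerance value is a placeholder, as the TODO notes):

import torch

out = torch.tensor([1.00009, 2.00009, -0.99991])
reference_out = torch.tensor([1.0, 2.0, -1.0])
error_tolerance = 1e-4

# Element-wise bound: every entry must lie within the tolerance.
assert torch.all((out - reference_out).abs() < error_tolerance)

# Equivalent formulation via allclose with a purely absolute tolerance.
assert torch.allclose(out, reference_out, rtol=0.0, atol=error_tolerance)

# The norm-based bound aggregates errors, so it fails here even though
# every individual entry is within tolerance.
assert torch.linalg.norm(out - reference_out) >= error_tolerance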
Binary file not shown.