Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/QAT_INT8/run_qa_no_trainer_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def parse_args():
"--do_lowering",
choices=["cutlass", "triton"],
type=str,
default="triton",
default=None,
help="convert QAT model to utilize real INT8 GPU kernel, 'cutlass' or 'triton'",
)

Expand Down Expand Up @@ -1136,7 +1136,7 @@ def speedtest(model, exam_inp, Ntest=100):
logger.info(
f"\n {label} {'with' if comp_mode else 'without'} torch.compile"
)
model_copy = deepcopy(model)
model_copy = deepcopy(model).half()

if label == "int8":
qcfg = qconfig_init(recipe="qat_int8", args=args)
Expand Down Expand Up @@ -1178,7 +1178,7 @@ def speedtest(model, exam_inp, Ntest=100):

# Median runtime using fixed input (in msec)
med_runtime = speedtest(model_copy, exam_inp)
metrics = squad_eval(model_copy) if label == "int8" else {"f1": None}
metrics = squad_eval(model_copy) # if label == "int8" else {"f1": None}

summary["precision"].append(label)
summary["compile mode"].append(comp_mode)
Expand Down
5 changes: 3 additions & 2 deletions fms_mo/custom_ext_kernels/triton_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def imatmul_kernel(
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
## ------ prepare LSB rounding/truncation masks -------
round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
# msb_mask = 0x00FFFFFF # only needed when simulating truncation on MSB
## ---------------------------------------------------------

for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
Expand Down Expand Up @@ -326,7 +327,7 @@ def grid(META):
kernel_config = {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_K": chunk_size,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_N": 128, # was 32
"GROUP_SIZE_M": 8,
"num_warps": 2,
"num_stages": 5,
Expand All @@ -335,7 +336,7 @@ def grid(META):
kernel_config = {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_K": chunk_size,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_N": 128, # was 64
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4,
Expand Down
25 changes: 14 additions & 11 deletions fms_mo/modules/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
qlin_int.max_acc_bits = kwargs.get("max_acc_bits", 32)
qlin_int.accminmax = (
-(1 << (qlin_int.max_acc_bits - 1)),
1 << (qlin_int.max_acc_bits - 1) - 1,
(1 << (qlin_int.max_acc_bits - 1)) - 1,
)
qlin_int.truncate_lsb = kwargs.get("truncate_lsb", 0)
qlin_int.chunk_size = kwargs.get("chunk_size", 100000)
Expand Down Expand Up @@ -871,16 +871,16 @@ def from_torch_iW(cls, nnlin_iW, prec, a_cv, a_cvn, w_cv, zero_shift, **kwargs):

qlinear_iW.nbits_a = 8 # Only support INT8 for now
qlinear_iW.nbits_w = 8
qlinear_iW.acc_dtype = torch.float16
qlinear_iW.acc_dtype = kwargs.get("acc_dtype", torch.float)
qlinear_iW.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", True)
qlinear_iW.use_int_kernel = True
qlinear_iW.use_int_kernel = kwargs.get("use_int_kernel", "triton")
qlinear_iW.weight = nn.Parameter(
nnlin_iW.weight.to(torch.int8), requires_grad=False
)
qlinear_iW.max_acc_bits = kwargs.get("max_acc_bits", 32)
qlinear_iW.accminmax = (
-(1 << (qlinear_iW.max_acc_bits - 1)),
1 << (qlinear_iW.max_acc_bits - 1) - 1,
(1 << (qlinear_iW.max_acc_bits - 1)) - 1,
)
qlinear_iW.truncate_lsb = kwargs.get("truncate_lsb", False)
qlinear_iW.chunk_size = kwargs.get("chunk_size", 100000)
Expand Down Expand Up @@ -1027,11 +1027,11 @@ def iaddmm_int(self, bias, m1, m2):
else:
m1 = self.qa_fmo_mo_qfunc(m1)

if m1.shape[1] > self.chunk_size:
if m1.shape[1] > self.chunk_size and self.use_int_kernel != "triton":
idx = list(range(0, m1.shape[1], self.chunk_size))
Nchunk = len(idx)
idx.append(m1.shape[1])
fp16_out = torch.zeros(
accumulator = torch.zeros(
(m1.shape[0], m2.shape[1]), dtype=torch.float16, device=m1.device
)
trun_scale = 1
Expand All @@ -1052,11 +1052,11 @@ def iaddmm_int(self, bias, m1, m2):
# could cast to smaller data type to further simulate HW behavior, for example,
# if HW truncates 8b from both sides of i32 accumulator, the remaining data can
# be cast to i16 to be more realistic. pay attention to overflow handling
fp16_out += imm_out.to(torch.float16)
accumulator += imm_out.to(torch.float16)

return (
fp16_out
* (trun_scale * self.input_scale * self.w_scale).to(torch.float16)
accumulator
* (trun_scale * self.input_scale * self.w_scale) # .to(torch.float16)
+ bias
).to(self.acc_dtype)
# The safest casting, i32 -> f32
Expand Down Expand Up @@ -1145,10 +1145,13 @@ def extra_repr(self) -> str:
"""
Returns an alternative string representation of the object
"""
return (
repr_str = (
f"in={self.in_features}, out={self.out_features}, bias={self.bias is not None}, "
f"use_int_kernel={self.use_int_kernel}"
f"int_kernel={self.use_int_kernel}"
)
if self.truncate_lsb > 0 or self.max_acc_bits < 32:
repr_str += f", acc_bits={self.max_acc_bits}, trun_lsb={self.truncate_lsb}"
return repr_str

def __getstate__(self):
"""
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ dependencies = [
"accelerate>=0.20.3,!=0.34,<1.4",
"transformers>=4.45,<4.49",
"torch>=2.2.0,<2.5",
"triton>=3.0,<3.2",
"tqdm>=4.66.2,<5.0",
"datasets>=3.0.0,<4.0",
"ninja>=1.11.1.1,<2.0",
"tensorboard",
"notebook",
"torchvision>=0.8",
"torchvision>=0.17",
"evaluate",
"huggingface_hub",
"pandas",
Expand Down
Loading