diff --git a/examples/QAT_INT8/run_qa_no_trainer_qat.py b/examples/QAT_INT8/run_qa_no_trainer_qat.py
index 04dc1070..c26e16bf 100644
--- a/examples/QAT_INT8/run_qa_no_trainer_qat.py
+++ b/examples/QAT_INT8/run_qa_no_trainer_qat.py
@@ -390,7 +390,7 @@ def parse_args():
         "--do_lowering",
         choices=["cutlass", "triton"],
         type=str,
-        default="triton",
+        default=None,
         help="convert QAT model to utilize real INT8 GPU kernel, 'cutlass' or 'triton'",
     )
@@ -1136,7 +1136,7 @@ def speedtest(model, exam_inp, Ntest=100):
         logger.info(
             f"\n {label} {'with' if comp_mode else 'without'} torch.compile"
         )
-        model_copy = deepcopy(model)
+        model_copy = deepcopy(model).half()

         if label == "int8":
             qcfg = qconfig_init(recipe="qat_int8", args=args)
@@ -1178,7 +1178,7 @@
         # Median runtime using fixed input (in msec)
         med_runtime = speedtest(model_copy, exam_inp)
-        metrics = squad_eval(model_copy) if label == "int8" else {"f1": None}
+        metrics = squad_eval(model_copy)  # if label == "int8" else {"f1": None}

         summary["precision"].append(label)
         summary["compile mode"].append(comp_mode)
diff --git a/fms_mo/custom_ext_kernels/triton_kernels.py b/fms_mo/custom_ext_kernels/triton_kernels.py
index bc4e4780..7b1bd58b 100644
--- a/fms_mo/custom_ext_kernels/triton_kernels.py
+++ b/fms_mo/custom_ext_kernels/triton_kernels.py
@@ -235,6 +235,7 @@ def imatmul_kernel(
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
     ## ------ prepare LSB rounding/truncation masks -------
     round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0
+    # msb_mask = 0x00FFFFFF  # only needed when simulating truncation on MSB
     ## ---------------------------------------------------------

     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
@@ -326,7 +327,7 @@ def grid(META):
         kernel_config = {
             "BLOCK_SIZE_M": 128,
             "BLOCK_SIZE_K": chunk_size,
-            "BLOCK_SIZE_N": 32,
+            "BLOCK_SIZE_N": 128,  # was 32
             "GROUP_SIZE_M": 8,
             "num_warps": 2,
             "num_stages": 5,
@@ -335,7 +336,7 @@
         kernel_config = {
             "BLOCK_SIZE_M": 128,
             "BLOCK_SIZE_K": chunk_size,
-            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_N": 128,  # was 64
             "GROUP_SIZE_M": 8,
             "num_warps": 4,
             "num_stages": 4,
diff --git a/fms_mo/modules/linear.py b/fms_mo/modules/linear.py
index 1a56e9e1..2c2e6527 100644
--- a/fms_mo/modules/linear.py
+++ b/fms_mo/modules/linear.py
@@ -752,7 +752,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs):
         qlin_int.max_acc_bits = kwargs.get("max_acc_bits", 32)
         qlin_int.accminmax = (
             -(1 << (qlin_int.max_acc_bits - 1)),
-            1 << (qlin_int.max_acc_bits - 1) - 1,
+            (1 << (qlin_int.max_acc_bits - 1)) - 1,
         )
         qlin_int.truncate_lsb = kwargs.get("truncate_lsb", 0)
         qlin_int.chunk_size = kwargs.get("chunk_size", 100000)
@@ -871,16 +871,16 @@ def from_torch_iW(cls, nnlin_iW, prec, a_cv, a_cvn, w_cv, zero_shift, **kwargs):
         qlinear_iW.nbits_a = 8  # Only support INT8 for now
         qlinear_iW.nbits_w = 8
-        qlinear_iW.acc_dtype = torch.float16
+        qlinear_iW.acc_dtype = kwargs.get("acc_dtype", torch.float)
         qlinear_iW.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", True)
-        qlinear_iW.use_int_kernel = True
+        qlinear_iW.use_int_kernel = kwargs.get("use_int_kernel", "triton")
         qlinear_iW.weight = nn.Parameter(
             nnlin_iW.weight.to(torch.int8), requires_grad=False
         )
         qlinear_iW.max_acc_bits = kwargs.get("max_acc_bits", 32)
         qlinear_iW.accminmax = (
             -(1 << (qlinear_iW.max_acc_bits - 1)),
-            1 << (qlinear_iW.max_acc_bits - 1) - 1,
+            (1 << (qlinear_iW.max_acc_bits - 1)) - 1,
         )
         qlinear_iW.truncate_lsb = kwargs.get("truncate_lsb", False)
         qlinear_iW.chunk_size = kwargs.get("chunk_size", 100000)
@@ -1027,11 +1027,11 @@ def iaddmm_int(self, bias, m1, m2):
         else:
             m1 = self.qa_fmo_mo_qfunc(m1)

-        if m1.shape[1] > self.chunk_size:
+        if m1.shape[1] > self.chunk_size and self.use_int_kernel != "triton":
             idx = list(range(0, m1.shape[1], self.chunk_size))
             Nchunk = len(idx)
             idx.append(m1.shape[1])
-            fp16_out = torch.zeros(
+            accumulator = torch.zeros(
                 (m1.shape[0], m2.shape[1]), dtype=torch.float16, device=m1.device
             )
             trun_scale = 1
@@ -1052,11 +1052,11 @@
                 # could cast to smaller data type to further simulate HW behavior, for example,
                 # if HW truncates 8b from both sides of i32 accumulator, the remaining data can
                 # be cast to i16 to be more realistic. pay attention to overflow handling
-                fp16_out += imm_out.to(torch.float16)
+                accumulator += imm_out.to(torch.float16)

             return (
-                fp16_out
-                * (trun_scale * self.input_scale * self.w_scale).to(torch.float16)
+                accumulator
+                * (trun_scale * self.input_scale * self.w_scale)  # .to(torch.float16)
                 + bias
             ).to(self.acc_dtype)  # The safest casting, i32 -> f32
@@ -1145,10 +1145,13 @@ def extra_repr(self) -> str:
         """
         Returns an alternative string representation of the object
         """
-        return (
+        repr_str = (
             f"in={self.in_features}, out={self.out_features}, bias={self.bias is not None}, "
-            f"use_int_kernel={self.use_int_kernel}"
+            f"int_kernel={self.use_int_kernel}"
         )
+        if self.truncate_lsb > 0 or self.max_acc_bits < 32:
+            repr_str += f", acc_bits={self.max_acc_bits}, trun_lsb={self.truncate_lsb}"
+        return repr_str

     def __getstate__(self):
         """
diff --git a/pyproject.toml b/pyproject.toml
index d8eeeaf2..fc760362 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,12 +26,13 @@ dependencies = [
 "accelerate>=0.20.3,!=0.34,<1.4",
 "transformers>=4.45,<4.49",
 "torch>=2.2.0,<2.5",
+"triton>=3.0,<3.2",
 "tqdm>=4.66.2,<5.0",
 "datasets>=3.0.0,<4.0",
 "ninja>=1.11.1.1,<2.0",
 "tensorboard",
 "notebook",
-"torchvision>=0.8",
+"torchvision>=0.17",
 "evaluate",
 "huggingface_hub",
 "pandas",