diff --git a/.spellcheck-en-custom.txt b/.spellcheck-en-custom.txt index 4f569522..eb138cfa 100644 --- a/.spellcheck-en-custom.txt +++ b/.spellcheck-en-custom.txt @@ -38,6 +38,7 @@ Inductor inferenced inferencing isort +JIT Jupyter Kubernetes KV @@ -105,6 +106,7 @@ Tokenized tokenizer Tokenizer toml +triton Unquantized vals venv diff --git a/examples/QAT_INT8/README.md b/examples/QAT_INT8/README.md index 758d263b..b6abbb12 100644 --- a/examples/QAT_INT8/README.md +++ b/examples/QAT_INT8/README.md @@ -87,16 +87,16 @@ python run_qa_no_trainer_qat.py \ --max_seq_length 384 \ --doc_stride 128 \ --attn_impl eager \ - --do_lowering + --do_lowering ``` -This script uses an "external kernel" instead of the `torch.matmul` kernel to perform real `INT8` matmuls. This kernel is written for Nvidia's CUDA/CUTLASS library and is compiled once just ahead of the run. The compiled artifacts are usually stored in `~/.cache/torch_extensions/`. Remove this folder if a fresh recompile of the kernel is needed. +This script uses an "external kernel" instead of the `torch.matmul` kernel to perform real `INT8` matmuls. We have two options for INT kernel, one is written using Nvidia's CUDA/CUTLASS library and one is in Triton. Both will be compiled once just ahead of the run (i.e., just-in-time, JIT, compilation). The compiled artifacts are usually stored in `~/.cache/torch_extensions/`. Remove this folder if a fresh recompile of the kernel is needed. Checkout [Example Test Results](#example-test-results) to compare against your results. ## Example Test Results -For comparison purposes, here are some of the results we found during testing when tested with `PyTorch 2.3.1`: +For comparison purposes, here are some of the results from an A100. CUTLASS results were obtained with `PyTorch 2.3.1` while Triton results were obtained using `PyTorch 2.4.1`: > [!NOTE] > Accuracy could vary ~ +-0.2 from run to run. @@ -106,9 +106,12 @@ For comparison purposes, here are some of the results we found during testing wh |fp16|128|eager |88.21 (as fine-tuned) |126.38| | |128|Inductor | |71.59| | |128|CUDAGRAPH | |71.13| -|INT8|128|eager |88.33|329.45 1| +|INT8 CUTLASS|128|eager |88.33|329.45 1| | |128|Inductor |88.42|67.87 2| | |128|CUDAGRAPH |-- |-- 3| +|INT8 triton|128|eager |88.10|358.51| +| |128|Inductor |88.13|99.91 4| +| |128|CUDAGRAPH |88.13|100.21 4| 1 `INT8` matmuls are ~2x faster than `FP16` matmuls. However, `INT8` models will have additional overhead compared to `FP16` models. For example, converting FP tensors to INT before INT matmul. @@ -116,6 +119,8 @@ For comparison purposes, here are some of the results we found during testing wh 3 `CUDAGRAPH` is the most effective way to minimize job launching overheads and can achieve ~2X end-to-end speed-up in this case. However, there seem to be bugs associated with this option at the moment. Further investigation is still on-going. +4 Unlike our CUTLASS `INT8` kernel, which is ~2x faster than `FP16` matmul, our Triton `INT8` is not as optimized and performs only comparable with `FP16` on mid-to-large tensor sizes. + ## Code Walk-through In this section, we will deep dive into what happens during the example steps. diff --git a/examples/QAT_INT8/run_qa_no_trainer_qat.py b/examples/QAT_INT8/run_qa_no_trainer_qat.py index 8fd7c991..04dc1070 100644 --- a/examples/QAT_INT8/run_qa_no_trainer_qat.py +++ b/examples/QAT_INT8/run_qa_no_trainer_qat.py @@ -388,8 +388,10 @@ def parse_args(): ) parser.add_argument( "--do_lowering", - action="store_true", - help="convert QAT model to utilize real INT8 GPU kernel", + choices=["cutlass", "triton"], + type=str, + default="triton", + help="convert QAT model to utilize real INT8 GPU kernel, 'cutlass' or 'triton'", ) args = parser.parse_args() @@ -1086,7 +1088,7 @@ def squad_eval(model, keep_model_in_eval_mode=True): qmodel_prep(model, exam_inp, qcfg, optimizer, use_dynamo=True) # ---- [fms_mo] the following code are performing speed tests ---- - elif args.do_lowering: + elif args.do_lowering in ["cutlass", "triton"]: # Standard from copy import deepcopy import time @@ -1158,7 +1160,11 @@ def speedtest(model, exam_inp, Ntest=100): parent_mod = model_copy.get_submodule(parent_name) qmod = getattr(parent_mod, module_name) setattr( - parent_mod, module_name, QLinearINT8Deploy.from_fms_mo(qmod) + parent_mod, + module_name, + QLinearINT8Deploy.from_fms_mo( + qmod, use_int_kernel=args.do_lowering + ), ) if comp_mode is not False: @@ -1385,6 +1391,7 @@ def speedtest(model, exam_inp, Ntest=100): ) logger.info(f"Predict metrics: {predict_metric}") + log = {} if args.with_tracking: log = { "squad_v2" if args.version_2_with_negative else "squad": eval_metric, diff --git a/fms_mo/custom_ext_kernels/triton_kernels.py b/fms_mo/custom_ext_kernels/triton_kernels.py new file mode 100644 index 00000000..bc4e4780 --- /dev/null +++ b/fms_mo/custom_ext_kernels/triton_kernels.py @@ -0,0 +1,361 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This file contains external kernels for FP and INT8 matmul written in triton.""" + +# Third Party +from triton.language.extra import libdevice +import torch +import triton +import triton.language as tl + + +def get_cuda_autotune_config(chunk_size=None): + """Basic use of triton.Config() is like: + triton.Config( + {'BLOCK_SIZE_M': i, + 'BLOCK_SIZE_N': j, + 'BLOCK_SIZE_K': k, + 'GROUP_SIZE_M': l + }, + num_stages=m, + num_warps=n + ) + User could override BLOCK_SIZE_K to a certain chunk_size (must >16). + """ + test_combinations = [ + (128, 256, 64, 8, 3, 8), + (64, 256, 32, 8, 4, 4), + (128, 128, 32, 8, 4, 4), + (128, 64, 32, 8, 4, 4), + (64, 128, 32, 8, 4, 4), + (128, 32, 32, 8, 4, 4), + (64, 32, 32, 8, 5, 2), + (32, 64, 32, 8, 5, 2), + # Good config for fp8 inputs. + (128, 256, 128, 8, 3, 8), + (256, 128, 128, 8, 3, 8), + (256, 64, 128, 8, 4, 4), + (64, 256, 128, 8, 4, 4), + (128, 128, 128, 8, 4, 4), + (128, 64, 64, 8, 4, 4), + (64, 128, 64, 8, 4, 4), + (128, 32, 64, 8, 4, 4), + ] + return [ + triton.Config( + { + "BLOCK_SIZE_M": i, + "BLOCK_SIZE_N": j, + "BLOCK_SIZE_K": chunk_size if chunk_size else k, + "GROUP_SIZE_M": l, + }, + num_stages=m, + num_warps=n, + ) + for i, j, k, l, m, n in test_combinations + ] + + +# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, +# which consumes: +# - A list of `triton.Config` objects that define different configurations of +# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try +# - An auto-tuning *key* whose change in values will trigger evaluation of all the +# provided configs +# => Need to avoid using auto-tune for real model inference! But for micro-benchmarking purpose, we +# could enable the decorator below +# @triton.autotune(configs=get_cuda_autotune_config(), key=['M', 'N', 'K']) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + # Matrix dimensions + M, + N, + K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + chunk_trun_bits, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ACTIVATION: tl.constexpr, +): + """Kernel for computing the matmul C = A x B that include LSB truncation. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + Args: + chunk_trun_bits (int): number of LSB to truncate/round. [0 to 23] + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + ## ------ prepare LSB rounding/truncation masks ------- + # NOTE mask will be applied on accumulator, which is alway FP32, so we may truncate up to 23b + # e.g., 20b -> trun_mask = 0xFFF00000, round_bit = 0x00080000 + # 8b -> trun_mask = 0xFFFFFF00, round_bit = 0x00000080 + full_32b_mask = 0xFFFFFFFF + trun_mask = (full_32b_mask << chunk_trun_bits) & full_32b_mask + round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0 + ## --------------------------------------------------------- + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator = tl.dot(a, b, accumulator, input_precision="ieee") + # tl.dot() default is using TF32 approximation, not good enough for LSB truncation exp + + ## ------ add chunky LSB rounding/masking -------- + if chunk_trun_bits != 0: + accumulator = libdevice.uint_as_float( + (libdevice.float_as_uint(accumulator) + round_bit) & trun_mask + ) + ## --------------------------------------------------------- + + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + + c = accumulator # do not cast to (tl.float16) just yet + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +# Reminder: avoid auto-tune for real model inference! But for micro-benchmarking purpose, could +# enable the decorator below +# @triton.autotune(configs=get_cuda_autotune_config(),key=['M', 'N', 'K'],) +@triton.jit +def imatmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + chunk_trun_bits, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + ACTIVATION: tl.constexpr, +): + """Kernel for computing the INT matmul C = A x B that include LSB truncation. A and B should be + INT8, C should be INT32. (Pretty much the same code as float version.) + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + Args: + chunk_trun_bits (int): number of LSBs to truncate/round. + """ + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32) + ## ------ prepare LSB rounding/truncation masks ------- + round_bit = 1 << (chunk_trun_bits - 1) if chunk_trun_bits > 0 else 0 + ## --------------------------------------------------------- + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator, input_precision="ieee") + + ## ------ add chunky LSB rounding/masking -------- + if chunk_trun_bits != 0: + accumulator = (accumulator + round_bit) >> chunk_trun_bits + accumulator = accumulator << chunk_trun_bits + ## --------------------------------------------------------- + + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + + c = accumulator + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +@triton.jit +def leaky_relu(x): + """Activation function that could be fused into matmul kernel""" + return tl.where(x >= 0, x, 0.01 * x) + + +def tl_matmul_chunk_truncate(a, b, activation="", chunk_trun_bits=0, chunk_size=16): + """Triton matmul for HW behavior simulation. Supports float and int8. + a. variable chunk size (i.e., BLOCK_SIZE_K) + b. LSB truncation, must <23 if using float. + + Args: + a, b: input tensors. FloatX, X in [32, 16, 8] or INT8. + activation (str, optional): activation func to be fused, see relu example. + chunk_trun_bits (int, optional): number of LSBs to be truncated/rounded. + chunk_size (int, optional): BLOCK_SIZE_K, some HW has specific chunk size. must >= 16. + + Returns: + _type_: _description_ + + NOTE: + use empirical way to determine BLOCK sizes, may not be optimal. But need to avoid autotune for + real model inference. otherwise auto-tune will be triggered in every forward call. + """ + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + assert a.dtype == b.dtype, "Input dtypes inconsistent" + + allowed_dtypes = [torch.float, torch.bfloat16, torch.float16] + cuda_cc = torch.cuda.get_device_capability() + if cuda_cc[0] >= 8: + allowed_dtypes.append(torch.int8) + if cuda_cc[0] >= 9 or cuda_cc == (8, 9): + allowed_dtypes += [torch.float8_e4m3fn, torch.float8_e5m2] + assert a.dtype in allowed_dtypes, "Input dtype is not supported" + M, K = a.shape + K, N = b.shape + + # Allocates output, always accumulate in FP32/INT32 then cast (if floats) + def isPowerofTwo(x): + """triton-specific limitation: block size needs to be power of 2.""" + return (x & (x - 1)) == 0 + + if a.dtype == torch.int8: + mm_kernel = imatmul_kernel + chunk_size = max(chunk_size, 32) if isPowerofTwo(chunk_size) else 32 + c = torch.zeros((M, N), device=a.device, dtype=torch.int32) + else: + assert chunk_trun_bits < 23, "FP32 accumulator only has 23 mantissa bits" + mm_kernel = matmul_kernel + chunk_size = max(chunk_size, 16) if isPowerofTwo(chunk_size) else 16 + c = torch.zeros((M, N), device=a.device, dtype=torch.float32) + + # 1D launch kernel where each block gets its own program. + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + + if M < 1024 or N < 1024: + kernel_config = { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_K": chunk_size, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 5, + } + else: + kernel_config = { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_K": chunk_size, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4, + } + + mm_kernel[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + chunk_trun_bits=chunk_trun_bits, + ACTIVATION=activation, + **kernel_config, # if using auto-tune, comment this line out. + ) + return c.to(a.dtype) if a.dtype != torch.int8 else c diff --git a/fms_mo/modules/bmm.py b/fms_mo/modules/bmm.py index 86dcea5c..aa16deff 100644 --- a/fms_mo/modules/bmm.py +++ b/fms_mo/modules/bmm.py @@ -364,7 +364,7 @@ def from_fms_mo(cls, fms_mo_qbmm, **kwargs): qbmm_int.num_bits_m1 = fms_mo_qbmm.num_bits_m1 qbmm_int.num_bits_m2 = fms_mo_qbmm.num_bits_m2 qcfg = getattr(fms_mo_qbmm, "qcfg", None) - qbmm_int.useINTkernel = False # always False until int kernel is implemented + qbmm_int.use_int_kernel = False # always False until int kernel is implemented qbmm_int.use_PT_native_Qfunc = qcfg["use_PT_native_Qfunc"] if qcfg else False with torch.no_grad(): @@ -438,7 +438,7 @@ def extra_repr(self) -> str: """ return ( f"nbits_m1,m2={self.num_bits_m1},{self.num_bits_m2}, " - f"useINTkernel={self.useINTkernel}" + f"use_int_kernel={self.use_int_kernel}" ) def forward(self, m1: torch.Tensor, m2: torch.Tensor) -> torch.Tensor: diff --git a/fms_mo/modules/linear.py b/fms_mo/modules/linear.py index 9243f127..1a56e9e1 100644 --- a/fms_mo/modules/linear.py +++ b/fms_mo/modules/linear.py @@ -27,6 +27,9 @@ import torch.nn.functional as F # Local +from fms_mo.custom_ext_kernels.triton_kernels import ( + tl_matmul_chunk_truncate as tl_matmul, +) from fms_mo.custom_ext_kernels.utils import pack_vectorized from fms_mo.quant.quantizers import ( HardPrune, @@ -708,6 +711,17 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs): Args: cls: The class of the QLinearModule to be created. fms_mo_qlinear: The QLinear module to be converted. + (experimental) + use_int_kernel: choose from ['cutlass', 'triton', False], "cutlass" kernel is faster, + "triton" supports chunky truncation, "False" fallbacks to torch.matmul + max_acc_bits: usually INT matmul accumulate in INT32, but some HW could have different + design, such as using INT24 accumulator, which will saturate at + (-2**(acc_bit-1) +1, 2**(acc_bit-1) ) + truncate_lsb: some HW may apply truncation on least-significant bits (LSBs) of the + accumulated partial sum + chunk_size: some HW may have specific chunk size (BLOCK SIZE, especially in k-dim) for + the reason to avoid overflow/underflow problem. This can be simulated using + PyTorch (break a matmul into serial smaller matmuls, slow) or Triton kernel Returns: A QLinearINT8Deploy object initialized with the weights and biases from the @@ -731,10 +745,17 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs): ) # Make sure to register an Op for integer matmul, could be real INT matmul or emulation qcfg = getattr(fms_mo_qlinear, "qcfg", {}) - qlin_int.useINTkernel = qcfg.get("useINTkernel", True) + qlin_int.use_int_kernel = kwargs.get( + "use_int_kernel", qcfg.get("use_int_kernel", "cutlass") + ) qlin_int.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", False) - qlin_int.acc24minmax = (int(-(2**24) / 2 + 1), int(2**24 / 2)) - qlin_int.simi24toi16 = kwargs.get("simi24toi16", False) + qlin_int.max_acc_bits = kwargs.get("max_acc_bits", 32) + qlin_int.accminmax = ( + -(1 << (qlin_int.max_acc_bits - 1)), + 1 << (qlin_int.max_acc_bits - 1) - 1, + ) + qlin_int.truncate_lsb = kwargs.get("truncate_lsb", 0) + qlin_int.chunk_size = kwargs.get("chunk_size", 100000) qlin_int.acc_dtype = torch.float16 qlin_int.nbits_a = fms_mo_qlinear.num_bits_feature # only support INT8 for now qlin_int.nbits_w = fms_mo_qlinear.num_bits_weight @@ -751,7 +772,7 @@ def from_fms_mo(cls, fms_mo_qlinear, **kwargs): ) # Qw.clipval should have been updated after this qlin_int.weight = nn.Parameter( w_int8.to(torch.int8), requires_grad=False - ) # NOTE: Needs INT W stored as FP... + ) # NOTE: may need INT W stored as FP in some cases if qlin_int.usePTnativeQfunc: input_scale = torch.tensor( @@ -852,12 +873,17 @@ def from_torch_iW(cls, nnlin_iW, prec, a_cv, a_cvn, w_cv, zero_shift, **kwargs): qlinear_iW.nbits_w = 8 qlinear_iW.acc_dtype = torch.float16 qlinear_iW.usePTnativeQfunc = kwargs.get("use_PT_native_Qfunc", True) - qlinear_iW.useINTkernel = True + qlinear_iW.use_int_kernel = True qlinear_iW.weight = nn.Parameter( nnlin_iW.weight.to(torch.int8), requires_grad=False ) - qlinear_iW.acc24minmax = (int(-(2**24) / 2 + 1), int(2**24 / 2)) - qlinear_iW.simi24toi16 = kwargs.get("simi24toi16", False) + qlinear_iW.max_acc_bits = kwargs.get("max_acc_bits", 32) + qlinear_iW.accminmax = ( + -(1 << (qlinear_iW.max_acc_bits - 1)), + 1 << (qlinear_iW.max_acc_bits - 1) - 1, + ) + qlinear_iW.truncate_lsb = kwargs.get("truncate_lsb", False) + qlinear_iW.chunk_size = kwargs.get("chunk_size", 100000) with torch.no_grad(): if qlinear_iW.usePTnativeQfunc: @@ -1001,25 +1027,36 @@ def iaddmm_int(self, bias, m1, m2): else: m1 = self.qa_fmo_mo_qfunc(m1) - if self.simi24toi16: - chunk_size = 99999 - idx = list(range(0, m1.shape[1], chunk_size)) + if m1.shape[1] > self.chunk_size: + idx = list(range(0, m1.shape[1], self.chunk_size)) Nchunk = len(idx) idx.append(m1.shape[1]) fp16_out = torch.zeros( (m1.shape[0], m2.shape[1]), dtype=torch.float16, device=m1.device ) + trun_scale = 1 + if self.truncate_lsb > 0: + round_bit = 1 << (self.truncate_lsb - 1) + trun_scale = 1 << self.truncate_lsb + for i in range(Nchunk): imm_out = torch.ops.fms_mo.imatmul( m1[:, idx[i] : idx[i + 1]], m2[idx[i] : idx[i + 1], :] ) - imm_out = imm_out.clamp(self.acc24minmax[0], self.acc24minmax[1]) - imm_out = torch.bitwise_right_shift(imm_out + 128, 8) - imm_out = imm_out.to(torch.int16) + if self.max_acc_bits < 32: + imm_out = imm_out.clamp(self.accminmax[0], self.accminmax[1]) + if self.truncate_lsb > 0: + imm_out = torch.bitwise_right_shift( + imm_out + round_bit, self.truncate_lsb + ) + # could cast to smaller data type to further simulate HW behavior, for example, + # if HW truncates 8b from both sides of i32 accumulator, the remaining data can + # be cast to i16 to be more realistic. pay attention to overflow handling fp16_out += imm_out.to(torch.float16) return ( - fp16_out * (256 * self.input_scale * self.w_scale).to(torch.float16) + fp16_out + * (trun_scale * self.input_scale * self.w_scale).to(torch.float16) + bias ).to(self.acc_dtype) # The safest casting, i32 -> f32 @@ -1049,31 +1086,43 @@ def set_matmul_op(self): """ Sets the matmul operator for the quantized linear module. - If `useINTkernel` is True and CUDA is available, it will use the INT kernel + If `use_int_kernel` is True and CUDA is available, it will use the INT kernel for integer matrix multiplication. Otherwise, it will use the FP kernel. If the operator has already been set, it will do nothing. """ - if self.useINTkernel and not torch.cuda.is_available(): + if self.use_int_kernel and not torch.cuda.is_available(): logger.warning( - "Cannot set useINTkernel=True when CUDA is not available. " - "Fallback to useINTkernel=False" + "Cannot set use_int_kernel=True when CUDA is not available. " + "Fallback to use_int_kernel=False" ) - self.useINTkernel = False + self.use_int_kernel = False if hasattr(torch.ops, "fms_mo") and hasattr(torch.ops.fms_mo, "imatmul"): # imatmul already registered, e.g. when swapping the 2nd QLinear self.imatmul = torch.ops.fms_mo.imatmul - self.iaddmm = self.iaddmm_int if self.useINTkernel else self.iaddmm_FP + self.iaddmm = self.iaddmm_int if self.use_int_kernel else self.iaddmm_FP else: # When swapping the first QLinear, need to register our custom Op and choose the kernel + # Standard + from functools import partial + # Local from fms_mo.custom_ext_kernels.utils import ( cutlass_ops_load_and_reg, imatmul_ops_reg, ) - if self.useINTkernel: # will use real imatmul + if self.use_int_kernel == "triton": + # will use real imatmul written in triton + imm_func = partial( + tl_matmul, + chunk_trun_bits=self.truncate_lsb, + chunk_size=self.chunk_size, + ) + + elif self.use_int_kernel == "cutlass": + # will use real imatmul written in cutlass cutlass_ops_load_and_reg() # Third Party import cutlass_mm # this module will only be available after calling reg() @@ -1082,9 +1131,9 @@ def set_matmul_op(self): else: imm_func = torch.matmul - imatmul_ops_reg(self.useINTkernel, imm_func) + imatmul_ops_reg(self.use_int_kernel, imm_func) self.imatmul = torch.ops.fms_mo.imatmul - self.iaddmm = self.iaddmm_int if self.useINTkernel else self.iaddmm_FP + self.iaddmm = self.iaddmm_int if self.use_int_kernel else self.iaddmm_FP def _get_name(self): """ @@ -1098,7 +1147,7 @@ def extra_repr(self) -> str: """ return ( f"in={self.in_features}, out={self.out_features}, bias={self.bias is not None}, " - f"useINTkernel={self.useINTkernel}" + f"use_int_kernel={self.use_int_kernel}" ) def __getstate__(self): @@ -1727,3 +1776,141 @@ def isinstance_qlinear(module): bool: True if the module is a quantized linear class, False otherwise. """ return isinstance(module, QLinear_modules) + + +class LinearFuncFPxFwdBwd(torch.autograd.Function): + """Linear function using FP24 accumulation, experimental only. + Input and weights can be fp16, bf16, or fp32. W.shape = [out, in]. + W and bias could be of different dtype from input, will cast before calling + triton kernel. This triton kernel will always use fp32 accumulation, then + truncate/rounded last 8 or 16 or 20 bits (from LSB side). + Modified from microxcaling Linear. + """ + + @staticmethod + def forward(ctx, x, weight, bias=None, trun_bits=0, chunk_size=16): + assert x.dtype in [torch.float, torch.bfloat16, torch.float16] + # input can be 2D or 3D, need to reshape before tl_matmul + org_dtype = x.dtype + target_shape_output = x.shape[:-1] + (weight.shape[0],) + x = x.reshape(-1, x.shape[-1]) + + if bias is not None: + ctx.has_bias = True + ctx.bias_dtype = bias.dtype + else: + ctx.has_bias = False + + ctx.save_for_backward(x, weight) # x, W are saved in their original dtype + ctx.trun_bits = trun_bits + ctx.chunk_size = chunk_size + + # triton kernel assumes 2D inputs and cast the return to input.dtype + output = tl_matmul( + x, + weight.t().to(org_dtype), + chunk_trun_bits=trun_bits, + chunk_size=chunk_size, + ).reshape(target_shape_output) + + if bias is not None: + output = output + bias.to(org_dtype) + + return output + + @staticmethod + def backward(ctx, grad_output): + # load context, should be bf16 already, x should be 2D already + x, weight = ctx.saved_tensors # x, W are saved in original dtype + trun_bits = ctx.trun_bits + chunk_size = ctx.chunk_size + out_dim = weight.shape[0] + in_dim = weight.shape[1] + dtype_input = x.dtype + # input and output could be 3D tl_matmul only takes 2D. + target_shape_grad_input = grad_output.shape[:-1] + (in_dim,) + grad_output_2D = grad_output.reshape(-1, out_dim).to(dtype_input) + + # Compute grad_weight, shape = [out, in] + # NOTE: this triton kernel requires A matrix to be contiguous + grad_weight = tl_matmul( + grad_output_2D.transpose(0, 1).contiguous(), + x, + chunk_trun_bits=trun_bits, + chunk_size=chunk_size, + ).to(weight.dtype) + # Compute grad_input in 2D then reshape to target shape, could be 3D or 2D + grad_input = ( + tl_matmul( + grad_output_2D, + weight.to(dtype_input), + chunk_trun_bits=trun_bits, + chunk_size=chunk_size, + ) + .reshape(target_shape_grad_input) + .to(dtype_input) + ) + + if not ctx.has_bias: + grad_bias = None + else: + grad_bias = grad_output_2D.sum(0).to(ctx.bias_dtype) + + return grad_input, grad_weight, grad_bias, None + + +class LinearFPxAcc(torch.nn.Linear): + """Linear layer wrapper that can simulate the HW behavior of LSB truncation on FP accumulation. + Some HW may have options to allow FP matmul engine to accumulate in precision lower than FP32, + such as accumulate in TF32 or even BF16. According to Nvidia doc, ~7-10x speed up with minor + accuracy trade-off. This supports both FWD and BWD. + Ref: + 1. https://developer.nvidia.com/blog/accelerating-ai-training-with-tf32-tensor-cores/ + 2. PyTorch's "torch.backends.cuda.matmul.allow_tf32" + """ + + @classmethod + def from_nn(cls, nnlin, trun_bits=0, **kwargs): + """Converts a torch.nn.Linear module to a LinearFPxAcc, which supports accumulation in + reduced precision FPx, where x < 32. + + Args: + cls (class): The class to be created. + nnlin (torch.nn.Linear): The original torch.nn.Linear module. + trun_bits (int): truncate [0 to 22] LSBs from FP32 accumulation. + **kwargs: Additional keyword arguments. + + Returns: + LinearFPxAcc: The converted linear layer. + """ + + target_device = kwargs.get( + "target_device", kwargs.get("device", next(nnlin.parameters()).device) + ) + + lin24acc = cls( + nnlin.in_features, + nnlin.out_features, + bias=nnlin.bias is not None, + device=target_device, + ) + + lin24acc.weight = nnlin.weight + lin24acc.trun_bits = trun_bits + + if nnlin.bias is not None: + lin24acc.bias = nnlin.bias + return lin24acc.to(target_device) + + def forward(self, inputs): + # This Linear Class will cast to BF16 before matmul and return FP32 + return LinearFuncFPxFwdBwd.apply(inputs, self.weight, self.bias, self.trun_bits) + + def extra_repr(self) -> str: + """ + Returns an alternative string representation of the object. + """ + return ( + f"in={self.in_features}, out={self.out_features}, bias={self.bias is not None}, " + f"trun_bits={self.trun_bits}" + ) diff --git a/pyproject.toml b/pyproject.toml index dea6534a..d8eeeaf2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ "numpy>=1.26.4,<2.3.0", "accelerate>=0.20.3,!=0.34,<1.4", "transformers>=4.45,<4.49", -"torch>=2.2.0,<2.4", +"torch>=2.2.0,<2.5", "tqdm>=4.66.2,<5.0", "datasets>=3.0.0,<4.0", "ninja>=1.11.1.1,<2.0", diff --git a/tests/triton_kernels/__init__.py b/tests/triton_kernels/__init__.py new file mode 100644 index 00000000..094b6434 --- /dev/null +++ b/tests/triton_kernels/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/triton_kernels/test_triton_mm.py b/tests/triton_kernels/test_triton_mm.py new file mode 100644 index 00000000..efedb1b8 --- /dev/null +++ b/tests/triton_kernels/test_triton_mm.py @@ -0,0 +1,120 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pytest configuration file with fixtures for triton kernel functionality test""" + +# Third Party +import pytest +import torch + +# Local +from fms_mo.custom_ext_kernels.triton_kernels import ( + tl_matmul_chunk_truncate as tl_matmul, +) +from fms_mo.modules.linear import LinearFPxAcc + + +@pytest.mark.parametrize("mkn", [64, 256, 1024, 4096]) +@pytest.mark.parametrize( + "dtype_to_test", + [ + torch.float, + torch.float16, + torch.bfloat16, + torch.float8_e4m3fn, + torch.float8_e5m2, + ], +) +def test_triton_matmul_fp(mkn, dtype_to_test): + """Parametric tests for triton matmul kernel using variety of tensor sizes and dtypes.""" + if not torch.cuda.is_available(): + # only run the test when GPU is available + return + + torch.manual_seed(23) + m = n = k = mkn + a = torch.randn((m, k), device="cuda", dtype=torch.float) + b = torch.randn((k, n), device="cuda", dtype=torch.float) + + torch_mm_device = "cuda" + if dtype_to_test in [torch.float8_e4m3fn, torch.float8_e5m2]: + cuda_cc = torch.cuda.get_device_capability() + if cuda_cc[0] < 9 and cuda_cc != (8, 9): + return + # torch.matmul does not support fp8 x fp8 on cuda + torch_mm_device = "cpu" + + a = a.to(dtype_to_test) + b = b.to(dtype_to_test) + torch_output = ( + torch.matmul(a.to(torch_mm_device), b.to(torch_mm_device)) + .to("cuda") + .to(torch.float) + ) + tl_output_no_trun = tl_matmul(a, b).to(torch.float) + tl_output_trun_8b = tl_matmul(a, b, chunk_trun_bits=8).to(torch.float) + + diff_no_trun = torch_output - tl_output_no_trun + diff_trun_8b = torch_output - tl_output_trun_8b + + assert torch.norm(diff_no_trun) / torch.norm(torch_output) < 1e-5 + assert torch.norm(diff_trun_8b) / torch.norm(torch_output) < 1e-3 + + +@pytest.mark.parametrize("mkn", [64, 256, 1024, 4096]) +def test_triton_matmul_int8(mkn): + """Parametric tests for triton imatmul kernel using variety of tensor sizes.""" + if not torch.cuda.is_available(): + # only run the test when GPU is available + return + + torch.manual_seed(23) + m = n = k = mkn + a = torch.randint(-128, 127, (m, k), device="cuda", dtype=torch.int8) + b = torch.randint(-128, 127, (k, n), device="cuda", dtype=torch.int8) + + torch_output = torch.matmul(a.to(torch.float), b.to(torch.float)) + # cast tl_matmul results to float because torch.norm only supports float + tl_output_no_trun = tl_matmul(a, b).to(torch.float) + tl_output_trun_8b = tl_matmul(a, b, chunk_trun_bits=8).to(torch.float) + + diff_no_trun = torch_output - tl_output_no_trun + diff_trun_8b = torch_output - tl_output_trun_8b + + assert torch.norm(diff_no_trun) / torch.norm(torch_output) < 1e-5 + assert torch.norm(diff_trun_8b) / torch.norm(torch_output) < 1e-2 + + +@pytest.mark.parametrize("feat_in_out", [(64, 128), (256, 1024), (1024, 4096)]) +@pytest.mark.parametrize("trun_bits", [0, 8, 12, 16]) +def test_linear_fpx_acc(feat_in_out, trun_bits): + """Parametric tests for LinearFPxAcc. This Linear utilizes triton kernel hence can only be run + on CUDA. + """ + if not torch.cuda.is_available(): + # only run the test when GPU is available + return + + torch.manual_seed(23) + feat_in, feat_out = feat_in_out + lin = torch.nn.Linear(feat_in, feat_out, device="cuda") + lin_fpx = LinearFPxAcc.from_nn(lin, trun_bits=trun_bits) + inputs = torch.randn((512, feat_in), device="cuda") + + with torch.no_grad(): + baseline = lin(inputs) + diff = lin_fpx(inputs) - baseline + rel_err = torch.norm(diff) / torch.norm(baseline) + + rel_tol = 1e-2 if trun_bits > 10 else 1e-4 + assert rel_err < rel_tol