2 changes: 2 additions & 0 deletions .spellcheck-en-custom.txt
@@ -38,6 +38,7 @@ Inductor
inferenced
inferencing
isort
JIT
Jupyter
Kubernetes
KV
@@ -105,6 +106,7 @@ Tokenized
tokenizer
Tokenizer
toml
triton
Unquantized
vals
venv
13 changes: 9 additions & 4 deletions examples/QAT_INT8/README.md
@@ -87,16 +87,16 @@ python run_qa_no_trainer_qat.py \
--max_seq_length 384 \
--doc_stride 128 \
--attn_impl eager \
--do_lowering
--do_lowering <cutlass or triton>
```

This script uses an "external kernel" instead of the `torch.matmul` kernel to perform real `INT8` matmuls. This kernel is written for Nvidia's CUDA/CUTLASS library and is compiled once just ahead of the run. The compiled artifacts are usually stored in `~/.cache/torch_extensions/`. Remove this folder if a fresh recompile of the kernel is needed.
This script uses an "external kernel" instead of the `torch.matmul` kernel to perform real `INT8` matmuls. Two INT8 kernels are available: one written with Nvidia's CUDA/CUTLASS library and one written in Triton. Both are compiled once just ahead of the run (i.e., just-in-time, JIT, compilation). The compiled artifacts are usually stored in `~/.cache/torch_extensions/`. Remove this folder if a fresh recompile of the kernel is needed.
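If a fresh JIT recompile is needed, a minimal sketch of clearing that cache programmatically (assuming the default cache location) looks like this:

```python
# Sketch: remove PyTorch's extension build cache so the external INT8 kernel
# (CUTLASS or Triton) is JIT-compiled from scratch on the next run.
# Assumes the default cache location; adjust if TORCH_EXTENSIONS_DIR is set.
import shutil
from pathlib import Path

cache_dir = Path.home() / ".cache" / "torch_extensions"
if cache_dir.exists():
    shutil.rmtree(cache_dir)
    print(f"Removed {cache_dir}; the kernel will be rebuilt on the next run.")
```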

Checkout [Example Test Results](#example-test-results) to compare against your results.

## Example Test Results

For comparison purposes, here are some of the results we found during testing when tested with `PyTorch 2.3.1`:
For comparison purposes, here are some results measured on an A100. CUTLASS results were obtained with `PyTorch 2.3.1`, while Triton results were obtained with `PyTorch 2.4.1`:

> [!NOTE]
> Accuracy could vary ~ +-0.2 from run to run.
@@ -106,16 +106,21 @@ For comparison purposes, here are some of the results we found during testing wh
|fp16|128|eager |88.21 (as fine-tuned) |126.38|
| |128|Inductor | |71.59|
| |128|CUDAGRAPH | |71.13|
|INT8|128|eager |88.33|329.45 <sup>1</sup>|
|INT8 CUTLASS|128|eager |88.33|329.45 <sup>1</sup>|
| |128|Inductor |88.42|67.87 <sup>2</sup>|
| |128|CUDAGRAPH |-- |-- <sup>3</sup>|
|INT8 triton|128|eager |88.10|358.51|
| |128|Inductor |88.13|99.91 <sup>4</sup>|
| |128|CUDAGRAPH |88.13|100.21 <sup>4</sup>|

<sup>1</sup> `INT8` matmuls are ~2x faster than `FP16` matmuls. However, `INT8` models incur additional overhead compared to `FP16` models, e.g., converting FP tensors to INT before each INT matmul.
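For illustration only, here is a minimal sketch (placeholder shapes and per-tensor scales, with the integer GEMM emulated in int32 so it runs anywhere) of the extra work that surrounds the core INT8 matmul:

```python
# Sketch of an INT8 matmul path: quantize FP activations, run the integer GEMM,
# then rescale the INT32 accumulator back to FP. The real CUTLASS/Triton kernels
# execute the GEMM on the GPU; it is emulated in int32 here for portability.
import torch

x_fp = torch.randn(128, 768)                                     # FP activations
w_int8 = torch.randint(-128, 128, (768, 768), dtype=torch.int8)  # pre-quantized weights
x_scale, w_scale = 0.02, 0.01                                    # placeholder per-tensor scales

x_int8 = torch.clamp((x_fp / x_scale).round(), -128, 127).to(torch.int8)  # extra quantize step
y_int32 = x_int8.to(torch.int32) @ w_int8.to(torch.int32)                 # the (fast) integer GEMM
y_fp = y_int32.float() * (x_scale * w_scale)                              # extra dequantize step
```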

<sup>2</sup> Each of these additional quantization operations is relatively 'cheap', but the overhead of launching each kernel is not negligible. Using `torch.compile` can fuse the ops and reduce the total number of kernels being launched.
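As a sketch (not the script's exact code path; `model` and `exam_inp` follow the naming used in `run_qa_no_trainer_qat.py`), fusing those small ops with the Inductor backend looks like:

```python
# Sketch: compile the lowered INT8 model with Inductor so the small surrounding
# ops (quantize, clamp, rescale, ...) are fused and fewer kernels are launched.
import torch

compiled_model = torch.compile(model, backend="inductor")
with torch.no_grad():
    out = compiled_model(**exam_inp)  # first call compiles; later calls reuse the fused kernels
```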

<sup>3</sup> `CUDAGRAPH` is the most effective way to minimize kernel launch overheads and can achieve ~2X end-to-end speed-up in this case. However, there seem to be bugs associated with this option at the moment. Further investigation is ongoing.
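For reference, CUDA-graph capture is commonly enabled via `torch.compile`'s "reduce-overhead" mode; whether this matches the script's exact `comp_mode` handling is an assumption:

```python
# Sketch: "reduce-overhead" mode captures the forward pass into a CUDA graph and
# replays it, collapsing many per-kernel launches into a single graph replay.
# Input shapes must stay static between calls for the captured graph to be reused.
import torch

compiled_model = torch.compile(model, mode="reduce-overhead")
with torch.no_grad():
    out = compiled_model(**exam_inp)
```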

<sup>4</sup> Unlike our CUTLASS `INT8` kernel, which is ~2x faster than `FP16` matmul, our Triton `INT8` kernel is not as optimized and performs only comparably to `FP16` on mid-to-large tensor sizes.

## Code Walk-through

In this section, we will deep dive into what happens during the example steps.
15 changes: 11 additions & 4 deletions examples/QAT_INT8/run_qa_no_trainer_qat.py
@@ -388,8 +388,10 @@ def parse_args():
)
parser.add_argument(
"--do_lowering",
action="store_true",
help="convert QAT model to utilize real INT8 GPU kernel",
choices=["cutlass", "triton"],
type=str,
default="triton",
help="convert QAT model to utilize real INT8 GPU kernel, 'cutlass' or 'triton'",
)

args = parser.parse_args()
@@ -1086,7 +1088,7 @@ def squad_eval(model, keep_model_in_eval_mode=True):
qmodel_prep(model, exam_inp, qcfg, optimizer, use_dynamo=True)

# ---- [fms_mo] the following code are performing speed tests ----
elif args.do_lowering:
elif args.do_lowering in ["cutlass", "triton"]:
# Standard
from copy import deepcopy
import time
@@ -1158,7 +1160,11 @@ def speedtest(model, exam_inp, Ntest=100):
parent_mod = model_copy.get_submodule(parent_name)
qmod = getattr(parent_mod, module_name)
setattr(
parent_mod, module_name, QLinearINT8Deploy.from_fms_mo(qmod)
parent_mod,
module_name,
QLinearINT8Deploy.from_fms_mo(
qmod, use_int_kernel=args.do_lowering
),
)

if comp_mode is not False:
@@ -1385,6 +1391,7 @@ def speedtest(model, exam_inp, Ntest=100):
)
logger.info(f"Predict metrics: {predict_metric}")

log = {}
if args.with_tracking:
log = {
"squad_v2" if args.version_2_with_negative else "squad": eval_metric,
Expand Down