Skip to content

Add fp4 support #3532

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ repos:
- id: clang-format
types_or: [c++, c, cuda]
- repo: https://github.com/keith/pre-commit-buildifier
rev: 6.4.0
rev: 8.0.3
hooks:
- id: buildifier
args:
- --warnings=all
- id: buildifier-lint
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.23
rev: v0.24.1
hooks:
- id: validate-pyproject
- repo: https://github.com/pycqa/isort
Expand All @@ -37,17 +37,17 @@ repos:
- id: isort
name: isort (python)
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.9.0"
rev: "v1.15.0"
hooks:
- id: mypy
exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py"
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.3.3
rev: v0.11.7
hooks:
- id: ruff
- repo: https://github.com/psf/black
rev: 24.3.0
rev: 25.1.0
hooks:
- id: black
exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
Expand All @@ -57,7 +57,7 @@ repos:
- id: typos
- repo: https://github.com/astral-sh/uv-pre-commit
# uv version.
rev: 0.5.5
rev: 0.7.1
hooks:
# Update the uv lockfile
- id: uv-lock
Expand Down
19 changes: 19 additions & 0 deletions py/torch_tensorrt/_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ class dtype(Enum):
:meta hide-value:
"""

f4 = auto()
"""4 bit floating-point number, equivalent to ``dtype.fp4`` and ``dtype.float4``

:meta hide-value:
"""

uint8 = u8
int8 = i8

Expand All @@ -91,6 +97,9 @@ class dtype(Enum):
float8 = f8
fp8 = f8

float4 = f4
fp4 = f4

half = f16
fp16 = f16
float16 = f16
Expand Down Expand Up @@ -162,6 +171,8 @@ def _from(
return dtype.i32
elif t == torch.float8_e4m3fn:
return dtype.f8
elif t == torch.float4_e2m1fn_x2:
return dtype.f4
elif t == torch.half:
return dtype.f16
elif t == torch.float:
Expand All @@ -188,6 +199,8 @@ def _from(
return dtype.i8
elif t == trt.DataType.FP8:
return dtype.f8
elif t == trt.DataType.FP4:
return dtype.fp4
elif t == trt.DataType.INT32:
return dtype.i32
elif t == trt.DataType.INT64:
Expand Down Expand Up @@ -357,6 +370,8 @@ def to(
return torch.long
elif self == dtype.f8:
return torch.float8_e4m3fn
elif self == dtype.f4:
return torch.float4_e2m1fn_x2
elif self == dtype.f16:
return torch.half
elif self == dtype.f32:
Expand Down Expand Up @@ -394,6 +409,8 @@ def to(
return trt.DataType.BOOL
elif self == dtype.bf16:
return trt.DataType.BF16
elif self == dtype.f4:
return trt.DataType.FP4
elif use_default:
return trt.DataType.FLOAT
else:
Expand All @@ -410,6 +427,8 @@ def to(
return np.int64
elif self == dtype.f16:
return np.float16
elif self == dtype.f4:
return np.float4_e2m1fn_x2
elif self == dtype.f32:
return np.float32
elif self == dtype.f64:
Expand Down
9 changes: 8 additions & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@
REQUIRE_FULL_COMPILATION = False
DRYRUN = False
HARDWARE_COMPATIBLE = False
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
# Kernel precisions the dynamo compile path accepts for engine building:
# fp32/fp16/bf16 floats plus the int8, fp8 and fp4 quantized types.
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8, dtype.f4}
TIMING_CACHE_PATH = os.path.join(
tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin"
)
Expand Down
8 changes: 8 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/_ConversionContext.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from dataclasses import dataclass, field
from typing import Union

import numpy as np
import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.types import TRTNetwork

Expand All @@ -21,3 +23,9 @@ class ConversionContext:
)
requires_output_allocator: bool = False
mapping: dict[str, np.array] = field(default_factory=dict)
cpu_weights_reference_holder: dict[str, Union[torch.Tensor, np.array]] = field(
default_factory=dict
)

def clear_cpu_weights_reference_holder(self) -> None:
    """Empty the mapping of CPU-side weight tensors held during conversion.

    ``cpu_weights_reference_holder`` keeps ``torch.Tensor``/``np.array``
    weights alive (e.g. so their data pointers stay valid while TensorRT
    consumes them); clearing it releases those host copies.
    """
    self.cpu_weights_reference_holder.clear()
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,8 @@ def run(
)
_LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")

self.ctx.clear_cpu_weights_reference_holder()

self._save_timing_cache(
builder_config, self.compilation_settings.timing_cache_path
)
Expand Down
36 changes: 36 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,42 @@ def aten_ops_quantize_op(
)


try:
    # modelopt registers torch.ops.tensorrt.* quantization ops as an import
    # side effect; the assert verifies the op we convert actually exists.
    import modelopt.torch.quantization as mtq  # noqa: F401

    assert torch.ops.tensorrt.dynamic_block_quantize_op.default
except Exception as e:
    # Include the op name and the underlying error so a failed/partial
    # modelopt install is diagnosable from the log.
    _LOGGER.warning(
        "Unable to import dynamic block quantize op. Please install modelopt library "
        "(https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) "
        f"to add support for compiling quantized models. Error details: {e}"
    )
else:

    @dynamo_tensorrt_converter(
        torch.ops.tensorrt.dynamic_block_quantize_op.default,
        supports_dynamic_shapes=True,
    )
    def aten_ops_dynamic_block_quantize_op(
        ctx: ConversionContext,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
        name: str,
    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
        """Convert ``tensorrt.dynamic_block_quantize_op`` to TensorRT layers.

        Forwards the op's first seven positional arguments unchanged to the
        dynamic block (FP4) quantization converter implementation.
        """
        return impl.dynamic_block_quantize.quantize(
            ctx,
            target,
            SourceIR.ATEN,
            name,
            *args[:7],
        )


@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims, supports_dynamic_shapes=True)
def aten_ops_squeeze(
Expand Down
47 changes: 44 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def create_constant(
name: str,
dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]],
min_rank: Optional[int] = 1,
target_quantized_type: Optional[TRTDataType] = None,
) -> TRTTensor:
"""
Add a TensorRT constant layer whose value is `value` to `ctx.net`.
Expand All @@ -338,6 +339,7 @@ def create_constant(
dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
If a dtype is given, we will convert the type of the given `value` to this dtype.
min_rank (int): minimum rank of the constant tensor.
target_quantized_type (Optional[TRTDataType]): If a quantized type is given, we will convert the type of the given `value` to this dtype.
Returns:
A TensorRT ITensor that represents the given value.
"""
Expand All @@ -361,12 +363,48 @@ def create_constant(
shape = list(torch_value.shape)

if torch_value is not None:
if torch_value.dtype == torch.float8_e4m3fn:
weights = trt.Weights(
type=trt.DataType.FP8,
ptr=torch_value.data_ptr(),
count=torch_value.numel(),
)
constant = ctx.net.add_constant(
shape,
weights,
)
constant.name = name
ctx.cpu_weights_reference_holder[name + " FP8_CONSTANT"] = torch_value
return constant.get_output(0)

if torch_value.dtype == torch.uint8:
if (
target_quantized_type is None
or target_quantized_type != trt.DataType.FP4
):
# IConstantLayer does not support uint8 directly; it only accepts uint8 as a carrier for packed FP4 data
raise ValueError(
"Currently supported target_quantized_type for uint8 is FP4, got {target_quantized_type=}"
)
shape[-1] = shape[-1] * 2
weights = trt.Weights(
type=trt.DataType.FP4,
ptr=torch_value.data_ptr(),
count=torch_value.numel() * 2,
)
constant = ctx.net.add_constant(
shape,
weights,
)
constant.name = name
ctx.cpu_weights_reference_holder[name + " FP4_CONSTANT"] = torch_value
return constant.get_output(0)

if torch_value.dtype == torch.bfloat16:
torch_value_fp32 = torch_value.to(torch.float32)
numpy_value = torch_value_fp32.numpy()
else:
numpy_value = torch_value.numpy()

ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1)
constant = ctx.net.add_constant(
shape,
Expand All @@ -381,7 +419,6 @@ def create_constant(
trt.DataType.BF16,
name + "_bf16_cast",
)

return constant.get_output(0)
else:
raise ValueError(
Expand All @@ -395,6 +432,7 @@ def get_trt_tensor(
name: str,
dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None,
min_rank: int = 1,
target_quantized_type: Optional[TRTDataType] = None,
) -> TRTTensor:
"""
Given a value of random type, we try to convert it to a TensorRT ITensor.
Expand All @@ -408,6 +446,7 @@ def get_trt_tensor(
dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
If dtype is provided, the given value will be converted to this dtype.
min_rank (int): minimum rank of the constant tensor.
target_quantized_type (Optional[TRTDataType]): If a quantized type is given, we will convert the type of the given `value` to this dtype.
Returns:
A TensorRT ITensor that represents the given value.
"""
Expand All @@ -420,7 +459,9 @@ def get_trt_tensor(
input_val = input_val.astype(np.float32)

if isinstance(input_val, (torch.Tensor, np.ndarray, int, float, bool)):
return create_constant(ctx, input_val, name, dtype, min_rank)
return create_constant(
ctx, input_val, name, dtype, min_rank, target_quantized_type
)
elif isinstance(input_val, TRTTensor):
return input_val
else:
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
condition,
conv,
deconv,
dynamic_block_quantize,
elementwise,
embedding,
full,
Expand Down
Loading
Loading