diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f31305568d..a7b91eec34 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -21,14 +21,14 @@ repos:
       - id: clang-format
         types_or: [c++, c, cuda]
   - repo: https://github.com/keith/pre-commit-buildifier
-    rev: 6.4.0
+    rev: 8.0.3
     hooks:
       - id: buildifier
        args:
          - --warnings=all
       - id: buildifier-lint
   - repo: https://github.com/abravalheri/validate-pyproject
-    rev: v0.23
+    rev: v0.24.1
     hooks:
       - id: validate-pyproject
   - repo: https://github.com/pycqa/isort
@@ -37,17 +37,17 @@ repos:
       - id: isort
        name: isort (python)
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: "v1.9.0"
+    rev: "v1.15.0"
     hooks:
       - id: mypy
        exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py"
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.3.3
+    rev: v0.11.7
     hooks:
       - id: ruff
   - repo: https://github.com/psf/black
-    rev: 24.3.0
+    rev: 25.1.0
     hooks:
       - id: black
        exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
@@ -57,7 +57,7 @@ repos:
       - id: typos
   - repo: https://github.com/astral-sh/uv-pre-commit
     # uv version.
-    rev: 0.5.5
+    rev: 0.7.1
     hooks:
       # Update the uv lockfile
       - id: uv-lock
diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py
index c706c345d6..e0a78e1a0b 100644
--- a/py/torch_tensorrt/_enums.py
+++ b/py/torch_tensorrt/_enums.py
@@ -80,6 +80,12 @@ class dtype(Enum):
     :meta hide-value:
     """

+    f4 = auto()
+    """4 bit floating-point number, equivalent to ``dtype.fp4`` and ``dtype.float4``
+
+    :meta hide-value:
+    """
+
     uint8 = u8
     int8 = i8

@@ -91,6 +97,9 @@ class dtype(Enum):
     float8 = f8
     fp8 = f8

+    float4 = f4
+    fp4 = f4
+
     half = f16
     fp16 = f16
     float16 = f16
@@ -162,6 +171,8 @@ def _from(
             return dtype.i32
         elif t == torch.float8_e4m3fn:
             return dtype.f8
+        elif t == torch.float4_e2m1fn_x2:
+            return dtype.f4
         elif t == torch.half:
             return dtype.f16
         elif t == torch.float:
@@ -188,6 +199,8 @@ def _from(
             return dtype.i8
         elif t == trt.DataType.FP8:
             return dtype.f8
+        elif t == trt.DataType.FP4:
+            return dtype.fp4
         elif t == trt.DataType.INT32:
             return dtype.i32
         elif t == trt.DataType.INT64:
@@ -357,6 +370,8 @@ def to(
             return torch.long
         elif self == dtype.f8:
             return torch.float8_e4m3fn
+        elif self == dtype.f4:
+            return torch.float4_e2m1fn_x2
         elif self == dtype.f16:
             return torch.half
         elif self == dtype.f32:
@@ -394,6 +409,8 @@ def to(
             return trt.DataType.BOOL
         elif self == dtype.bf16:
             return trt.DataType.BF16
+        elif self == dtype.f4:
+            return trt.DataType.FP4
         elif use_default:
             return trt.DataType.FLOAT
         else:
@@ -410,6 +427,8 @@ def to(
             return np.int64
         elif self == dtype.f16:
             return np.float16
+        elif self == dtype.f4:
+            return np.float4_e2m1fn_x2
         elif self == dtype.f32:
             return np.float32
         elif self == dtype.f64:
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index aafd1072f4..921cb37646 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -29,7 +29,14 @@
 REQUIRE_FULL_COMPILATION = False
 DRYRUN = False
 HARDWARE_COMPATIBLE = False
-SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
+SUPPORTED_KERNEL_PRECISIONS = {
+    dtype.f32,
+    dtype.f16,
+    dtype.bf16,
+    dtype.i8,
+    dtype.f8,
+    dtype.f4,
+}
 TIMING_CACHE_PATH = os.path.join(
     tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin"
 )
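Note on the new enum entries above: `dtype.f4` is aliased by `dtype.fp4` and `dtype.float4`, and round-trips to `torch.float4_e2m1fn_x2` and `trt.DataType.FP4`. A minimal sketch of that mapping, assuming a PyTorch build that exposes `torch.float4_e2m1fn_x2` (illustrative only, not part of the diff):

    import torch
    import tensorrt as trt
    from torch_tensorrt import dtype

    assert dtype.fp4 == dtype.f4 == dtype.float4
    assert dtype._from(torch.float4_e2m1fn_x2) == dtype.f4
    assert dtype.f4.to(trt.DataType) == trt.DataType.FP4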
diff --git a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py
index 141b68f3e7..1c4926bcfa 100644
--- a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py
+++ b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py
@@ -1,6 +1,8 @@
 from dataclasses import dataclass, field
+from typing import Union

 import numpy as np
+import torch
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.types import TRTNetwork

@@ -21,3 +23,9 @@ class ConversionContext:
     )
     requires_output_allocator: bool = False
     mapping: dict[str, np.array] = field(default_factory=dict)
+    cpu_weights_reference_holder: dict[str, Union[torch.Tensor, np.array]] = field(
+        default_factory=dict
+    )
+
+    def clear_cpu_weights_reference_holder(self) -> None:
+        self.cpu_weights_reference_holder.clear()
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 39a1ed957d..bb1a77b4eb 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -743,6 +743,8 @@ def run(
             )
         _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")

+        self.ctx.clear_cpu_weights_reference_holder()
+
         self._save_timing_cache(
             builder_config, self.compilation_settings.timing_cache_path
         )
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 9d6602ddca..e542f1d417 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -619,6 +619,42 @@ def aten_ops_quantize_op(
     )


+try:
+    import modelopt.torch.quantization as mtq  # noqa: F401
+
+    assert torch.ops.tensorrt.dynamic_block_quantize_op.default
+except Exception as e:
+    _LOGGER.warning(
+        "Unable to import quantize op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models"
+    )
+else:
+
+    @dynamo_tensorrt_converter(
+        torch.ops.tensorrt.dynamic_block_quantize_op.default,
+        supports_dynamic_shapes=True,
+    )
+    def aten_ops_dynamic_block_quantize_op(
+        ctx: ConversionContext,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: str,
+    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+        return impl.dynamic_block_quantize.quantize(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],
+            args[1],
+            args[2],
+            args[3],
+            args[4],
+            args[5],
+            args[6],
+        )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims, supports_dynamic_shapes=True)
 def aten_ops_squeeze(
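For readability, the positional args unpacked by `aten_ops_dynamic_block_quantize_op` above map onto the parameters of `impl.dynamic_block_quantize.quantize` introduced later in this diff; a keyword-style sketch of the same call, shown only for clarity (names taken from that signature):

    return impl.dynamic_block_quantize.quantize(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        input_tensor=args[0],
        block_size=args[1],
        amax=args[2],
        num_bits=args[3],
        exponent_bits=args[4],
        scale_num_bits=args[5],
        scale_exponent_bits=args[6],
    )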
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index 685f40b254..b5b7cce868 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -326,6 +326,7 @@ def create_constant(
     name: str,
     dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]],
     min_rank: Optional[int] = 1,
+    target_quantized_type: Optional[TRTDataType] = None,
 ) -> TRTTensor:
     """
     Add a TensorRT constant layer whose value is `value` to `ctx.net`.
@@ -338,6 +339,7 @@ def create_constant(
         dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
             If a dtype is given, we will convert the type of the given `value` to this dtype.
         min_rank (int): minimum rank of the constant tensor.
+        target_quantized_type (Optional[TRTDataType]): If a quantized type is given, the given `value` is treated as packed data of that quantized type (currently FP4 weights packed in uint8).
     Returns:
         A TensorRT ITensor that represents the given value.
     """
@@ -361,12 +363,48 @@ def create_constant(
         shape = list(torch_value.shape)

         if torch_value is not None:
+            if torch_value.dtype == torch.float8_e4m3fn:
+                weights = trt.Weights(
+                    type=trt.DataType.FP8,
+                    ptr=torch_value.data_ptr(),
+                    count=torch_value.numel(),
+                )
+                constant = ctx.net.add_constant(
+                    shape,
+                    weights,
+                )
+                constant.name = name
+                ctx.cpu_weights_reference_holder[name + " FP8_CONSTANT"] = torch_value
+                return constant.get_output(0)
+
+            if torch_value.dtype == torch.uint8:
+                if (
+                    target_quantized_type is None
+                    or target_quantized_type != trt.DataType.FP4
+                ):
+                    # IConstantLayer does not support uint8; it is only accepted here as FP4 data packed into uint8
+                    raise ValueError(
+                        f"Currently supported target_quantized_type for uint8 is FP4, got {target_quantized_type=}"
+                    )
+                shape[-1] = shape[-1] * 2
+                weights = trt.Weights(
+                    type=trt.DataType.FP4,
+                    ptr=torch_value.data_ptr(),
+                    count=torch_value.numel() * 2,
+                )
+                constant = ctx.net.add_constant(
+                    shape,
+                    weights,
+                )
+                constant.name = name
+                ctx.cpu_weights_reference_holder[name + " FP4_CONSTANT"] = torch_value
+                return constant.get_output(0)
+
             if torch_value.dtype == torch.bfloat16:
                 torch_value_fp32 = torch_value.to(torch.float32)
                 numpy_value = torch_value_fp32.numpy()
             else:
                 numpy_value = torch_value.numpy()
-            ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1)

             constant = ctx.net.add_constant(
                 shape,
@@ -381,7 +419,6 @@
                 trt.DataType.BF16,
                 name + "_bf16_cast",
             )
-
         return constant.get_output(0)
     else:
         raise ValueError(
@@ -395,6 +432,7 @@ def get_trt_tensor(
     name: str,
     dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None,
     min_rank: int = 1,
+    target_quantized_type: Optional[TRTDataType] = None,
 ) -> TRTTensor:
     """
     Given a value of random type, we try to convert it to a TensorRT ITensor.
@@ -408,6 +446,7 @@
         dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
             If dtype is provided, the given value will be converted to this dtype.
         min_rank (int): minimum rank of the constant tensor.
+        target_quantized_type (Optional[TRTDataType]): If a quantized type is given, the given `value` is treated as packed data of that quantized type (currently FP4 weights packed in uint8).
     Returns:
         A TensorRT ITensor that represents the given value.
     """
@@ -420,7 +459,9 @@ def get_trt_tensor(
         input_val = input_val.astype(np.float32)

     if isinstance(input_val, (torch.Tensor, np.ndarray, int, float, bool)):
-        return create_constant(ctx, input_val, name, dtype, min_rank)
+        return create_constant(
+            ctx, input_val, name, dtype, min_rank, target_quantized_type
+        )
     elif isinstance(input_val, TRTTensor):
         return input_val
     else:
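The uint8 branch added to `create_constant` assumes the tensor holds two FP4 values per byte, which is why the last dimension is doubled and `count` is `numel() * 2`. A rough sketch of how packed weights are expected to be handed in, using hypothetical tensor names purely for illustration:

    # packed_fp4 is logically (128, 32) in FP4, stored as (128, 16) uint8
    packed_fp4 = torch.randint(0, 256, (128, 16), dtype=torch.uint8)
    weights = get_trt_tensor(
        ctx,
        packed_fp4,
        name + "_weights_fp4",
        target_quantized_type=trt.DataType.FP4,
    )  # resulting ITensor has shape (128, 32) and dtype FP4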
""" @@ -420,7 +459,9 @@ def get_trt_tensor( input_val = input_val.astype(np.float32) if isinstance(input_val, (torch.Tensor, np.ndarray, int, float, bool)): - return create_constant(ctx, input_val, name, dtype, min_rank) + return create_constant( + ctx, input_val, name, dtype, min_rank, target_quantized_type + ) elif isinstance(input_val, TRTTensor): return input_val else: diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index df580b1516..10af2ad892 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -7,6 +7,7 @@ condition, conv, deconv, + dynamic_block_quantize, elementwise, embedding, full, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/dynamic_block_quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/dynamic_block_quantize.py new file mode 100644 index 0000000000..f76a84dea5 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/dynamic_block_quantize.py @@ -0,0 +1,272 @@ +from typing import Optional, Union + +import numpy as np +import tensorrt as trt +import torch +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily +from torch.fx.node import Target +from torch_tensorrt.dynamo._SourceIR import SourceIR +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import ( + get_trt_tensor, +) +from torch_tensorrt.fx.converters.converter_utils import set_layer_name +from torch_tensorrt.fx.types import TRTTensor + + +def quantize( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_tensor: TRTTensor, + block_size: int, + amax: Union[np.ndarray, torch.Tensor], + num_bits: int, + exponent_bits: int, + scale_num_bits: int, + scale_exponent_bits: int, +) -> TRTTensor: + """ + Adds quantize and dequantize ops (QDQ) which quantize to FP4 based + on the output_type set and dequantizes them back. + """ + if len(input_tensor.shape) not in (2, 3): + raise ValueError( + f"dynamic_block_quantize converter received an input of {input_tensor.shape} shape. Supported shapes: 2D or 3D" + ) + with unset_fake_temporarily(): + axis = -1 + global_scale = _calculate_global_scale(ctx, name, amax) + if ".weight_quantizer" in name: + output = _static_double_quantize( + ctx, + target, + source_ir, + name, + input_tensor, + global_scale, + axis, + ) + elif ".input_quantizer" in name: + output = _dynamic_double_quantize( + ctx, + target, + source_ir, + name, + input_tensor, + global_scale, + axis, + ) + else: + raise ValueError( + f"quantizer received an input of {name}. Supported values: weight_quantizer | input_quantizer" + ) + return output + + +def _dynamic_double_quantize( + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_tensor: TRTTensor, + global_scale: torch.Tensor, + axis: int = -1, + block_size: int = 16, + output_type: trt.DataType = trt.DataType.FP4, + scale_type: trt.DataType = trt.DataType.FP8, +) -> TRTTensor: + """ + quantize input tensor to fp4 + Parameters: + ctx: ConversionContext, + target: Target, + source_ir: Optional[SourceIR] + name: str + input_tensor : TRTTensor (On GPU) + The input TRTTensor. + global_scale : Tensor (On GPU) + The global per-tensor scaling factor. It should contain only 1 element. + axis : int + The axis to quantize. Default is -1 (the last axis). + block_size : int + The block size for quantization. Default is 16. 
+        output_type : trt.DataType
+            The data type for quantized data. Default is FP4.
+        scale_type : trt.DataType
+            The data type for block scale. Default is FP8.
+
+    """
+    global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale")
+
+    if input_tensor.dtype not in [
+        trt.DataType.HALF,
+        trt.DataType.FLOAT,
+        trt.DataType.BF16,
+    ]:
+        raise ValueError(
+            f"Currently supported input tensor types are float16 | float32 | bfloat16; got unsupported dtype: {input_tensor.dtype}"
+        )
+    # dynamically quantize the input tensor to fp4
+    dynamic_quantize_layer = ctx.net.add_dynamic_quantize(
+        input_tensor,
+        axis,
+        block_size,
+        output_type,
+        scale_type,
+    )
+    dynamic_quantize_layer.set_input(1, global_scale)
+    set_layer_name(
+        dynamic_quantize_layer, target, name + "_dynamic_quantize", source_ir
+    )
+    quantized_data_in_fp4 = dynamic_quantize_layer.get_output(0)
+    quantized_scale_in_fp8 = dynamic_quantize_layer.get_output(1)
+
+    return _double_dequantize(
+        ctx,
+        target,
+        source_ir,
+        name,
+        quantized_data_in_fp4,
+        quantized_scale_in_fp8,
+        global_scale,
+        axis,
+        input_tensor.dtype,
+    )
+
+
+def _double_dequantize(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    quantized_data_in_fp4: TRTTensor,
+    quantized_scale_in_fp8: TRTTensor,
+    global_scale: torch.Tensor,
+    axis: int = -1,
+    output_type: trt.DataType = trt.DataType.FLOAT,
+) -> TRTTensor:
+    """
+    Double dequantize first dequantizes the scale from fp8 to the original dtype (default is float32)
+    and then dequantizes the data from fp4 to the original dtype (default is float32).
+    Parameters:
+        ctx: ConversionContext,
+        target: Target,
+        source_ir: Optional[SourceIR]
+        name: str
+        quantized_data_in_fp4: TRTTensor
+        quantized_scale_in_fp8: TRTTensor
+        global_scale: torch.Tensor
+        axis: int
+        output_type: trt.DataType
+    """
+    # dequantize the scale from fp8 to the original dtype (default is float32)
+    dequantize_scale_layer = ctx.net.add_dequantize(
+        quantized_scale_in_fp8, global_scale, output_type
+    )
+    dequantize_scale_layer.axis = axis
+    dequantize_scale_layer.to_type = output_type
+    set_layer_name(
+        dequantize_scale_layer, target, name + "_dequantize_scale", source_ir
+    )
+    dequantized_scale = dequantize_scale_layer.get_output(0)
+
+    # dequantize quantized_data_in_fp4 from fp4 to the original dtype (default is float32)
+    dequantize_data_layer = ctx.net.add_dequantize(
+        quantized_data_in_fp4, dequantized_scale, output_type
+    )
+    dequantize_data_layer.axis = axis
+    dequantize_data_layer.to_type = output_type
+    set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir)
+    dequantized_data = dequantize_data_layer.get_output(0)
+    return dequantized_data
+
+
+def _static_double_quantize(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    weights_tensor: torch.Tensor,
+    global_scale: torch.Tensor,
+    axis: int,
+) -> TRTTensor:
+    """
+    Parameters:
+        ctx: ConversionContext,
+        target: Target,
+        source_ir: Optional[SourceIR],
+        name: str,
+        weights_tensor : Tensor (On GPU)
+            The input tensor for weights.
+        global_scale : Tensor (On GPU)
+            The global per-tensor scaling factor. It should contain only 1 element.
+        axis: int
+            The axis to quantize. Default is -1 (the last axis).
+    Returns:
+        A dequantized weights TRTTensor (FP4 quantize/dequantize applied, in the original dtype)
+    """
+
+    import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor
+
+    if weights_tensor.dtype == torch.float16:
+        original_dtype = trt.DataType.HALF
+    elif weights_tensor.dtype == torch.float32:
+        original_dtype = trt.DataType.FLOAT
+    elif weights_tensor.dtype == torch.bfloat16:
+        original_dtype = trt.DataType.BF16
+    else:
+        raise ValueError(
+            f"Currently supported weights tensor types are float16 | float32 | bfloat16; got unsupported dtype: {weights_tensor.dtype}"
+        )
+    block_scale_fp8 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor(
+        weights_tensor,
+        16,
+        global_scale,
+    )[0]
+    weights_tensor_fp4 = nvfp4_tensor.NVFP4QTensor.quantize(
+        weights_tensor,
+        16,
+        block_scale_fp8,
+        global_scale,
+    )[0]._quantized_data
+
+    block_scale_fp8 = get_trt_tensor(
+        ctx,
+        block_scale_fp8,
+        name + "_block_scale_fp8",
+        target_quantized_type=trt.DataType.FP8,
+    )
+    global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale")
+    weights_tensor_fp4 = get_trt_tensor(
+        ctx,
+        weights_tensor_fp4,
+        name + "_weights_fp4",
+        target_quantized_type=trt.DataType.FP4,
+    )
+
+    dequantized_data = _double_dequantize(
+        ctx,
+        target,
+        source_ir,
+        name,
+        weights_tensor_fp4,
+        block_scale_fp8,
+        global_scale,
+        axis,
+        original_dtype,
+    )
+    return dequantized_data
+
+
+def _calculate_global_scale(
+    ctx: ConversionContext,
+    name: str,
+    amax: torch.Tensor,
+) -> torch.Tensor:
+    # calculate the global scale (the global per-tensor scaling factor; it should contain only 1 element)
+    assert len(amax.shape) == 0, "amax should be a scalar"
+    global_scale = amax / 6 / 448
+    global_scale.masked_fill_(global_scale == 0, 1.0)
+    return global_scale
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
index 0feec63316..190b6752b4 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -101,6 +101,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:

     # TODO: Update this function when quantization is added
     def is_impure(self, node: torch.fx.node.Node) -> bool:
-        if node.target in (torch.ops.tensorrt.quantize_op.default,):
+        if node.target in (
+            torch.ops.tensorrt.quantize_op.default,
+            torch.ops.tensorrt.dynamic_block_quantize_op.default,
+        ):
             return True
         return False
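For reference, the scaling scheme implemented above is: global_scale = amax / (6 * 448), where 6 is the maximum magnitude of FP4 E2M1 and 448 is the maximum of FP8 E4M3; each 16-element block then gets an FP8 block scale, and dequantization is applied twice. A sketch of the effective reconstruction, expressed as comments rather than code from the diff:

    # dequantized_scale = fp8_block_scale * global_scale
    # dequantized_data  = fp4_value * dequantized_scale
    #                   = fp4_value * fp8_block_scale * global_scale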
diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py
index 9e22fef929..189da962b5 100644
--- a/tests/py/dynamo/models/test_models_export.py
+++ b/tests/py/dynamo/models/test_models_export.py
@@ -199,6 +199,127 @@ def test_resnet18_half(ir):
     torch._dynamo.reset()


+@unittest.skipIf(
+    torch.cuda.get_device_capability() < (10, 0),
+    "FP4 quantization requires compute capability 10.0 or later",
+)
+@unittest.skipIf(
+    not importlib.util.find_spec("modelopt"),
+    "ModelOpt is required to run this test",
+)
+@pytest.mark.unit
+def test_base_fp4_dynamic_shapes(ir):
+    import modelopt.torch.quantization as mtq
+    from modelopt.torch.quantization.utils import export_torch_mode
+
+    dtype = torch.float16
+
+    class SimpleNetwork(torch.nn.Module):
+        def __init__(self):
+            super(SimpleNetwork, self).__init__()
+            self.linear1 = torch.nn.Linear(
+                in_features=64, out_features=32, bias=True, dtype=dtype
+            )
+
+        def forward(self, x):
+            x = self.linear1(x)
+            return x
+
+    def calibrate_loop(model):
+        """Simple calibration function for testing."""
+        model(dummy_inputs)
+
+    BATCH_SIZE = torch.export.Dim("BATCH_SIZE", min=16, max=128)
+    batch_size = 64
+    dummy_inputs = torch.ones(batch_size, 64, dtype=dtype).cuda()
+
+    model = SimpleNetwork().eval().cuda()
+
+    quant_cfg = mtq.NVFP4_DEFAULT_CFG
+    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+    # model has qdq nodes at this point
+    with torch.no_grad():
+        with export_torch_mode():
+            exp_program = torch.export.export(
+                model, (dummy_inputs,), strict=False, dynamic_shapes=({0: BATCH_SIZE},)
+            )
+
+            trt_model = torchtrt.dynamo.compile(
+                exp_program,
+                inputs=[dummy_inputs],
+                min_block_size=1,
+                debug=True,
+                cache_built_engines=False,
+                reuse_cached_engines=False,
+                use_explicit_typing=True,
+            )
+            batch_size = 128
+            input_tensor = torch.ones(batch_size, 64, dtype=dtype).cuda()
+            expected_output = model(input_tensor)
+            outputs_trt = trt_model(input_tensor)
+            abs_diff = torch.abs(expected_output - outputs_trt)
+            print(f"max/mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}")
+            assert torch.allclose(expected_output, outputs_trt, rtol=0.3, atol=0.3)
+
+
+@unittest.skipIf(
+    torch.cuda.get_device_capability() < (10, 0),
+    "FP4 quantization requires compute capability 10.0 or later",
+)
+@unittest.skipIf(
+    not importlib.util.find_spec("modelopt"),
+    "ModelOpt is required to run this test",
+)
+@pytest.mark.unit
+def test_base_fp4_static_shapes(ir):
+    import modelopt.torch.quantization as mtq
+    from modelopt.torch.quantization.utils import export_torch_mode
+
+    dtype = torch.bfloat16
+
+    class SimpleNetwork(torch.nn.Module):
+        def __init__(self):
+            super(SimpleNetwork, self).__init__()
+            self.linear1 = torch.nn.Linear(
+                in_features=64, out_features=32, bias=True, dtype=dtype
+            )
+
+        def forward(self, x):
+            x = self.linear1(x)
+            return x
+
+    def calibrate_loop(model):
+        """Simple calibration function for testing."""
+        model(input_tensor)
+
+    input_tensor = torch.randn(128, 64, dtype=dtype).cuda()
+
+    model = SimpleNetwork().eval().cuda()
+    expected_output = model(input_tensor)
+
+    quant_cfg = mtq.NVFP4_DEFAULT_CFG
+    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+    # model has qdq nodes at this point
+    with torch.no_grad():
+        with export_torch_mode():
+            exp_program = torch.export.export(model, (input_tensor,), strict=False)
+            from torch.fx import passes
+
+            trt_model = torchtrt.dynamo.compile(
+                exp_program,
+                inputs=[input_tensor],
+                min_block_size=1,
+                debug=True,
+                cache_built_engines=False,
+                reuse_cached_engines=False,
+                use_explicit_typing=True,
+            )
+            outputs_trt = trt_model(input_tensor)
+            abs_diff = torch.abs(expected_output - outputs_trt)
+            print(f"max/mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}")
+            assert torch.allclose(expected_output, outputs_trt, rtol=0.3, atol=0.3)
+
+
 @unittest.skipIf(
     torch.cuda.get_device_capability() < (8, 9),
     "FP8 quantization requires compute capability 8.9 or later",
@@ -230,8 +351,8 @@ def calibrate_loop(model):
     input_tensor = torch.randn(1, 10).cuda()

     model = SimpleNetwork().eval().cuda()
-
     quant_cfg = mtq.FP8_DEFAULT_CFG
+
     mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
     # model has FP8 qdq nodes at this point
     output_pyt = model(input_tensor)