diff --git a/hls4ml/converters/pytorch/convolution.py b/hls4ml/converters/pytorch/convolution.py index f000286c72..0294eb8ab1 100644 --- a/hls4ml/converters/pytorch/convolution.py +++ b/hls4ml/converters/pytorch/convolution.py @@ -1,8 +1,12 @@ -from hls4ml.converters.pytorch_to_hls import pytorch_handler +import numpy as np + +from hls4ml.converters.pytorch_to_hls import addQuantizationParameters, convert_uaq_to_apfixed, pytorch_handler from hls4ml.converters.utils import compute_padding_1d_pytorch, compute_padding_2d_pytorch, parse_data_format +from hls4ml.model.quantizers import BrevitasQuantizer +from hls4ml.model.types import FixedPrecisionType -@pytorch_handler('Conv1d') +@pytorch_handler('Conv1d', 'QuantConv1d') def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config): assert 'Conv1d' in operation @@ -13,12 +17,50 @@ def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, c layer['class_name'] = 'Conv1D' layer['data_format'] = 'channels_first' # Pytorch default (can't change) - layer['weight_data'] = class_object.weight.data.numpy() - if class_object.bias is not None: - layer['bias_data'] = class_object.bias.data.numpy() - else: - layer['bias_data'] = None + if "Quant" in operation: + if class_object.weight_quant.is_quant_enabled: + width = int(class_object.quant_weight().bit_width) + scale = class_object.quant_weight().scale.detach().numpy() + mantissa, _ = np.frexp(scale) + # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly + # use the already quantized tensor from brevitas + if mantissa == 0.5: + ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_weight().scale)) + layer['weight_data'] = class_object.quant_weight().detach().value.numpy() + layer['weight_quantizer'] = BrevitasQuantizer( + width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True) + ) + else: + raise Exception( + '''Non-power of 2 quantization of weights not supported when injecting brevitas models. 
+                    Please use QONNX instead.'''
+                )
+        else:
+            layer['weight_data'] = class_object.weight.data.numpy()
+
+        if class_object.bias_quant.is_quant_enabled:
+            width = int(class_object.quant_bias().bit_width)
+            ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_bias().scale))
+            layer['bias_data'] = class_object.quant_bias().detach().value.numpy()
+            layer['bias_quantizer'] = BrevitasQuantizer(
+                width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)
+            )
+        else:
+            if class_object.bias is not None:
+                layer['bias_data'] = class_object.bias.data.numpy()
+            else:
+                layer['bias_data'] = None
+        if class_object.input_quant.is_quant_enabled:
+            layer = addQuantizationParameters(layer, class_object.input_quant, 'input', act=True)
+        if class_object.output_quant.is_quant_enabled:
+            layer = addQuantizationParameters(layer, class_object.output_quant, 'output', act=True)
+    else:
+        layer['weight_data'] = class_object.weight.data.numpy()
+        if class_object.bias is not None:
+            layer['bias_data'] = class_object.bias.data.numpy()
+        else:
+            layer['bias_data'] = None

     # Input info
     (*_, layer['in_width'], layer['n_chan']) = parse_data_format(
         input_shapes[0], 'channels_first'
@@ -47,7 +89,7 @@ def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, c
     return layer, output_shape


-@pytorch_handler('Conv2d')
+@pytorch_handler('Conv2d', 'QuantConv2d')
 def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
     assert 'Conv2d' in operation

@@ -58,11 +100,52 @@ def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, c
     layer['class_name'] = 'Conv2D'
     layer['data_format'] = 'channels_first'  # Pytorch default (can't change)

-    layer['weight_data'] = class_object.weight.data.numpy()
-    if class_object.bias is not None:
-        layer['bias_data'] = class_object.bias.data.numpy()
+    if "Quant" in operation:
+        if class_object.weight_quant.is_quant_enabled:
+            width = int(class_object.quant_weight().bit_width)
+            scale = class_object.quant_weight().scale.detach().numpy()
+            mantissa, _ = np.frexp(scale)
+            # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly
+            # use the already quantized tensor from brevitas
+            if mantissa == 0.5:
+                ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_weight().scale))
+                layer['weight_data'] = class_object.quant_weight().detach().value.numpy()
+                layer['weight_quantizer'] = BrevitasQuantizer(
+                    width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)
+                )
+            else:
+                raise Exception(
+                    '''Non-power of 2 quantization of weights not supported when injecting brevitas models.
+                    Please use QONNX instead.'''
+                )
+            # layer = addQuantizationParameters(layer, class_object.quant_weight(), 'weight')
+            # layer['weight_data'] = class_object.quant_weight().detach().value.numpy()
+        else:
+            layer['weight_data'] = class_object.weight.data.numpy()
+
+        if class_object.bias_quant.is_quant_enabled:
+            width = int(class_object.quant_bias().bit_width)
+            ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_bias().scale))
+            layer['bias_data'] = class_object.quant_bias().detach().value.numpy()
+            layer['bias_quantizer'] = BrevitasQuantizer(
+                width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)
+            )
+        else:
+            if class_object.bias is not None:
+                layer['bias_data'] = class_object.bias.data.numpy()
+            else:
+                layer['bias_data'] = None
+        if class_object.input_quant.is_quant_enabled:
+            layer = addQuantizationParameters(layer, class_object.input_quant, 'input', act=True)
+        if class_object.output_quant.is_quant_enabled:
+            layer = addQuantizationParameters(layer, class_object.output_quant, 'output', act=True)
+    else:
-        layer['bias_data'] = None
+        layer['weight_data'] = class_object.weight.data.numpy()
+        if class_object.bias is not None:
+            layer['bias_data'] = class_object.bias.data.numpy()
+        else:
+            layer['bias_data'] = None

     # Input info
     (*_, layer['in_height'], layer['in_width'], layer['n_chan']) = parse_data_format(
diff --git a/hls4ml/converters/pytorch/core.py b/hls4ml/converters/pytorch/core.py
index 57c42f401f..958a4aae55 100644
--- a/hls4ml/converters/pytorch/core.py
+++ b/hls4ml/converters/pytorch/core.py
@@ -1,6 +1,8 @@
 import numpy as np

-from hls4ml.converters.pytorch_to_hls import pytorch_handler
+from hls4ml.converters.pytorch_to_hls import addQuantizationParameters, convert_uaq_to_apfixed, pytorch_handler
+from hls4ml.model.quantizers import BrevitasQuantizer
+from hls4ml.model.types import FixedPrecisionType


 @pytorch_handler('Constant')
@@ -20,7 +22,33 @@ def parse_constant_layer(operation, layer_name, node):
     return layer, output_shape


-@pytorch_handler('Linear')
+# A QuantIdentity layer does nothing but quantize its inputs. Insert a `Quant` node to be processed by the QONNX optimizers.
+@pytorch_handler('QuantIdentity')
+def parse_quantidentity_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
+    assert 'QuantIdentity' in operation
+
+    layer = {}
+    layer['inputs'] = input_names
+
+    layer['class_name'] = 'Quant'
+    layer['name'] = layer_name
+
+    if class_object.act_quant.is_quant_enabled:
+        layer['bitwidth'] = int(class_object.act_quant.bit_width())
+        layer['signed'] = class_object.act_quant.is_signed
+        layer['scale'] = np.full(np.array(input_shapes[0][1:]), class_object.act_quant.scale())
+        layer['zeropt'] = float(class_object.act_quant.zero_point())
+        layer['narrow'] = class_object.act_quant.is_narrow_range
+        layer['rounding_mode'] = class_object.act_quant.rounding_mode
+
+    else:
+        raise Exception('''A QuantIdentity layer without act quant does nothing, please remove it from the model.''')
+    output_shape = input_shapes[0]
+
+    return layer, output_shape
+
+
+@pytorch_handler('Linear', 'QuantLinear')
 def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
     assert 'Linear' in operation

@@ -36,6 +64,44 @@ def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, c
     else:
         layer['bias_data'] = None

+    if "Quant" in operation:
+        if class_object.weight_quant.is_quant_enabled:
+            width = int(class_object.quant_weight().bit_width)
+            scale = class_object.quant_weight().scale.detach().numpy()
+            mantissa, _ = np.frexp(scale)
+            # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly
+            # use the already quantized tensor from brevitas
+            if mantissa == 0.5:
+                ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_weight().scale))
+                layer['weight_data'] = class_object.quant_weight().detach().value.numpy()
+                layer['weight_quantizer'] = BrevitasQuantizer(
+                    width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)
+                )
+            else:
+                raise Exception(
+                    '''Non-power of 2 quantization of weights not supported when injecting brevitas models.
+                    Please use QONNX instead.'''
+                )
+        else:
+            layer['weight_data'] = class_object.weight.data.numpy()
+
+        if class_object.bias_quant.is_quant_enabled:
+            width = int(class_object.quant_bias().bit_width)
+            ap_fixed_params = convert_uaq_to_apfixed(width, float(class_object.quant_bias().scale))
+            layer['bias_data'] = class_object.quant_bias().detach().value.numpy()
+            layer['bias_quantizer'] = BrevitasQuantizer(
+                width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)
+            )
+        else:
+            if class_object.bias is not None:
+                layer['bias_data'] = class_object.bias.data.numpy()
+            else:
+                layer['bias_data'] = None
+        if class_object.input_quant.is_quant_enabled:
+            layer = addQuantizationParameters(layer, class_object.input_quant, 'input', act=True)
+        if class_object.output_quant.is_quant_enabled:
+            layer = addQuantizationParameters(layer, class_object.output_quant, 'output', act=True)
+
     if class_object is not None:
         layer['n_in'] = class_object.in_features
         layer['n_out'] = class_object.out_features
@@ -54,7 +120,19 @@ def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, c
     return layer, output_shape


-activation_layers = ['Softmax', 'ReLU', 'LeakyReLU', 'Threshold', 'ELU', 'PReLU', 'Sigmoid', 'Tanh']
+activation_layers = [
+    'Softmax',
+    'ReLU',
+    'LeakyReLU',
+    'Threshold',
+    'ELU',
+    'PReLU',
+    'Sigmoid',
+    'Tanh',
+    'QuantReLU',
+    'QuantSigmoid',
+    'QuantTanh',
+]


 @pytorch_handler(*activation_layers)
@@ -66,6 +144,12 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod
     layer['name'] = layer_name
     layer['inputs'] = input_names

+    if "Quant" in operation:
+        layer['class_name'] = operation.split('Quant')[-1]
+        layer['activation'] = layer['class_name']
+        if class_object.act_quant.is_quant_enabled:
+            layer = addQuantizationParameters(layer, class_object.act_quant, 'output', act=True)
+
     if node.op == 'call_module':
         if layer['class_name'] in ['ReLU', 'Sigmoid', 'Tanh']:
             layer['class_name'] = 'Activation'
diff --git a/hls4ml/converters/pytorch/merge.py b/hls4ml/converters/pytorch/merge.py
index 1f1e11dcb7..faae193ac9 100644
--- a/hls4ml/converters/pytorch/merge.py
+++ b/hls4ml/converters/pytorch/merge.py
@@ -1,4 +1,4 @@
-from hls4ml.converters.pytorch_to_hls import pytorch_handler
+from hls4ml.converters.pytorch_to_hls import addQuantizationParameters, pytorch_handler

 concat_layers = ['cat', 'concat', 'concatenate']

@@ -28,7 +28,7 @@ def parse_concat_layer(operation, layer_name, input_names, input_shapes, node, c
     return layer, output_shape


-add_layers = ['add']
+add_layers = ['add', 'QuantEltwiseAdd']
 multiply_layers = ['mul', 'multiply']
 subtract_layers = ['sub', 'subtract']
 min_layers = ['fmin', 'minimum']
@@ -56,6 +56,12 @@ def parse_merge_layer(operation, layer_name, input_names, input_shapes, node, cl

     layer['inputs'] = input_names

+    if 'Quant' in operation:
+        if class_object.input_quant.is_quant_enabled:
+            layer = addQuantizationParameters(layer, class_object.input_quant, 'input', act=True)
+        if class_object.output_quant.is_quant_enabled:
+            layer = addQuantizationParameters(layer, class_object.input_quant, 'output', act=True, scale_up=True)
+
     output_shape = input_shapes[0][:]

     return layer, output_shape
diff --git a/hls4ml/converters/pytorch/pooling.py b/hls4ml/converters/pytorch/pooling.py
index 54e840cacb..8433a438b6 100644
--- a/hls4ml/converters/pytorch/pooling.py
+++ b/hls4ml/converters/pytorch/pooling.py
@@ -1,7 +1,12 @@
 from hls4ml.converters.pytorch_to_hls import pytorch_handler
 from hls4ml.converters.utils import compute_padding_1d_pytorch, compute_padding_2d_pytorch, parse_data_format

-pooling_layers = ['MaxPool1d', 'MaxPool2d', 'AvgPool1d', 'AvgPool2d']
+pooling_layers = [
+    'MaxPool1d',
+    'MaxPool2d',
+    'AvgPool1d',
+    'AvgPool2d',
+]  # TODO add support for special quantized average pool layers


 @pytorch_handler(*pooling_layers)
@@ -10,9 +15,9 @@ def parse_pooling_layer(operation, layer_name, input_names, input_shapes, node,

     layer = {}

-    if operation == 'MaxPool1d':
+    if 'MaxPool1d' in operation:
         layer['class_name'] = 'MaxPooling1D'
-    if operation == 'MaxPool2d':
+    if 'MaxPool2d' in operation:
         layer['class_name'] = 'MaxPooling2D'
     if operation == 'AvgPool1d':
         layer['class_name'] = 'AveragePooling1D'
diff --git a/hls4ml/converters/pytorch/recurrent.py b/hls4ml/converters/pytorch/recurrent.py
index 5d8f6a58bd..a4bac46d0c 100644
--- a/hls4ml/converters/pytorch/recurrent.py
+++ b/hls4ml/converters/pytorch/recurrent.py
@@ -1,6 +1,8 @@
 import numpy as np

-from hls4ml.converters.pytorch_to_hls import pytorch_handler
+from hls4ml.converters.pytorch_to_hls import addQuantizationParameters, convert_uaq_to_apfixed, pytorch_handler
+from hls4ml.model.quantizers import BrevitasQuantizer
+from hls4ml.model.types import FixedPrecisionType, NamedType

 rnn_layers = ['RNN', 'LSTM', 'GRU']

@@ -72,3 +74,151 @@ def parse_rnn_layer(operation, layer_name, input_names, input_shapes, node, clas
         layer['pass_initial_states'] = True

     return layer, output_shape
+
+
+quant_rnn_layers = ['QuantRNN']  # QuantLSTM is very complex and might come later. There is no QuantGRU in brevitas at this point
+
+
+@pytorch_handler(*quant_rnn_layers)
+def parse_quant_rnn_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
+    assert operation in quant_rnn_layers
+    operation = operation.split('Quant')[-1]
+
+    if len(class_object._modules['layers']) > 1:
+        raise Exception('hls4ml does not support num_layers > 1')
+
+    if class_object.num_directions > 1:
+        raise Exception('hls4ml does not support bidirectional RNNs')
+
+    layer = {}
+
+    layer["name"] = layer_name
+
+    layer['inputs'] = input_names
+    if 'IOType' in config.keys():
+        if len(input_names) > 1 and config['IOType'] == 'io_stream':
+            raise Exception('Passing initial values for the hidden state is not supported for io_stream input type.')
+
+    layer['class_name'] = operation
+    if operation == 'RNN':
+        layer['class_name'] = 'SimpleRNN'
+
+    layer['return_sequences'] = False  # parameter does not exist in pytorch
+    layer['return_state'] = False  # parameter does not exist in pytorch
+
+    if layer['class_name'] == 'SimpleRNN':
+        layer['activation'] = 'tanh' if 'Tanh' in str(class_object._modules['layers'][0][0].cell.act_fn) else 'ReLU'
+    else:
+        layer['activation'] = 'tanh'  # GRU and LSTM are hard-coded to use tanh in pytorch
+
+    if layer['class_name'] == 'GRU' or layer['class_name'] == 'LSTM':
+        layer['recurrent_activation'] = 'sigmoid'  # GRU and LSTM are hard-coded to use sigmoid in pytorch
+
+    layer['time_major'] = not class_object._modules['layers'][0][0].cell.batch_first
+    # TODO Should we handle time_major?
+    if layer['time_major']:
+        raise Exception('hls4ml only supports "batch-first == True"')
+
+    layer['n_timesteps'] = input_shapes[0][1]
+    layer['n_in'] = input_shapes[0][2]
+
+    layer['n_out'] = class_object._modules['layers'][0][0].hidden_size
+
+    RNNObject = class_object._modules['layers'][0][0]
+
+    if RNNObject.gate_params.input_weight.weight_quant.is_quant_enabled:
+        width = int(RNNObject.gate_params.input_weight.quant_weight().bit_width)
+        scale = RNNObject.gate_params.input_weight.quant_weight().scale.detach().numpy()
+        signed = RNNObject.gate_params.input_weight.quant_weight().signed
+        mantissa, _ = np.frexp(scale)
+        # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly
+        # use the already quantized tensor from brevitas
+        if mantissa == 0.5:
+            ap_fixed_params = convert_uaq_to_apfixed(width, float(RNNObject.gate_params.input_weight.quant_weight().scale))
+            layer['weight_data'] = RNNObject.gate_params.input_weight.quant_weight().detach().value.numpy()
+            layer['weight_quantizer'] = BrevitasQuantizer(
+                width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=signed)
+            )
+        else:
+            raise Exception(
+                '''Non-power of 2 quantization of weights not supported when injecting brevitas models.
+                Please use QONNX instead.'''
+            )
+
+    if RNNObject.gate_params.hidden_weight.weight_quant.is_quant_enabled:
+        width = int(RNNObject.gate_params.hidden_weight.quant_weight().bit_width)
+        scale = RNNObject.gate_params.hidden_weight.quant_weight().scale.detach().numpy()
+        signed = RNNObject.gate_params.hidden_weight.quant_weight().signed
+        mantissa, _ = np.frexp(scale)
+        # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly
+        # use the already quantized tensor from brevitas
+        if mantissa == 0.5:
+            ap_fixed_params = convert_uaq_to_apfixed(width, float(RNNObject.gate_params.hidden_weight.quant_weight().scale))
+            layer['recurrent_weight_data'] = RNNObject.gate_params.hidden_weight.quant_weight().detach().value.numpy()
+            layer['recurrent_weight_quantizer'] = BrevitasQuantizer(
+                width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=signed)
+            )
+        else:
+            raise Exception(
+                '''Non-power of 2 quantization of weights not supported when injecting brevitas models.
+                Please use QONNX instead.'''
+            )
+
+    input_bias = RNNObject.gate_params.quant_bias()
+    if input_bias is not None:
+        width = int(input_bias.bit_width)
+        scale = input_bias.scale.detach().numpy()
+        mantissa, _ = np.frexp(scale)
+        # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly
+        # use the already quantized tensor from brevitas
+        if mantissa == 0.5:
+            ap_fixed_params = convert_uaq_to_apfixed(width, scale)
+
+            layer['bias_data'] = input_bias.detach().value.numpy()
+            layer['bias_quantizer'] = BrevitasQuantizer(
+                width, FixedPrecisionType(width=width, integer=int(ap_fixed_params[1]), signed=True)
+            )
+        else:
+            raise Exception(
+                '''Non-power of 2 quantization of biases not supported when injecting brevitas models.
+                Please use QONNX instead.'''
+            )
+    else:
+        layer['bias_data'] = np.zeros(layer['weight_data'].shape[0])
+        layer['bias_quantizer'] = layer['weight_quantizer']
+
+    layer['recurrent_bias_data'] = np.zeros(layer['recurrent_weight_data'].shape[0])
+    layer['recurrent_bias_quantizer'] = layer['weight_quantizer']
+
+    acc_scale = RNNObject.cell.gate_acc_quant.scale()
+    acc_bitwidth = int(RNNObject.cell.gate_acc_quant.bit_width())
+    mantissa, _ = np.frexp(acc_scale)
+    # if scale is power of 2 we can simply use hls4ml FixedPrecisionType and directly
+    # use the already quantized tensor from brevitas
+    if mantissa == 0.5:
+        ap_fixed_params = convert_uaq_to_apfixed(acc_bitwidth, acc_scale)
+        precision = FixedPrecisionType(width=acc_bitwidth, integer=int(ap_fixed_params[1]), signed=True)
+        layer['accum_t'] = NamedType(layer["name"] + '_accum_t', precision)
+
+    else:
+        raise Exception(
+            '''Non-power of 2 quantization of the accumulator not supported when injecting brevitas models.
+            Please use QONNX instead.'''
+        )
+
+    if RNNObject.cell.output_quant.is_quant_enabled:
+        layer = addQuantizationParameters(layer, RNNObject.cell.output_quant, 'output', act=True)
+        layer = addQuantizationParameters(layer, RNNObject.cell.output_quant, 'input', act=True)
+
+    if layer['class_name'] == 'GRU':
+        layer['apply_reset_gate'] = 'after'  # Might be true for pytorch? It's not a free parameter
+
+    output_shape = [input_shapes[0][0], layer['n_out']]
+
+    layer['pytorch'] = True  # need to switch some behaviors to match pytorch implementations
+    if len(input_names) == 1:
+        layer['pass_initial_states'] = False
+    else:
+        layer['pass_initial_states'] = True
+
+    return layer, output_shape
diff --git a/hls4ml/converters/pytorch/reshape.py b/hls4ml/converters/pytorch/reshape.py
index f7392ab8da..8bf7a6d5cd 100644
--- a/hls4ml/converters/pytorch/reshape.py
+++ b/hls4ml/converters/pytorch/reshape.py
@@ -120,10 +120,25 @@ def parse_flatten_layer(operation, layer_name, input_names, input_shapes, node,
     return layer, output_shape


-@pytorch_handler('Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d')
+@pytorch_handler(
+    'Upsample',
+    'UpsamplingNearest2d',
+    'UpsamplingBilinear2d',
+    'QuantUpsample',
+    'QuantUpsamplingNearest2d',
+    'QuantUpsamplingBilinear2d',
+)
 def handle_upsample(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
-    assert operation in ['Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d']
+    assert operation in [
+        'Upsample',
+        'UpsamplingNearest2d',
+        'UpsamplingBilinear2d',
+        'QuantUpsample',
+        'QuantUpsamplingNearest2d',
+        'QuantUpsamplingBilinear2d',
+    ]
+
     layer = {}
     layer['name'] = layer_name
     layer['inputs'] = input_names
diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py
index a36ff5eb67..a8146c7d47 100644
--- a/hls4ml/converters/pytorch_to_hls.py
+++ b/hls4ml/converters/pytorch_to_hls.py
@@ -1,3 +1,5 @@
+import math
+
 import numpy as np

 from hls4ml.model import ModelGraph
@@ -61,6 +63,59 @@ def get_weights_data(data_reader, layer_name, var_name):
     return (*data,)


+def convert_uaq_to_apfixed(bitwidth, scale_factor):
+    """
+    parameters:
+    bitwidth: int
+    scale_factor: float
+
+    return:
+    fract_bitwidth: int
+    int_bitwidth: int
+    """
+    fract_bitwidth = -math.log2(scale_factor)
+    int_bitwidth = bitwidth - fract_bitwidth
+
+    return (fract_bitwidth, int_bitwidth)
+
+
+# embed quantization information into the layer dictionary for a Quant layer
+# so that this layer can be added to the model
+def
addQuantizationParameters(layer, quant_object, quant_type, act=False, scale_up=False): + if not act: + # currently not used, might be use later for non-power-of-2 scales + bit_width = int(quant_object.bit_width) + signed = quant_object.signed + scale = float(quant_object.scale) + zeropoint = float(quant_object.zero_point) + if signed: + narrow = True + else: + narrow = False + rounding_mode = 'ROUND' + else: + bit_width = int(quant_object.bit_width()) + signed = quant_object.is_signed + scale = float(quant_object.scale()) + # bit of a hack to make adding operations with QuantEltWiseAdd work + if scale_up: + scale = 2 ** (math.log2(scale) + 1) + zeropoint = float(quant_object.zero_point()) + narrow = quant_object.is_narrow_range + rounding_mode = quant_object.rounding_mode + + layer[f'{quant_type}_quantization'] = { + 'bit_width': bit_width, + 'signed': signed, + 'scale': scale, + 'zeropoint': zeropoint, + 'narrow': narrow, + 'rounding_mode': rounding_mode, + } + return layer + + # ----------------------Layer handling--------------------- # layer_handlers = {} @@ -144,7 +199,7 @@ def parse_pytorch_model(config, verbose=True): tracer = CustomFXTracer() traced_model = tracer.trace(model) # Define layers to skip for conversion to HLS - skip_layers = ['Dropout', 'Sequential'] + skip_layers = ['Dropout', 'QuantDropout', 'Sequential'] # All supported layers supported_layers = get_supported_pytorch_layers() + skip_layers @@ -200,6 +255,10 @@ def parse_pytorch_model(config, verbose=True): if pytorch_class not in supported_layers: raise Exception(f'Unsupported layer {pytorch_class}') + if 'IOType' in config.keys(): + if "QuantUpsampl" in pytorch_class and config['IOType'] == 'io_stream': + raise Exception('Quant upsampling layers currently not supported with io_stream') + if layer_counter != 0: input_shapes = [output_shape] # In case there are multiple inputs @@ -224,7 +283,7 @@ def parse_pytorch_model(config, verbose=True): # parse info from class object input_names = [inputs_map.get(str(i), str(i)) for i in node.args] - if pytorch_class in ["RNN", "GRU", "LSTM"]: + if pytorch_class in ['RNN', 'GRU', 'LSTM', 'QuantRNN']: input_shapes = [] input_names = [] for arg in node.args: diff --git a/hls4ml/model/graph.py b/hls4ml/model/graph.py index 76c621a1a2..b770e0a5bf 100644 --- a/hls4ml/model/graph.py +++ b/hls4ml/model/graph.py @@ -495,7 +495,9 @@ def insert_node(self, node, before=None, input_idx=0): next_nodes.append(x) if before is None: - next_node = next((x for x in self.graph.values() if x.inputs and x.inputs[0] in prev_node.outputs), None) + next_node = next( + (x for x in self.graph.values() if x.inputs and set(x.inputs).intersection(prev_node.outputs)), None + ) else: if before not in next_nodes: raise Exception( @@ -591,7 +593,6 @@ def replace_node(self, old_node, new_node): for i, n in enumerate(node.outputs): if n in repl: node.outputs[i] = repl[n] - self.graph = OrderedDict((new_node.name, new_node) if k == old_node.name else (k, v) for k, v in self.graph.items()) old_name = old_node.name diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index 0efeaafa3d..3d773838b1 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -386,8 +386,8 @@ def initialize(self): class Quant(Layer): # The QONNX quantization layer """ - This is a QONNX quantization layer. Optimizations should convert it - before HLS is produced. + This is a QONNX quantization layer. Can also be inserted in direct brevitas parsing. + Optimizations should convert it before HLS is produced. 
""" _expected_attributes = [ @@ -452,6 +452,8 @@ class Dense(Layer): WeightAttribute('bias'), TypeAttribute('weight'), TypeAttribute('bias'), + Attribute('input_quantization', value_type=dict, default={}), + Attribute('output_quantization', value_type=dict, default={}), ] def initialize(self): @@ -500,6 +502,8 @@ class Conv1D(Layer): WeightAttribute('bias'), TypeAttribute('weight'), TypeAttribute('bias'), + Attribute('input_quantization', value_type=dict, default={}), + Attribute('output_quantization', value_type=dict, default={}), ] def initialize(self): @@ -611,6 +615,8 @@ class Conv2D(Layer): WeightAttribute('bias'), TypeAttribute('weight'), TypeAttribute('bias'), + Attribute('input_quantization', value_type=dict, default={}), + Attribute('output_quantization', value_type=dict, default={}), ] def initialize(self): diff --git a/hls4ml/model/optimizer/__init__.py b/hls4ml/model/optimizer/__init__.py index c474970448..b7498eeafc 100644 --- a/hls4ml/model/optimizer/__init__.py +++ b/hls4ml/model/optimizer/__init__.py @@ -30,6 +30,13 @@ del module_path del optimizers +register_flow( + 'parse_brevitas', + [ + 'brevitas_input_output_optimizer', + ], +) + register_flow( 'parse_qonnx', [ @@ -53,6 +60,7 @@ 'conv_to_conv_x_d', 'conv_to_depthwise_conv_x_d', ], + requires=['parse_brevitas'], ) register_flow( @@ -77,6 +85,7 @@ 'merge_linear_activation', # many of the above optimzers need to be done before this 'infer_precision_types', + 'adjust_resize_input_precision', ], requires=['parse_qonnx'], ) # TODO Maybe not all QKeras optmizers belong here? diff --git a/hls4ml/model/optimizer/passes/brevitas_optimizer.py b/hls4ml/model/optimizer/passes/brevitas_optimizer.py new file mode 100644 index 0000000000..a786b9c60f --- /dev/null +++ b/hls4ml/model/optimizer/passes/brevitas_optimizer.py @@ -0,0 +1,63 @@ +# Inserts Quant nodes into the model as needed for input/output quantization of layers in brevitas +import numpy as np + +from hls4ml.model.optimizer import OptimizerPass + + +class BrevitasInputOutputOptimizer(OptimizerPass): + '''Takes nodes parsed from brevitas and inserts Quant nodes into the model if necessary''' + + def match(self, node): + if ('output_quantization' in node.attributes.keys() and not len(node.attributes['output_quantization']) == 0) or ( + 'input_quantization' in node.attributes.keys() and not len(node.attributes['input_quantization']) == 0 + ): + return True + else: + return False + + def transform(self, model, node): + + # See if Quant layer needs to be added for the output + if 'output_quantization' in node.attributes.keys() and not len(node.attributes['output_quantization']) == 0: + print(node.attributes['output_quantization']) + attributes = {} + + input = node.name + # Other attributes + attributes['narrow'] = node.attributes['output_quantization']['narrow'] + attributes['rounding_mode'] = node.attributes['output_quantization']['rounding_mode'] + attributes['signed'] = node.attributes['output_quantization']['signed'] + attributes['bitwidth'] = node.attributes['output_quantization']['bit_width'] + attributes['zeropt'] = node.attributes['output_quantization']['zeropoint'] + attributes['scale'] = np.array([node.attributes['output_quantization']['scale']]) + + quant_node = model.make_node('Quant', f'quant_output_for_{node.get_attr("name")}', attributes, [input]) + quant_node.set_attr('name', f'quant_output_for_{node.get_attr("name")}') + + model.insert_node(quant_node) + + node.attributes['output_quantization'] = {} + + elif 'input_quantization' in node.attributes.keys() and 
not len(node.attributes['input_quantization']) == 0: + + attributes = {} + + # Other attributes + attributes['narrow'] = node.attributes['input_quantization']['narrow'] + attributes['rounding_mode'] = node.attributes['input_quantization']['rounding_mode'] + attributes['signed'] = node.attributes['input_quantization']['signed'] + attributes['bitwidth'] = node.attributes['input_quantization']['bit_width'] + attributes['zeropt'] = node.attributes['input_quantization']['zeropoint'] + attributes['scale'] = np.array([node.attributes['input_quantization']['scale']]) + + for i, input in enumerate(node.inputs): + quant_node = model.make_node( + 'Quant', f'quant_input_for_{node.get_attr("name")}_input_{i}', attributes, [input] + ) + quant_node.set_attr('name', f'quant_input_for_{node.get_attr("name")}_input_{i}') + + model.insert_node(quant_node, input_idx=i) + + node.attributes['input_quantization'] = {} + + return True diff --git a/hls4ml/model/optimizer/passes/convert_to_channels_last.py b/hls4ml/model/optimizer/passes/convert_to_channels_last.py index 6511a6967b..24abf34e48 100644 --- a/hls4ml/model/optimizer/passes/convert_to_channels_last.py +++ b/hls4ml/model/optimizer/passes/convert_to_channels_last.py @@ -116,7 +116,7 @@ def transform(self, model, node): # Add transpose for output layer elif ( - node.get_attr('name') in model.outputs + node.name in model.outputs and len(outshape) > 1 and model.config.config['HLSConfig']['Model']['TransposeOutputs'] ): diff --git a/hls4ml/model/optimizer/passes/quant_opt.py b/hls4ml/model/optimizer/passes/quant_opt.py index 6c9badd832..bf6854ae61 100644 --- a/hls4ml/model/optimizer/passes/quant_opt.py +++ b/hls4ml/model/optimizer/passes/quant_opt.py @@ -88,7 +88,6 @@ class QuantToActivation(OptimizerPass): def match(self, node): # only matches after the other inputs are already folded - is_match = ( isinstance(node, Quant) and len(node.inputs) == 1 @@ -105,8 +104,8 @@ def match(self, node): scale_unit_or_po2 = (scale == np.ones_like(scale)).all() if not scale_unit_or_po2 and _ALSO_MATCH_PO2: # This optimization only works if all scales are the same - if np.all(scale[0] == scale): - mantissa, _ = np.frexp(scale[0]) + if np.all(next(iter(scale.flat)) == scale): + mantissa, _ = np.frexp(next(iter(scale.flat))) scale_unit_or_po2 = mantissa == 0.5 is_match = scale_unit_or_po2 @@ -125,9 +124,8 @@ def transform(self, model, node): integer = bitwidth scale = node.get_attr('scale') if _ALSO_MATCH_PO2 and not (scale == np.ones_like(scale)).all(): - _, exp = np.frexp(scale[0]) + _, exp = np.frexp(next(iter(scale.flat))) integer = bitwidth + exp - 1 - precision, quantizer = _calculate_precision_quantizer(bitwidth, integer, signed, narrow, rounding_mode) attributes = {'activation': 'linear', 'quantizer': quantizer} @@ -139,8 +137,7 @@ def transform(self, model, node): new_name = f'{node.name}_act' model.config.set_name_config(new_name, config) model.config.parse_name_config(new_name, config) - - new_node = model.make_node(Activation, new_name, attributes, [node.inputs[0]], [x for x in node.outputs]) + new_node = model.make_node(Activation, new_name, attributes, [node.inputs[0]], [f'{x}_act' for x in node.outputs]) model.replace_node(node, new_node) return True @@ -268,7 +265,6 @@ def transform(self, model, node): rescale_name = f'{node.name}_rescale' model.config.set_name_config(rescale_name, rescale_config) model.config.parse_name_config(rescale_name, rescale_config) - firstscale = 1 / scale firstbias = bias attributes_scale['scale_data'] = np.broadcast_to(firstscale, 
inshape) diff --git a/hls4ml/model/optimizer/passes/resize_remove_constants.py b/hls4ml/model/optimizer/passes/resize_remove_constants.py index fd2f1cfadd..70a68923b5 100644 --- a/hls4ml/model/optimizer/passes/resize_remove_constants.py +++ b/hls4ml/model/optimizer/passes/resize_remove_constants.py @@ -36,3 +36,20 @@ def transform(self, model, node): # Clean all the '' inputs node.inputs = list(filter(None, node.inputs)) return True + + +class AdjustResizeInputPrecision(OptimizerPass): + """ + This optimizer makes sure that the input data type of a Resize layer matches the output data type of the previous layer. + """ + + def match(self, node): + is_match = isinstance(node, Resize) and not ( + node.get_input_node().types['result_t'].precision == node.get_output_variable().type.precision + ) + return is_match + + def transform(self, model, node): + node.get_output_variable().type.precision = node.get_input_node().types['result_t'].precision + + return True diff --git a/hls4ml/model/quantizers.py b/hls4ml/model/quantizers.py index b445c70af3..2dc542569a 100644 --- a/hls4ml/model/quantizers.py +++ b/hls4ml/model/quantizers.py @@ -171,6 +171,22 @@ def __call__(self, data): return y +class BrevitasQuantizer(Quantizer): + """Wrapper around brevitas quantizers. Since we can get the already quantized tensors + directly from the brevitas QuantTensor objects, nothing needs to be done + + Args: + bits: bitwidth of the quantized tensor + hls_type: hls_type of the quantized tensor + """ + + def __init__(self, bits, hls_type): + super().__init__(bits, hls_type) + + def __call__(self, data): + return data + + class QuantNodeQuantizer(Quantizer): """ This implements a quantizer for a FixedPrecisionType with width==integer diff --git a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_recurrent.h b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_recurrent.h index 678161006f..3367213167 100644 --- a/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_recurrent.h +++ b/hls4ml/templates/oneapi/firmware/nnet_utils/nnet_recurrent.h @@ -493,7 +493,7 @@ void simple_rnn_pytorch_init_state(const data_T &data, const h_T &hin, res_T &re } // Do SimpleRNN - simple_rnn_pytorch_cell(in, hidden_state_temp, h, kernel, rec_kernel, bias, rec_bias); + simple_rnn_pytorch_cell(in, hidden_state_temp, h, kernel, rec_kernel, bias, rec_bias); // Write result #pragma unroll diff --git a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h index d3411f351b..794e46972e 100644 --- a/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h +++ b/hls4ml/templates/quartus/firmware/nnet_utils/nnet_recurrent.h @@ -501,7 +501,7 @@ void simple_rnn_pytorch(data_T data[CONFIG_T::n_timesteps * CONFIG_T::n_in], h_T } // Do SimpleRNN - simple_rnn_pytorch_cell(in, hidden_state_temp, h, kernel, rec_kernel, bias, rec_bias); + simple_rnn_pytorch_cell(in, hidden_state_temp, h, kernel, rec_kernel, bias, rec_bias); // Write result #pragma unroll diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h b/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h index 618767dcb5..5a2783c2c1 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h @@ -73,8 +73,9 @@ void lstm(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG #pragma HLS ARRAY_PARTITION variable=inputacc_c complete #pragma HLS ARRAY_PARTITION variable=s_actstate complete - nnet::dense(data, tmpres, param, param_b); - 
nnet::dense(h_newstate, tmpres_state, param_r, param_br); + nnet::dense(data, tmpres, param, param_b); + nnet::dense(h_newstate, tmpres_state, param_r, + param_br); for (int iacc = 0; iacc < (3 * CONFIG_T::n_state); iacc++) { #pragma HLS UNROLL @@ -254,7 +255,7 @@ void lstm_stack(data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in], h_T h_newsta data_in[j] = data[j + iloop * CONFIG_T::n_in]; } - nnet::lstm(reset_state, data_in, h_newstate, s_newstate, param, param_r, param_b, param_br); + nnet::lstm(reset_state, data_in, h_newstate, s_newstate, param, param_r, param_b, param_br); if (CONFIG_T::n_sequence_out > 1) for (int i = CONFIG_T::n_state * iloop, j = 0; i < (CONFIG_T::n_state * (iloop + 1)); i++, j++) { #pragma HLS UNROLL diff --git a/pyproject.toml b/pyproject.toml index 041428ea9f..895de19e88 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ optional-dependencies.qkeras = [ optional-dependencies.quartus-report = [ "calmjs-parse", "tabulate" ] optional-dependencies.sr = [ "sympy>=1.13.1" ] optional-dependencies.testing = [ + "brevitas", "calmjs-parse", "hgq>=0.2.3", "onnx>=1.4", diff --git a/test/pytest/test_brevitas_parsing.py b/test/pytest/test_brevitas_parsing.py new file mode 100644 index 0000000000..c8989aaed2 --- /dev/null +++ b/test/pytest/test_brevitas_parsing.py @@ -0,0 +1,355 @@ +from pathlib import Path + +import brevitas.nn as qnn +import numpy as np +import pytest +import torch +from brevitas.quant import Int8ActPerTensorFixedPoint, Int8WeightPerTensorFixedPoint, Int8WeightPerTensorFloat +from torch import nn +from torch.nn import Module + +import hls4ml +from hls4ml.converters import convert_from_pytorch_model +from hls4ml.utils.config import config_from_pytorch_model + +test_root_path = Path(__file__).parent + +quants = { + 'Int8WeightPerTensorFixedPoint': Int8WeightPerTensorFixedPoint, + 'Int8ActPerTensorFixedPoint': Int8ActPerTensorFixedPoint, + 'Int8WeightPerTensorFloat': Int8WeightPerTensorFloat, +} + + +class QuantModelConv2d(Module): + def __init__(self): + super().__init__() + self.conv1 = qnn.QuantConv2d(3, 6, 5, bias=True, weight_quant=Int8WeightPerTensorFixedPoint) + self.relu1 = nn.ReLU() + + def forward(self, x): + out = self.relu1(self.conv1(x)) + return out + + +class QuantModelConv1d(Module): + def __init__(self): + super().__init__() + self.conv1 = qnn.QuantConv1d(3, 6, 4, bias=True, weight_quant=Int8WeightPerTensorFixedPoint) + self.relu1 = nn.ReLU() + + def forward(self, x): + out = self.relu1(self.conv1(x)) + return out + + +class QuantModelLinear(Module): + def __init__(self, weight_quant, input_quant): + super().__init__() + self.lin1 = qnn.QuantLinear(4, 4, bias=False, weight_quant=quants[weight_quant], input_quant=quants[input_quant]) + self.relu1 = qnn.QuantReLU(act_quant=quants[input_quant]) + + def forward(self, x): + out = self.relu1(self.lin1(x)) + return out + + +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'oneAPI']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +@pytest.mark.parametrize('weight_quant', ['Int8WeightPerTensorFixedPoint']) +@pytest.mark.parametrize('io_quant', ['Int8ActPerTensorFixedPoint']) +def test_quantlinear(backend, io_type, weight_quant, io_quant): + model = QuantModelLinear(weight_quant, io_quant) + + x = torch.rand(1, 4) + pytorch_prediction = model(x).detach().numpy() + config = config_from_pytorch_model(model, input_shape=(None, 4)) + output_dir = str(test_root_path / f'hls4mlprj_brevitas_linear_{backend}_{io_type}_{weight_quant}_{io_quant}') + + hls_model 
= convert_from_pytorch_model( + model, + hls_config=config, + output_dir=output_dir, + backend=backend, + io_type=io_type, + ) + hls_model.compile() + + hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) + + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0.0, atol=0.05) + + +@pytest.mark.parametrize('backend', ['Vivado', 'Quartus']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_quantconv1d(backend, io_type): + model = QuantModelConv1d() + + n_in = 3 + n_out = 6 + size_in = 5 + + x = torch.randn(1, n_in, size_in) + + pytorch_prediction = model(x).detach().numpy() + if io_type == 'io_stream': + x = np.ascontiguousarray(x.permute(0, 2, 1)) + config = config_from_pytorch_model( + model, (None, n_in, size_in), channels_last_conversion="internal", transpose_outputs=False + ) + else: + config = config_from_pytorch_model( + model, (None, n_in, size_in), channels_last_conversion="full", transpose_outputs=True + ) + + output_dir = str(test_root_path / f'hls4mlprj_brevitas_conv1d_{backend}_{io_type}') + + from hls4ml.utils.torch import CustomFXTracer + + tracer = CustomFXTracer() + traced_model = tracer.trace(model) + nNodes = 0 + convNode = None + for _node in traced_model.nodes: + nNodes += 1 + if nNodes == 2: + convNode = _node + + children = {c[0]: c[1] for c in model.named_children()} + class_object_conv = children[convNode.target] + + out_width = int( + ( + size_in + + 2 * class_object_conv.padding[0] + - class_object_conv.dilation[0] * (class_object_conv.kernel_size[0] - 1) + - 1 + ) + / class_object_conv.stride[0] + + 1 + ) # following https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + + hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type) + hls_model.compile() + + if io_type == 'io_stream': + hls_prediction = np.transpose(np.reshape(hls_model.predict(x), (1, out_width, n_out)), (0, 2, 1)) + else: + hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) + + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01) + + +@pytest.mark.parametrize('backend', ['Vivado', 'Quartus']) +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +def test_quantconv2d(backend, io_type): + model = QuantModelConv2d() + + n_in = 3 + n_out = 6 + size_in_width = 5 + size_in_height = 6 + + x = torch.randn(1, n_in, size_in_height, size_in_width) + + pytorch_prediction = model(x).detach().numpy() + if io_type == 'io_stream': + x = np.ascontiguousarray(x.permute(0, 2, 3, 1)) + config = config_from_pytorch_model( + model, (None, n_in, size_in_height, size_in_width), channels_last_conversion="internal", transpose_outputs=False + ) + else: + config = config_from_pytorch_model( + model, (None, n_in, size_in_height, size_in_width), channels_last_conversion="full", transpose_outputs=True + ) + + output_dir = str(test_root_path / f'hls4mlprj_brevitas_conv2d_{backend}_{io_type}') + + from hls4ml.utils.torch import CustomFXTracer + + tracer = CustomFXTracer() + traced_model = tracer.trace(model) + + nNodes = 0 + convNode = None + for _node in traced_model.nodes: + nNodes += 1 + if nNodes == 2: + convNode = _node + + children = {c[0]: c[1] for c in model.named_children()} + class_object_conv = children[convNode.target] + + out_width = int( + ( + size_in_width + + 2 * class_object_conv.padding[1] + - class_object_conv.dilation[1] * (class_object_conv.kernel_size[1] - 1) + - 1 + 
) + / class_object_conv.stride[1] + + 1 + ) # following https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + out_height = int( + ( + size_in_height + + 2 * class_object_conv.padding[0] + - class_object_conv.dilation[0] * (class_object_conv.kernel_size[0] - 1) + - 1 + ) + / class_object_conv.stride[0] + + 1 + ) # following https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html + + hls_model = convert_from_pytorch_model( + model, + hls_config=config, + output_dir=output_dir, + backend=backend, + io_type=io_type, + ) + hls_model.compile() + + if io_type == 'io_stream': + hls_prediction = np.transpose(np.reshape(hls_model.predict(x), (1, out_height, out_width, n_out)), (0, 3, 1, 2)) + else: + hls_prediction = np.reshape(hls_model.predict(x.detach().numpy()), pytorch_prediction.shape) + + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0.0, atol=0.05) + + +in_height = 6 +in_width = 8 +in_feat = 4 + +size = 2 +atol = 5e-3 + + +@pytest.fixture(scope='module') +def data_1d(): + X = np.random.rand(100, in_feat, in_width) + return X + + +@pytest.fixture(scope='module') +def data_2d(): + X = np.random.rand(100, in_feat, in_height, in_width) + return X + + +class QuantUpsample1DModel(nn.Module): + def __init__(self): + super().__init__() + self.identity = qnn.QuantIdentity(act_quant=Int8ActPerTensorFixedPoint, return_quant_tensor=True) + self.upsample = qnn.QuantUpsample(scale_factor=2) + self.relu = nn.ReLU() + + def forward(self, x): + return self.relu(self.upsample(self.identity(x))) + + +class QuantUpsample2DModel(nn.Module): + def __init__(self): + super().__init__() + # this scale_factor tests proper output shape calculation with fractional scaling and parsing per-axis scales + self.identity = qnn.QuantIdentity(act_quant=Int8ActPerTensorFixedPoint, return_quant_tensor=True) + self.upsample = qnn.QuantUpsamplingNearest2d(scale_factor=(1, 2.4)) # Would also work with Upsample(mode='nearest') + self.relu = nn.ReLU() + + def forward(self, x): + return self.relu(self.upsample(self.identity(x))) + + +@pytest.mark.parametrize('io_type', ['io_parallel']) # Quant upsampling layers currently not supported in io_stream +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +def test_pytorch_upsampling1d(data_1d, io_type, backend): + model = QuantUpsample1DModel() + + config = hls4ml.utils.config_from_pytorch_model( + model, + (None, in_feat, in_width), + default_precision='ap_fixed<16,6>', + channels_last_conversion="internal", + transpose_outputs=False, + ) + odir = str(test_root_path / f'hls4mlprj_pytorch_upsampling_1d_{backend}_{io_type}') + hls_model = hls4ml.converters.convert_from_pytorch_model( + model, hls_config=config, io_type=io_type, output_dir=odir, backend=backend + ) + hls_model.compile() + + data_1d_t = np.ascontiguousarray(data_1d.transpose([0, 2, 1])) + + pytorch_prediction = model(torch.Tensor(data_1d)).value.detach().numpy() + hls_prediction = hls_model.predict(data_1d_t) + + pred_shape = list(pytorch_prediction.shape) + pred_shape.append(pred_shape.pop(1)) # Transpose shape to channels_last + hls_prediction = hls_prediction.reshape(pred_shape).transpose([0, 2, 1]) # Transpose back + + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01) + + +@pytest.mark.parametrize('io_type', ['io_parallel']) # Fractional scaling doesn't work with io_stream +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +def test_pytorch_upsampling2d(data_2d, io_type, backend): + model = QuantUpsample2DModel() + 
+ config = hls4ml.utils.config_from_pytorch_model( + model, + (in_feat, in_height, in_width), + default_precision='ap_fixed<16,6>', + channels_last_conversion="full", # With conversion to channels_last + transpose_outputs=True, + ) + odir = str(test_root_path / f'hls4mlprj_pytorch_upsampling_2d_{backend}_{io_type}') + hls_model = hls4ml.converters.convert_from_pytorch_model( + model, hls_config=config, io_type=io_type, output_dir=odir, backend=backend + ) + hls_model.compile() + + pytorch_prediction = model(torch.Tensor(data_2d)).value.detach().numpy().flatten() + hls_prediction = hls_model.predict(data_2d).flatten() + + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=1e-2, atol=0.01) + + +class QuantEltwiseAddModel(nn.Module): + def __init__(self): + super().__init__() + self.add = qnn.QuantEltwiseAdd(input_quant=Int8ActPerTensorFixedPoint, output_quant=Int8ActPerTensorFixedPoint) + + def forward(self, x, y): + return self.add(x, y) + + +@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream']) +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus']) +def test_brevitas_quanteltwiseadd(io_type, backend): + model = QuantEltwiseAddModel() + + x = torch.rand(1, 4, 4) + y = torch.rand(1, 4, 4) + + pytorch_prediction = model(torch.Tensor(x), torch.Tensor(y)).detach().numpy() + + config = hls4ml.utils.config_from_pytorch_model( + model, + [(None, 4, 4), (None, 4, 4)], + default_precision='ap_fixed<16,6>', + channels_last_conversion="off", + transpose_outputs=False, + ) + odir = str(test_root_path / f'hls4mlprj_brevitas_quanteltwiseadd_{backend}_{io_type}') + hls_model = hls4ml.converters.convert_from_pytorch_model( + model, hls_config=config, io_type=io_type, output_dir=odir, backend=backend + ) + hls_model.compile() + + hls_prediction = hls_model.predict([x.detach().numpy(), y.detach().numpy()]) + + pred_shape = pytorch_prediction.shape + hls_prediction = hls_prediction.reshape(pred_shape) + + np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=5e-2, atol=0.05) diff --git a/test/pytest/test_recurrent_brevitas.py b/test/pytest/test_recurrent_brevitas.py new file mode 100644 index 0000000000..fa6dd1c4e7 --- /dev/null +++ b/test/pytest/test_recurrent_brevitas.py @@ -0,0 +1,70 @@ +from pathlib import Path + +import brevitas.nn as qnn +import numpy as np +import pytest +import torch +from brevitas.quant import ( + Int8ActPerTensorFixedPoint, + Int8BiasPerTensorFixedPointInternalScaling, + Int8WeightPerTensorFixedPoint, +) +from torch import nn + +from hls4ml.converters import convert_from_pytorch_model +from hls4ml.utils.config import config_from_pytorch_model + +test_root_path = Path(__file__).parent + + +class QuantRNNModel(nn.Module): + def __init__(self): + super().__init__() + self.rnn = qnn.QuantRNN( + input_size=10, + hidden_size=20, + bidirectional=False, + shared_input_hidden_weights=False, + batch_first=True, + weight_quant=Int8WeightPerTensorFixedPoint, + bias_quant=Int8BiasPerTensorFixedPointInternalScaling, + io_quant=Int8ActPerTensorFixedPoint, + gate_acc_quant=Int8ActPerTensorFixedPoint, + return_quant_tensor=True, + bias=True, + ) + + def forward(self, x, h0): + output, _ = self.rnn(x, (h0)) + return output + + +@pytest.mark.parametrize('backend', ['Quartus', 'oneAPI']) +@pytest.mark.parametrize('io_type', ['io_parallel']) +def test_rnn(backend, io_type): + model = QuantRNNModel() + model.eval() + + X_input = torch.randn(1, 1, 10) + X_input = np.round(X_input * 2**16) * 2**-16 # make it exact ap_fixed<32,16> + h0 = 
torch.randn(1, 1, 20)
+    h0 = np.round(h0 * 2**16) * 2**-16
+
+    pytorch_prediction = model(torch.Tensor(X_input), torch.Tensor(h0)).detach().value.numpy()
+
+    config = config_from_pytorch_model(
+        model,
+        [(None, 1, 10), (None, 1, 20)],
+        channels_last_conversion="off",
+        transpose_outputs=False,
+        default_precision='fixed<32,16>',
+    )
+    output_dir = str(test_root_path / f'hls4mlprj_brevitas_rnn_{backend}_{io_type}')
+
+    hls_model = convert_from_pytorch_model(model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type)
+
+    hls_model.compile()
+
+    hls_prediction = np.reshape(hls_model.predict([X_input.detach().numpy(), h0.detach().numpy()]), pytorch_prediction.shape)
+
+    np.testing.assert_allclose(hls_prediction, pytorch_prediction, atol=2)  # quite bad accuracy so far
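Example usage (not part of the patch): a minimal sketch of how a quantized Brevitas model would go through the new direct-parsing path, mirroring the API calls exercised in test_brevitas_parsing.py above. The layer sizes, backend choice and output directory are illustrative placeholders; only fixed-point (power-of-2 scale) quantizers are assumed, since the parser above rejects other scales.

    import brevitas.nn as qnn
    import torch
    from brevitas.quant import Int8ActPerTensorFixedPoint, Int8WeightPerTensorFixedPoint
    from torch import nn

    from hls4ml.converters import convert_from_pytorch_model
    from hls4ml.utils.config import config_from_pytorch_model

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            # fixed-point (power-of-2 scale) quantizers are required by the direct Brevitas parser
            self.lin = qnn.QuantLinear(4, 4, bias=False, weight_quant=Int8WeightPerTensorFixedPoint)
            self.relu = qnn.QuantReLU(act_quant=Int8ActPerTensorFixedPoint)

        def forward(self, x):
            return self.relu(self.lin(x))

    model = Model()
    x = torch.rand(1, 4)

    # build an hls4ml config directly from the Brevitas model and convert it
    config = config_from_pytorch_model(model, input_shape=(None, 4))
    hls_model = convert_from_pytorch_model(
        model, hls_config=config, output_dir='hls4mlprj_brevitas_example', backend='Vivado', io_type='io_parallel'
    )
    hls_model.compile()
    print(hls_model.predict(x.detach().numpy()))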