diff --git a/hls4ml/backends/fpga/fpga_backend.py b/hls4ml/backends/fpga/fpga_backend.py
index bd85937d89..e52bb85494 100644
--- a/hls4ml/backends/fpga/fpga_backend.py
+++ b/hls4ml/backends/fpga/fpga_backend.py
@@ -26,6 +26,7 @@
     GlobalPooling2D,
     MatMul,
     Merge,
+    MultiHeadAttention,
     Pooling1D,
     Pooling2D,
     Quant,
@@ -71,6 +72,7 @@ def __init__(self, name):
             Dot,
             Conv,
             MatMul,
+            MultiHeadAttention,
         ]
 
         for layer in accum_layers:
diff --git a/hls4ml/backends/vitis/passes/transformer_templates.py b/hls4ml/backends/vitis/passes/transformer_templates.py
new file mode 100644
index 0000000000..9b91d3b081
--- /dev/null
+++ b/hls4ml/backends/vitis/passes/transformer_templates.py
@@ -0,0 +1,146 @@
+from hls4ml.backends.backend import get_backend
+from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
+from hls4ml.model.layers import MultiHeadAttention
+
+# dense layer template
+mult_config_template = """struct config{index}_{mNum} : nnet::dense_config {{
+    static const unsigned n_in = {n_in};
+    static const unsigned n_out = {n_out};
+    static const unsigned io_type = nnet::{iotype};
+    static const unsigned strategy = nnet::{strategy};
+    static const unsigned reuse_factor = {reuse};
+    static const unsigned n_zeros = {nzeros};
+    static const unsigned n_nonzeros = {nonzeros};
+    static const unsigned multiplier_limit = DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor;
+    static const bool store_weights_in_bram = false;
+    typedef {accum_t.name} accum_t;
+    typedef {attention_output_bias_t.name} bias_t;
+    typedef {attention_output_weight_t.name} weight_t;
+    typedef ap_{index_t} index_t;
+    template <class data_T, class res_T, class CONFIG_T>
+    using kernel = nnet::{dense_function}<data_T, res_T, CONFIG_T>;
+    template <class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
+}};\n"""
+
+# activation template
+softmax_config_template = """struct {type}_config{index} : nnet::activ_config {{
+    static const unsigned n_in = {n_in};
+    static const unsigned table_size = {table_size};
+    static const unsigned io_type = nnet::{iotype};
+    static const unsigned reuse_factor = {reuse};
+    static const nnet::softmax_implementation implementation = nnet::softmax_implementation::{implementation};
+    typedef {table_t.name} exp_table_t;
+    typedef {table_t.name} inv_table_t;
+}};\n"""
+
+mha_config_template = """struct config{index} : nnet::multiheadattention_config {{
+    typedef {accum_t.name} accum_t;
+    typedef {attention_output_bias_t.name} bias_t;
+    typedef {attention_output_weight_t.name} weight_t;
+    typedef {config_mult_t1} config_mult1;
+    typedef {config_mult_t2} config_mult2;
+    typedef {config_activ_t1} softmax_config1;
+
+    static const unsigned num_heads = {num_heads};
+    static const unsigned head_dim_key = {head_dim_key};
+    static const unsigned head_dim_value = {head_dim_value};
+    static const unsigned feature_dim = {feature_dim};
+    static const unsigned seq_len = {seq_len};
+
+    static const unsigned io_type = nnet::{iotype};
+    static const unsigned reuse_factor = {reuse};
+    static const bool store_weights_in_bram = false;
+}};\n"""
+
+mha_function_template = """nnet::multiheadattention<{input_t}, {output_t}, {config}>({input_q}, {input_kv},
+    {output}, {w_o}, {b_o}, {w_k}, {b_k}, {w_q}, {b_q}, {w_v}, {b_v});"""
+
+mha_include_list = ['nnet_utils/nnet_multiheadattention.h']
+
+
+class MhaConfigTemplate(LayerConfigTemplate):
+    def __init__(self):
+        super().__init__(MultiHeadAttention)
+        self.template = mha_config_template
+        self.mult1_template = mult_config_template
+        self.mult2_template = mult_config_template
+        self.activ1_template = softmax_config_template
+
def format(self, node): + params = self._default_config_params(node) + params['num_heads'] = node.get_attr('num_heads') + params['head_dim_key'] = node.get_attr('head_dim_key') + params['head_dim_value'] = node.get_attr('head_dim_value') + params['feature_dim'] = node.get_attr('feature_dim') + params['seq_len'] = node.get_attr('seq_len') + params['config_mult_t1'] = f'config{node.index}_1' + params['config_mult_t2'] = f'config{node.index}_2' + params['config_activ_t1'] = '{}_config{}'.format("softmax", node.index) + params['strategy'] = node.get_attr('strategy') + mha_config = self.template.format(**params) + + mult_params1 = self._default_config_params(node) + mult_params1['strategy'] = 'latency' + mult_params1['mNum'] = '1' + mult_params1['n_in'] = node.get_attr('feature_dim') + mult_params1['n_out'] = node.get_attr('head_dim_key') + mult_params1['product_type'] = get_backend('vivado').product_type( + node.get_input_variable().type.precision, node.get_weights('query_weight').type.precision + ) + mult_params1['reuse'] = params['reuse'] + mult_params1['index'] = str(node.index) + mult_params1['nzeros'] = 0 + mult_params1['nonzeros'] = params['feature_dim'] * params['num_heads'] * params['head_dim_key'] + mult_params1['dense_function'] = 'DenseLatency' + mult_config1 = self.mult1_template.format(**mult_params1) + + mult_params2 = self._default_config_params(node) + mult_params2['strategy'] = 'latency' + mult_params2['mNum'] = '2' + mult_params2['n_in'] = node.get_attr('head_dim_value') * node.get_attr('num_heads') + mult_params2['n_out'] = node.get_attr('feature_dim') + mult_params2['product_type'] = get_backend('vivado').product_type( + node.get_input_variable().type.precision, node.get_weights('attention_output_weight').type.precision + ) + mult_params2['reuse'] = params['reuse'] + mult_params2['index'] = str(node.index) + mult_params2['nzeros'] = 0 + mult_params2['nonzeros'] = params['feature_dim'] * params['num_heads'] * params['head_dim_key'] + mult_params2['dense_function'] = 'DenseLatency' + mult_config2 = self.mult2_template.format(**mult_params2) + + act_params = self._default_config_params(node) + act_params['n_in'] = node.get_attr('seq_len') + act_params['type'] = 'softmax' + act_params['implementation'] = 'legacy' # in MHA: latency,stable not work, legacy works + act_config = self.activ1_template.format(**act_params) + + return mult_config1 + '\n' + mult_config2 + '\n' + act_config + '\n' + mha_config + + +class MhaFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__(MultiHeadAttention, include_header=mha_include_list) + self.template = mha_function_template + + def format(self, node): + params = {} + params.update(node.attributes) + params['config'] = f'config{node.index}' + params['input_t'] = node.get_input_variable().type.name + params['output_t'] = node.get_output_variable().type.name + + params['input_q'] = node.model.get_layer_output_variable(node.inputs[0]).name + params['input_kv'] = node.model.get_layer_output_variable(node.inputs[1]).name + params['output'] = node.get_output_variable().name + params['w_o'] = node.get_weights('attention_output_weight').name + params['b_o'] = node.get_weights('attention_output_bias').name + params['w_k'] = node.get_weights('key_weight').name + params['b_k'] = node.get_weights('key_bias').name + params['w_q'] = node.get_weights('query_weight').name + params['b_q'] = node.get_weights('query_bias').name + params['w_v'] = node.get_weights('value_weight').name + params['b_v'] = node.get_weights('value_bias').name 
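+        # For illustration only (hypothetical layer index 4 and auto-generated variable names), the
+        # call rendered from mha_function_template above would look roughly like:
+        #   nnet::multiheadattention<input_t, result_t, config4>(q_in, kv_in, out,
+        #       w_o, b_o, w_k, b_k, w_q, b_q, w_v, b_v);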
+ + return self.template.format(**params) diff --git a/hls4ml/backends/vitis/vitis_backend.py b/hls4ml/backends/vitis/vitis_backend.py index cf623bb19a..7126fd5cc7 100644 --- a/hls4ml/backends/vitis/vitis_backend.py +++ b/hls4ml/backends/vitis/vitis_backend.py @@ -3,6 +3,9 @@ from hls4ml.backends import VivadoBackend from hls4ml.model.flow import get_flow, register_flow +from hls4ml.model.layers import MultiHeadAttention +from hls4ml.model.optimizer import layer_optimizer +from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType, NamedType from hls4ml.report import parse_vivado_report @@ -13,6 +16,9 @@ def __init__(self): self._register_flows() def _register_flows(self): + initializers = self._get_layer_initializers() + init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name) + validation_passes = [ 'vitis:validate_conv_implementation', 'vitis:validate_resource_strategy', @@ -30,6 +36,7 @@ def _register_flows(self): ip_flow_requirements = get_flow('vivado:ip').requires.copy() ip_flow_requirements.insert(ip_flow_requirements.index('vivado:init_layers'), validation_flow) + ip_flow_requirements.insert(ip_flow_requirements.index('vivado:streaming'), init_flow) ip_flow_requirements.insert(ip_flow_requirements.index('vivado:apply_templates'), template_flow) self._default_flow = register_flow('ip', None, requires=ip_flow_requirements, backend=self.name) @@ -120,3 +127,24 @@ def build( os.chdir(curr_dir) return parse_vivado_report(model.config.get_output_dir()) + + @layer_optimizer(MultiHeadAttention) + def init_mha(self, layer): + # TODO Allow getting recurrent reuse factor from the config + reuse_factor = layer.model.config.get_reuse_factor(layer) + layer.set_attr('reuse_factor', reuse_factor) + index_t = IntegerPrecisionType(width=1, signed=False) + layer.set_attr('index_t', index_t) + if 'table_t' not in layer.attributes: + layer.set_attr( + 'table_t', NamedType(name=layer.name + '_table_t', precision=FixedPrecisionType(width=24, integer=8)) + ) + if 'table_size' not in layer.attributes: + layer.set_attr('table_size', 2048) + if 'accum_t' not in layer.attributes: + layer.set_attr('accum_t', FixedPrecisionType(width=24, integer=8)) + if 'inv_range' not in layer.attributes: + layer.set_attr('inv_range', 128) + if 'exp_range' not in layer.attributes: + layer.set_attr('exp_range', 8) + layer.set_attr('strategy', 'resource') # latency diff --git a/hls4ml/converters/keras/multiheadattention.py b/hls4ml/converters/keras/multiheadattention.py new file mode 100644 index 0000000000..c295236561 --- /dev/null +++ b/hls4ml/converters/keras/multiheadattention.py @@ -0,0 +1,60 @@ +from hls4ml.converters.keras_to_hls import get_weights_data, keras_handler, parse_default_keras_layer + + +@keras_handler('MultiHeadAttention') +def parse_mutiheadattention_layer(keras_layer, input_names, input_shapes, data_reader): + # assume input_shapes is: [[None, seq, dim]] + assert 'MultiHeadAttention' in keras_layer['class_name'] + assert input_shapes[0] == keras_layer['config']['query_shape'] + + layer = parse_default_keras_layer(keras_layer, input_names) + + layer['num_heads'] = keras_layer['config']['num_heads'] + layer['head_dim_key'] = keras_layer['config']['key_dim'] + layer['head_dim_value'] = keras_layer['config']['value_dim'] + layer['query_shape'] = keras_layer['config']['query_shape'] + layer['key_shape'] = keras_layer['config']['key_shape'] + layer['value_shape'] = keras_layer['config']['value_shape'] + layer['feature_dim'] = layer['query_shape'][-1] + 
layer['seq_len'] = layer['query_shape'][-2] + + if keras_layer['config']['output_shape']: + raise Exception('hls4ml does not support a defined output shape, the output shape must equal to the query shape') + else: + output_shape = layer['query_shape'] + + layer['attention_axes'] = ( + keras_layer['config']['attention_axes'] if (keras_layer['config']['attention_axes'][0] == 1) else False + ) + if layer['attention_axes'] is False: + raise Exception('assigning the attention_axes is not currently supported by hls4ml') + + if not (len(layer['query_shape']) == 3 and len(layer['key_shape']) == 3 and len(layer['value_shape']) == 3): + raise Exception('only 3D shapes for query, key, and value are currently supported by hls4ml') + + attn_scores_rank = 4 + layer['softmax_axis'] = list(range(attn_scores_rank - len(layer['attention_axes']), attn_scores_rank)) + + weights_sources = [ + ('attention_output', 'kernel'), + ('attention_output', 'bias'), + ('key', 'kernel'), + ('key', 'bias'), + ('query', 'kernel'), + ('query', 'bias'), + ('value', 'kernel'), + ('value', 'bias'), + ] + + for lname, wtype in weights_sources: + data = get_weights_data(data_reader, layer['name'], f'{lname}/{wtype}') + if wtype == 'kernel': + vtype = 'weight' + if lname in ['key', 'query', 'value']: + data = data.transpose((1, 0, 2)) + else: + vtype = 'bias' + + layer[f'{lname}_{vtype}_data'] = data + + return layer, output_shape diff --git a/hls4ml/converters/keras_to_hls.py b/hls4ml/converters/keras_to_hls.py index 00561e6ba8..07b8193674 100644 --- a/hls4ml/converters/keras_to_hls.py +++ b/hls4ml/converters/keras_to_hls.py @@ -284,6 +284,10 @@ def parse_keras_model(model_arch, reader): # Extract inbound nodes if 'inbound_nodes' in keras_layer and len(keras_layer['inbound_nodes']) > 0: input_names = [inputs_map.get(inp[0], inp[0]) for inp in keras_layer['inbound_nodes'][0]] + if keras_layer['inbound_nodes'][0][0][-1]: + # multi_head_attention has inbound: [[['input_3', 0, 0, {'value': ['dense_3', 0, 0]}]]] + inputname2 = list(keras_layer['inbound_nodes'][0][0][-1].values()) + input_names += [inp[0] for inp in inputname2] else: input_names = None diff --git a/hls4ml/converters/pytorch/multiheadattention.py b/hls4ml/converters/pytorch/multiheadattention.py new file mode 100644 index 0000000000..7c53aeeb54 --- /dev/null +++ b/hls4ml/converters/pytorch/multiheadattention.py @@ -0,0 +1,54 @@ +import numpy as np + +from hls4ml.converters.pytorch_to_hls import pytorch_handler + + +@pytorch_handler('MultiheadAttention') +def parse_multiheadattention_layer( + operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config +): + assert 'MultiheadAttention' in operation + assert len(input_shapes) == 3 + + layer = {} + + layer['class_name'] = 'MultiHeadAttention' + layer['name'] = layer_name + layer['inputs'] = input_names + + layer['num_heads'] = class_object.num_heads + layer['head_dim_key'] = class_object.kdim // layer['num_heads'] + layer['head_dim_value'] = class_object.vdim // layer['num_heads'] + layer['query_shape'] = input_shapes[0] + layer['key_shape'] = input_shapes[1] + layer['value_shape'] = input_shapes[2] + + if not (len(layer['query_shape']) == len(layer['key_shape']) == len(layer['value_shape']) == 3): + raise Exception('only 3D shapes for query, key, and value are currently supported by hls4ml') + + layer['feature_dim'] = class_object.embed_dim + layer['seq_len'] = layer['query_shape'][-2] + + output_shape = layer['query_shape'] + + layer['attention_axes'] = [1] + layer['softmax_axis'] = [3] + + 
in_proj_weights = class_object.in_proj_weight.data.numpy() + in_proj_bias = class_object.in_proj_bias.data.numpy() + + weight_data = np.split(in_proj_weights, [class_object.embed_dim, class_object.embed_dim + class_object.kdim], axis=0) + bias_data = np.split(in_proj_bias, [class_object.embed_dim, class_object.embed_dim + class_object.kdim], axis=0) + + for weight_type, weight, bias in zip(['query', 'key', 'value'], weight_data, bias_data): + layer[f'{weight_type}_weight_data'] = weight.T.reshape( + layer['feature_dim'], layer['num_heads'], layer['head_dim_key'] + ).transpose(1, 0, 2) + layer[f'{weight_type}_bias_data'] = bias.reshape(layer['num_heads'], layer['head_dim_key']) + + layer['attention_output_weight_data'] = class_object.out_proj.weight.data.numpy().T.reshape( + layer['num_heads'], layer['head_dim_key'], layer['feature_dim'] + ) + layer['attention_output_bias_data'] = class_object.out_proj.bias.data.numpy() + + return layer, output_shape diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index 03e3d9ce8a..d2b995de3d 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -1621,6 +1621,54 @@ def initialize(self): self.add_output_variable([len(self.get_attr('expression'))], [f'N_OUTPUTS_{self.index}'], var_name='y') +class MultiHeadAttention(Layer): + _expected_attributes = [ + Attribute('num_heads'), + Attribute('head_dim_key'), + Attribute('head_dim_value'), + Attribute('feature_dim'), + Attribute('seq_len'), + WeightAttribute('attention_output_weight'), + WeightAttribute('attention_output_bias'), + WeightAttribute('key_weight'), + WeightAttribute('key_bias'), + WeightAttribute('query_weight'), + WeightAttribute('query_bias'), + WeightAttribute('value_weight'), + WeightAttribute('value_bias'), + TypeAttribute('attention_output_weight'), + TypeAttribute('attention_output_bias'), + TypeAttribute('key_weight'), + TypeAttribute('key_bias'), + TypeAttribute('query_weight'), + TypeAttribute('query_bias'), + TypeAttribute('value_weight'), + TypeAttribute('value_bias'), + ] + + def initialize(self): + weights = [ + 'attention_output_weight', + 'attention_output_bias', + 'key_weight', + 'key_bias', + 'query_weight', + 'query_bias', + 'value_weight', + 'value_bias', + ] + + for w in weights: + data_name = f'{w}_data' + var_name = f'{w}{{index}}' + data = self.get_attr(data_name) + self.add_weights_variable(name=w, var_name=var_name, data=data) + + shape = self.attributes['query_shape'][1:] + dims = [f'seq_out_{self.index}', f'feature_out_{self.index}'] + self.add_output_variable(shape, dims) + + layer_map = { 'Input': Input, 'InputLayer': Input, @@ -1687,6 +1735,7 @@ def initialize(self): 'BatchNormOnnx': BatchNormOnnx, 'LayerGroup': LayerGroup, 'SymbolicExpression': SymbolicExpression, + 'MultiHeadAttention': MultiHeadAttention, # TensorFlow-specific layers: 'BiasAdd': BiasAdd, } diff --git a/hls4ml/model/optimizer/passes/infer_precision.py b/hls4ml/model/optimizer/passes/infer_precision.py index bd439e4a0f..635fd582ae 100644 --- a/hls4ml/model/optimizer/passes/infer_precision.py +++ b/hls4ml/model/optimizer/passes/infer_precision.py @@ -87,6 +87,9 @@ def _infer_precision(self, node, types_to_infer): if node_class in ['ParametrizedActivation']: return self._infer_par_act_precision(node, types_to_infer) + if node_class in ['MultiHeadAttention']: + return self._infer_mha_precision(node, types_to_infer) + # What about quantized activation layer? Setting it to 'auto' manually will break it here. 
We should prevent
         # this in config_from_* functions
@@ -573,3 +576,58 @@ def _infer_par_act_precision(self, node, types_to_infer):
             inferred_types.append('param_t')
 
         return inferred_types
+
+    def _infer_mha_precision(self, node, types_to_infer):
+        inferred_types = []
+
+        for weightvar in (
+            'attention_output_weight',
+            'attention_output_bias',
+            'key_weight',
+            'key_bias',
+            'query_weight',
+            'query_bias',
+            'value_weight',
+            'value_bias',
+        ):
+            if f'{weightvar}_t' in types_to_infer:
+                self._infer_default_type(node, f'{weightvar}_t')
+                node.weights[weightvar].update_precision(node.types[f'{weightvar}_t'].precision)
+                inferred_types.append(f'{weightvar}_t')
+
+        if 'result_t' in types_to_infer:
+            input_precision = node.get_input_variable().type.precision
+            weight_precision = node.types['attention_output_weight_t'].precision
+            bias_precision = node.types['attention_output_bias_t'].precision
+
+            if self._all_supported_types((input_precision, weight_precision, bias_precision)):
+
+                after_weight_width = input_precision.width + weight_precision.width
+                after_weight_integer = input_precision.integer + weight_precision.integer
+                after_weight_signed = input_precision.signed or weight_precision.signed
+
+                out_signed = after_weight_signed or bias_precision.signed
+                out_integer = (
+                    max(
+                        after_weight_integer + (bias_precision.signed and not after_weight_signed),
+                        bias_precision.integer + (after_weight_signed and not bias_precision.signed),
+                    )
+                    + 1
+                )
+                out_width = out_integer + max(after_weight_width - after_weight_integer, bias_precision.fractional)
+
+                # Apply max precision constraints if specified in model config
+                max_precision = self._get_maximum_precision(node)
+                if max_precision is not None:
+                    out_width = min(out_width, max_precision.width)
+                    out_integer = min(out_integer, max_precision.integer)
+
+                out_precision = FixedPrecisionType(out_width, out_integer, out_signed)
+            else:
+                out_precision = self._get_default_precision(node)
+
+            node.types['result_t'].name = f'{node.name}_result_t'
+            node.types['result_t'].precision = out_precision
+            inferred_types.append('result_t')
+
+        return inferred_types
diff --git a/hls4ml/templates/vitis/nnet_utils/nnet_multiheadattention.h b/hls4ml/templates/vitis/nnet_utils/nnet_multiheadattention.h
new file mode 100644
index 0000000000..f05ba10821
--- /dev/null
+++ b/hls4ml/templates/vitis/nnet_utils/nnet_multiheadattention.h
@@ -0,0 +1,328 @@
+#ifndef NNET_MHT_H_
+#define NNET_MHT_H_
+
+#include "hls_stream.h"
+#include "nnet_activation.h"
+#include "nnet_common.h"
+#include "nnet_dense.h"
+#include "nnet_mult.h"
+#include <math.h>
+
+namespace nnet {
+
+struct multiheadattention_config {
+    // Internal data type definitions
+    typedef float bias_t;
+    typedef float weight_t;
+    typedef float accum_t;
+    typedef ap_fixed<16, 8> multi_t;
+
+    // Layer Sizes
+    static const unsigned num_heads = 10;
+    static const unsigned head_dim_key = 10;
+    static const unsigned head_dim_value = 10;
+    static const unsigned feature_dim = 20;
+    static const unsigned seq_len = 500;
+
+    // Resource reuse info
+    static const unsigned io_type = io_parallel;
+    static const unsigned strategy = latency;
+    static const unsigned reuse_factor = 1;
+    static const bool store_weights_in_bram = false;
+
+    template <class x_T, class y_T> using product = nnet::product::mult<x_T, y_T>;
+};
+
+template <class data_T, int PackSize> struct datapack { data_T data[PackSize]; };
+
+template <class data_T, int size> void read_stream_array(hls::stream<data_T> data_in[size], data_T out[size]) {
+    for (int k = 0; k < size; ++k) {
+        #pragma HLS UNROLL
+        out[k] = data_in[k].read();
+    }
+}
+
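+// Illustration only (explanatory comment, not synthesized logic): matrixmul_transpose below computes,
+// per head, the row-wise softmax of the scaled attention scores
+//   QK[i][j] = softmax_j( sum_k Q[i][k] * K[j][k] / sqrt(head_dim_key) )
+// K is first buffered into krow[], so each Q row read from the stream can be multiplied against every
+// K row in a pipelined loop before the softmax is applied to that row of scores.
+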
+template <class data_T, class res_T, typename CONFIG_T>
+void matrixmul_transpose(hls::stream<datapack<data_T, CONFIG_T::head_dim_key>> &Q,
+                         hls::stream<datapack<data_T, CONFIG_T::head_dim_key>> &K,
+                         res_T QK[CONFIG_T::seq_len][CONFIG_T::seq_len]) // seq_Q, seq_K
+{
+    const data_T dk = 1.0 / sqrt(CONFIG_T::head_dim_key);
+    data_T QK_1;
+    typename CONFIG_T::accum_t QKij;
+    data_T Qi[CONFIG_T::head_dim_key];
+    data_T Product[CONFIG_T::seq_len]; // seq_Q, seq_K
+    res_T qk_smout[CONFIG_T::seq_len];
+    data_T krow[CONFIG_T::seq_len * CONFIG_T::head_dim_key];
+    #pragma HLS ARRAY_PARTITION variable=Qi complete
+    #pragma HLS ARRAY_PARTITION variable=Product complete
+    #pragma HLS ARRAY_PARTITION variable=qk_smout complete
+    #pragma HLS ARRAY_PARTITION variable=QK complete dim=2
+    #pragma HLS ARRAY_PARTITION variable=krow complete
+
+    datapack<data_T, CONFIG_T::head_dim_key> datak_pack, dataq_pack;
+    #pragma HLS DATA_PACK variable=Q
+    #pragma HLS DATA_PACK variable=K
+    #pragma HLS DATA_PACK variable=datak_pack
+    #pragma HLS DATA_PACK variable=dataq_pack
+
+    // int multiplier_limit = ceil(float(CONFIG_T::seq_len * CONFIG_T::head_dim_key) / float(CONFIG_T::reuse_factor));
+    // CONFIG_T::template product<data_T, data_T>::limit(multiplier_limit);
+
+prep_k:
+    for (int i = 0; i < CONFIG_T::seq_len; ++i) {
+        #pragma HLS PIPELINE II=CONFIG_T::reuse_factor
+        datak_pack = K.read();
+        for (int j = 0; j < CONFIG_T::head_dim_key; ++j) {
+            #pragma HLS UNROLL
+            krow[i * CONFIG_T::head_dim_key + j] = datak_pack.data[j];
+        }
+    }
+
+row:
+    for (int i = 0; i < CONFIG_T::seq_len; ++i) {
+        #pragma HLS PIPELINE II=CONFIG_T::reuse_factor
+        dataq_pack = Q.read();
+
+    q:
+        for (int q_i = 0; q_i < CONFIG_T::head_dim_key; ++q_i) {
+            #pragma HLS UNROLL
+            Qi[q_i] = dataq_pack.data[q_i];
+        }
+    col:
+        for (int j = 0; j < CONFIG_T::seq_len; ++j) {
+            QKij = 0;
+        product:
+            for (int k = 0; k < CONFIG_T::head_dim_key; ++k) {
+                QK_1 = CONFIG_T::template product<data_T, data_T>::product(Qi[k], krow[j * CONFIG_T::head_dim_key + k]);
+                QKij += QK_1;
+            }
+            Product[j] = QKij * dk;
+        }
+        softmax<data_T, res_T, typename CONFIG_T::softmax_config1>(Product, qk_smout);
+        for (int n = 0; n < CONFIG_T::seq_len; ++n) {
+            #pragma HLS UNROLL
+            QK[i][n] = qk_smout[n];
+        }
+    }
+}
+
+template <class data_T, class res_T, typename CONFIG_T>
+void matrixmul(data_T QK[CONFIG_T::seq_len][CONFIG_T::seq_len],
+               hls::stream<datapack<data_T, CONFIG_T::head_dim_value>> &V,
+               hls::stream<res_T> S[CONFIG_T::head_dim_value]) // S: attention score
+{
+    #pragma HLS DATA_PACK variable=V
+    #pragma HLS ARRAY_PARTITION variable=QK complete dim=2
+    #pragma HLS ARRAY_PARTITION variable=S complete dim=1
+
+    datapack<data_T, CONFIG_T::head_dim_value> datav_pack;
+    #pragma HLS DATA_PACK variable=datav_pack
+
+    // int multiplier_limit = ceil(float(CONFIG_T::seq_len * CONFIG_T::head_dim_value) / float(CONFIG_T::reuse_factor));
+    // CONFIG_T::template product<data_T, data_T>::limit(multiplier_limit);
+
+    data_T dataV[CONFIG_T::seq_len * CONFIG_T::head_dim_value];
+    #pragma HLS ARRAY_PARTITION variable = dataV complete dim = 1
+
+    for (int j = 0; j < CONFIG_T::seq_len; ++j) {
+        #pragma HLS PIPELINE II=CONFIG_T::reuse_factor
+        datav_pack = V.read();
+        for (int i = 0; i < CONFIG_T::head_dim_value; ++i) {
+            #pragma HLS UNROLL
+            dataV[CONFIG_T::seq_len * i + j] = datav_pack.data[i];
+        }
+    }
+
+    data_T Sij, S_1;
+    data_T QKi[CONFIG_T::seq_len];
+    #pragma HLS ARRAY_PARTITION variable=QKi complete
+row:
+    for (int i = 0; i < CONFIG_T::seq_len; ++i) {
+        #pragma HLS PIPELINE II=CONFIG_T::reuse_factor
+    qk:
+        for (int q_i = 0; q_i < CONFIG_T::seq_len; ++q_i) {
+            #pragma HLS UNROLL
+            QKi[q_i] = QK[i][q_i];
+        }
+    col:
+        for (int j = 0; j < CONFIG_T::head_dim_value; ++j) {
+            Sij = 0;
+        product:
+            for (int k = 0; k < CONFIG_T::seq_len; ++k) {
+                S_1 = CONFIG_T::template product<data_T, data_T>::product(QKi[k], dataV[j * CONFIG_T::seq_len + k]);
+                Sij += S_1;
+            }
+            S[j].write(Sij);
+        }
+    }
+}
+
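+// Note (explanatory comment only): matrixmul above streams out S = softmax(QK^T) * V one sequence
+// step at a time. V is first unpacked into dataV[] in transposed (column-major) order, so the inner
+// product
+//   S[i][j] = sum_k QK[i][k] * V[k][j]
+// reads dataV[j * seq_len + k] contiguously for each output column j.
+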
+template <class data_T, class res_T, typename CONFIG_T>
+void lin_projection(hls::stream<data_T> data_q[CONFIG_T::feature_dim], hls::stream<data_T> data_vk[CONFIG_T::feature_dim],
+                    hls::stream<datapack<data_T, CONFIG_T::head_dim_key>> &k_proj,
+                    hls::stream<datapack<data_T, CONFIG_T::head_dim_key>> &q_proj,
+                    hls::stream<datapack<data_T, CONFIG_T::head_dim_value>> &v_proj,
+                    typename CONFIG_T::weight_t key_weight[CONFIG_T::feature_dim * CONFIG_T::head_dim_key],
+                    typename CONFIG_T::bias_t key_bias[CONFIG_T::head_dim_key],
+                    typename CONFIG_T::weight_t query_weight[CONFIG_T::feature_dim * CONFIG_T::head_dim_key],
+                    typename CONFIG_T::bias_t query_bias[CONFIG_T::head_dim_key],
+                    typename CONFIG_T::weight_t value_weight[CONFIG_T::feature_dim * CONFIG_T::head_dim_value],
+                    typename CONFIG_T::bias_t value_bias[CONFIG_T::head_dim_value]) {
+    #pragma HLS DATA_PACK variable=k_proj
+    #pragma HLS DATA_PACK variable=q_proj
+    #pragma HLS DATA_PACK variable=v_proj
+
+    #pragma HLS ARRAY_PARTITION variable=data_q complete dim=1
+    #pragma HLS ARRAY_PARTITION variable=data_vk complete dim=1
+
+k_h:
+    for (int j = 0; j < CONFIG_T::seq_len; ++j) {
+        #pragma HLS PIPELINE
+
+        data_T proj_k[CONFIG_T::head_dim_key];
+        data_T proj_q[CONFIG_T::head_dim_key];
+        data_T proj_v[CONFIG_T::head_dim_value];
+        data_T in_q[CONFIG_T::feature_dim];
+        data_T in_v[CONFIG_T::feature_dim];
+        #pragma HLS ARRAY_PARTITION variable=proj_k complete dim=1
+        #pragma HLS ARRAY_PARTITION variable=proj_q complete dim=1
+        #pragma HLS ARRAY_PARTITION variable=proj_v complete dim=1
+        #pragma HLS ARRAY_PARTITION variable=in_q complete dim=1
+        #pragma HLS ARRAY_PARTITION variable=in_v complete dim=1
+
+        datapack<data_T, CONFIG_T::head_dim_key> proj_k_pack;
+        datapack<data_T, CONFIG_T::head_dim_key> proj_q_pack;
+        datapack<data_T, CONFIG_T::head_dim_value> proj_v_pack;
+        #pragma HLS DATA_PACK variable=proj_k_pack
+        #pragma HLS DATA_PACK variable=proj_q_pack
+        #pragma HLS DATA_PACK variable=proj_v_pack
+
+        read_stream_array<data_T, CONFIG_T::feature_dim>(data_q, in_q);
+        read_stream_array<data_T, CONFIG_T::feature_dim>(data_vk, in_v);
+
+        dense<data_T, data_T, typename CONFIG_T::config_mult1>(in_v, proj_k_pack.data, key_weight, key_bias);
+        dense<data_T, data_T, typename CONFIG_T::config_mult1>(in_q, proj_q_pack.data, query_weight, query_bias);
+        dense<data_T, data_T, typename CONFIG_T::config_mult1>(in_v, proj_v_pack.data, value_weight, value_bias);
+
+        k_proj.write(proj_k_pack);
+        q_proj.write(proj_q_pack);
+        v_proj.write(proj_v_pack);
+    }
+}
+
+template <class data_T, class res_T, typename CONFIG_T>
+void dense_out(hls::stream<data_T> data_in[CONFIG_T::num_heads][CONFIG_T::head_dim_value],
+               res_T res[CONFIG_T::seq_len * CONFIG_T::feature_dim],
+               typename CONFIG_T::weight_t
+                   attention_output_weight[CONFIG_T::num_heads * CONFIG_T::head_dim_value * CONFIG_T::feature_dim],
+               typename CONFIG_T::bias_t attention_output_bias[CONFIG_T::feature_dim]) {
+    data_T mat_res_con[CONFIG_T::num_heads * CONFIG_T::head_dim_value];
+    res_T dense_out[CONFIG_T::feature_dim];
+    #pragma HLS ARRAY_PARTITION variable=mat_res_con complete dim=1
+    #pragma HLS ARRAY_PARTITION variable=dense_out complete dim=1
+output_dense:
+    for (int k = 0; k < CONFIG_T::seq_len; ++k) {
+
+        #pragma HLS PIPELINE
+        for (int i = 0; i < CONFIG_T::num_heads; ++i) {
+            #pragma HLS UNROLL
+            for (int j = 0; j < CONFIG_T::head_dim_value; ++j) {
+                #pragma HLS UNROLL
+                mat_res_con[CONFIG_T::head_dim_value * i + j] = data_in[i][j].read();
+            }
+        }
+        dense<data_T, res_T, typename CONFIG_T::config_mult2>(mat_res_con, dense_out, attention_output_weight,
+                                                              attention_output_bias);
+        for (int i = 0; i < CONFIG_T::feature_dim; ++i) {
+            #pragma HLS UNROLL
+            res[CONFIG_T::feature_dim * k + i] = dense_out[i];
+        }
+    }
+}
+
+template <class data_T, typename CONFIG_T>
+void data_prep(data_T data[CONFIG_T::seq_len * CONFIG_T::feature_dim], hls::stream<data_T> d[CONFIG_T::feature_dim]) {
+    #pragma HLS ARRAY_PARTITION variable=d complete dim=1
+    for (int j = 0; j < CONFIG_T::seq_len; ++j) {
+        for (int k = 0; k < CONFIG_T::feature_dim; ++k) {
+            #pragma HLS UNROLL
+            d[k].write(data[j * CONFIG_T::feature_dim + k]);
+        }
+    }
+}
+
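+// Overview (explanatory comment only): the top-level function below instantiates one copy of the
+// per-head pipeline under #pragma HLS DATAFLOW:
+//   data_prep -> lin_projection -> matrixmul_transpose (QK^T + softmax) -> matrixmul (xV) -> dense_out
+// Each head i consumes a slice of the packed projection weights, e.g. key_weight + head_dim_key *
+// feature_dim * i, so the converters are expected to lay the weights out as (num_heads, feature_dim,
+// head_dim).
+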
+template <class data_T, class res_T, typename CONFIG_T>
+void multiheadattention(
+    data_T data_q[CONFIG_T::seq_len * CONFIG_T::feature_dim], data_T data_vk[CONFIG_T::seq_len * CONFIG_T::feature_dim],
+    res_T res[CONFIG_T::seq_len * CONFIG_T::feature_dim],
+    typename CONFIG_T::weight_t attention_output_weight[CONFIG_T::num_heads * CONFIG_T::head_dim_value *
+                                                         CONFIG_T::feature_dim], // num_heads,head_size_v,dim
+    typename CONFIG_T::bias_t attention_output_bias[CONFIG_T::feature_dim],
+    typename CONFIG_T::weight_t
+        key_weight[CONFIG_T::feature_dim * CONFIG_T::num_heads * CONFIG_T::head_dim_key], // n_head,dim,head_dim
+    typename CONFIG_T::bias_t key_bias[CONFIG_T::num_heads * CONFIG_T::head_dim_key],
+    typename CONFIG_T::weight_t
+        query_weight[CONFIG_T::feature_dim * CONFIG_T::num_heads * CONFIG_T::head_dim_key], // same shape as key
+    typename CONFIG_T::bias_t query_bias[CONFIG_T::num_heads * CONFIG_T::head_dim_key],
+    typename CONFIG_T::weight_t value_weight[CONFIG_T::feature_dim * CONFIG_T::num_heads * CONFIG_T::head_dim_value],
+    typename CONFIG_T::bias_t value_bias[CONFIG_T::num_heads * CONFIG_T::head_dim_value]) {
+    hls::stream<data_T> d_value[CONFIG_T::num_heads][CONFIG_T::feature_dim];
+    hls::stream<data_T> d_query[CONFIG_T::num_heads][CONFIG_T::feature_dim];
+    hls::stream<datapack<data_T, CONFIG_T::head_dim_key>> q_proj[CONFIG_T::num_heads];
+    hls::stream<datapack<data_T, CONFIG_T::head_dim_key>> k_proj[CONFIG_T::num_heads];
+    hls::stream<datapack<data_T, CONFIG_T::head_dim_value>> v_proj[CONFIG_T::num_heads];
+    res_T qk_mul[CONFIG_T::num_heads][CONFIG_T::seq_len][CONFIG_T::seq_len];
+    hls::stream<res_T> matr_out[CONFIG_T::num_heads][CONFIG_T::head_dim_value];
+    #pragma HLS stream variable=d_value type=fifo depth=CONFIG_T::feature_dim
+    #pragma HLS stream variable=d_query type=fifo depth=CONFIG_T::feature_dim
+    #pragma HLS stream variable=q_proj type=fifo depth=CONFIG_T::seq_len
+    #pragma HLS stream variable=k_proj type=fifo depth=CONFIG_T::seq_len
+    #pragma HLS stream variable=v_proj type=fifo depth=CONFIG_T::seq_len
+    #pragma HLS stream variable=matr_out type=fifo depth=CONFIG_T::head_dim_value
+
+    #pragma HLS DATAFLOW
+    #pragma HLS ARRAY_PARTITION variable=d_query complete dim=1
+    #pragma HLS ARRAY_PARTITION variable=v_proj complete dim=1
+    #pragma HLS ARRAY_PARTITION variable=q_proj complete dim=1
+    #pragma HLS ARRAY_PARTITION variable=k_proj complete dim=1
+    #pragma HLS ARRAY_PARTITION variable=qk_mul complete dim=1
+    #pragma HLS ARRAY_PARTITION variable=matr_out complete dim=1
+prepq:
+    for (int i = 0; i < CONFIG_T::num_heads; ++i) {
+        #pragma HLS UNROLL
+        nnet::data_prep<data_T, CONFIG_T>(data_q, d_query[i]);
+    }
+prepvk:
+    for (int i = 0; i < CONFIG_T::num_heads; ++i) {
+        #pragma HLS UNROLL
+        nnet::data_prep<data_T, CONFIG_T>(data_vk, d_value[i]);
+    }
+
+lin_proj:
+    for (int i = 0; i < CONFIG_T::num_heads; ++i) {
+        #pragma HLS UNROLL
+        nnet::lin_projection<data_T, res_T, CONFIG_T>(
+            d_query[i], d_value[i], k_proj[i], q_proj[i], v_proj[i],
+            key_weight + (CONFIG_T::head_dim_key * CONFIG_T::feature_dim * i), key_bias + (CONFIG_T::head_dim_key * i),
+            query_weight + (CONFIG_T::head_dim_key * CONFIG_T::feature_dim * i), query_bias + (CONFIG_T::head_dim_key * i),
+            value_weight + (CONFIG_T::head_dim_value * CONFIG_T::feature_dim * i),
+            value_bias + (CONFIG_T::head_dim_value * i));
+    }
+
+maxtrixmul1:
+    for (int i = 0; i < CONFIG_T::num_heads; ++i) {
+        #pragma HLS UNROLL
+        nnet::matrixmul_transpose<data_T, res_T, CONFIG_T>(q_proj[i], k_proj[i], qk_mul[i]);
+    }
+
+maxtrixmul2:
+    for (int i = 0; i < CONFIG_T::num_heads; ++i) {
+        #pragma HLS UNROLL
+        nnet::matrixmul<data_T, res_T, CONFIG_T>(qk_mul[i], v_proj[i], matr_out[i]); // stream
+    }
+
+    nnet::dense_out<data_T, res_T, CONFIG_T>(matr_out, res, attention_output_weight, attention_output_bias);
+}
+} // namespace nnet
+
+#endif
diff --git 
a/test/pytest/test_multiheadattention.py b/test/pytest/test_multiheadattention.py new file mode 100644 index 0000000000..e446a22f70 --- /dev/null +++ b/test/pytest/test_multiheadattention.py @@ -0,0 +1,52 @@ +from pathlib import Path + +import numpy as np +import pytest +from tensorflow.keras import Model +from tensorflow.keras.layers import Input, MultiHeadAttention + +import hls4ml + +test_root_path = Path(__file__).parent + +batch_size = 100 +seq_len = 10 +num_heads = 2 +key_dim = 4 + +atol = 2e-2 + + +@pytest.fixture(scope='module') +def query_data(): + return np.random.rand(batch_size, seq_len, num_heads * key_dim) + + +@pytest.fixture(scope='module') +def key_value_data(): + return np.random.rand(batch_size, seq_len, num_heads * key_dim) + + +@pytest.fixture(scope='module') +def model(): + query_input = Input(shape=(seq_len, num_heads * key_dim)) + key_value_input = Input(shape=(seq_len, num_heads * key_dim)) + mha_layer = MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)(query_input, key_value_input) + model = Model(inputs=[query_input, key_value_input], outputs=mha_layer) + model.compile() + return model + + +# Currently only Vitis in io_parallel mode is supported +def test_multiheadattention(model, query_data, key_value_data): + config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend='Vitis') + output_dir = str(test_root_path / 'hls4mlprj_multiheadattention_Vitis_io_parallel') + hls_model = hls4ml.converters.convert_from_keras_model( + model, backend='Vitis', hls_config=config, io_type='io_parallel', output_dir=output_dir + ) + hls_model.compile() + + # Predict + y_keras = model.predict([query_data, key_value_data]).flatten() + y_hls = hls_model.predict([query_data, key_value_data]).flatten() + np.testing.assert_allclose(y_keras, y_hls, rtol=0, atol=atol, verbose=True) diff --git a/test/pytest/test_multiheadattention_pytorch.py b/test/pytest/test_multiheadattention_pytorch.py new file mode 100644 index 0000000000..a89ea5ac5a --- /dev/null +++ b/test/pytest/test_multiheadattention_pytorch.py @@ -0,0 +1,67 @@ +from pathlib import Path + +import numpy as np +import pytest +import torch +from torch import nn + +import hls4ml + +test_root_path = Path(__file__).parent + +batch_size = 100 +seq_len = 10 +num_heads = 2 +embed_dim = 8 + +atol = 2e-2 + + +@pytest.fixture(scope='module') +def query_data(): + return np.random.rand(batch_size, seq_len, embed_dim) + + +@pytest.fixture(scope='module') +def key_value_data(): + return np.random.rand(batch_size, seq_len, embed_dim) + + +class MultiHeadAttentionModel(nn.Module): + def __init__(self): + super().__init__() + self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) + + def forward(self, query, key, value): + output, _ = self.mha(query, key, value) + return output + + +# Currently only Vitis in io_parallel mode is supported +def test_multiheadattention(query_data, key_value_data): + model = MultiHeadAttentionModel() + model.eval() + + config = hls4ml.utils.config_from_pytorch_model( + model, + [(seq_len, embed_dim), (seq_len, embed_dim), (seq_len, embed_dim)], + granularity='name', + backend='Vitis', + channels_last_conversion='off', + transpose_outputs=False, + ) + output_dir = str(test_root_path / 'hls4mlprj_multiheadattention_pytorch_Vitis_io_parallel') + hls_model = hls4ml.converters.convert_from_pytorch_model( + model, backend='Vitis', hls_config=config, io_type='io_parallel', output_dir=output_dir + ) + hls_model.compile() + + # Predict + y_pytorch = ( + 
model(torch.Tensor(query_data), torch.Tensor(key_value_data), torch.Tensor(key_value_data)) + .detach() + .numpy() + .flatten() + ) + y_hls = hls_model.predict([query_data, key_value_data, key_value_data]).flatten() + np.testing.assert_allclose(y_pytorch, y_hls, rtol=0, atol=atol, verbose=True)
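A minimal usage sketch (not part of the patch) of how the new MultiHeadAttention support would typically be exercised from Python. The layer name 'multi_head_attention' and the reuse-factor override are illustrative assumptions, not defaults mandated by this PR:

    import hls4ml
    from tensorflow.keras import Model
    from tensorflow.keras.layers import Input, MultiHeadAttention

    # Two-input attention block: query and key/value share the same (seq_len, feature_dim) shape
    q = Input(shape=(10, 8))
    kv = Input(shape=(10, 8))
    out = MultiHeadAttention(num_heads=2, key_dim=4)(q, kv)
    model = Model([q, kv], out)

    config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend='Vitis')
    # Per-layer overrides follow the usual hls4ml config conventions (the key is model-dependent)
    config['LayerName']['multi_head_attention']['ReuseFactor'] = 1

    hls_model = hls4ml.converters.convert_from_keras_model(
        model, hls_config=config, backend='Vitis', io_type='io_parallel', output_dir='hls4mlprj_mha'
    )
    hls_model.compile()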