Add Multi-Head Attention support for Vitis #1163

Open · wants to merge 68 commits into main

Commits (68)
c4c818b
paser_mht
Ethan0Jiang Jul 13, 2022
3ee64d1
change parser and modify keras_to_hls
Ethan0Jiang Jul 13, 2022
5626a1a
IR_mutihead_attention
Ethan0Jiang Jul 14, 2022
d51f8a9
IR done
Ethan0Jiang Jul 15, 2022
89025a2
create mha file in template
Ethan0Jiang Jul 19, 2022
d76cf60
mha .h file dummy algo
Ethan0Jiang Jul 19, 2022
56811de
config of mha
Ethan0Jiang Jul 21, 2022
45cd493
update mha config
Ethan0Jiang Jul 21, 2022
1402f48
dummy mha
Ethan0Jiang Jul 21, 2022
430b9ea
add transpose into mha
Ethan0Jiang Jul 23, 2022
97f3e8d
projection_of_qkv_in_mha
Ethan0Jiang Jul 27, 2022
52cc7e8
mha_first_draft
Ethan0Jiang Aug 4, 2022
3961f97
able to predict model correct
Ethan0Jiang Aug 11, 2022
3533999
delete some unnassary comments
Ethan0Jiang Aug 11, 2022
d2f0df6
delete comments
Ethan0Jiang Aug 11, 2022
6aaa5ed
resource strategy of transformer
Ethan0Jiang Sep 16, 2022
3b7a288
change sm lagacy
Ethan0Jiang Oct 1, 2022
130092d
update MHA, optimized
Ethan0Jiang Oct 12, 2022
09b0ba0
support resource
Ethan0Jiang Oct 23, 2022
b49fffd
update
Ethan0Jiang Nov 27, 2022
5324a11
dense_muti_dim_support
Ethan0Jiang Dec 30, 2022
bf8c788
parallel execute dense
Ethan0Jiang Jan 1, 2023
b6be2c4
updates
Ethan0Jiang Jan 27, 2023
2472b7d
add_layerNorm_support
Ethan0Jiang Feb 15, 2023
97e71e9
MHA updated
Ethan0Jiang Feb 27, 2023
5ed4a76
LayerNorm_bug_fix
Ethan0Jiang Apr 4, 2023
5d28f58
update bit precision
Ethan0Jiang Apr 15, 2023
2fc68d0
config update
Ethan0Jiang Apr 17, 2023
b5c95cf
add some comment
Ethan0Jiang Apr 21, 2023
3b8aa8d
run pre-commit
JanFSchulte Sep 13, 2024
d28b24c
Added support on QMultiHeadAttention, QLayerNormalization, and quanti…
LostEcho365 Aug 7, 2023
de79bb9
updated on hls4ml transformer
LostEcho365 Nov 12, 2023
6c23326
trying to clean the diff
JanFSchulte Sep 13, 2024
20a0199
trying to clean the diff
JanFSchulte Sep 13, 2024
ddccde2
trying to clean the diff
Sep 17, 2024
afbe00b
trying to clean the diff
Sep 17, 2024
dedf96c
trying to clean the diff
Sep 17, 2024
a9de9cb
undo vhdl -> verilog change
Sep 18, 2024
49313d3
halfway working layernorm + test
Sep 18, 2024
1156ba5
layernorm is now pretty functional
Sep 18, 2024
17e0048
layernorm on pytorch also
Sep 19, 2024
63891fd
minor cleanup
Sep 19, 2024
8dccac6
more cleanup, pre-commit
Sep 19, 2024
595cc71
test for mha which kinda works maybe if you squint
Sep 19, 2024
5f3ec00
multihead attention working on keras and pytorch
Sep 20, 2024
5697334
fiddly precision / accuracy changes for layernorm
Sep 25, 2024
d2e27b8
Merge remote-tracking branch 'upstream/main' into transformer
rianbrooksflynn Oct 11, 2024
a149f2e
fix lookup table and label loops
rianbrooksflynn Oct 22, 2024
552fa83
remove dense_seq
rianbrooksflynn Oct 23, 2024
69f26bc
Merge remote-tracking branch 'upstream/main' into transformer
rianbrooksflynn Oct 23, 2024
be5f5a4
undo qkeras changes
rianbrooksflynn Oct 23, 2024
adf7356
fix merge conflict residue
rianbrooksflynn Oct 24, 2024
8437581
Merge remote-tracking branch 'upstream/main' into transformer
rianbrooksflynn Nov 4, 2024
6139f54
fix multiplier_limit config for mha
JanFSchulte Nov 13, 2024
a041095
remove extraneous seq_len in dense config
rianbrooksflynn Nov 13, 2024
0c8cd71
add hls stream pragmas
JanFSchulte Nov 14, 2024
1d5f8bb
Merge branch 'transformer' of https://github.com/JanFSchulte/hls4ml i…
JanFSchulte Nov 14, 2024
8006825
fix stream size
JanFSchulte Nov 14, 2024
097b6a0
Merge remote-tracking branch 'upstream/main' into multi-head-attention
rianbrooksflynn Jan 6, 2025
0580360
remove layernorm changes
rianbrooksflynn Jan 6, 2025
a0b9390
remove softmax changes
rianbrooksflynn Jan 13, 2025
a82a6aa
Merge remote-tracking branch 'upstream/main' into multi-head-attention
rianbrooksflynn Jan 13, 2025
88beb79
port to vitis
rianbrooksflynn Jan 13, 2025
9ed1eec
[pre-commit.ci] auto fixes from pre-commit hooks
pre-commit-ci[bot] Jan 14, 2025
20f2729
delete extraneous print statement
rianbrooksflynn Jan 14, 2025
702d13a
Merge branch 'main' into multi-head-attention
rianbrooksflynn Feb 20, 2025
4c90dd4
Merge branch 'main' into multi-head-attention
JanFSchulte Mar 18, 2025
c7449bd
revert changes that break Conv1D tests
JanFSchulte Mar 20, 2025
2 changes: 2 additions & 0 deletions hls4ml/backends/fpga/fpga_backend.py
@@ -26,6 +26,7 @@
GlobalPooling2D,
MatMul,
Merge,
MultiHeadAttention,
Pooling1D,
Pooling2D,
Quant,
@@ -71,6 +72,7 @@ def __init__(self, name):
Dot,
Conv,
MatMul,
MultiHeadAttention,
]

for layer in accum_layers:
146 changes: 146 additions & 0 deletions hls4ml/backends/vitis/passes/transformer_templates.py
@@ -0,0 +1,146 @@
from hls4ml.backends.backend import get_backend
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
from hls4ml.model.layers import MultiHeadAttention

# dense layer template
mult_config_template = """struct config{index}_{mNum} : nnet::dense_config {{
static const unsigned n_in = {n_in};
static const unsigned n_out = {n_out};
static const unsigned io_type = nnet::{iotype};
static const unsigned strategy = nnet::{strategy};
static const unsigned reuse_factor = {reuse};
static const unsigned n_zeros = {nzeros};
static const unsigned n_nonzeros = {nonzeros};
static const unsigned multiplier_limit = DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor;
static const bool store_weights_in_bram = false;
typedef {accum_t.name} accum_t;
typedef {attention_output_bias_t.name} bias_t;
typedef {attention_output_weight_t.name} weight_t;
typedef ap_{index_t} index_t;
template<class data_T, class res_T, class CONFIG_T>
using kernel = nnet::{dense_function}<data_T, res_T, CONFIG_T>;
template<class x_T, class y_T>
using product = nnet::product::{product_type}<x_T, y_T>;
}};\n"""

# activation template
softmax_config_template = """struct {type}_config{index} : nnet::activ_config {{
static const unsigned n_in = {n_in};
static const unsigned table_size = {table_size};
static const unsigned io_type = nnet::{iotype};
static const unsigned reuse_factor = {reuse};
static const nnet::softmax_implementation implementation = nnet::softmax_implementation::{implementation};
typedef {table_t.name} exp_table_t;
typedef {table_t.name} inv_table_t;
}};\n"""

mha_config_template = """struct config{index} : nnet::multiheadattention_config {{
typedef {accum_t.name} accum_t;
typedef {attention_output_bias_t.name} bias_t;
typedef {attention_output_weight_t.name} weight_t;
typedef {config_mult_t1} config_mult1;
typedef {config_mult_t2} config_mult2;
typedef {config_activ_t1} softmax_config1;

static const unsigned num_heads = {num_heads};
static const unsigned head_dim_key = {head_dim_key};
static const unsigned head_dim_value = {head_dim_value};
static const unsigned feature_dim = {feature_dim};
static const unsigned seq_len = {seq_len};

static const unsigned io_type = nnet::{iotype};
static const unsigned reuse_factor = {reuse};
static const bool store_weights_in_bram = false;
}};\n"""

mha_function_template = """nnet::multiheadattention<{input_t}, {output_t}, {config}>({input_q}, {input_kv},
{output}, {w_o}, {b_o}, {w_k}, {b_k}, {w_q}, {b_q}, {w_v}, {b_v});"""

mha_include_list = ['nnet_utils/nnet_multiheadattention.h']


class MhaConfigTemplate(LayerConfigTemplate):
def __init__(self):
super().__init__(MultiHeadAttention)
self.template = mha_config_template
self.mult1_template = mult_config_template
self.mult2_template = mult_config_template
self.activ1_template = softmax_config_template

def format(self, node):
params = self._default_config_params(node)
params['num_heads'] = node.get_attr('num_heads')
params['head_dim_key'] = node.get_attr('head_dim_key')
params['head_dim_value'] = node.get_attr('head_dim_value')
params['feature_dim'] = node.get_attr('feature_dim')
params['seq_len'] = node.get_attr('seq_len')
params['config_mult_t1'] = f'config{node.index}_1'
params['config_mult_t2'] = f'config{node.index}_2'
params['config_activ_t1'] = '{}_config{}'.format("softmax", node.index)
params['strategy'] = node.get_attr('strategy')
mha_config = self.template.format(**params)

mult_params1 = self._default_config_params(node)
mult_params1['strategy'] = 'latency'
mult_params1['mNum'] = '1'
mult_params1['n_in'] = node.get_attr('feature_dim')
mult_params1['n_out'] = node.get_attr('head_dim_key')
mult_params1['product_type'] = get_backend('vivado').product_type(
node.get_input_variable().type.precision, node.get_weights('query_weight').type.precision
)
mult_params1['reuse'] = params['reuse']
mult_params1['index'] = str(node.index)
mult_params1['nzeros'] = 0
mult_params1['nonzeros'] = params['feature_dim'] * params['num_heads'] * params['head_dim_key']
mult_params1['dense_function'] = 'DenseLatency'
mult_config1 = self.mult1_template.format(**mult_params1)

mult_params2 = self._default_config_params(node)
mult_params2['strategy'] = 'latency'
mult_params2['mNum'] = '2'
mult_params2['n_in'] = node.get_attr('head_dim_value') * node.get_attr('num_heads')
mult_params2['n_out'] = node.get_attr('feature_dim')
mult_params2['product_type'] = get_backend('vivado').product_type(
node.get_input_variable().type.precision, node.get_weights('attention_output_weight').type.precision
)
mult_params2['reuse'] = params['reuse']
mult_params2['index'] = str(node.index)
mult_params2['nzeros'] = 0
mult_params2['nonzeros'] = params['feature_dim'] * params['num_heads'] * params['head_dim_key']
mult_params2['dense_function'] = 'DenseLatency'
mult_config2 = self.mult2_template.format(**mult_params2)

act_params = self._default_config_params(node)
act_params['n_in'] = node.get_attr('seq_len')
act_params['type'] = 'softmax'
act_params['implementation'] = 'legacy'  # in MHA, the 'latency' and 'stable' softmax implementations do not work; 'legacy' does
act_config = self.activ1_template.format(**act_params)

return mult_config1 + '\n' + mult_config2 + '\n' + act_config + '\n' + mha_config


class MhaFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__(MultiHeadAttention, include_header=mha_include_list)
self.template = mha_function_template

def format(self, node):
params = {}
params.update(node.attributes)
params['config'] = f'config{node.index}'
params['input_t'] = node.get_input_variable().type.name
params['output_t'] = node.get_output_variable().type.name

params['input_q'] = node.model.get_layer_output_variable(node.inputs[0]).name
params['input_kv'] = node.model.get_layer_output_variable(node.inputs[1]).name
params['output'] = node.get_output_variable().name
params['w_o'] = node.get_weights('attention_output_weight').name
params['b_o'] = node.get_weights('attention_output_bias').name
params['w_k'] = node.get_weights('key_weight').name
params['b_k'] = node.get_weights('key_bias').name
params['w_q'] = node.get_weights('query_weight').name
params['b_q'] = node.get_weights('query_bias').name
params['w_v'] = node.get_weights('value_weight').name
params['b_v'] = node.get_weights('value_bias').name

return self.template.format(**params)
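
For reference, the multiplier_limit expression in mult_config_template follows the usual hls4ml dense-layer resource bookkeeping. A minimal sketch of the arithmetic with illustrative numbers (feature_dim = 64, head_dim_key = 16, reuse_factor = 4, fully dense weights; none of these values come from the PR), where div_roundup stands in for the DIV_ROUNDUP macro used in the HLS headers:

def div_roundup(a, b):
    # integer ceiling division, the assumed behaviour of DIV_ROUNDUP in the generated HLS code
    return (a + b - 1) // b

n_in, n_out, reuse_factor, n_zeros = 64, 16, 4, 0  # illustrative values only
multiplier_limit = div_roundup(n_in * n_out, reuse_factor) - n_zeros // reuse_factor
print(multiplier_limit)  # 256 multipliers, each reused reuse_factor = 4 times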
28 changes: 28 additions & 0 deletions hls4ml/backends/vitis/vitis_backend.py
@@ -3,6 +3,9 @@

from hls4ml.backends import VivadoBackend
from hls4ml.model.flow import get_flow, register_flow
from hls4ml.model.layers import MultiHeadAttention
from hls4ml.model.optimizer import layer_optimizer
from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType, NamedType
from hls4ml.report import parse_vivado_report


@@ -13,6 +16,9 @@ def __init__(self):
self._register_flows()

def _register_flows(self):
initializers = self._get_layer_initializers()
init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name)

validation_passes = [
'vitis:validate_conv_implementation',
'vitis:validate_resource_strategy',
@@ -30,6 +36,7 @@ def _register_flows(self):

ip_flow_requirements = get_flow('vivado:ip').requires.copy()
ip_flow_requirements.insert(ip_flow_requirements.index('vivado:init_layers'), validation_flow)
ip_flow_requirements.insert(ip_flow_requirements.index('vivado:streaming'), init_flow)
ip_flow_requirements.insert(ip_flow_requirements.index('vivado:apply_templates'), template_flow)

self._default_flow = register_flow('ip', None, requires=ip_flow_requirements, backend=self.name)
@@ -120,3 +127,24 @@ def build(
os.chdir(curr_dir)

return parse_vivado_report(model.config.get_output_dir())

@layer_optimizer(MultiHeadAttention)
def init_mha(self, layer):
# TODO Allow getting recurrent reuse factor from the config
reuse_factor = layer.model.config.get_reuse_factor(layer)
layer.set_attr('reuse_factor', reuse_factor)
index_t = IntegerPrecisionType(width=1, signed=False)
layer.set_attr('index_t', index_t)
if 'table_t' not in layer.attributes:
layer.set_attr(
'table_t', NamedType(name=layer.name + '_table_t', precision=FixedPrecisionType(width=24, integer=8))
)
if 'table_size' not in layer.attributes:
layer.set_attr('table_size', 2048)
if 'accum_t' not in layer.attributes:
layer.set_attr('accum_t', FixedPrecisionType(width=24, integer=8))
if 'inv_range' not in layer.attributes:
layer.set_attr('inv_range', 128)
if 'exp_range' not in layer.attributes:
layer.set_attr('exp_range', 8)
layer.set_attr('strategy', 'resource')  # the other supported option is 'latency'
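
For context, a minimal sketch of converting a Keras model that contains a MultiHeadAttention layer with this backend (layer sizes and output directory are placeholders, not taken from the PR). When the user configuration does not specify them, init_mha above falls back to defaults such as table_size = 2048 and a 24-bit fixed-point accum_t:

import tensorflow as tf

import hls4ml

inp = tf.keras.Input(shape=(8, 16))  # (seq_len, feature_dim), illustrative sizes
out = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=4)(inp, inp)  # self-attention
model = tf.keras.Model(inp, out)

config = hls4ml.utils.config_from_keras_model(model, granularity='name')
hls_model = hls4ml.converters.convert_from_keras_model(
    model, hls_config=config, backend='Vitis', output_dir='hls4ml_prj_mha'
)
hls_model.compile()  # compile the generated C++ for emulation; hls_model.build() would run the full Vitis HLS flow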
60 changes: 60 additions & 0 deletions hls4ml/converters/keras/multiheadattention.py
@@ -0,0 +1,60 @@
from hls4ml.converters.keras_to_hls import get_weights_data, keras_handler, parse_default_keras_layer


@keras_handler('MultiHeadAttention')
def parse_mutiheadattention_layer(keras_layer, input_names, input_shapes, data_reader):
# assume input_shapes is: [[None, seq, dim]]
assert 'MultiHeadAttention' in keras_layer['class_name']
assert input_shapes[0] == keras_layer['config']['query_shape']

layer = parse_default_keras_layer(keras_layer, input_names)

layer['num_heads'] = keras_layer['config']['num_heads']
layer['head_dim_key'] = keras_layer['config']['key_dim']
layer['head_dim_value'] = keras_layer['config']['value_dim']
layer['query_shape'] = keras_layer['config']['query_shape']
layer['key_shape'] = keras_layer['config']['key_shape']
layer['value_shape'] = keras_layer['config']['value_shape']
layer['feature_dim'] = layer['query_shape'][-1]
layer['seq_len'] = layer['query_shape'][-2]

if keras_layer['config']['output_shape']:
raise Exception('hls4ml does not support a defined output shape; the output shape must equal the query shape')
else:
output_shape = layer['query_shape']

layer['attention_axes'] = (
keras_layer['config']['attention_axes'] if (keras_layer['config']['attention_axes'][0] == 1) else False
)
if layer['attention_axes'] is False:
raise Exception('assigning the attention_axes is not currently supported by hls4ml')

if not (len(layer['query_shape']) == 3 and len(layer['key_shape']) == 3 and len(layer['value_shape']) == 3):
raise Exception('only 3D shapes for query, key, and value are currently supported by hls4ml')

attn_scores_rank = 4
layer['softmax_axis'] = list(range(attn_scores_rank - len(layer['attention_axes']), attn_scores_rank))

weights_sources = [
('attention_output', 'kernel'),
('attention_output', 'bias'),
('key', 'kernel'),
('key', 'bias'),
('query', 'kernel'),
('query', 'bias'),
('value', 'kernel'),
('value', 'bias'),
]

for lname, wtype in weights_sources:
data = get_weights_data(data_reader, layer['name'], f'{lname}/{wtype}')
if wtype == 'kernel':
vtype = 'weight'
if lname in ['key', 'query', 'value']:
data = data.transpose((1, 0, 2))
else:
vtype = 'bias'

layer[f'{lname}_{vtype}_data'] = data

return layer, output_shape
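
As a shape check on the kernel handling above (assuming the usual Keras layout, with illustrative sizes that are not from the PR): Keras stores the query/key/value projection kernels as (feature_dim, num_heads, head_dim), and transpose((1, 0, 2)) reorders them to (num_heads, feature_dim, head_dim) before they are handed to the HLS templates:

import numpy as np

feature_dim, num_heads, head_dim = 16, 4, 8  # illustrative sizes
keras_kernel = np.zeros((feature_dim, num_heads, head_dim))  # assumed Keras kernel layout
hls_kernel = keras_kernel.transpose((1, 0, 2))
print(hls_kernel.shape)  # (4, 16, 8) == (num_heads, feature_dim, head_dim)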
4 changes: 4 additions & 0 deletions hls4ml/converters/keras_to_hls.py
@@ -284,6 +284,10 @@ def parse_keras_model(model_arch, reader):
# Extract inbound nodes
if 'inbound_nodes' in keras_layer and len(keras_layer['inbound_nodes']) > 0:
input_names = [inputs_map.get(inp[0], inp[0]) for inp in keras_layer['inbound_nodes'][0]]
if keras_layer['inbound_nodes'][0][0][-1]:
# multi_head_attention has inbound: [[['input_3', 0, 0, {'value': ['dense_3', 0, 0]}]]]
inputname2 = list(keras_layer['inbound_nodes'][0][0][-1].values())
input_names += [inp[0] for inp in inputname2]
else:
input_names = None

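The inbound_nodes handling above can be traced on a hypothetical entry (layer names are illustrative): the query input appears positionally, the value input arrives as a keyword argument in the trailing dict, and both end up in input_names:

inbound_nodes = [[['input_3', 0, 0, {'value': ['dense_3', 0, 0]}]]]  # hypothetical entry

input_names = [inp[0] for inp in inbound_nodes[0]]  # ['input_3'] (positional query input)
kwargs = inbound_nodes[0][0][-1]  # {'value': ['dense_3', 0, 0]}
input_names += [inp[0] for inp in kwargs.values()]  # ['input_3', 'dense_3']
print(input_names)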
54 changes: 54 additions & 0 deletions hls4ml/converters/pytorch/multiheadattention.py
@@ -0,0 +1,54 @@
import numpy as np

from hls4ml.converters.pytorch_to_hls import pytorch_handler


@pytorch_handler('MultiheadAttention')
def parse_multiheadattention_layer(
operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config
):
assert 'MultiheadAttention' in operation
assert len(input_shapes) == 3

layer = {}

layer['class_name'] = 'MultiHeadAttention'
layer['name'] = layer_name
layer['inputs'] = input_names

layer['num_heads'] = class_object.num_heads
layer['head_dim_key'] = class_object.kdim // layer['num_heads']
layer['head_dim_value'] = class_object.vdim // layer['num_heads']
layer['query_shape'] = input_shapes[0]
layer['key_shape'] = input_shapes[1]
layer['value_shape'] = input_shapes[2]

if not (len(layer['query_shape']) == len(layer['key_shape']) == len(layer['value_shape']) == 3):
raise Exception('only 3D shapes for query, key, and value are currently supported by hls4ml')

layer['feature_dim'] = class_object.embed_dim
layer['seq_len'] = layer['query_shape'][-2]

output_shape = layer['query_shape']

layer['attention_axes'] = [1]
layer['softmax_axis'] = [3]

in_proj_weights = class_object.in_proj_weight.data.numpy()
in_proj_bias = class_object.in_proj_bias.data.numpy()

weight_data = np.split(in_proj_weights, [class_object.embed_dim, class_object.embed_dim + class_object.kdim], axis=0)
bias_data = np.split(in_proj_bias, [class_object.embed_dim, class_object.embed_dim + class_object.kdim], axis=0)

for weight_type, weight, bias in zip(['query', 'key', 'value'], weight_data, bias_data):
layer[f'{weight_type}_weight_data'] = weight.T.reshape(
layer['feature_dim'], layer['num_heads'], layer['head_dim_key']
).transpose(1, 0, 2)
layer[f'{weight_type}_bias_data'] = bias.reshape(layer['num_heads'], layer['head_dim_key'])

layer['attention_output_weight_data'] = class_object.out_proj.weight.data.numpy().T.reshape(
layer['num_heads'], layer['head_dim_key'], layer['feature_dim']
)
layer['attention_output_bias_data'] = class_object.out_proj.bias.data.numpy()

return layer, output_shape
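
The in_proj_weight split above can be sanity-checked on a toy module (sizes are illustrative, not from the PR). With the default kdim = vdim = embed_dim, PyTorch packs the query, key, and value projections into a single (3 * embed_dim, embed_dim) matrix, which np.split separates at embed_dim and embed_dim + kdim:

import numpy as np
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)  # illustrative sizes
in_proj = mha.in_proj_weight.data.numpy()  # shape (3 * embed_dim, embed_dim) = (24, 8)
q_w, k_w, v_w = np.split(in_proj, [8, 16], axis=0)  # split at embed_dim, embed_dim + kdim
print(q_w.shape, k_w.shape, v_w.shape)  # (8, 8) (8, 8) (8, 8)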