Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BT] add BetterTransformer support for ProphetNet #923

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/bettertransformer/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ The list of supported model below:
- [Marian](https://arxiv.org/abs/1804.00344)
- [MBart](https://arxiv.org/abs/2001.08210)
- [M2M100](https://arxiv.org/abs/2010.11125)
- [ProphetNet](https://arxiv.org/abs/2001.04063)
- [RemBERT](https://arxiv.org/abs/2010.12821)
- [RoBERTa](https://arxiv.org/abs/1907.11692)
- [RoCBert](https://aclanthology.org/2022.acl-long.65.pdf)
Expand Down
2 changes: 2 additions & 0 deletions optimum/bettertransformer/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
DistilBertLayerBetterTransformer,
FSMTEncoderLayerBetterTransformer,
MBartEncoderLayerBetterTransformer,
ProphetNetEncoderLayerBetterTransformer,
ViltLayerBetterTransformer,
ViTLayerBetterTransformer,
Wav2Vec2EncoderLayerBetterTransformer,
Expand Down Expand Up @@ -74,6 +75,7 @@ class BetterTransformerManager:
"opt": {"OPTAttention": OPTAttentionLayerBetterTransformer},
"pegasus": {"PegasusAttention": BartAttentionLayerBetterTransformer},
"rembert": {"RemBertLayer": BertLayerBetterTransformer},
"prophetnet": {"ProphetNetEncoderLayer": ProphetNetEncoderLayerBetterTransformer},
"roberta": {"RobertaLayer": BertLayerBetterTransformer},
"roc_bert": {"RoCBertLayer": BertLayerBetterTransformer},
"roformer": {"RoFormerLayer": BertLayerBetterTransformer},
Expand Down
133 changes: 133 additions & 0 deletions optimum/bettertransformer/models/encoder_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,6 +1275,139 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
return (hidden_states, attention_mask)


class ProphetNetEncoderLayerBetterTransformer(BetterTransformerBaseLayer):
def __init__(self, prophetnet_layer, config):
r"""
A simple conversion of the ProphetNet Encoder layer to its `BetterTransformer` implementation.

Args:
prophet_net_layer (`torch.nn.Module`):
The original ProphetNet Layer where the weights needs to be retrieved.
"""
super().__init__(config)
self.config = config
# In_proj layer
self.in_proj_weight = nn.Parameter(
torch.cat(
[
prophetnet_layer.self_attn.query_proj.weight,
prophetnet_layer.self_attn.key_proj.weight,
prophetnet_layer.self_attn.value_proj.weight,
]
)
)
self.in_proj_bias = nn.Parameter(
torch.cat(
[
prophetnet_layer.self_attn.query_proj.bias,
prophetnet_layer.self_attn.key_proj.bias,
prophetnet_layer.self_attn.value_proj.bias,
]
)
)

# Out proj layer
self.out_proj_weight = prophetnet_layer.self_attn.out_proj.weight
self.out_proj_bias = prophetnet_layer.self_attn.out_proj.bias

# Linear layer 1
self.linear1_weight = prophetnet_layer.feed_forward.intermediate.weight
self.linear1_bias = prophetnet_layer.feed_forward.intermediate.bias

# Linear layer 2
self.linear2_weight = prophetnet_layer.feed_forward.output.weight
self.linear2_bias = prophetnet_layer.feed_forward.output.bias

# Layer norm 1
self.norm1_eps = prophetnet_layer.self_attn_layer_norm.eps
self.norm1_weight = prophetnet_layer.self_attn_layer_norm.weight
self.norm1_bias = prophetnet_layer.self_attn_layer_norm.bias

# Layer norm 2
self.norm2_eps = prophetnet_layer.feed_forward_layer_norm.eps
self.norm2_weight = prophetnet_layer.feed_forward_layer_norm.weight
self.norm2_bias = prophetnet_layer.feed_forward_layer_norm.bias

# Model hyper parameters
self.num_heads = prophetnet_layer.self_attn.num_attn_heads
self.embed_dim = prophetnet_layer.self_attn.head_dim * self.num_heads

# Last step: set the last layer to `False` -> this will be set to `True` when converting the model
self.is_last_layer = False

self.original_layers_mapping = {
"in_proj_weight": [
"self_attn.query_proj.weight",
"self_attn.key_proj.weight",
"self_attn.value_proj.weight",
],
"in_proj_bias": ["self_attn.query_proj.bias", "self_attn.key_proj.bias", "self_attn.value_proj.bias"],
"out_proj_weight": "self_attn.out_proj.weight",
"out_proj_bias": "self_attn.out_proj.bias",
"linear1_weight": "feed_forward.intermediate.weight",
"linear1_bias": "feed_forward.intermediate.bias",
"linear2_weight": "feed_forward.output.weight",
"linear2_bias": "feed_forward.output.bias",
"norm1_weight": "self_attn_layer_norm.weight",
"norm1_bias": "self_attn_layer_norm.bias",
"norm2_weight": "feed_forward_layer_norm.weight",
"norm2_bias": "feed_forward_layer_norm.bias",
}

self.validate_bettertransformer()

def forward(self, hidden_states, attention_mask, *_, **__):
r"""
This is just a wrapper around the forward function proposed in:
https://github.com/huggingface/transformers/pull/19553
"""
super().forward_checker()

if not hasattr(hidden_states, "original_shape"):
original_shape = hidden_states.shape
else:
original_shape = hidden_states.original_shape

if hidden_states.is_nested:
attention_mask = None

if attention_mask is not None:
# attention mask comes in with values 0 and -inf. we convert to torch.nn.TransformerEncoder style bool mask
# 0->false->keep this token -inf->true->mask this token
attention_mask = attention_mask.squeeze(1)[:, 0]
attention_mask = attention_mask.bool()
attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
attention_mask = None

hidden_states = torch._transformer_encoder_layer_fwd(
hidden_states,
self.embed_dim,
self.num_heads,
self.in_proj_weight,
self.in_proj_bias,
self.out_proj_weight,
self.out_proj_bias,
self.use_gelu,
self.norm_first,
self.norm1_eps,
self.norm1_weight,
self.norm1_bias,
self.norm2_weight,
self.norm2_bias,
self.linear1_weight,
self.linear1_bias,
self.linear2_weight,
self.linear2_bias,
attention_mask,
)
if not self.is_last_layer:
hidden_states.original_shape = original_shape
elif hidden_states.is_nested and self.is_last_layer:
hidden_states = hidden_states.to_padded_tensor(0.0, original_shape)
return (hidden_states,)


class CLIPLayerBetterTransformer(BetterTransformerBaseLayer):
def __init__(self, layer, config):
r"""
Expand Down
4 changes: 2 additions & 2 deletions optimum/utils/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
# limitations under the License.


import collections
import importlib.util
import itertools
import os
import subprocess
import sys
import unittest
from collections.abc import MutableMapping
from typing import Any, Callable, Dict, Iterable, Optional, Tuple

import torch
Expand All @@ -35,7 +35,7 @@ def flatten_dict(dictionary: Dict):
items = []
for k, v in dictionary.items():
new_key = k
if isinstance(v, collections.MutableMapping):
if isinstance(v, MutableMapping):
items.extend(flatten_dict(v).items())
else:
items.append((new_key, v))
Expand Down
1 change: 1 addition & 0 deletions tests/bettertransformer/test_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class BetterTransformersEncoderDecoderTest(BetterTransformersTestMixin, unittest
"marian",
"mbart",
"pegasus",
"prophetnet",
"t5",
]

Expand Down
1 change: 1 addition & 0 deletions tests/bettertransformer/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"mbart": "hf-internal-testing/tiny-random-mbart",
"opt": "hf-internal-testing/tiny-random-OPTModel",
"pegasus": "hf-internal-testing/tiny-random-PegasusModel",
"prophetnet": "hirotasoshu/tiny-random-prophetnet", # the other tiny ones have a too small max_position_embeddings
"rembert": "hf-internal-testing/tiny-random-rembert",
"roberta": "hf-internal-testing/tiny-random-RobertaModel",
"rocbert": "hf-internal-testing/tiny-random-RoCBertModel",
Expand Down