Add BetterTransformer integration for Detr
#1065
base: main
Conversation
Branch updated from e06cdbe to 24209cb.
Requesting review. @younesbelkada @fxmarty
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Thank you very much @awinml for your hard work!
Another test is failing:
FAILED bettertransformer/test_common.py::BetterTransformerIntegrationTests::test_dict_class_consistency - AttributeError: 'tuple' object has no attribute 'keys'
FAILED bettertransformer/test_common.py::BetterTransformerIntegrationTests::test_raise_activation_fun_10_detr - AttributeError: 'tuple' object has no attribute 'keys'
Could you please double check? 🙏 After that we can merge I think :D
@younesbelkada The following test keeps failing. Any suggestions on how to fix it?

============================================= FAILURES =============================================
________________ BetterTransformerIntegrationTests.test_raise_activation_fun_10_detr ________________
a = (<test_common.BetterTransformerIntegrationTests testMethod=test_raise_activation_fun_10_detr>,), kw = {}

    @wraps(func)
    def standalone_func(*a, **kw):
>       return func(*(a + p.args), **p.kwargs, **kw)

../../../../environments/env12-opt/lib/python3.8/site-packages/parameterized/parameterized.py:620:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

tests/bettertransformer/test_common.py:161: in test_raise_activation_fun
    self.assertTrue("Activation function" in str(cm.exception))
E   AssertionError: False is not true

====================================== short test summary info ======================================
FAILED tests/bettertransformer/test_common.py::BetterTransformerIntegrationTests::test_raise_activation_fun_10_detr - AssertionError: False is not true
hidden_states = torch._transformer_encoder_layer_fwd(
    hidden_states,
    self.embed_dim,
    self.num_heads,
    self.in_proj_weight,
    self.in_proj_bias,
    self.out_proj_weight,
    self.out_proj_bias,
    self.norm_first,
    self.norm1_eps,
    self.norm1_weight,
    self.norm1_bias,
    self.norm2_eps,
    self.norm2_weight,
    self.norm2_bias,
    self.linear1_weight,
    self.linear1_bias,
    self.linear2_weight,
    self.linear2_bias,
    attention_mask,
)
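For context on why a single omitted argument is so damaging here: the fused entry point is called purely positionally, so dropping one argument silently shifts every later one into the wrong slot. A minimal sketch (a hypothetical stand-in, not torch's real kernel) illustrates the failure mode:

```python
# Hypothetical stand-in (NOT the real torch kernel) showing why purely
# positional calls are fragile: omit one argument and every later value
# lands in the wrong slot.
def fused_layer_fwd(embed_dim, num_heads, use_gelu, norm_first, norm_eps):
    # Return the slots as received so any mis-ordering becomes visible.
    return {
        "embed_dim": embed_dim,
        "num_heads": num_heads,
        "use_gelu": use_gelu,
        "norm_first": norm_first,
        "norm_eps": norm_eps,
    }

correct = fused_layer_fwd(256, 8, True, False, 1e-5)
# Dropping use_gelu shifts norm_first's value (False) into use_gelu's slot.
# A Python function of fixed arity would raise TypeError, but a variadic
# C++ entry point can silently accept the shifted arguments.
shifted = fused_layer_fwd(256, 8, False, 1e-5, None)
```

This is why the reviewers stress matching the exact argument order of `torch._transformer_encoder_layer_fwd`: the kernel has (almost) no defaults and no keyword names to catch a mistake.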
@awinml you forgot self.use_gelu, I guess that's why your test is failing.
@IlyasMoutawwakil I ran the tests after adding self.use_gelu. The above test (#1065 (comment)) still fails. self.use_gelu was removed because I don't think it is relevant for this model. Please correct me if I am wrong.
The argument order in torch._transformer_encoder_layer_fwd is very important because there are (almost) no default values. use_gelu tells the C++ kernel whether to use gelu or relu (true/false); any other activation function should raise an error here:
https://github.com/huggingface/optimum/blob/eb05de1f4abe34ebc23c638780934ebd2fe3bb02/optimum/bettertransformer/models/base.py#LL116-L119
but apparently it doesn't in your case, and that's why the test doesn't pass (it's a sanity check).
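A hedged sketch of what that sanity check amounts to (the function name here is hypothetical; the real logic lives in optimum/bettertransformer/models/base.py at the link above):

```python
# Illustrative sketch of the activation sanity check (names are hypothetical;
# the real check is in optimum/bettertransformer/models/base.py).
def resolve_use_gelu(act_fn: str) -> bool:
    """Map the model's activation name to the kernel's use_gelu flag."""
    if act_fn in ("gelu", "gelu_new"):
        return True
    if act_fn == "relu":
        return False
    # test_raise_activation_fun asserts the message contains "Activation function".
    raise ValueError(
        f"Activation function {act_fn} is not supported by BetterTransformer."
    )
```

The failing assertion `"Activation function" in str(cm.exception)` suggests the Detr layer never reaches this check (or raises a different error first), which is what needs tracking down.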
@IlyasMoutawwakil Makes sense! I updated the order to match the link. The test still seems to fail. Can you have a look?
@younesbelkada Any suggestions for fixing the failing test? #1065 (comment)
@IlyasMoutawwakil The CI tests fail with the following error: FAILED bettertransformer/test_common.py::BetterTransformerIntegrationTests::test_raise_activation_fun_10_detr - ImportError:
DetrConvEncoder requires the timm library but it was not found in your environment. You can install it with pip:
`pip install timm`. Please note that you may need to restart your runtime after installation. Am I missing something here?
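One common way to handle that in CI (a sketch only, assuming unittest-style test classes; the class and method names here are hypothetical) is to skip the Detr tests when timm is absent rather than let the import error fail the run:

```python
import importlib.util
import unittest

# Hedged sketch: guard Detr BetterTransformer tests behind a timm
# availability check, since DetrConvEncoder builds its convolutional
# backbone with timm.
timm_available = importlib.util.find_spec("timm") is not None

class DetrBetterTransformerTest(unittest.TestCase):
    @unittest.skipUnless(timm_available, "DetrConvEncoder requires the timm library")
    def test_conversion(self):
        # A real test would convert a Detr model to BetterTransformer here.
        self.assertTrue(True)
```

Alternatively, the CI environment can simply install timm; the skip approach just keeps the suite green on machines without optional dependencies.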
Branch updated from 3010883 to a25e2d5.
@fxmarty I have added
What does this PR do?
Add BetterTransformer integration for Detr.
Fixes huggingface/transformers#20372
Fixes #488
Closes #684
Closes #1022
This PR completes the stalled PR #684.
Who can review?
@younesbelkada @fxmarty