Add generation caching in TextEnvironment and fix bugs in TextEnvironment #2556
Conversation
I would be very grateful for a review by:
I was unable to execute the pre-commit hook, so I manually ran the linter.
Thanks for the PR!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Just to be sure, as I'm unfamiliar with their implementation: the trl trainers like PPO should not try to backpropagate through the generated tokens, right?
Co-authored-by: Quentin Gallouédec <[email protected]>
The CI failure for Python 3.9 seems unrelated to this PR.
Yes, that's correct. The backprop is done on the output of a forward pass.
@qgallouedec Could you run the pre-commit hooks to fix the linting issues? I haven't gotten them to work.
I'm still working on adding some more tests and cleaning up the code a bit.
Ok, ping me when it's ready and I'll run the pre-commit hooks and merge.
For future reference: setting position_ids for generation did not appear necessary, as it seems to be handled here: https://github.com/huggingface/transformers/blob/241c04d36867259cdf11dbb4e9d9a60f9cb65ebc/src/transformers/generation/utils.py#L409-L416
cache_position also appears to be handled automatically, here https://github.com/huggingface/transformers/blob/241c04d36867259cdf11dbb4e9d9a60f9cb65ebc/src/transformers/generation/utils.py#L799-L806 and here https://github.com/huggingface/transformers/blob/241c04d36867259cdf11dbb4e9d9a60f9cb65ebc/src/transformers/generation/utils.py#L1560. Looking at the transformers generation code, there currently seem to be issues with torch.compile in transformers generate (mentioned in a comment there), see https://github.com/huggingface/transformers/blob/2e752ead46a8845e8a160d2043c1336447895690/src/transformers/generation/utils.py#L1582, so I think I will include a warning in the docs.
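For illustration, the logic at those lines amounts to roughly the following (a paraphrase for readability, not code from this PR; details may differ between transformers versions):

```python
import torch

def infer_position_ids(attention_mask: torch.LongTensor) -> torch.LongTensor:
    # Paraphrase of the pattern in the linked generation/utils.py lines:
    # positions count attended tokens cumulatively; padded positions are
    # masked to a dummy value of 1.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    return position_ids

def infer_cache_position(past_length: int, seq_length: int) -> torch.LongTensor:
    # Paraphrase: cache positions simply continue where the cache left off.
    return torch.arange(past_length, seq_length)
```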
It seems the model used for testing was not in fact GPT-2, and the variable name was incorrect. The documentation will be updated accordingly.
In the previous commits, I also fixed what I believe to be an off-by-one error in StringStoppingCriteria: the number of generated tokens was computed inaccurately.
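As an illustration of where such an off-by-one can arise, here is a minimal sketch of a string-based stopping criterion (not the actual trl implementation; the class and parameter names are only illustrative):

```python
import torch
from transformers import StoppingCriteria

class StringStopSketch(StoppingCriteria):
    """Minimal sketch: stop once any stop string appears in the generated text."""

    def __init__(self, stop_strings, tokenizer, prompt_length: int):
        self.stop_strings = stop_strings
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length  # number of prompt tokens

    def __call__(self, input_ids: torch.LongTensor, scores, **kwargs) -> bool:
        # The generated-token count is input_ids.shape[1] - self.prompt_length;
        # slicing one position too early or too late is the off-by-one trap.
        generated = input_ids[0, self.prompt_length:]
        text = self.tokenizer.decode(generated)
        return any(s in text for s in self.stop_strings)
```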
I noticed that GPT2 seems to support only the legacy cache format, so I am adding support for it.
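For context, recent transformers versions provide conversion helpers between the Cache API and the legacy tuple-of-tuples format, so a bridge along these lines is one possible way to support such models (a sketch; the helper names are hypothetical):

```python
from transformers import DynamicCache

def to_legacy(cache: DynamicCache):
    # DynamicCache -> legacy tuple of per-layer (key, value) tensors,
    # for models that only understand the old format.
    return cache.to_legacy_cache()

def from_legacy(past_key_values):
    # Legacy tuples -> DynamicCache, so newer code paths can keep using it.
    return DynamicCache.from_legacy_cache(past_key_values)
```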
It appears that it was still not fixed. I am working on a solution and on testing StringStoppingCriteria.
As the tests did not include an encoder-decoder architecture, I did not test for one either; I think that is out of scope for this pull request. Where this was a concern in _generate_batched, I mirrored the implementation already provided.
This PR mainly affects the TextEnvironment class and adds caching between generation calls, so that all previous activations do not have to be recomputed when generating the next segment. This is mainly intended for use cases where many tool calls are performed sequentially, and the activations for the (possibly quite large) system prompt would otherwise have to be recomputed at each step. For stability, caching is optional.
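A minimal standalone sketch of the caching idea, assuming a recent transformers version in which generate can return and accept past_key_values (the model name, prompts, and token counts here are placeholders; the actual TextEnvironment integration is more involved):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# First segment: prefill the (possibly large) system prompt once.
inputs = tokenizer("System prompt with tool descriptions...", return_tensors="pt")
out = model.generate(
    **inputs, max_new_tokens=20, return_dict_in_generate=True, use_cache=True
)

# Second segment: append the tool response and continue generation,
# passing the cache back in so the earlier activations are reused.
tool_ids = tokenizer("Tool result: 42.", return_tensors="pt").input_ids
next_ids = torch.cat([out.sequences, tool_ids], dim=-1)
out = model.generate(
    next_ids,
    max_new_tokens=20,
    past_key_values=out.past_key_values,
    return_dict_in_generate=True,
)
```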
Bug fixes:
This PR also addresses two bugs I encountered:
I fixed the bug and also added a check at generation time to ensure that the padded inputs do not exceed the maximum length.
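The generation-time check presumably amounts to something like this (a hypothetical sketch; the function name is illustrative):

```python
import torch

def ensure_fits(input_ids: torch.LongTensor, max_length: int) -> None:
    # After padding a batch, the longest sequence determines the width,
    # so the padded batch can exceed max_length even if each raw prompt fits.
    if input_ids.shape[-1] > max_length:
        raise ValueError(
            f"Padded input length {input_ids.shape[-1]} exceeds max_length {max_length}."
        )
```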
RE testing:
I only made sure that the tests in tests/test_environments.py were passing.
Using `make test`, some tests were failing and the suite took a long time to run. However, the only tests that call TextEnvironment seem to be in test_environments.py, so the rest should be unaffected as far as I know. Nevertheless, I would be grateful if somebody else could run all the tests before merging; I suspect that my environment may not be ideally configured. Is testing automated via CI?