Add generation caching in TextEnvironment and fix bugs in TextEnvironment #2556

Open
wants to merge 32 commits into
base: main

konrad-gerlach

This PR mainly affects the TextEnvironment class. It adds caching between generation calls so that the activations of all previous tokens do not have to be recomputed when generating the next segment. This is mainly intended for use cases with many sequential tool calls, where the activations for the (possibly quite large) system prompt would otherwise be recalculated at every step. For stability, caching is optional.
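
A minimal sketch of why caching helps, not the code in this PR: `ToyModel` and `run_segments` are hypothetical stand-ins that only count how many tokens have to be processed, and the segment sizes are made up for illustration.

```python
from dataclasses import dataclass


@dataclass
class ToyModel:
    """Hypothetical stand-in for a language model; counts processed tokens."""

    tokens_processed: int = 0

    def forward(self, new_tokens, cache):
        # A real model would compute key/value activations here; the toy
        # version just records the work and appends the tokens to the cache.
        self.tokens_processed += len(new_tokens)
        return cache + list(new_tokens)


def run_segments(model, segments, use_cache):
    """Feed the segments one after another, optionally reusing the cache."""
    history, cache = [], []
    for segment in segments:
        history.extend(segment)
        if use_cache:
            # Only the tokens added since the last call are processed.
            cache = model.forward(segment, cache)
        else:
            # Without caching, the whole history is recomputed every time.
            cache = model.forward(history, [])
    return cache


# A long system prompt followed by two short tool-call segments.
segments = [["sys"] * 100, ["call"] * 5, ["result"] * 5]

uncached, cached = ToyModel(), ToyModel()
run_segments(uncached, segments, use_cache=False)
run_segments(cached, segments, use_cache=True)
print(uncached.tokens_processed, cached.tokens_processed)  # 315 110
```

In this toy setting the uncached run reprocesses the system prompt at every step, while the cached run touches each token once.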

Bug fixes:
This PR also fixes two bugs I encountered:

  1. The max_length check in the TextEnvironment class threw an error because it assumed a batch dimension was present when there was none.
    I fixed the bug and also added a check at generation time to ensure that the padded inputs do not exceed max_length either.
  2. The StringStoppingCriteria did not take generated EOS tokens into account, which I have now fixed.
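
To illustrate the first fix, here is a hedged sketch of a length check that works with or without a batch dimension and also accounts for padding. The helper names and the `pad_to_multiple_of` parameter are assumptions for illustration and use plain lists instead of tensors; this is not the implementation in this PR.

```python
def as_batch(input_ids):
    """Return a list of sequences whether or not a batch dimension is present."""
    if input_ids and isinstance(input_ids[0], int):
        return [input_ids]  # single unbatched sequence
    return list(input_ids)


def padded_length(batch, pad_to_multiple_of=None):
    """Length every sequence has after padding to the longest one."""
    longest = max((len(seq) for seq in batch), default=0)
    if pad_to_multiple_of:
        # Round up to the next multiple (ceiling division).
        longest = -(-longest // pad_to_multiple_of) * pad_to_multiple_of
    return longest


def exceeds_max_length(input_ids, max_length, pad_to_multiple_of=None):
    """True if the (padded) inputs would be longer than max_length."""
    batch = as_batch(input_ids)
    return padded_length(batch, pad_to_multiple_of) > max_length


print(exceeds_max_length([1, 2, 3], max_length=4))                        # False
print(exceeds_max_length([[1, 2], [1, 2, 3, 4, 5]], max_length=4))        # True
print(exceeds_max_length([1, 2, 3], max_length=3, pad_to_multiple_of=4))  # True
```

The last call shows why checking the padded length matters: the raw sequence fits, but the padded one does not.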

RE testing:
I only made sure that the tests in tests/test_environments.py pass.
With make test, some tests failed and the suite took a long time to run. However, the only tests that call TextEnvironment seem to be in test_environments.py, so the rest should be unaffected as far as I know. Nevertheless, I would be grateful if somebody else could run all the tests before merging, as I suspect that my environment is not ideally configured. Is testing automated via CI?

@konrad-gerlach konrad-gerlach marked this pull request as draft January 10, 2025 15:16
@konrad-gerlach konrad-gerlach marked this pull request as ready for review January 10, 2025 15:57
@konrad-gerlach
Author

I would be very grateful for a review by:
@lvwerra
@vwxyzjn
@younesbelkada
@qgallouedec
or anyone else who feels up to the task.

@konrad-gerlach konrad-gerlach force-pushed the text_environment_caching branch from 6a87c8d to 3f57ee9 Compare January 10, 2025 16:27
@konrad-gerlach konrad-gerlach force-pushed the text_environment_caching branch from 3f57ee9 to ede7e81 Compare January 10, 2025 16:33
@konrad-gerlach
Author

I was unable to execute the pre-commit hook, so I manually ran the linter.

@konrad-gerlach konrad-gerlach marked this pull request as draft January 11, 2025 10:47
@konrad-gerlach konrad-gerlach marked this pull request as ready for review January 11, 2025 12:58
@qgallouedec
Member

Thanks for the PR!
Let's see what the CI outputs.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@konrad-gerlach
Author

konrad-gerlach commented Jan 12, 2025

Just to be sure, as I'm unfamiliar with their implementation: the trl Trainers like PPO should not try to backpropagate through the generated tokens, right?

@konrad-gerlach
Author

The CI failing for Python 3.9 seems unrelated to this PR.

@qgallouedec
Member

> The trl Trainers like PPO should not try to backpropagate through the generated tokens, right?

Yes, that's correct. The backprop is done on the output of a forward pass.

@konrad-gerlach
Author

@qgallouedec Could you run pre-commit to fix the linting issues? I haven't gotten it to work.

@konrad-gerlach
Author

I'm still working on adding some more tests and cleaning up the code a bit.

@qgallouedec
Member

OK, ping me when it's ready. I'll run pre-commit and merge.

@konrad-gerlach
Author

For future reference: setting position_ids for generation did not appear to be necessary, as it seems to be handled here: https://github.com/huggingface/transformers/blob/241c04d36867259cdf11dbb4e9d9a60f9cb65ebc/src/transformers/generation/utils.py#L409-L416
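
For readers who want the gist without following the link: a common way to derive position ids from an attention mask is a cumulative sum minus one, with padded positions set to a dummy value. The sketch below uses plain lists and a hypothetical helper name, and assumes that this is the pattern in question rather than quoting the linked code.

```python
from itertools import accumulate


def position_ids_from_mask(attention_mask):
    """cumsum(mask) - 1, with padded positions set to a dummy value of 1."""
    running = accumulate(attention_mask)
    return [total - 1 if keep else 1 for total, keep in zip(running, attention_mask)]


# Left-padded prompt: two padding tokens followed by three real tokens.
print(position_ids_from_mask([0, 0, 1, 1, 1]))  # [1, 1, 0, 1, 2]
```

The real tokens get positions 0, 1, 2 regardless of how much left padding precedes them.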

@konrad-gerlach
Author

cache_position also appears to be handled automatically here https://github.com/huggingface/transformers/blob/241c04d36867259cdf11dbb4e9d9a60f9cb65ebc/src/transformers/generation/utils.py#L799-L806 and here https://github.com/huggingface/transformers/blob/241c04d36867259cdf11dbb4e9d9a60f9cb65ebc/src/transformers/generation/utils.py#L1560
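
As a rough illustration of what cache_position conveys (an assumption about the general idea, not a copy of the linked code): when part of the sequence is already cached, only the positions of the tokens that still need to be processed are passed on. The helper name is hypothetical.

```python
def cache_positions(cached_length, total_length):
    """Positions of the tokens that are not yet covered by the cache."""
    return list(range(cached_length, total_length))


# 100 tokens are already cached and 5 new tokens were appended.
print(cache_positions(100, 105))  # [100, 101, 102, 103, 104]
```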

Looking at the transformers generation code, there currently seem to be issues with torch.compile in transformers' generate (mentioned in a code comment), see https://github.com/huggingface/transformers/blob/2e752ead46a8845e8a160d2043c1336447895690/src/transformers/generation/utils.py#L1582. I think I will include a warning in the docs.

@konrad-gerlach
Author

It seems the model used for testing was not in fact GPT2 and the variable name was incorrect. Documentation will be updated accordingly.

@konrad-gerlach
Author

In the previous commits, I also fixed what I believe to be an off-by-one error in the StringStoppingCriteria: an inaccurate number for generated_tokens.

@konrad-gerlach
Author

I noticed that GPT2 seems to support only the legacy cache format, so I am adding support for it.

@konrad-gerlach
Author

> In the previous commits, I also fixed what I believe to be an off-by-one error in the StringStoppingCriteria: an inaccurate number for generated_tokens.

It appears that it was still not fixed. I am working on a solution and on tests for StringStoppingCriteria.
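
To make the intended behaviour concrete, here is a toy sketch of a string-based stopping criterion that counts generated tokens per row and also stops on a generated EOS token. `ToyStoppingCriteria` is a hypothetical stand-in operating on lists of token strings, and it assumes the criterion is first called after exactly one token has been generated; it is not the StringStoppingCriteria implementation.

```python
class ToyStoppingCriteria:
    """Hypothetical stand-in for a string-based stopping criterion."""

    def __init__(self, stop_strings, eos_token):
        self.stop_strings = stop_strings
        self.eos_token = eos_token
        self.start_length = None
        self.generated_tokens = []
        self.done = []

    def __call__(self, sequences):
        if self.start_length is None:
            # Assumption: the first call happens after one generated token,
            # so the prompt is one token shorter than the current sequences.
            self.start_length = len(sequences[0]) - 1
            self.generated_tokens = [0] * len(sequences)
            self.done = [False] * len(sequences)
        for i, tokens in enumerate(sequences):
            if self.done[i]:
                continue  # finished rows only receive padding from here on
            self.generated_tokens[i] += 1
            text = "".join(tokens[self.start_length:])
            hit_stop = any(stop in text for stop in self.stop_strings)
            hit_eos = tokens[-1] == self.eos_token
            self.done[i] = hit_stop or hit_eos
        return all(self.done)


prompt = ["a", "b"]
# Row 0 emits a stop string in step 2; row 1 emits EOS in step 3.
steps = [("x", "y"), ("<call>", "z"), ("<pad>", "<eos>")]

criteria = ToyStoppingCriteria(stop_strings=["<call>"], eos_token="<eos>")
rows = [list(prompt), list(prompt)]
for new_tokens in steps:
    for row, token in zip(rows, new_tokens):
        row.append(token)
    finished = criteria(rows)

print(criteria.generated_tokens, finished)  # [2, 3] True
```

Counting a token only while its row is unfinished keeps the count equal to the number of tokens the row actually generated, including the stop string or EOS token itself.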

@konrad-gerlach
Author

As the existing tests did not cover an encoder-decoder architecture, I did not test for it either. I think this is out of scope for this pull request. Where it was relevant in _generate_batched, I mirrored the existing implementation.

3 participants