Add generation caching in TextEnvironment and fix bugs in TextEnvironment #2556

Open
wants to merge 32 commits into main from text_environment_caching
Changes from 14 commits
Commits (32)
ab86162
feat: add caching for TextEnvironment and fix bugs
Jan 10, 2025
d09ec63
feat: make TextEnvironment caching optional and add documentation
Jan 10, 2025
b7885cc
fix: failing TextEnvironment tests
Jan 10, 2025
034c5f7
test: add tests for TextEnvironment caching and fix cache combining bug
Jan 10, 2025
18eb106
test: remove unnecessary parametrized class decorator
Jan 10, 2025
44fd184
docs: update TextEnvironmentDocs with caching
Jan 10, 2025
28601c2
fix: run linter on TextEnvironment and TextEnvironment tests
Jan 10, 2025
2a7ec4e
fix: comment
Jan 10, 2025
af06d63
fix: Args comment
Jan 10, 2025
f6f12b5
fix: TextEnvironment cache combination and batching issue
Jan 10, 2025
ede7e81
tests: make caching test more complex
Jan 10, 2025
acddaa7
fix: combine caches of different sequence lengths
Jan 11, 2025
e38940e
docs: update caching warning
Jan 12, 2025
66d0ce4
fix: prevent bos tokens in tool response
Jan 12, 2025
a051e46
docs: Update docs/source/text_environments.md
konrad-gerlach Jan 12, 2025
9ea9287
Update trl/environment/base_environment.py
konrad-gerlach Jan 12, 2025
ae1233a
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 12, 2025
a2860bc
fix: code cleanup
Jan 12, 2025
23014fb
fix: attended to invalid last generated token and off-by-one in Strin…
Jan 14, 2025
bdaa922
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 15, 2025
a097c5b
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 17, 2025
7324ee1
fix: off by one error in StringStoppingCriteria
Jan 21, 2025
9b6a6ec
Merge branch 'main' into text_environment_caching
konrad-gerlach Jan 21, 2025
39763b1
feat: test logits are same with and without caching
Jan 22, 2025
b70f51c
fix: model and tokenizer were called gpt2 but were another model
Jan 22, 2025
7b2169d
docs: add warning for torch.compile with TextEnvironment use_cache
Jan 22, 2025
c4b5400
Merge branch 'text_environment_caching' of https://github.com/konrad-…
Jan 22, 2025
5725b18
fix: StringStoppingCriteria and add test
Jan 23, 2025
589dcb7
refactor: move StoppingCriteria test
Jan 23, 2025
5e1a7dd
feat: add support for models without cache class support
Jan 23, 2025
cc99580
refactor: make caching code optional in TextEnvironment
Jan 23, 2025
50119a8
docs: TextEnvironment use_cache note untested Encoder-Decoder archite…
Jan 23, 2025
1 change: 1 addition & 0 deletions docs/source/text_environments.md
@@ -114,6 +114,7 @@ Let's decompose the settings:
| `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.|
| `max_length` | The maximum number of tokens to allow in an episode. |
| `generation_kwargs`| Generation settings used by the language model. |
| `use_cache` | Cache keys and values between segment generations. Warning: this feature is experimental! With caching enabled, TextEnvironment is not suited for training in the sense of backpropagating through the generated graph, although use with trl trainers is still possible. Caching also requires that there be no computational dependencies between examples at inference time; when using BatchNorm, the model should therefore be in eval mode. Caching is not guaranteed to produce results identical to generation without caching, so test for yourself whether it suits your needs, model, and generation_kwargs. Cache use has been tested for GPT-2 with greedy search. |
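
As an aside (not part of this patch), here is a minimal sketch of enabling the flag. It assumes `use_cache` is passed as a `TextEnvironment` keyword argument alongside the other settings above, and it uses a placeholder tool and reward function purely for illustration:

```python
import torch
from transformers import AutoTokenizer

from trl import AutoModelForCausalLMWithValueHead, TextEnvironment

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model.eval()  # keep batch-dependent layers (e.g. BatchNorm) from coupling examples at inference time

env = TextEnvironment(
    model,
    tokenizer,
    tools={"DummyTool": lambda text: text},  # placeholder tool for illustration
    reward_fn=lambda texts: [torch.tensor(0.0) for _ in texts],  # placeholder reward
    prompt="I am a prompt!\n",
    generation_kwargs={"do_sample": False, "max_new_tokens": 16, "pad_token_id": tokenizer.eos_token_id},
    use_cache=True,  # experimental: reuse keys and values across segment generations
)

queries, responses, masks, rewards, histories = env.run(["What is 2 + 2?"])
```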

You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools!

217 changes: 212 additions & 5 deletions tests/test_environments.py
@@ -16,7 +16,7 @@
from unittest.mock import patch

import torch
from transformers import AutoTokenizer
from transformers import AutoTokenizer, DynamicCache

from trl import AutoModelForCausalLMWithValueHead, TextEnvironment, TextHistory

@@ -26,10 +26,22 @@ def __call__(self, text):
return text


def dummy_generate(histories):
def dummy_generate(
histories, past_key_values=None, past_attention_masks=None, past_input_ids=None, last_active_histories=None
):
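# Stands in for TextEnvironment generation with caching: it accepts the cached state and returns
# (histories, past_key_values, past_attention_masks, past_input_ids, last_active_histories),
# using None placeholders since the mock produces no real cache.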
for i in range(len(histories)):
histories[i].append_segment("<request><DummyTool>test<call>", torch.tensor([1, 2, 3]), system=False)
return histories
return histories, None, None, None, None


def reshape_cache(cache):
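# Reshape flat key/value entries into the legacy cache layout expected by
# DynamicCache.from_legacy_cache: one (keys, values) pair per layer, each of
# shape (batch_size, num_heads, seq_len, head_dim).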
new_cache = []
for layer in cache:
keys, values = layer
keys = keys.reshape((-1, 1, 1, 1))
values = values.reshape((-1, 1, 1, 1))
new_cache.append((keys, values))
return tuple(new_cache)


class TextHistoryTest(unittest.TestCase):
@@ -79,6 +91,7 @@ def test_text_history_last_segment(self):
history.append_segment("General Kenobi!", torch.tensor([4, 5, 6]))
history.append_segment("You are a bold one!", torch.tensor([7, 8, 9]))
self.assertEqual(history.last_text_segment, "You are a bold one!")
self.assertTrue(torch.all(history.last_token_segment == torch.tensor([7, 8, 9])).item())

def test_text_history_split_query_response(self):
text = "Hello there!"
@@ -131,10 +144,10 @@ def test_text_environment_generate(self):

model_inputs = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts]

generations_batched = env._generate_batched(model_inputs, batch_size=2)
generations_batched, _, _, _, _ = env._generate_batched(model_inputs, batch_size=2)
generations_batched = self.gpt2_tokenizer.batch_decode(generations_batched)

generations_single = [env._generate_batched([inputs], batch_size=1)[0] for inputs in model_inputs]
generations_single = [env._generate_batched([inputs], batch_size=1)[0][0] for inputs in model_inputs]
generations_single = self.gpt2_tokenizer.batch_decode(generations_single)

self.assertEqual(generations_single, generations_batched)
@@ -276,3 +289,197 @@ def test_text_environment_run(self, mock_generate):
("I am a prompt!\n" + "Hello there! General Kenobi!")
+ (2 * "<request><DummyTool>test<call>test<response>"),
)

def test_combine_cache(self):
env = TextEnvironment(
self.gpt2_model,
self.gpt2_tokenizer,
tools={"DummyTool": DummyTool()},
reward_fn=lambda x: [torch.tensor(i) for i, _ in enumerate(x)],
prompt="I am a prompt!\n",
max_turns=2,
)

caches = [
(
(torch.tensor([[1], [2]]), torch.tensor([[3], [4]])),
(torch.tensor([[7], [8]]), torch.tensor([[9], [10]])),
),
(
(torch.tensor([[5]]), torch.tensor([[6]])),
(torch.tensor([[11]]), torch.tensor([[12]])),
),
]
caches = [DynamicCache().from_legacy_cache(reshape_cache(cache)) for cache in caches]
attention_masks = [torch.tensor([[0, 1], [1, 0]]), torch.tensor([[2, 4]])]
input_ids = [torch.tensor([[1, 4], [2, 5]]), torch.tensor([[3, 6]])]
example_mask = [True, False, True]

expected_cache = reshape_cache(
(
(torch.tensor([[1], [5]]), torch.tensor([[3], [6]])),
(torch.tensor([[7], [11]]), torch.tensor([[9], [12]])),
)
)
expected_attention_mask = torch.tensor([[0, 1], [2, 4]])
expected_input_ids = torch.tensor([[1, 4], [3, 6]])

combined_cache, combined_attention_masks, combined_input_ids = env._combine_cache(
example_mask, caches, attention_masks, input_ids
)

self.assertEqual(len(combined_cache), len(expected_cache))
self.assertEqual(len(combined_cache[0]), len(expected_cache[0]))
self.assertTrue(torch.all(combined_cache[0][0] == expected_cache[0][0]))
self.assertTrue(torch.all(combined_cache[0][1] == expected_cache[0][1]))
self.assertEqual(len(combined_cache[1]), len(expected_cache[1]))
self.assertTrue(torch.all(combined_cache[1][0] == expected_cache[1][0]))
self.assertTrue(torch.all(combined_cache[1][1] == expected_cache[1][1]))
self.assertTrue(torch.all(combined_attention_masks == expected_attention_mask))
self.assertTrue(torch.all(combined_input_ids == expected_input_ids))

def test_get_batched_cache(self):
env = TextEnvironment(
self.gpt2_model,
self.gpt2_tokenizer,
tools={"DummyTool": DummyTool()},
reward_fn=lambda x: [torch.tensor(i) for i, _ in enumerate(x)],
prompt="I am a prompt!\n",
max_turns=2,
)

cache = reshape_cache(
(
(torch.tensor([[1], [2], [3]]), torch.tensor([[4], [5], [6]])),
(torch.tensor([[7], [8], [9]]), torch.tensor([[10], [11], [12]])),
)
)
attention_masks = torch.tensor([[1], [2], [3]])
input_ids = torch.tensor([[4], [5], [6]])
batched_cache, batched_attention_masks, batched_input_ids = env._get_batched_cache(
1, 3, cache, attention_masks, input_ids
)
batched_cache = batched_cache.to_legacy_cache()
expected_cache = reshape_cache(
(
(torch.tensor([[2], [3]]), torch.tensor([[5], [6]])),
(torch.tensor([[8], [9]]), torch.tensor([[11], [12]])),
)
)

self.assertEqual(len(batched_cache), len(expected_cache))
self.assertEqual(len(batched_cache[0]), len(expected_cache[0]))
self.assertTrue(torch.all(batched_cache[0][0] == expected_cache[0][0]))
self.assertTrue(torch.all(batched_cache[0][1] == expected_cache[0][1]))
self.assertEqual(len(batched_cache[1]), len(expected_cache[1]))
self.assertTrue(torch.all(batched_cache[1][0] == expected_cache[1][0]))
self.assertTrue(torch.all(batched_cache[1][1] == expected_cache[1][1]))

expected_attention_mask = torch.tensor([[2], [3]])
self.assertTrue(torch.all(batched_attention_masks == expected_attention_mask))

expected_input_ids = torch.tensor([[5], [6]])
self.assertTrue(torch.all(batched_input_ids == expected_input_ids))

def test_cached_generate_batched(self):
generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.gpt2_tokenizer.eos_token_id}
env = TextEnvironment(
self.gpt2_model,
self.gpt2_tokenizer,
tools=[DummyTool()],
reward_fn=lambda x: torch.tensor(1),
prompt="I am a prompt!\n",
generation_kwargs=generation_kwargs,
)

input_texts = ["this is a test", "this is another, longer test", "some other batch", "something unnecessary"]
model_inputs = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts]
outputs, past_key_values, past_attention_masks, past_input_ids, _ = env._generate_batched(
model_inputs, batch_size=2
)

past_key_values, past_attention_masks, past_input_ids = env._combine_cache(
[True, True, True, False], past_key_values, past_attention_masks, past_input_ids
)

input_texts2 = [" short interim", " a somewhat longer section in between"]
model_inputs2 = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts2]
# for single token query
model_inputs2.append(
torch.tensor([self.gpt2_tokenizer(" a", return_tensors="pt").input_ids], dtype=model_inputs2[0].dtype)
)

outputs_cached, _, _, _, _ = env._generate_batched(
model_inputs2,
batch_size=2,
combined_past_key_values=past_key_values,
combined_past_attention_masks=past_attention_masks,
combined_past_input_ids=past_input_ids,
)

model_inputs2_full = [
torch.concat([in1, out1, in2], dim=0) for in1, out1, in2 in zip(model_inputs[:-1], outputs, model_inputs2)
]
outputs_uncached, _, _, _, _ = env._generate_batched(model_inputs2_full, batch_size=2)
for cached, uncached in zip(outputs_cached, outputs_uncached):
self.assertTrue(torch.all(cached == uncached))

def test_different_sequence_lengths(self):
generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.gpt2_tokenizer.eos_token_id}
env = TextEnvironment(
self.gpt2_model,
self.gpt2_tokenizer,
tools=[DummyTool()],
reward_fn=lambda x: torch.tensor(1),
prompt="I am a prompt!\n",
generation_kwargs=generation_kwargs,
)

input_texts = ["this is a test", "this is another, longer test", "some other batch"]
model_inputs = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts]
outputs, past_key_values, past_attention_masks, past_input_ids, _ = env._generate_batched(
model_inputs, batch_size=2
)
# remove the last two tokens from the second batch to pretend they were never generated
second_cache = past_key_values[1].to_legacy_cache()
edited_cache = []
for layer in second_cache:
keys, values = layer
new_keys = keys[:, :, :-2, :]
new_values = values[:, :, :-2, :]
edited_cache.append((new_keys, new_values))

past_key_values[1] = DynamicCache().from_legacy_cache(tuple(edited_cache))
past_attention_masks[1] = past_attention_masks[1][:, :-2]
past_input_ids[1] = past_input_ids[1][:, :-2]

# ensure this actually removes generated tokens and not skipped tokens / padding
self.assertEqual(len(outputs[2]), 4)

past_key_values, past_attention_masks, past_input_ids = env._combine_cache(
[True, True, True], past_key_values, past_attention_masks, past_input_ids
)

self.assertEqual(past_attention_masks.shape, past_input_ids.shape)
self.assertEqual(past_key_values[0][0].shape[2], past_attention_masks.shape[1] - 1)
self.assertEqual(past_key_values[0][0].shape[0], past_attention_masks.shape[0])
input_texts2 = [" short interim", " a somewhat longer section in between"]
model_inputs2 = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts2]
# for single token query
model_inputs2.append(
torch.tensor([self.gpt2_tokenizer(" a", return_tensors="pt").input_ids], dtype=model_inputs2[0].dtype)
)
outputs_cached, _, _, _, _ = env._generate_batched(
model_inputs2,
batch_size=2,
combined_past_key_values=past_key_values,
combined_past_attention_masks=past_attention_masks,
combined_past_input_ids=past_input_ids,
)
outputs[2] = outputs[2][:-2] # remove last two generated tokens from input
model_inputs2_full = [
torch.concat([in1, out1, in2], dim=0) for in1, out1, in2 in zip(model_inputs, outputs, model_inputs2)
]
outputs_uncached, _, _, _, _ = env._generate_batched(model_inputs2_full, batch_size=2)
for cached, uncached in zip(outputs_cached, outputs_uncached):
self.assertTrue(torch.all(cached == uncached))