Skip to content

Commit 134bbf9

Browse files
committed
fix: move hydra heads resize_token_embeddings
move the hydra heads' and ref_model's resize_token_embeddings function calls to AcceleratePPOTrainer
1 parent 23cffd4 commit 134bbf9

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

trlx/trainer/accelerate_base_trainer.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,6 @@ def __init__(self, config, **kwargs): # noqa: C901
7373
self.tokenizer.add_tokens(self.additional_tokens)
7474
# resize the model by-default
7575
self.model.base_model.resize_token_embeddings(len(self.tokenizer))
76-
if hasattr(self.model, "frozen_head"):
77-
self.model.frozen_head.resize_token_embeddings(len(self.tokenizer))
78-
else:
79-
# resize a reference model when hydra heads are not used
80-
self.ref_model.resize_token_embeddings(len(self.tokenizer))
8176

8277
self.tokenizer.padding_side = config.tokenizer.padding_side
8378
self.tokenizer.truncation_side = config.tokenizer.truncation_side

trlx/trainer/accelerate_ppo_trainer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,12 @@ def __init__(self, config: TRLConfig, **kwargs):
7070
# Setup a reference model when hydra heads are not used
7171
if not hasattr(self.model, "frozen_head"):
7272
self.ref_model = self.get_arch(self.config)
73+
self.ref_model.resize_token_embeddings(len(self.tokenizer))
7374
self.ref_model.to(self.accelerator.device)
7475
self.ref_model.eval()
76+
else:
77+
# resize hydra heads
78+
self.model.frozen_head.resize_token_embeddings(len(self.tokenizer))
7579

7680
# Setup the KL controller
7781
# This helps prevent large divergences in the controller (policy)

0 commit comments

Comments
 (0)