
Commit 92d65ca

Update extending vocab docs (#2669)
- Recommends trainable tokens as first measure
- Clarifies a few things about saving embeddings
- Adds full-finetuning as an option of last resort

Co-authored-by: Benjamin Bossan <[email protected]>
1 parent 4346513 commit 92d65ca

2 files changed (+46, -3 lines)

docs/source/developer_guides/troubleshooting.md

Lines changed: 39 additions & 3 deletions
@@ -145,17 +145,45 @@ As an example, when loading a model that is using the DeBERTa architecture for s

### Extending the vocabulary

For many language fine-tuning tasks, extending the model's vocabulary is necessary since new tokens are being introduced. This requires extending the embedding layer to account for the new tokens and, depending on the fine-tuning method, also storing the embedding layer in addition to the adapter weights when saving the adapter. There are a few ways to achieve this, ordered from most to least parameter-efficient:

- [trainable tokens](../package_reference/trainable_tokens): train only the specified tokens and optionally store only the updated values
- training an adapter on the embedding matrix, optionally storing only the updated values
- full fine-tuning of the embedding layer

#### Using trainable tokens

Let's start with trainable tokens, in this case via its [LoRA integration](../developer_guides/lora#efficiently-train-tokens-alongside-lora). If you're interested in only training the new embeddings and nothing else, refer to the [standalone documentation](../package_reference/trainable_tokens).

To enable selective token training of the embedding layer, you'll need to supply the token IDs of your newly added tokens via the `trainable_token_indices` parameter. Optionally, you can specify which layer to target if there is more than one embedding layer. For a Mistral model, this could look like this:

```python
new_tokens = ['<think>', '</think>']
tokenizer.add_tokens(new_tokens)
base_model.resize_token_embeddings(len(tokenizer))

lora_config = LoraConfig(
    ...,
    trainable_token_indices={'embed_tokens': tokenizer.convert_tokens_to_ids(new_tokens)},
)
```

If your model uses tied weights (such as the `lm_head`), trainable tokens will try to resolve those and keep them updated as well, so in that case there should be no need to add `modules_to_save=["lm_head"]`. This only works if the model uses the Transformers convention for tying weights.
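
Whether this applies to your model can be checked on the model config; a minimal sketch, assuming `base_model` is the Transformers model from the snippet above:

```python
# Reports whether this architecture shares weights between the input embeddings
# and the output projection (lm_head); if it does, trainable tokens keeps both in sync.
print(base_model.config.tie_word_embeddings)
```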

Saving the model with `model.save_pretrained` may save the full embedding matrix instead of only the difference as a precaution, because the embedding matrix was resized. To save space, you can disable this behavior by setting `save_embedding_layers=False` when calling `save_pretrained`. This is safe to do as long as you don't modify the embedding matrix through other means as well, as such changes will not be tracked by trainable tokens.
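
For example, continuing from the snippet above (the output directory name is only illustrative):

```python
from peft import get_peft_model

model = get_peft_model(base_model, lora_config)
# ... train the model ...

# Store only the updated token values instead of the full, resized embedding matrix.
model.save_pretrained("mistral-lora-think-tokens", save_embedding_layers=False)
```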

#### Using an adapter, e.g. LoRA

Prepare the embedding layer by adding it to the `target_modules` of your adapter config. For example, the Mistral config could look like this:

```python
config = LoraConfig(..., target_modules=["embed_tokens", "lm_head", "q_proj", "v_proj"])
```

Once added to `target_modules`, PEFT automatically stores the embedding layer when saving the adapter if the model has the [`~transformers.PreTrainedModel.get_input_embeddings`] and [`~transformers.PreTrainedModel.get_output_embeddings`] methods. This is generally the case for Transformers models.

If the model's embedding layer doesn't follow the Transformers naming scheme but nevertheless implements `get_input_embeddings`, you can still save it by manually passing `save_embedding_layers=True` when saving the adapter:

```python
model = get_peft_model(...)
# ... train the model, then save it together with the embedding layer
model.save_pretrained("my_adapter", save_embedding_layers=True)  # directory name is illustrative
```

@@ -167,6 +195,14 @@ For inference, load the base model first and resize it the same way you did befo

For a complete example, please check out [this notebook](https://github.com/huggingface/peft/blob/main/examples/causal_language_modeling/peft_lora_clm_with_additional_tokens.ipynb).
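
As the context above notes, for inference you load the base model first and resize its embeddings the same way you did before training. A minimal sketch, where the model id and adapter path are placeholders and the tokenizer with the added tokens is assumed to have been saved alongside the adapter:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer = AutoTokenizer.from_pretrained("path/to/my_adapter")  # contains the added tokens

# Resize before loading the adapter, exactly as during training.
base_model.resize_token_embeddings(len(tokenizer))
model = PeftModel.from_pretrained(base_model, "path/to/my_adapter")
```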

#### Full fine-tuning

Full fine-tuning is more costly in terms of VRAM and storage space, but if all else fails, you can fall back to it and see if it works for you. Enable it by adding the name of the embedding layer to `modules_to_save`. Note that you need to add tied layers as well, e.g. `lm_head`. Example for a Mistral model with LoRA:

```python
config = LoraConfig(..., modules_to_save=["embed_tokens", "lm_head"], target_modules=["q_proj", "v_proj"])
```
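
To get a feeling for the added cost, you can compare trainable parameter counts; a small sketch, assuming `base_model` has already been loaded and resized:

```python
from peft import get_peft_model

model = get_peft_model(base_model, config)
# With embed_tokens and lm_head in modules_to_save, their full weight matrices are
# trained and stored alongside the adapter, so expect a much larger trainable count.
model.print_trainable_parameters()
```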

### Getting a warning about "weights not being initialized from the model checkpoint"

When you load your PEFT model which has been trained on a task (for example, classification), you may get a warning like:

docs/source/package_reference/trainable_tokens.md

Lines changed: 7 additions & 0 deletions
@@ -33,6 +33,13 @@ Note that this method does not add tokens for you, you have to add tokens to the
embedding matrix of the model accordingly. This method will only re-train the embeddings for the tokens you specify.
This method can also be used in conjunction with LoRA layers! See [the LoRA developer guide](../developer_guides/lora#efficiently-train-tokens-alongside-lora).

> [!TIP]
> Saving the model with [`~PeftModel.save_pretrained`] or retrieving the state dict using
> [`get_peft_model_state_dict`] when adding new tokens may save the full embedding matrix instead of only the difference
> as a precaution because the embedding matrix was resized. To save space, you can disable this behavior by setting
> `save_embedding_layers=False` when calling `save_pretrained`. This is safe to do as long as you don't modify the
> embedding matrix through other means as well, as such changes will not be tracked by trainable tokens.
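
A minimal sketch of both calls, assuming `model` is a PEFT model with trainable tokens enabled, that the directory name is illustrative, and that your PEFT version exposes the same `save_embedding_layers` argument on [`get_peft_model_state_dict`]:

```python
from peft import get_peft_model_state_dict

# Keep only the updated token values instead of the full, resized embedding matrix.
model.save_pretrained("model-with-new-tokens", save_embedding_layers=False)
state_dict = get_peft_model_state_dict(model, save_embedding_layers=False)
```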

## TrainableTokensConfig

[[autodoc]] tuners.trainable_tokens.config.TrainableTokensConfig
