diff --git a/.github/workflows/build_documentation.yml b/.github/workflows/build_documentation.yml index 0b084fbb05e..4fb1737d188 100644 --- a/.github/workflows/build_documentation.yml +++ b/.github/workflows/build_documentation.yml @@ -12,7 +12,7 @@ env: jobs: build: - uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main + uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main with: commit_sha: ${{ github.sha }} package: trl diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml index 1252f516abf..d2c77968d1f 100644 --- a/.github/workflows/build_pr_documentation.yml +++ b/.github/workflows/build_pr_documentation.yml @@ -13,7 +13,7 @@ concurrency: jobs: build: if: github.event.pull_request.draft == false - uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main + uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main with: commit_sha: ${{ github.event.pull_request.head.sha }} pr_number: ${{ github.event.number }} diff --git a/.github/workflows/tests_latest.yml b/.github/workflows/tests_latest.yml index 54debaba39d..94265905b40 100644 --- a/.github/workflows/tests_latest.yml +++ b/.github/workflows/tests_latest.yml @@ -26,7 +26,7 @@ jobs: steps: - name: Git checkout uses: actions/checkout@v6 - with: { ref: v1.2-release } + with: { ref: v1.3-release } - name: Set up Python 3.12 uses: actions/setup-python@v6 diff --git a/.github/workflows/upload_pr_documentation.yml b/.github/workflows/upload_pr_documentation.yml index db6b8d1d034..98a4931613b 100644 --- a/.github/workflows/upload_pr_documentation.yml +++ b/.github/workflows/upload_pr_documentation.yml @@ -8,7 +8,7 @@ on: jobs: build: - uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@9ad2de8582b56c017cb530c1165116d40433f1c6 # main + uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@2430c1ec91d04667414e2fa31ecfc36c153ea391 # main with: package_name: trl secrets: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6bf4090b549..017eb89f8c7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,10 +8,12 @@ repos: - id: ruff-format types_or: [ python, pyi ] - # - repo: https://github.com/codespell-project/codespell - # rev: v2.1.0 - # hooks: - # - id: codespell - # args: - # - --ignore-words-list=nd,reacher,thist,ths,magent,ba - # - --skip=docs/css/termynal.css,docs/js/termynal.js + - repo: local + hooks: + - id: doc-builder-style + name: Check style with doc-builder + language: python + entry: doc-builder style trl tests docs/source --max_len 119 + additional_dependencies: ["git+https://github.com/huggingface/doc-builder@2430c1ec91d04667414e2fa31ecfc36c153ea391", ruff] # See GH-5633 + pass_filenames: false + types_or: [python, markdown, rst] diff --git a/CITATION.cff b/CITATION.cff index c78b65d38fd..619482508c4 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -37,5 +37,5 @@ keywords: - language model alignment - post-training license: Apache-2.0 -version: '1.2' +version: '1.3' date-released: '2020-03-27' diff --git a/Makefile b/Makefile index 094f2cf256b..569a9bb584b 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,6 @@ test: precommit: python scripts/add_copyrights.py pre-commit run --all-files - doc-builder style trl tests docs/source --max_len 119 slow_tests: pytest -m "slow" tests/ $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",) diff --git a/VERSION b/VERSION index 14c65ab0d00..b58da95673d 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.3.0.dev0 \ No newline at end of file +1.4.0.dev0 \ No newline at end of file diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 72970949c5b..b5a7442e026 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -7,6 +7,8 @@ title: Quickstart title: Getting started - sections: + - local: chat_templates + title: Chat Templates - local: dataset_formats title: Dataset Formats - local: paper_index @@ -133,6 +135,8 @@ title: SDPO - local: ssd_trainer title: SSD + - local: tpo_trainer + title: TPO - local: xpo_trainer title: XPO title: Experimental diff --git a/docs/source/chat_template_utils.md b/docs/source/chat_template_utils.md index 2608f702560..53ee8fae46d 100644 --- a/docs/source/chat_template_utils.md +++ b/docs/source/chat_template_utils.md @@ -1,5 +1,7 @@ # Chat template utilities +For an overview of the chat templates bundled with TRL and the rationale behind the training patches, see [Chat Templates](chat_templates). + ## clone_chat_template [[autodoc]] clone_chat_template diff --git a/docs/source/chat_templates.md b/docs/source/chat_templates.md new file mode 100644 index 00000000000..34d2b3ba351 --- /dev/null +++ b/docs/source/chat_templates.md @@ -0,0 +1,113 @@ +# Chat Templates + +A [chat template](https://huggingface.co/docs/transformers/en/chat_templating) is a Jinja2 snippet that formats messages into the string a model was trained on. For example: + +```python +>>> from transformers import AutoTokenizer +>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct") +>>> tokenizer.chat_template +"{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" +>>> tokenizer.apply_chat_template([{"role": "user", "content": "Hi!"}], tokenize=False) +'<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHi!<|im_end|>\n' +``` + +In most cases you don't need to worry about chat templates: models ship their template along with the tokenizer, and TRL applies it for you. The whole thing is transparent. But some TRL recipes rely on features that most shipped templates don't include: + +- **SFT with `assistant_only_loss=True`** needs `{% generation %}` / `{% endgeneration %}` markers around assistant output, so the loss mask can target only assistant tokens. +- **GRPO with tool calls** needs the template to be *prefix-preserving*: appending a tool message must not change how earlier messages are rendered. + +TRL ships patched templates under [`trl/chat_templates/`](https://github.com/huggingface/trl/tree/main/trl/chat_templates) for common families (Qwen, Llama, DeepSeek-V3, GPT-OSS, ...) and swaps them in automatically for supported models. For any other model, you'll need to patch its template yourself. The rest of this page catalogs what's bundled. + +## Supported model families + +TRL stores reference copies of the original templates so it can identify supported models at init and swap in a training template when needed. The following families are recognized: Cohere, DeepSeek-V3, Gemma, GLM-4-MoE, GPT-OSS, Llama 3 / 3.1 / 3.2, Qwen2.5, Qwen3, Qwen3-VL, Qwen3.5, Qwen3.6. + +## Training templates + +Patched templates that fix training-specific issues. Swapped in at init when tools are enabled (GRPO) or when `assistant_only_loss=True` (SFT). + +### `cohere_training.jinja` + +Patched Cohere template. Diff vs `cohere.jinja`: + +Wrap assistant message output with `{% generation %}` / `{% endgeneration %}` so that `return_assistant_tokens_mask=True` produces correct masks for SFT assistant-only loss. + +### `deepseekv3_training.jinja` + +Patched DeepSeek-V3 template. Diff vs `deepseekv3.jinja`: + +- Uses `| tojson` on `tool['function']['arguments']` so that `arguments` can be passed as a `dict` (the documented format per [transformers docs](https://huggingface.co/docs/transformers/en/chat_extras#tool-calling-example)). The original template uses raw string concatenation, which crashes on dict inputs. +- Wraps assistant message output with `{% generation %}` / `{% endgeneration %}` markers for SFT assistant-only loss. + +### `gemma_training.jinja` + +Patched Gemma template (shared by Gemma and Gemma2, which ship identical chat templates). Diff vs `gemma.jinja`: + +Split the unified assistant output so that the `model\n` header (a prompt cue, not generated by the model) sits outside the generation block, and wrap the assistant content with `{% generation %}` / `{% endgeneration %}` markers for SFT assistant-only loss. + +### `glm4moe_training.jinja` + +Patched GLM-4-MoE template. Diff vs `glm4moe.jinja`: + +Require both `` and `` to be present before parsing, to avoid incorrect splitting when the model generates only one tag: + +```diff +- {%- if '' in content %} ++ {%- if '' in content and '' in content %} +``` + +Wrap assistant message output (including the thinking block and tool calls) with `{% generation %}` / `{% endgeneration %}` markers for SFT assistant-only loss. + +### `qwen3_training.jinja` + +Patched Qwen3 template. Diff vs `qwen3.jinja`: + +Require both `` and `` to be present before parsing, to avoid incorrect splitting when the model generates only one tag: + +```diff +- {%- if '' in content %} ++ {%- if '' in content and '' in content %} +``` + +Always include the thinking block regardless of message position. The original conditionally omits it based on `loop.last`, which changes the assistant rendering when a tool message is appended, breaking prefix-preservation: + +```diff +- {%- if loop.index0 > ns.last_query_index %} +- {%- if loop.last or (not loop.last and reasoning_content) %} +- {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} +- {%- else %} +- {{- '<|im_start|>' + message.role + '\n' + content }} +- {%- endif %} +- {%- else %} +- {{- '<|im_start|>' + message.role + '\n' + content }} +- {%- endif %} ++ {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} +``` + +Wrap assistant message output with `{% generation %}` / `{% endgeneration %}` so that `return_assistant_tokens_mask=True` produces correct masks for SFT assistant-only loss. + +### `gptoss_training.jinja` + +Patched GPT-OSS template. Diff vs `gptoss.jinja`: + +Wrap assistant message output with `{% generation %}` / `{% endgeneration %}` so that `return_assistant_tokens_mask=True` produces correct masks for SFT assistant-only loss. + +### `llama3_training.jinja` + +Patched Llama 3 template. Diff vs `llama3.jinja`: + +Wrap assistant message output with `{% generation %}` / `{% endgeneration %}` so that `return_assistant_tokens_mask=True` produces correct masks for SFT assistant-only loss. + +### `qwen2_5_training.jinja` + +Patched Qwen2.5 template. Diff vs `qwen2_5.jinja`: + +Wrap assistant message output with `{% generation %}` / `{% endgeneration %}` so that `return_assistant_tokens_mask=True` produces correct masks for SFT assistant-only loss. + +### `qwen3_6_training.jinja` + +Patched Qwen3.6 template. Diff vs `qwen3_6.jinja`: same set of changes as `qwen3_training.jinja` — require both `` and `` to be present before parsing, drop the `loop.index0 > ns.last_query_index` conditional so the thinking block is always emitted (prefix-preservation), and wrap assistant output with `{% generation %}` / `{% endgeneration %}` markers for SFT assistant-only loss. + +## Related utilities + +See [Chat Template Utilities](chat_template_utils) for the helper functions ([`clone_chat_template`], [`is_chat_template_prefix_preserving`], [`get_training_chat_template`]) that operate on these templates. diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 935ac78f1d5..98133d1a441 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -632,6 +632,9 @@ trainer = GRPOTrainer( Each tool must be a standard Python function with **type-hinted arguments and return types**, along with a **Google-style docstring** describing its purpose, arguments, and return value. For more details, see the [Passing tools guide](https://huggingface.co/docs/transformers/en/chat_extras#passing-tools). +> [!TIP] +> The GRPO tool call loop requires the chat template to be *prefix-preserving* (appending a tool message must not change how earlier messages are rendered). For known model families (e.g. Qwen3, DeepSeek-V3), TRL automatically swaps in a patched training template when tools are enabled. See [Chat Templates](chat_templates#training-templates) for the full list. + Example: ```python @@ -748,6 +751,7 @@ Tested with: - [**Qwen3**](https://huggingface.co/collections/Qwen/qwen3) — e.g., `Qwen/Qwen3-0.6B` - [**Qwen3-VL**](https://huggingface.co/collections/Qwen/qwen3-vl) — e.g., `Qwen/Qwen3-VL-2B-Instruct` - [**Qwen3.5**](https://huggingface.co/collections/Qwen/qwen35) — e.g., `Qwen/Qwen3.5-2B` +- [**Qwen3.6**](https://huggingface.co/collections/Qwen/qwen36) — e.g., `Qwen/Qwen3.6-35B-A3B` > [!TIP] > Compatibility with all LLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes. diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index 4c78c9d3bb5..3fdccefff30 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -1403,6 +1403,44 @@ training_args = CPOConfig( ) ``` +## Triple Preference Optimization + +Papers relating to the [`experimental.tpo.TPOTrainer`] + +### Triple Preference Optimization: Achieving Better Alignment using a Single Step Optimization + +**📜 Paper**: https://huggingface.co/papers/2405.16681 + +Introduces Triple Preference Optimization (TPO), a preference learning method that aligns an LLM with three responses per prompt — a gold (`reference`) completion, a preferred (`chosen`) completion and a dispreferred (`rejected`) completion — in a single optimization step. TPO combines a contrastive objective on the (chosen, rejected) pair with a supervised NLL term on the gold response, removing the need for a separate SFT stage and the reference model used in DPO. Used in TRL via [`experimental.tpo.TPOTrainer`]. To reproduce the paper's setting (Llama-3-Base, 5K), use this configuration: + +```python +from trl.experimental.tpo import TPOConfig + +training_args = TPOConfig( + loss_type="sigmoid", # contrastive loss between chosen and rejected (Section 3 of the paper) + tpo_alpha=1.0, # weight of the NLL term on the gold response (Section 3 of the paper) + beta=0.01, # β temperature (Table 6 of the paper) + learning_rate=5e-7, # Table 6 of the paper + num_train_epochs=1, + max_length=1024, +) +``` + +To use the TPO-L variant (length-normalized log-probabilities with a target reward margin γ), set `loss_type="tpo-l"` and `tpo_l_gamma`: + +```python +from trl.experimental.tpo import TPOConfig + +training_args = TPOConfig( + loss_type="tpo-l", # length-normalized variant (Section 3 of the paper) + tpo_alpha=1.0, + beta=0.01, + tpo_l_gamma=0.5, # γ target reward margin (Table 6 of the paper, Llama-3-Base 5K) + learning_rate=5e-7, + num_train_epochs=1, +) +``` + ## Nash Learning from Human Feedback Papers relating to the [`experimental.nash_md.NashMDTrainer`] diff --git a/docs/source/rapidfire_integration.md b/docs/source/rapidfire_integration.md index be1619d74cd..cdbab5bc8e4 100644 --- a/docs/source/rapidfire_integration.md +++ b/docs/source/rapidfire_integration.md @@ -1,29 +1,23 @@ # RapidFire AI Integration -RapidFire AI is an open-source experiment execution framework that enables concurrent training of multiple TRL configurations on the same GPU(s) through intelligent chunk-based scheduling. +RapidFire AI is an open-source experiment execution framework that integrates with TRL to turn "train one configuration at a time" into **real-time, side-by-side comparison of many configurations on the same GPU(s)** — so you can iterate on hyperparameters, LoRA settings, prompt schemes, and ablations **16–24× faster with no extra hardware**. -## Key Features +Links: [GitHub](https://github.com/RapidFireAI/rapidfireai) · [Docs](https://oss-docs.rapidfire.ai) · [Try in Colab](http://tinyurl.com/rapidfireai-colab) -- **16-24× higher experimentation throughput** compared to sequential training. -- **Almost no code changes** - drop-in configuration wrappers around TRL's and PEFT's existing configs. -- **Interactive Control Operations** - real-time control to stop, resume, clone, and modify training runs in flight -- **Automatic multi-GPU orchestration** with intelligent scheduling -- **Full compatibility** with transformers, PEFT, SFTTrainer, DPOTrainer, and GRPOTrainer -- **Full MLflow Integration**: Automatic experiment tracking and visualization -- **Production-Ready**: Already used in production environments with complete working examples. +## Why use RapidFire AI with TRL? -### Problem It Solves +When fine-tuning or post-training with TRL, you typically need to: -When fine-tuning or post-training with TRL, AI developers often need to: - Try different hyperparameter configurations - Compare different LoRA settings - Test different prompt schemes - Run ablation studies - -**Current approach**: Train each config one after another → slow and inefficient process - -**With RapidFire AI**: Train all configs in one go even on a single GPU → 16-24× faster process +| Scenario: comparing N training configs on the same GPU(s) | TRL alone | TRL + RapidFire AI | +| --- | --- | --- | +| Training strategy | Run N configs sequentially | Run N configs concurrently | +| When can you compare configs? | After all runs finish | Live, from the first chunk | +| Stop losers / clone winners mid-training | No | Yes (Interactive Control Operations) | ### How It Works @@ -37,20 +31,31 @@ Chunk 3: [Config A] → [Config B] → [Config C] → [Config D] ``` This enables: + - Early comparison of configurations on same data subsets incrementally - Efficient GPU utilization and minimizing idle times - Real-time and automated experiment metrics tracking - Dynamic control over runs in flight to incentivize more experimentation +## Key Features + +- **16-24× higher experimentation throughput** compared to sequential training. +- **Almost no code changes** - simple drop-in config APIs that just wrap around existing TRL and PEFT config APIs. +- **Interactive Control Operations** - real-time control to stop, resume, and clone-modify (with or without warm starting) training runs in flight. +- **Integration with Fully Sharded Data Parallel (FSDP)** for training large models that do not fit on a single GPU by sharding parameters, gradients, and optimizer states across multiple GPUs. +- **Full compatibility** with transformers, PEFT, SFTTrainer, DPOTrainer, and GRPOTrainer. +- **Pluggable experiment tracking**: MLflow (default), TensorBoard, and Trackio, enabled individually or in combination. +- **Zero-setup Google Colab support**: one-click tutorial notebooks for SFT, DPO, and GRPO on free T4 GPUs. +- **Production-Ready**: Already used in production environments with complete working examples. ## Installation ### Prerequisites - Python 3.12.x -- NVIDIA GPU with Compute Capability 7.x or 8.x +- NVIDIA GPU with Compute Capability 7.x or 8.x (multiple GPUs required for FSDP) - CUDA Toolkit 11.8+ -- PyTorch 2.7.1+ +- PyTorch 2.8+ ### pip install @@ -74,7 +79,7 @@ rapidfireai init rapidfireai start ``` -The dashboard will be available at `http://0.0.0.0:3000` where you can monitor and control experiments in real-time. +The dashboard will be available at `http://localhost:8853` where you can monitor and control experiments in real-time. ## Quick Start: SFT Training with Multiple Configs @@ -172,8 +177,8 @@ When you run this example: 1. **Config Expansion**: 2 base configurations × 2 PEFT configs = 4 total training runs 2. **Chunk-based Scheduling**: Training data is divided into chunks, and all 4 configs train concurrently 3. **GPU Swapping**: Models are swapped in/out of GPU memory based on chunk boundaries -4. **Real-time Tracking**: All metrics visible in the dashboard at `http://localhost:3000` -5. **Interactive Control**: Stop, resume, or clone any configuration from the dashboard +4. **Real-time Tracking**: All metrics visible in the dashboard at `http://localhost:8853` +5. **Interactive Control**: Stop, resume, or clone-modify any configuration from the dashboard This delivers **16-24× higher throughput** compared to training each configuration sequentially! @@ -195,7 +200,7 @@ training_args = RFSFTConfig( ) ``` -**Example Notebook**: [SFT for Customer Support](https://github.com/RapidFireAI/rapidfireai/blob/main/tutorial_notebooks/rf-tutorial-sft-chatqa-lite.ipynb) +**Example Notebook**: [SFT for Customer Support](https://github.com/RapidFireAI/rapidfireai/blob/main/tutorial_notebooks/fine-tuning/rf-tutorial-sft-chatqa-lite.ipynb) ### DPOTrainer @@ -213,7 +218,7 @@ training_args = RFDPOConfig( ) ``` -**Example Notebook**: [DPO for Preference Alignment](https://github.com/RapidFireAI/rapidfireai/blob/main/tutorial_notebooks/rf-tutorial-dpo-alignment-lite.ipynb) +**Example Notebook**: [DPO for Preference Alignment](https://github.com/RapidFireAI/rapidfireai/blob/main/tutorial_notebooks/post-training/rf-tutorial-dpo-alignment-lite.ipynb) ### GRPOTrainer @@ -230,7 +235,7 @@ training_args = RFGRPOConfig( ) ``` -**Example Notebook**: [GRPO for Math Reasoning](https://github.com/RapidFireAI/rapidfireai/blob/main/tutorial_notebooks/rf-tutorial-grpo-mathreasoning-lite.ipynb) +**Example Notebook**: [GRPO for Math Reasoning](https://github.com/RapidFireAI/rapidfireai/blob/main/tutorial_notebooks/post-training/rf-tutorial-grpo-mathreasoning-lite.ipynb) ## Core Concepts @@ -254,8 +259,8 @@ Through the RapidFire AI dashboard, you can dynamically control running experime - **Stop**: Pause a configuration (checkpointed automatically) - **Resume**: Continue from last checkpoint -- **Clone**: Duplicate a configuration with modifications -- **Clone & Warm Start**: Clone and initialize from parent's weights +- **Clone-Modify**: Duplicate a configuration with modifications (new run starts from scratch) +- **Clone-Modify with Warm Start**: Clone-modify and initialize from the parent's weights - **Delete**: Remove failed or unwanted runs This enables adaptive experimentation where you can stop underperforming configs early and clone promising ones with tweaked hyperparameters. @@ -320,7 +325,87 @@ config = RFModelConfig( ### Multi-GPU Support -RapidFire AI automatically detects and utilizes all available GPUs. No special configuration needed - the scheduler automatically distributes configurations across GPUs. +RapidFire AI automatically detects and utilizes all available GPUs. By default, the scheduler distributes independent configurations across GPUs (data-parallel across configs), so no special setup is required to run `N` configs on `N` GPUs concurrently. + +For models that do not fit on a single GPU, RapidFire AI also supports **Fully Sharded Data Parallel (FSDP)** to shard a single configuration across multiple GPUs — see the next section. + +### Multi-GPU Training with FSDP + +When a model is too large for a single GPU, enable FSDP directly through the training args of `RFSFTConfig` or `RFDPOConfig` — the same `fsdp` and `fsdp_config` fields exposed by Hugging Face `TrainingArguments`: + +```python +from rapidfireai.automl import RFModelConfig, RFSFTConfig, RFLoraConfig + +model_config = RFModelConfig( + model_name="meta-llama/Llama-3.1-8B-Instruct", + peft_config=RFLoraConfig( + r=16, lora_alpha=32, lora_dropout=0.05, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + bias="none", + ), + training_args=RFSFTConfig( + learning_rate=2e-4, + per_device_train_batch_size=1, + gradient_accumulation_steps=8, + fsdp="full_shard auto_wrap", + fsdp_config={ + "sharding_strategy": "FULL_SHARD", + "auto_wrap_policy": "TRANSFORMER_BASED_WRAP", + "backward_prefetch": "backward_pre", + "forward_prefetch": True, + "use_orig_params": False, + "cpu_ram_efficient_loading": True, + "offload_params": True, + "sync_module_states": True, + "limit_all_gathers": True, + }, + ), + model_type="causal_lm", + model_kwargs={"torch_dtype": "auto"}, +) +``` + +Key points: + +- FSDP works transparently with RapidFire AI's chunk-based scheduling, IC Ops (stop / resume / clone-modify with or without warm-starting), and all supported metric tracking backends. +- FSDP is fully compatible with PEFT / LoRA — LoRA adapter weights are collected efficiently across shards when saving checkpoints. +- FSDP composes with grid search and random search: each expanded config gets its own sharded training run. + +**Example Notebooks**: +- [SFT with FSDP (lite, small model)](https://github.com/RapidFireAI/rapidfireai/blob/main/tutorial_notebooks/fine-tuning/rf-tutorial-sft-chatqa-fsdp-lite.ipynb) +- [SFT with FSDP (large model)](https://github.com/RapidFireAI/rapidfireai/blob/main/tutorial_notebooks/fine-tuning/rf-tutorial-sft-chatqa-fsdp-large.ipynb) +- [DPO with FSDP](https://github.com/RapidFireAI/rapidfireai/blob/main/tutorial_notebooks/post-training/rf-tutorial-dpo-alignment-fsdp-lite.ipynb) + +### Experiment Tracking Backends + +RapidFire AI supports three metric logging backends that can be used individually or together: **MLflow** (the default for local installs), **TensorBoard** (the default in Google Colab), and **Trackio**. + +Select one or more backends at server startup with the `--tracking-backends` flag: + +```bash +# MLflow only (default on local installs) +rapidfireai start --tracking-backends mlflow + +# TensorBoard only +rapidfireai start --tracking-backends tensorboard + +# Any combination +rapidfireai start --tracking-backends mlflow tensorboard trackio +``` + +Equivalent environment variables are also available: + +- `RF_MLFLOW_ENABLED` (default `true`, or `false` in Colab) +- `RF_TENSORBOARD_ENABLED` (default `false`, or `true` in Colab) +- `RF_TRACKIO_ENABLED` (default `false`) + +All three backends receive the same metrics (loss, evaluation scores, learning rate, etc.) and respect IC Ops run lifecycle events, so you can use, for example, Trackio for lightweight sharing alongside MLflow for a full local dashboard. + +### Running in Google Colab + +RapidFire AI runs on free Google Colab T4 GPUs, with tutorial notebooks for SFT, DPO, GRPO, and RAG / context-engineering workflows. In Colab, TensorBoard is the default tracking backend (MLflow is disabled for simplicity), and the usual `rapidfireai init` / `rapidfireai start` commands run directly from notebook cells — no terminal access required. + +Get started: [RapidFire AI in Google Colab](http://tinyurl.com/rapidfireai-colab). ## Best Practices diff --git a/docs/source/reducing_memory_usage.md b/docs/source/reducing_memory_usage.md index db78aded07e..c351c1d826d 100644 --- a/docs/source/reducing_memory_usage.md +++ b/docs/source/reducing_memory_usage.md @@ -180,6 +180,23 @@ training_args = GKDConfig(..., use_liger_kernel=True) +## Chunked cross-entropy for reducing peak memory usage + +At large vocabulary sizes, the `[batch × seq_len × vocab]` logits tensor produced by the LM head is one of the dominant activations held in memory across forward and backward. `loss_type="chunked_nll"` in [`SFTTrainer`] avoids materializing it all at once: positions with `labels == -100` are dropped *before* the `lm_head` matmul, and the cross-entropy is computed in chunks of tokens using gradient checkpointing, so peak activation memory scales with `chunk_size × vocab_size` instead of `(batch × seq_len) × vocab_size`. + +Same math as the default loss — this is a memory optimization, not a new loss. + +```python +from trl import SFTConfig + +training_args = SFTConfig(..., loss_type="chunked_nll") +``` + +Expect **typically ~30 % less peak VRAM, up to ~50 %** on large-vocab models (measured on `Qwen3-1.7B`, vocab ≈ 151k — ~30 % on single-GPU, up to ~50 % under FSDP2 × 4 GPUs) with wall time typically neutral or slightly faster. See the [PR #5575](https://github.com/huggingface/trl/pull/5575) for the full benchmark across single-GPU, DDP, FSDP2, packing, long-context, and fp32 configurations. +Under FSDP2, pass `--fsdp_reshard_after_forward false` to `accelerate launch` — the chunked path otherwise re-gathers `lm_head.weight` per chunk during backward, adding noticeable wall-time. + +Not compatible with `use_liger_kernel=True`, PEFT, or VLM. + ## Padding-free Padding-free batching is an alternative approach for reducing memory usage. In this method, a batch is first sampled and then flattened into a single sequence, avoiding padding. Unlike packing, which can result in incomplete sequences by combining parts of different samples, padding-free batching ensures that all sequences remain complete and intact. diff --git a/docs/source/sft_trainer.md b/docs/source/sft_trainer.md index 07ccde4f114..0f25a2e6f48 100644 --- a/docs/source/sft_trainer.md +++ b/docs/source/sft_trainer.md @@ -108,6 +108,9 @@ where \\( y_t \\) is the target token at timestep \\( t \\), and the model is > [!TIP] > The paper [On the Generalization of SFT: A Reinforcement Learning Perspective with Reward Rectification](https://huggingface.co/papers/2508.05629) proposes an alternative loss function, called **Dynamic Fine-Tuning (DFT)**, which aims to improve generalization by rectifying the reward signal. This method can be enabled by setting `loss_type="dft"` in the [`SFTConfig`]. For more details, see [Paper Index - Dynamic Fine-Tuning](paper_index#on-the-generalization-of-sft-a-reinforcement-learning-perspective-with-reward-rectification). +> [!TIP] +> For a memory-efficient variant of the standard loss, set `loss_type="chunked_nll"` in the [`SFTConfig`]. Same math as `"nll"`, but the `lm_head` projection skips ignored-label tokens and the cross-entropy is processed in chunks, so peak activation memory does not scale with the full vocab × seq_len logits tensor. See [Chunked cross-entropy for reducing peak memory usage](reducing_memory_usage#chunked-cross-entropy-for-reducing-peak-memory-usage). + ### Label shifting and masking During training, the loss is computed using a **one-token shift**: the model is trained to predict each token in the sequence based on all previous tokens. Specifically, the input sequence is shifted right by one position to form the target labels. @@ -169,7 +172,7 @@ training_args = SFTConfig(assistant_only_loss=True) ![train_on_assistant](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/train_on_assistant.png) > [!WARNING] -> This functionality requires the chat template to include `{% generation %}` and `{% endgeneration %}` keywords. For known model families (e.g. Qwen3), TRL automatically patches the template when `assistant_only_loss=True`. For other models, check that your chat template includes these keywords — see [HuggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja#L76-L82) for an example. +> This functionality requires the chat template to include `{% generation %}` and `{% endgeneration %}` keywords. For known model families (e.g. Qwen3), TRL automatically patches the template when `assistant_only_loss=True`. See [Chat Templates](chat_templates#training-templates) for the full list of bundled training templates. For other models, check that your chat template includes these keywords. See [HuggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja#L76-L82) for an example. ### Train on completion only diff --git a/docs/source/tpo_trainer.md b/docs/source/tpo_trainer.md new file mode 100644 index 00000000000..9b025c069b8 --- /dev/null +++ b/docs/source/tpo_trainer.md @@ -0,0 +1,135 @@ +# TPO Trainer + +[![All_models-TPO-blue](https://img.shields.io/badge/All_models-TPO-blue)](https://huggingface.co/models?other=tpo,trl) + +## Overview + +Triple Preference Optimization (TPO) was introduced in the paper [Triple Preference Optimization: Achieving Better Alignment using a Single Step Optimization](https://huggingface.co/papers/2405.16681) by Amir Saeidi, Shivanshu Verma, Aswin RRV, and Chitta Baral. TPO enhances the instruction-following and reasoning capabilities of large language models in a single training step, starting from a pre-trained or instruction-tuned model. + +The abstract from the paper is the following: + +> Reinforcement Learning with Human Feedback (RLHF) enhances the alignment of Large Language Models (LLMs). However, its limitations have led to the development of Direct Preference Optimization (DPO), an RL-free approach designed to overcome these shortcomings. While studies have shown that DPO improves instruction-following capabilities, it negatively impacts the reasoning ability of LLMs. Additionally, DPO is highly sensitive to judgment noise in preference datasets and the size of the training set. Although several modifications to DPO have been proposed, they still fail to fully resolve these issues. To address these limitations, we propose Triple Preference Optimization (TPO), a new preference learning method designed to enhance both reasoning and instruction-following abilities through one-step optimization. We compare TPO against DPO and its recent variants using state-of-the-art training setups, including both base and instructiontuned models such as Mistral and Llama 3. Our evaluation covers a comprehensive range of chat-based and reasoning benchmarks. The results demonstrate that TPO achieves significant improvements over existing methods without substantially increasing response length across different dataset sizes. Specifically, TPO outperforms DPO and SimPO by up to 7.0% and 7.3% points on Arena-Hard, 12.2% and 13.3% points on MixEval-Hard, 10.4% and 10.1% points on MMLU-Pro, and 19.0% and 19.2% points on GSM8K, respectively. Furthermore, TPO achieves these improvements while requiring less data than DPO. + +This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif). + +## Quick start + +This example demonstrates how to train a model using the TPO method. We use the [Qwen 3 0.6B model](https://huggingface.co/Qwen/Qwen3-0.6B) as the base model. TPO requires a *triple-preference* dataset (`prompt`, `chosen`, `rejected`, `reference`) — see [Expected dataset type](#expected-dataset-type-and-format) below. + +Below is the script to train the model: + +```python +# train_tpo.py +from datasets import load_dataset +from trl.experimental.tpo import TPOConfig, TPOTrainer +from transformers import AutoModelForCausalLM, AutoTokenizer + +model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B") +tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") +train_dataset = load_dataset("tpo-alignment/triple-preference-ultrafeedback-40K", split="train") + +training_args = TPOConfig(output_dir="Qwen3-0.6B-TPO") +trainer = TPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset) +trainer.train() +``` + +Execute the script using the following command: + +```bash +accelerate launch train_tpo.py +``` + +## Expected dataset type and format + +TPO requires a *triple-preference* dataset: each example must contain a `prompt`, a `chosen` (preferred) completion, a `rejected` (dispreferred) completion **and** a `reference` (gold) completion. The [`experimental.tpo.TPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. + +```python +# Standard format +triple_preference_example = { + "prompt": "The sky is", + "reference": " a beautiful shade of blue.", # gold response (used for the NLL term) + "chosen": " blue.", + "rejected": " green.", +} + +# Conversational format +triple_preference_example = { + "prompt": [{"role": "user", "content": "What color is the sky?"}], + "reference": [{"role": "assistant", "content": "It is a beautiful shade of blue."}], + "chosen": [{"role": "assistant", "content": "It is blue."}], + "rejected": [{"role": "assistant", "content": "It is green."}], +} +``` + +The reference response is typically the highest-quality completion available for the prompt; in the original TPO paper it is taken from the response with the highest score in [UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback), with the second-highest used as the chosen completion and the lowest as the rejected completion. + +## Example script + +We provide an example script to train a model using the TPO method. The script is available at [`trl/experimental/tpo/tpo.py`](https://github.com/huggingface/trl/blob/main/trl/experimental/tpo/tpo.py). + +To test the TPO script with the [Qwen 3 0.6B model](https://huggingface.co/Qwen/Qwen3-0.6B) on a triple-preference dataset, run the following command: + +```bash +accelerate launch trl/experimental/tpo/tpo.py \ + --model_name_or_path Qwen/Qwen3-0.6B \ + --dataset_name tpo-alignment/triple-preference-ultrafeedback-40K \ + --beta 0.01 \ + --tpo_alpha 1.0 \ + --learning_rate 5e-7 \ + --num_train_epochs 1 \ + --output_dir Qwen3-0.6B-TPO +``` + +## Looking deeper into the TPO method + +Triple Preference Optimization (TPO) extends preference-based alignment from pairs to *triples* `(y_gold, y_chosen, y_rejected)`. The model is jointly optimized with two objectives in a single step: + +1. A **contrastive loss** between the chosen and rejected completions, similar in spirit to DPO/SimPO but computed directly from the policy log-probabilities (no separate reference policy is required). +2. A **supervised negative log-likelihood (NLL) loss** on the gold (`reference`) completion, weighted by `tpo_alpha`. This term replaces the standalone SFT stage typically required before DPO. + +The total TPO loss is: + +$$ +\mathcal{L}_{\mathrm{TPO}}(\theta) = \mathcal{L}_{\mathrm{contrast}}(\theta) + \alpha \cdot \mathcal{L}_{\mathrm{NLL}}(\theta; y_{\text{gold}}) +$$ + +where \\( \alpha \\) is `tpo_alpha` and \\( \mathcal{L}_{\mathrm{contrast}} \\) is selected via `loss_type`. + +### Loss types + +| `loss_type=` | Description | +| --- | --- | +| `"sigmoid"` (default) | Sigmoid loss on the (sum) log-probability difference between the chosen and rejected completions, as in the original [TPO](https://huggingface.co/papers/2405.16681) paper. | +| `"hinge"` | Hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, `beta` is the reciprocal of the margin. | +| `"ipo"` | IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper, computed on length-normalized log-probabilities. | +| `"tpo-l"` | Length-normalized TPO variant: uses average per-token log-probabilities and adds a target reward margin `tpo_l_gamma` to the Bradley-Terry objective, in the spirit of [SimPO](https://huggingface.co/papers/2405.14734). | + +Setting `tpo_alpha=0.0` disables the NLL term entirely (the reference response is then unused, and the corresponding cross-entropy is skipped to save compute). + +## Logged metrics + +While training and evaluating we record the following metrics: + +* `loss`: The total TPO loss (contrastive + `tpo_alpha` × NLL) averaged over the current logging interval. +* `entropy`: The average entropy of the model's predicted token distribution over completion tokens. +* `mean_token_accuracy`: The proportion of completion tokens for which the model's top-1 prediction matches the chosen completion. +* `num_tokens`: The total number of tokens processed so far. +* `logits/chosen`: The average logit values assigned by the model to the tokens in the chosen completion. +* `logits/rejected`: The average logit values assigned by the model to the tokens in the rejected completion. +* `logps/chosen`: The average log-probability assigned by the model to the chosen completion. +* `logps/rejected`: The average log-probability assigned by the model to the rejected completion. +* `rewards/chosen`: The average implicit reward computed for the chosen completion, defined as \\( \beta \log \pi_{\theta}(y^{+}\!\mid x) \\). +* `rewards/rejected`: The average implicit reward computed for the rejected completion, defined as \\( \beta \log \pi_{\theta}(y^{-}\!\mid x) \\). +* `rewards/margins`: The average implicit reward margin between the chosen and rejected completions. +* `rewards/accuracies`: The proportion of examples where the implicit reward for the chosen completion is higher than that for the rejected completion. + +## TPOTrainer + +[[autodoc]] experimental.tpo.TPOTrainer + - train + - save_model + - push_to_hub + +## TPOConfig + +[[autodoc]] experimental.tpo.TPOConfig diff --git a/docs/source/vllm_integration.md b/docs/source/vllm_integration.md index 3b2d1e7b1b3..78ba3f1890f 100644 --- a/docs/source/vllm_integration.md +++ b/docs/source/vllm_integration.md @@ -3,7 +3,7 @@ This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood. > [!WARNING] -> TRL currently only supports vLLM versions from `0.11.0` to `0.18.0`. Please ensure you have a version in this range installed to avoid compatibility issues. +> TRL currently only supports vLLM versions from `0.12.0` to `0.18.0`. Please ensure you have a version in this range installed to avoid compatibility issues. > [!TIP] > The following trainers currently support generation with vLLM: diff --git a/examples/scripts/async_grpo.py b/examples/scripts/async_grpo.py index 78e6c2e7253..ccd020b9d13 100644 --- a/examples/scripts/async_grpo.py +++ b/examples/scripts/async_grpo.py @@ -12,16 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -CUDA_VISIBLE_DEVICES=1 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-4B \ - --weight-transfer-config '{"backend":"nccl"}' \ - --max-model-len 9216 +# /// script +# dependencies = [ +# "trl", +# "math-verify", +# "latex2sympy2_extended", +# "trackio", +# ] +# /// -LOG_LEVEL=DEBUG CUDA_VISIBLE_DEVICES=0 accelerate launch examples/scripts/async_grpo.py """ +pip install math_verify + +CUDA_VISIBLE_DEVICES=1 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-0.6B \ + --max-model-len 2048 \ + --logprobs-mode processed_logprobs \ + --weight-transfer-config '{"backend":"nccl"}' -import logging -import os +CUDA_VISIBLE_DEVICES=0 accelerate launch examples/scripts/async_grpo.py +""" from datasets import load_dataset @@ -29,34 +38,33 @@ from trl.rewards import accuracy_reward -logging.basicConfig( - level=getattr(logging, os.environ.get("LOG_LEVEL", "INFO").upper(), logging.INFO), - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", -) -logging.getLogger("trl").setLevel(logging.DEBUG) - - def format_sample(sample): - return {"prompt": sample["messages"][:1], "solution": sample["answer"]} + return { + "prompt": [{"role": "user", "content": sample["question"]}], + "solution": sample["answer"].split("####")[-1].strip(), + } def main() -> None: - dataset = load_dataset("open-r1/OpenR1-Math-220k", split="train[:10000]") + dataset = load_dataset("openai/gsm8k", "main", split="train") dataset = dataset.map(format_sample, remove_columns=dataset.column_names) config = AsyncGRPOConfig( - output_dir="./results", - per_device_train_batch_size=1, - num_train_epochs=1, - max_completion_length=4096, - max_steps=10, + output_dir="async_grpo_gsm8k", + save_strategy="no", + per_device_train_batch_size=16, + gradient_accumulation_steps=2, + max_completion_length=1024, + chat_template_kwargs={"enable_thinking": False}, + max_steps=200, + learning_rate=1e-5, report_to="trackio", - trackio_space_id=None, - project="async_grpo", + trackio_space_id="async-grpo-gsm8k", + project="async-grpo-gsm8k", log_completions=True, ) trainer = AsyncGRPOTrainer( - model="Qwen/Qwen3-4B", + model="Qwen/Qwen3-0.6B", args=config, train_dataset=dataset, reward_funcs=accuracy_reward, diff --git a/pyproject.toml b/pyproject.toml index 12b60d37424..f8fd2d13b83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,7 +79,7 @@ test = [ "pytest" ] vllm = [ - "vllm>=0.11.0,<=0.18.0", + "vllm>=0.12.0,<=0.18.0", "fastapi", "pydantic", "aiohttp>=3.13.3", diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index bf8c39e13a9..73b7d166949 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -74,6 +74,8 @@ Qwen2VLForConditionalGeneration, Qwen3_5Config, Qwen3_5ForConditionalGeneration, + Qwen3_5MoeConfig, + Qwen3_5MoeForConditionalGeneration, Qwen3Config, Qwen3ForCausalLM, Qwen3ForSequenceClassification, @@ -180,10 +182,13 @@ def init_weights_tiny_model(model): ("mistralai/Mistral-7B-Instruct-v0.1", MistralConfig, MistralForCausalLM, torch.bfloat16, "0.1"), ("mistralai/Mistral-7B-Instruct-v0.2", MistralConfig, MistralForCausalLM, torch.bfloat16, "0.2"), ("facebook/opt-1.3b", OPTConfig, OPTForCausalLM, torch.float16, None), - ("microsoft/Phi-3.5-mini-instruct", Phi3Config, Phi3ForCausalLM, torch.bfloat16, None), + ("microsoft/Phi-3-mini-4k-instruct", Phi3Config, Phi3ForCausalLM, torch.bfloat16, "3"), + ("microsoft/Phi-3.5-mini-instruct", Phi3Config, Phi3ForCausalLM, torch.bfloat16, "3.5"), ("Qwen/Qwen2.5-32B-Instruct", Qwen2Config, Qwen2ForCausalLM, torch.bfloat16, "2.5"), ("Qwen/Qwen2.5-Coder-0.5B", Qwen2Config, Qwen2ForCausalLM, torch.bfloat16, "2.5-Coder"), ("Qwen/Qwen3-8B", Qwen3Config, Qwen3ForCausalLM, torch.bfloat16, None), + # It's important to have Qwen3-4B-Instruct-2507 as it doesn't have the same chat template (non-thinking variant) + ("Qwen/Qwen3-4B-Instruct-2507", Qwen3Config, Qwen3ForCausalLM, torch.bfloat16, "Instruct-2507"), ]: revision = "refs/pr/14" if model_id == "Qwen/Qwen3-8B" else "main" # chat template with {% generation %} tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision) @@ -211,8 +216,10 @@ def init_weights_tiny_model(model): kwargs = {} if model_id == "zai-org/GLM-4.5": kwargs["n_routed_experts"] = 4 - elif model_id in ("Qwen/Qwen3-30B-A3B", "openai/gpt-oss-20b"): + elif model_id == "Qwen/Qwen3-30B-A3B": kwargs["num_experts"] = 4 + elif model_id == "openai/gpt-oss-20b": + kwargs["num_local_experts"] = 4 config = config_class( vocab_size=len(tokenizer.vocab), @@ -330,6 +337,7 @@ def init_weights_tiny_model(model): ("Qwen/Qwen2.5-VL-3B-Instruct", Qwen2_5_VLForConditionalGeneration, torch.bfloat16), ("Qwen/Qwen3-VL-2B-Instruct", Qwen3VLForConditionalGeneration, torch.bfloat16), ("Qwen/Qwen3.5-0.8B", Qwen3_5ForConditionalGeneration, torch.bfloat16), + ("Qwen/Qwen3.6-35B-A3B", Qwen3_5MoeForConditionalGeneration, torch.bfloat16), ]: processor = AutoProcessor.from_pretrained(model_id) generation_config = GenerationConfig.from_pretrained(model_id) if model_id != "Qwen/Qwen3.5-0.8B" else None @@ -376,7 +384,7 @@ def init_weights_tiny_model(model): vision_config["depth"] = 2 vision_config["out_hidden_size"] = 16 - if issubclass(model_class.config_class, Qwen3_5Config): + if issubclass(model_class.config_class, (Qwen3_5Config, Qwen3_5MoeConfig)): # For tiny layer counts, default `layer_types` can end up with no full-attention layers (e.g. 2 layers and # default interval 4), which breaks Qwen3.5 dynamic cache logic. Keep one full-attention layer at the end. text_config["layer_types"] = ["linear_attention", "full_attention"] @@ -391,6 +399,12 @@ def init_weights_tiny_model(model): vision_config["intermediate_size"] = 32 vision_config["out_hidden_size"] = 16 + if issubclass(model_class.config_class, Qwen3_5MoeConfig): + text_config["num_experts"] = 4 + text_config["num_experts_per_tok"] = 2 + text_config["moe_intermediate_size"] = 32 + text_config["shared_expert_intermediate_size"] = 32 + if model_id == "llava-hf/llava-v1.6-mistral-7b-hf": # Hotfix: llava-hf/llava-v1.6-mistral-7b-hf mistakesly sets text_config.dtype to "bfloat16". # See https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf/discussions/46 @@ -412,14 +426,16 @@ def init_weights_tiny_model(model): config = AutoConfig.from_pretrained(model_id, text_config=text_config, vision_config=vision_config, **kwargs) model = model_class(config).to(dtype=dtype) - if issubclass(model_class.config_class, Qwen3_5Config): + if model_id.startswith("Qwen/Qwen3.5"): # Qwen3.5 models has some weights in float32, to mirror this in the tiny model we need to convert them to float32 manually. + # Qwen3.6 reuses the Qwen3_5Moe class but stores those weights in bf16, so the cast is not needed there. for layer in model.model.language_model.layers: if hasattr(layer, "linear_attn"): # applies to linear attention layers only layer.linear_attn.A_log.data = layer.linear_attn.A_log.data.float() layer.linear_attn.norm.weight.data = layer.linear_attn.norm.weight.data.float() - push_to_hub(model, processor, generation_config, "tiny") + suffix = "3.6" if model_id == "Qwen/Qwen3.6-35B-A3B" else None + push_to_hub(model, processor, generation_config, "tiny", suffix) # PEFT models model = Qwen3ForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM", dtype="auto") diff --git a/tests/experimental/test_distillation_trainer.py b/tests/experimental/test_distillation_trainer.py new file mode 100644 index 00000000000..a2f1d6f8182 --- /dev/null +++ b/tests/experimental/test_distillation_trainer.py @@ -0,0 +1,183 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from unittest.mock import MagicMock + +import pytest +import torch +import torch.nn.functional as F +from datasets import Dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from trl.experimental.distillation import DistillationConfig, DistillationTrainer +from trl.experimental.distillation.distillation_trainer import _add_tail_bucket, _jsd_divergence + +from ..testing_utils import TrlTestCase + + +def _ragged_server_response(): + # Two samples with completion lengths 1 and 3 respectively; matches the wire format + # of VLLMClient.get_sequence_logprobs (per-sample shape (comp_len, top_k=1)). + return { + "logprobs": [[[-2.3]], [[-1.1], [-0.4], [-3.0]]], + "logprob_token_ids": [[[90]], [[90], [9217], [100]]], + "actual_logprobs": [[[-2.3]], [[-1.1], [-0.4], [-3.0]]], + } + + +class TestGetTeacherTokenLogprobsFromServer(TrlTestCase): + def test_variable_lengths_use_neg_inf_sentinel_at_padding(self): + mock_self = MagicMock() + mock_self.teacher_client.get_sequence_logprobs = MagicMock(return_value=_ragged_server_response()) + mock_self.loss_top_k = 1 + mock_self.temperature = 1.0 + + inputs = { + "input_ids": torch.tensor([[10, 11, 90, 0, 0], [10, 11, 90, 9217, 100]]), + "attention_mask": torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]), + "labels": torch.tensor([[-100, -100, 90, -100, -100], [-100, -100, 90, 9217, 100]]), + } + + out = DistillationTrainer._get_teacher_token_logprobs_from_server(mock_self, inputs, aligned_prompt_length=2) + + assert out["actual_logprobs"].shape == (2, 3) + assert out["topk_logprobs"].shape == (2, 3, 1) + + # Real completion positions preserved. + assert out["actual_logprobs"][0, 0].item() == pytest.approx(-2.3, rel=1e-5) + assert out["actual_logprobs"][1, 0].item() == pytest.approx(-1.1, rel=1e-5) + assert out["actual_logprobs"][1, 2].item() == pytest.approx(-3.0, rel=1e-5) + + # Sample 0 is 1 token long; positions 1 and 2 are padded with the -inf sentinel. + assert out["actual_logprobs"][0, 1].item() == float("-inf") + assert out["actual_logprobs"][0, 2].item() == float("-inf") + assert out["topk_logprobs"][0, 1, 0].item() == float("-inf") + + # Sample 1 is full-length and fully finite. + assert torch.isfinite(out["actual_logprobs"][1, :]).all() + + +class TestServerReverseKLPaddingMask(TrlTestCase): + def test_mask_keeps_forward_and_backward_finite(self): + # Simulates the getter's output: sample 0 has completion length 1 (positions 1-2 + # padded with -inf), sample 1 is full-length. + teacher_topk = torch.tensor( + [[[-2.3], [float("-inf")], [float("-inf")]], [[-1.1], [-0.4], [-3.0]]], + dtype=torch.float32, + ) + labels = torch.tensor([[90, -100, -100], [90, 9217, 100]]) + + # Strategy B: neutralise -inf at labels == -100 before the divergence math. + pad_mask = (labels == -100).unsqueeze(-1) + zero = torch.zeros((), dtype=teacher_topk.dtype) + teacher_topk = torch.where(pad_mask, zero, teacher_topk) + + valid_mask = torch.ones_like(teacher_topk, dtype=torch.bool) + teacher_with_tail, support_mask = _add_tail_bucket(teacher_topk, valid_mask) + assert torch.isfinite(teacher_with_tail).all() + + raw_student = torch.randn(2, 3, 2, requires_grad=True) + student_log_probs = F.log_softmax(raw_student, dim=-1) + loss = _jsd_divergence(student_log_probs, teacher_with_tail, beta=1.0, support_mask=support_mask) + assert torch.isfinite(loss).all() + + loss.sum().backward() + assert torch.isfinite(raw_student.grad).all() + + +def _canned_teacher_logprobs(**kwargs): + # Fabricate ragged per-sample logprobs matching the requested sequence shapes. + sequences = kwargs["sequences"] + prompt_lengths = kwargs["prompt_lengths"] + top_k = kwargs.get("top_logprobs", 1) + logprobs, token_ids, actual = [], [], [] + for seq, plen in zip(sequences, prompt_lengths, strict=True): + comp_len = len(seq) - plen + logprobs.append([[-1.0 - 0.05 * i] * top_k for i in range(comp_len)]) + token_ids.append([[int(seq[plen + i])] * top_k for i in range(comp_len)]) + actual.append([[-1.0 - 0.05 * i] for i in range(comp_len)]) + return {"logprobs": logprobs, "logprob_token_ids": token_ids, "actual_logprobs": actual} + + +def _variable_length_dataset(): + return Dataset.from_list( + [ + {"messages": [{"role": "user", "content": "What's 2+2?"}, {"role": "assistant", "content": "4."}]}, + { + "messages": [ + {"role": "user", "content": "Name three primary colors."}, + { + "role": "assistant", + "content": "Red, green, and blue are the three primary colors commonly used in additive color mixing.", + }, + ] + }, + ] + ) + + +class TestDistillationTrainerServerPath(TrlTestCase): + @classmethod + def setup_class(cls): + model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" + cls.device = "cuda" if torch.cuda.is_available() else "cpu" + cls.tokenizer = AutoTokenizer.from_pretrained(model_id) + cls.tokenizer.pad_token = cls.tokenizer.eos_token + cls.model_id = model_id + + def _run_one_step(self, bs, ga, monkeypatch): + from trl.generation import vllm_client as vllm_client_module + + fake_client = MagicMock() + fake_client.get_sequence_logprobs.side_effect = _canned_teacher_logprobs + monkeypatch.setattr(vllm_client_module, "VLLMClient", lambda *a, **kw: fake_client) + + config = DistillationConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=bs, + gradient_accumulation_steps=ga, + learning_rate=1e-4, + max_length=64, + max_prompt_length=32, + max_completion_length=32, + use_teacher_server=True, + teacher_model_server_url="http://fake-teacher.invalid:8000", + loss_top_k=1, + beta=1.0, + lmbda=0.0, + loss_add_tail=True, + save_strategy="no", + report_to="none", + logging_steps=1, + ) + model = AutoModelForCausalLM.from_pretrained(self.model_id, dtype=torch.float32).to(self.device) + trainer = DistillationTrainer( + model=model, + args=config, + train_dataset=_variable_length_dataset(), + processing_class=self.tokenizer, + ) + trainer.teacher_client = fake_client + trainer.train() + return [rec for rec in trainer.state.log_history if "grad_norm" in rec] + + @pytest.mark.slow + @pytest.mark.parametrize(("bs", "ga"), [(1, 2), (2, 1)]) + def test_reverse_kl_finite_grad_with_ragged_batch(self, bs, ga, monkeypatch): + records = self._run_one_step(bs=bs, ga=ga, monkeypatch=monkeypatch) + assert records, "Expected at least one grad_norm log entry during training" + for record in records: + assert math.isfinite(record["grad_norm"]), f"grad_norm={record['grad_norm']} leaked -inf into backward" + assert math.isfinite(record["loss"]) diff --git a/tests/experimental/test_kto_trainer.py b/tests/experimental/test_kto_trainer.py index 791d1808579..2f1af929a7e 100644 --- a/tests/experimental/test_kto_trainer.py +++ b/tests/experimental/test_kto_trainer.py @@ -15,11 +15,11 @@ import multiprocess import pytest import torch -from datasets import load_dataset +from datasets import Dataset, load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer from trl.experimental.kto import KTOConfig, KTOTrainer -from trl.experimental.kto.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize +from trl.experimental.kto.kto_trainer import _get_kl_dataset from ..testing_utils import TrlTestCase, require_liger_kernel, require_peft @@ -117,60 +117,53 @@ def test_tokenize_and_process_tokens(self): ) dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") + train_dataset = dummy_dataset["train"] trainer = KTOTrainer( model=self.model, ref_model=self.ref_model, args=training_args, processing_class=self.tokenizer, - train_dataset=dummy_dataset["train"], + train_dataset=train_dataset, eval_dataset=dummy_dataset["test"], ) - train_dataset = dummy_dataset["train"] - tokenized_dataset = train_dataset.map( - _tokenize, - fn_kwargs={"tokenizer": trainer.processing_class}, - batched=True, - batch_size=2, + # Verify the tokenization step: dataset stores raw token IDs (aligned with DPO style). + # prompt_ids must start with the tokenized prompt text. + prompt_ids = self.tokenizer(train_dataset["prompt"][0])["input_ids"] + assert trainer.train_dataset[0]["prompt_ids"][: len(prompt_ids)] == prompt_ids + # completion_ids are the raw answer tokens (no prompt prefix, no BOS/EOS added yet). + assert len(trainer.train_dataset[0]["completion_ids"]) > 0 + + # Verify the collator output (assembly, BOS/EOS insertion, labels). + example = trainer.train_dataset[0] + batch = trainer.data_collator([example]) + # completion_input_ids ends with EOS + assert batch["completion_input_ids"][0, -1].item() == self.tokenizer.eos_token_id + # completion_labels: prompt prefix masked with -100, answer+EOS unmasked and matching input_ids + completion_input_ids = batch["completion_input_ids"][0].tolist() + completion_labels = batch["completion_labels"][0].tolist() + first_unmasked = next(i for i, lbl in enumerate(completion_labels) if lbl != -100) + assert first_unmasked > 0 # at least the prompt is masked + assert completion_labels[first_unmasked:] == completion_input_ids[first_unmasked:] + + # Test corruption of (prompt, completion) pairs for KL dataset. + # _get_kl_dataset shifts completion_ids by one within each batch; prompt_ids are unchanged. + synthetic = Dataset.from_dict( + { + "prompt_ids": [[1, 2], [3, 4], [5, 6]], + "completion_ids": [[10, 11], [20, 21], [30, 31]], + "label": [True, False, True], + } ) - assert tokenized_dataset["prompt"][:] == train_dataset["prompt"][:] - assert tokenized_dataset["completion"][:] == train_dataset["completion"][:] - assert tokenized_dataset["label"][:] == train_dataset["label"][:] - assert tokenized_dataset["prompt_input_ids"][0] == [46518, 374, 2664, 1091] - assert tokenized_dataset["prompt_attention_mask"][0] == [1, 1, 1, 1] - assert tokenized_dataset["answer_input_ids"][0] == [27261, 13] - assert tokenized_dataset["answer_attention_mask"][0] == [1, 1] - - # Test corruption of (prompt, completion) pairs for KL dataset for batch_size in [2, 3]: - tokenized_kl_dataset = tokenized_dataset.map(_get_kl_dataset, batched=True, batch_size=batch_size) - - # Verify that the "answer_input_ids" have been modified, meaning the new "answer_input_ids" differ - # from the original ones. However, when the length of the dataset modulo batch_size equals 1, - # the last batch remains unaltered. This is a rare scenario that does not impact the training - # process, so we exclude it from testing by iterating only up to len - 1. - for i in range(len(tokenized_kl_dataset["answer_input_ids"]) - 1): - assert tokenized_dataset["prompt_input_ids"][i] == tokenized_kl_dataset["prompt_input_ids"][i] - assert ( - tokenized_dataset["prompt_attention_mask"][i] == tokenized_kl_dataset["prompt_attention_mask"][i] - ) - assert tokenized_dataset["answer_input_ids"][i] != tokenized_kl_dataset["answer_input_ids"][i] - - fn_kwargs = { - "prefix": "", - "tokenizer": trainer.processing_class, - "max_length": trainer.max_length, - } - processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs, num_proc=2) - assert processed_dataset["prompt"][:] == train_dataset["prompt"][:] - assert processed_dataset["completion"][:] == train_dataset["completion"][:] - assert processed_dataset["label"][:] == train_dataset["label"][:] - assert processed_dataset["prompt_input_ids"][0] == [46518, 374, 2664, 1091] - assert processed_dataset["prompt_attention_mask"][0] == [1, 1, 1, 1] - assert processed_dataset["completion_input_ids"][0] == [46518, 374, 2664, 1091, 27261, 13, 151645] - assert processed_dataset["completion_attention_mask"][0] == [1, 1, 1, 1, 1, 1, 1] - assert processed_dataset["completion_labels"][0] == [-100, -100, -100, -100, 27261, 13, 151645] + rotated = synthetic.map(_get_kl_dataset, batched=True, batch_size=batch_size) + + # Verify that completion_ids have been rotated (differ from original). When the dataset length + # modulo batch_size equals 1, the last batch is unaltered: exclude it from the check. + for i in range(len(rotated) - 1): + assert synthetic["prompt_ids"][i] == rotated["prompt_ids"][i] + assert synthetic["completion_ids"][i] != rotated["completion_ids"][i] def test_kto_trainer_without_providing_ref_model(self): training_args = KTOConfig( diff --git a/tests/experimental/test_ppo_trainer.py b/tests/experimental/test_ppo_trainer.py index 32d6726ff96..b3aea763ee3 100644 --- a/tests/experimental/test_ppo_trainer.py +++ b/tests/experimental/test_ppo_trainer.py @@ -61,7 +61,8 @@ "trl-internal-testing/tiny-MistralForCausalLM-0.1", "trl-internal-testing/tiny-MistralForCausalLM-0.2", "trl-internal-testing/tiny-OPTForCausalLM", - "trl-internal-testing/tiny-Phi3ForCausalLM", + "trl-internal-testing/tiny-Phi3ForCausalLM-3", + "trl-internal-testing/tiny-Phi3ForCausalLM-3.5", "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", ] diff --git a/tests/experimental/test_tpo_trainer.py b/tests/experimental/test_tpo_trainer.py new file mode 100644 index 00000000000..4220d729687 --- /dev/null +++ b/tests/experimental/test_tpo_trainer.py @@ -0,0 +1,325 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from datasets import load_dataset +from transformers.utils import is_peft_available + +from trl.experimental.tpo import TPOConfig, TPOTrainer +from trl.experimental.tpo.tpo_trainer import DataCollatorForTriplePreference + +from ..testing_utils import TrlTestCase, require_peft + + +if is_peft_available(): + from peft import LoraConfig + + +def _add_reference_column(example): + """Synthesize a `reference` (gold) completion for tests by reusing the chosen completion.""" + example["reference"] = example["chosen"] + return example + + +class TestDataCollatorForTriplePreference(TrlTestCase): + def test_padding_and_masks(self): + collator = DataCollatorForTriplePreference(pad_token_id=0) + examples = [ + {"prompt_ids": [1, 2, 3], "chosen_ids": [4, 5], "rejected_ids": [6], "reference_ids": [7, 8]}, + {"prompt_ids": [9, 10], "chosen_ids": [11], "rejected_ids": [12, 13], "reference_ids": [14]}, + ] + result = collator(examples) + + expected_input_ids = torch.tensor( + [ + [1, 2, 3, 4, 5], # prompt + chosen (example 1) + [9, 10, 11, 0, 0], # prompt + chosen (example 2, padded) + [1, 2, 3, 6, 0], # prompt + rejected (example 1, padded) + [9, 10, 12, 13, 0], # prompt + rejected (example 2, padded) + [1, 2, 3, 7, 8], # prompt + reference (example 1) + [9, 10, 14, 0, 0], # prompt + reference (example 2, padded) + ] + ) + expected_attention_mask = torch.tensor( + [ + [1, 1, 1, 1, 1], + [1, 1, 1, 0, 0], + [1, 1, 1, 1, 0], + [1, 1, 1, 1, 0], + [1, 1, 1, 1, 1], + [1, 1, 1, 0, 0], + ] + ) + expected_completion_mask = torch.tensor( + [ + [0, 0, 0, 1, 1], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 1, 1, 0], + [0, 0, 0, 1, 1], + [0, 0, 1, 0, 0], + ] + ) + + assert set(result.keys()) == {"input_ids", "attention_mask", "completion_mask"} + torch.testing.assert_close(result["input_ids"], expected_input_ids) + torch.testing.assert_close(result["attention_mask"], expected_attention_mask) + torch.testing.assert_close(result["completion_mask"], expected_completion_mask) + + def test_exclude_reference(self): + # When `include_reference=False`, the collator only emits the chosen/rejected halves so the per-step + # compute/memory cost matches DPO's `DataCollatorForPreference`. This is the layout used by + # `TPOTrainer` when `tpo_alpha=0.0`. + collator = DataCollatorForTriplePreference(pad_token_id=0, include_reference=False) + examples = [ + {"prompt_ids": [1, 2, 3], "chosen_ids": [4, 5], "rejected_ids": [6], "reference_ids": [7, 8]}, + {"prompt_ids": [9, 10], "chosen_ids": [11], "rejected_ids": [12, 13], "reference_ids": [14]}, + ] + result = collator(examples) + + expected_input_ids = torch.tensor( + [ + [1, 2, 3, 4, 5], # prompt + chosen (example 1) + [9, 10, 11, 0, 0], # prompt + chosen (example 2, padded) + [1, 2, 3, 6, 0], # prompt + rejected (example 1, padded) + [9, 10, 12, 13, 0], # prompt + rejected (example 2, padded) + ] + ) + assert result["input_ids"].shape == (4, 5) # 2 * B rows, no reference branch + torch.testing.assert_close(result["input_ids"], expected_input_ids) + assert set(result.keys()) == {"input_ids", "attention_mask", "completion_mask"} + + +class TestTPOTrainer(TrlTestCase): + def test_train(self): + # Get the dataset and synthesize a reference (gold) completion + dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + dataset = dataset.map(_add_reference_column) + + # Initialize the trainer + training_args = TPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + report_to="none", + ) + trainer = TPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + @pytest.mark.parametrize("loss_type", ["sigmoid", "hinge", "ipo", "tpo-l"]) + def test_train_loss_types(self, loss_type): + # Get the dataset and synthesize a reference (gold) completion + dataset = load_dataset("trl-internal-testing/zen", "standard_preference") + dataset = dataset.map(_add_reference_column) + + # Initialize the trainer + training_args = TPOConfig( + output_dir=self.tmp_dir, + loss_type=loss_type, + learning_rate=0.1, # use higher lr because gradients are tiny and default lr can stall updates + report_to="none", + eval_strategy="steps", + eval_steps=3, + ) + trainer = TPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset["train"], + eval_dataset=dataset["test"], + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_conversational(self): + # Get the dataset and synthesize a reference (gold) completion + dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") + dataset = dataset.map(_add_reference_column) + + # Initialize the trainer + training_args = TPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + report_to="none", + ) + trainer = TPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_without_nll(self): + # Setting tpo_alpha=0.0 disables the NLL term, skips the corresponding cross-entropy, and also drops the + # reference branch from the collated batch so the model doesn't pay the extra forward-pass cost. + dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + dataset = dataset.map(_add_reference_column) + + training_args = TPOConfig( + output_dir=self.tmp_dir, + tpo_alpha=0.0, + learning_rate=0.1, + report_to="none", + ) + trainer = TPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + ) + + # The default collator should drop the reference branch entirely when `tpo_alpha=0.0`. + assert isinstance(trainer.data_collator, DataCollatorForTriplePreference) + assert trainer.data_collator.include_reference is False + + # Verify the collated batch is 2 * per_device_train_batch_size (chosen + rejected only), not 3 * B. + batch = trainer.data_collator(list(trainer.train_dataset.select(range(2)))) + assert batch["input_ids"].shape[0] == 4 # 2 branches * 2 examples + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_train_implicit_prompt(self): + # Implicit-prompt variant: no `prompt` column, the prompt is embedded in `chosen`/`rejected` and (for TPO) + # also in `reference`. Regression test for the `extract_prompt` bug where the reference column was left + # untouched, silently doubling the prompt in the reference branch. + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + # Synthesize a reference column that shares the same implicit prompt as chosen/rejected + dataset = dataset.map(_add_reference_column) + + training_args = TPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + report_to="none", + ) + trainer = TPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + + def test_implicit_prompt_mismatched_reference_raises(self): + # When the dataset has no `prompt` column and the `reference` completion does not share the implicit + # prompt prefix of `chosen`/`rejected`, the trainer must raise a clear error rather than silently + # corrupting the reference branch. + dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train") + + def _set_unrelated_reference(example): + example["reference"] = "unrelated completion without the shared prompt prefix." + return example + + dataset = dataset.map(_set_unrelated_reference) + + training_args = TPOConfig(output_dir=self.tmp_dir, report_to="none") + with pytest.raises(ValueError, match="implicit prompt"): + TPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + ) + + def test_missing_reference_column_raises(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + + training_args = TPOConfig(output_dir=self.tmp_dir, report_to="none") + with pytest.raises(ValueError, match="reference"): + TPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + ) + + @require_peft + def test_train_with_peft(self): + dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") + dataset = dataset.map(_add_reference_column) + + training_args = TPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + report_to="none", + ) + trainer = TPOTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + args=training_args, + train_dataset=dataset, + peft_config=LoraConfig(), + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + + for n, param in previous_trainable_params.items(): + if "lora" in n: + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 4ce74c0f760..79c38ba1eaf 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -152,6 +152,7 @@ def test_add_response_schema(self, tokenizer_name): [ pytest.param("trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", id="qwen3_vl"), pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"), + pytest.param("trl-internal-testing/tiny-Qwen3_5MoeForConditionalGeneration-3.6", id="qwen36"), ], ) def test_add_response_schema_vlm(self, processor_name): @@ -182,6 +183,16 @@ class TestSupportsToolCalling: @pytest.mark.parametrize( "model_id", [ + pytest.param("trl-internal-testing/tiny-DeepseekV3ForCausalLM", id="deepseekv3"), + pytest.param("trl-internal-testing/tiny-DeepseekV3ForCausalLM-0528", id="deepseekv3-0528"), + pytest.param( + "trl-internal-testing/tiny-Gemma4ForConditionalGeneration", + id="gemma4", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.5.0"), + reason="Gemma4 models were introduced in transformers-5.5.0", + ), + ), pytest.param( "trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe", @@ -195,7 +206,16 @@ class TestSupportsToolCalling: pytest.param("trl-internal-testing/tiny-LlamaForCausalLM-3.2", id="llama3.2"), pytest.param("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", id="qwen2.5"), pytest.param("trl-internal-testing/tiny-Qwen3ForCausalLM", id="qwen3"), + pytest.param("trl-internal-testing/tiny-Qwen3ForCausalLM-Instruct-2507", id="qwen3-instruct-2507"), pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3moe"), + pytest.param( + "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", + id="qwen3_vl", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("4.57.0"), + reason="Qwen3-VL was introduced in transformers-4.57.0", + ), + ), pytest.param( "trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35", @@ -204,6 +224,14 @@ class TestSupportsToolCalling: reason="Qwen3.5 tokenizer requires transformers>=5.0.0", ), ), + pytest.param( + "trl-internal-testing/tiny-Qwen3_5MoeForConditionalGeneration-3.6", + id="qwen36", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.0.0"), + reason="Qwen3.5 tokenizer requires transformers>=5.0.0", + ), + ), ], ) def test_supports_tool_calling(self, model_id): @@ -214,39 +242,40 @@ def test_supports_tool_calling(self, model_id): "model_id", [ # No chat template + pytest.param("trl-internal-testing/tiny-BartModel", id="bart"), pytest.param("trl-internal-testing/tiny-BloomForCausalLM", id="bloom"), pytest.param("trl-internal-testing/tiny-GPT2LMHeadModel", id="gpt2"), pytest.param("trl-internal-testing/tiny-GPTNeoXForCausalLM", id="gptneox"), + pytest.param("trl-internal-testing/tiny-GptNeoXForSequenceClassification", id="gptneox-seq"), pytest.param("trl-internal-testing/tiny-OPTForCausalLM", id="opt"), + pytest.param("trl-internal-testing/tiny-T5ForConditionalGeneration", id="t5"), # TemplateError: rejects tool role sequence pytest.param("trl-internal-testing/tiny-CohereForCausalLM", id="cohere"), pytest.param("trl-internal-testing/tiny-FalconMambaForCausalLM", id="falconmamba"), pytest.param("trl-internal-testing/tiny-GemmaForCausalLM", id="gemma"), pytest.param("trl-internal-testing/tiny-Gemma2ForCausalLM", id="gemma2"), - # Silently ignores tool messages + pytest.param("trl-internal-testing/tiny-Gemma3ForConditionalGeneration", id="gemma3"), + pytest.param("trl-internal-testing/tiny-Idefics2ForConditionalGeneration", id="idefics2"), + pytest.param("trl-internal-testing/tiny-Idefics3ForConditionalGeneration", id="idefics3"), + pytest.param("trl-internal-testing/tiny-LlavaNextForConditionalGeneration", id="llava_next"), + pytest.param("trl-internal-testing/tiny-MistralForCausalLM-0.1", id="mistral0.1"), + pytest.param("trl-internal-testing/tiny-MistralForCausalLM-0.2", id="mistral0.2"), + pytest.param("trl-internal-testing/tiny-SmolVLMForConditionalGeneration", id="smolvlm"), + # Silently drops both tool_calls and tool messages pytest.param("trl-internal-testing/tiny-Cohere2ForCausalLM", id="cohere2"), - pytest.param("trl-internal-testing/tiny-Phi3ForCausalLM", id="phi3"), - # Silently drops assistant tool_calls (basic Llama 3 template only reads message['content']) + pytest.param("trl-internal-testing/tiny-LlavaForConditionalGeneration", id="llava"), + pytest.param("trl-internal-testing/tiny-Phi3ForCausalLM-3", id="phi3"), + pytest.param("trl-internal-testing/tiny-Phi3ForCausalLM-3.5", id="phi3.5"), + # Renders tool message content as plain text but drops assistant tool_calls pytest.param("trl-internal-testing/tiny-LlamaForCausalLM-3", id="llama3"), + pytest.param("trl-internal-testing/tiny-Qwen2VLForConditionalGeneration", id="qwen2_vl"), + pytest.param("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", id="qwen2.5_vl"), ], ) def test_does_not_support_tool_calling(self, model_id): tokenizer = AutoTokenizer.from_pretrained(model_id) assert supports_tool_calling(tokenizer) is False - @pytest.mark.parametrize( - "model_id", - [ - # TypeError: template concatenates arguments as string (needs template patch) - pytest.param("trl-internal-testing/tiny-DeepseekV3ForCausalLM", id="deepseekv3"), - pytest.param("trl-internal-testing/tiny-DeepseekV3ForCausalLM-0528", id="deepseekv3-0528"), - ], - ) - @pytest.mark.xfail(reason="DeepseekV3 template expects arguments as JSON string, needs patch", strict=True) - def test_deepseek_tool_calling(self, model_id): - tokenizer = AutoTokenizer.from_pretrained(model_id) - assert supports_tool_calling(tokenizer) is True - class TestIsChatTemplatePrefixPreserving: def test_prefix_preserving_template(self): @@ -409,17 +438,41 @@ def test_prefix_preserving_template_processor(self): @pytest.mark.parametrize( "tokenizer_name", [ + pytest.param("trl-internal-testing/tiny-CohereForCausalLM", id="cohere"), pytest.param("trl-internal-testing/tiny-DeepseekV3ForCausalLM", id="deepseekv3"), + pytest.param("trl-internal-testing/tiny-GemmaForCausalLM", id="gemma"), + pytest.param("trl-internal-testing/tiny-Gemma2ForCausalLM", id="gemma2"), + pytest.param( + "trl-internal-testing/tiny-Glm4MoeForCausalLM", + id="glm4moe", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.0.0"), + reason="GLM4 tokenizer requires transformers>=5.0.0", + ), + ), pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"), pytest.param("trl-internal-testing/tiny-LlamaForCausalLM-3", id="llama3"), + pytest.param("trl-internal-testing/tiny-Phi3ForCausalLM-3", id="phi3"), pytest.param("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", id="qwen2.5"), pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"), + pytest.param( + "trl-internal-testing/tiny-Qwen3_5MoeForConditionalGeneration-3.6", + id="qwen36", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.0.0"), + reason="Qwen3.5 tokenizer requires transformers>=5.0.0", + ), + ), ], ) class TestGetTrainingChatTemplate: def test_new_chat_template_is_prefix_preserving(self, tokenizer_name): tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) tokenizer.chat_template = get_training_chat_template(tokenizer) + # Prefix-preservation is only meaningful for templates that actually support tool messages — the check + # itself renders one. Skip the assertion for tool-less templates (e.g. Gemma). + if not supports_tool_calling(tokenizer): + pytest.skip("Template does not support tool calling; prefix-preservation check is not applicable.") assert is_chat_template_prefix_preserving(tokenizer) is True def test_behavior_unchanged_single_user_no_generation_prompt(self, tokenizer_name): @@ -528,6 +581,8 @@ def test_behavior_unchanged_with_tools_with_and_without_system_message(self, tok def test_behavior_unchanged_with_tools_with_system_message(self, tokenizer_name): tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + if not supports_tool_calling(tokenizer): + pytest.skip("Template does not support tool calling; skipping tool_calls test.") tools = [ { "type": "function", @@ -612,6 +667,7 @@ def test_assistant_masks_multi_turn(self, tokenizer_name): pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"), pytest.param("trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", id="qwen3_vl"), pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"), + pytest.param("trl-internal-testing/tiny-Qwen3_5MoeForConditionalGeneration-3.6", id="qwen36"), pytest.param( "trl-internal-testing/tiny-Gemma4ForConditionalGeneration", id="gemma4", diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index a2b0d4a117f..68a60b9bb03 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -542,9 +542,11 @@ class TestApplyChatTemplate(TrlTestCase): "trl-internal-testing/tiny-LlamaForCausalLM-3", "trl-internal-testing/tiny-MistralForCausalLM-0.1", "trl-internal-testing/tiny-MistralForCausalLM-0.2", - "trl-internal-testing/tiny-Phi3ForCausalLM", + "trl-internal-testing/tiny-Phi3ForCausalLM-3", + "trl-internal-testing/tiny-Phi3ForCausalLM-3.5", "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", "trl-internal-testing/tiny-Qwen3ForCausalLM", + "trl-internal-testing/tiny-Qwen3ForCausalLM-Instruct-2507", pytest.param( "trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", marks=pytest.mark.skipif( @@ -552,6 +554,13 @@ class TestApplyChatTemplate(TrlTestCase): reason="Qwen3.5 tokenizer requires transformers>=5.0.0", ), ), + pytest.param( + "trl-internal-testing/tiny-Qwen3_5MoeForConditionalGeneration-3.6", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.0.0"), + reason="Qwen3.5 tokenizer requires transformers>=5.0.0", + ), + ), ] conversational_examples = [ diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 288dbb2ccb1..5000c0f449a 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1030,22 +1030,8 @@ def test_tag_added_peft(self): ), # "trl-internal-testing/tiny-Idefics2ForConditionalGeneration", high memory peak, skipped for now # "trl-internal-testing/tiny-Idefics3ForConditionalGeneration", high memory peak, skipped for now - pytest.param( - "trl-internal-testing/tiny-LlavaForConditionalGeneration", - marks=pytest.mark.xfail( - Version(transformers.__version__).is_devrelease, - reason="Upstream issue with transformers 5.6.0.dev0, see #5497", - strict=True, - ), - ), - pytest.param( - "trl-internal-testing/tiny-LlavaNextForConditionalGeneration", - marks=pytest.mark.xfail( - Version(transformers.__version__).is_devrelease, - reason="Upstream issue with transformers 5.6.0.dev0, see #5497", - strict=True, - ), - ), + "trl-internal-testing/tiny-LlavaForConditionalGeneration", + "trl-internal-testing/tiny-LlavaNextForConditionalGeneration", "trl-internal-testing/tiny-Qwen2VLForConditionalGeneration", "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", # "trl-internal-testing/tiny-SmolVLMForConditionalGeneration", seems not to support bf16 properly @@ -1065,6 +1051,13 @@ def test_tag_added_peft(self): reason="Qwen3.5 models were introduced in transformers-5.2.0", ), ), + pytest.param( + "trl-internal-testing/tiny-Qwen3_5MoeForConditionalGeneration-3.6", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.2.0"), + reason="Qwen3.5 models were introduced in transformers-5.2.0", + ), + ), ], ) @require_vision @@ -1098,10 +1091,14 @@ def test_train_vlm(self, model_id): # fmt: off if ( model_id == "trl-internal-testing/tiny-Gemma3ForConditionalGeneration" and "model.vision_tower.vision_model.head" in n or - model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "model.vision_tower.vision_model.post_layernorm" in n or - model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "vision_tower.vision_model.encoder.layers.1" in n or - model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "model.vision_tower.vision_model.post_layernorm" in n or - model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "vision_tower.vision_model.encoder.layers.1" in n or + model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "model.vision_tower.vision_model.post_layernorm" in n or # transformers < 5.6.0 + model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "vision_tower.vision_model.encoder.layers.1" in n or # transformers < 5.6.0 + model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "model.vision_tower.encoder.layers.1" in n or # transformers >= 5.6.0, see #5497 + model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "model.vision_tower.post_layernorm" in n or # transformers >= 5.6.0, see #5497 + model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "model.vision_tower.vision_model.post_layernorm" in n or # transformers < 5.6.0 + model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "vision_tower.vision_model.encoder.layers.1" in n or # transformers < 5.6.0 + model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "model.vision_tower.encoder.layers.1" in n or # transformers >= 5.6.0, see #5497 + model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "model.vision_tower.post_layernorm" in n or # transformers >= 5.6.0, see #5497 model_id == "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration" and "model.visual.deepstack_merger_list" in n ): # fmt: on diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 199148a13cd..8f3e14660d8 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -177,9 +177,8 @@ def _make_trainer(self): trainer.processing_class = SimpleNamespace( batch_decode=MagicMock(return_value=["decoded"]), ) + trainer._tokenizer = SimpleNamespace(eos_token_id=2, pad_token_id=0) trainer.tools = None - trainer.eos_token_id = 2 - trainer.pad_token_id = 0 trainer._metrics = { "train": { "num_tokens": [], @@ -1940,6 +1939,13 @@ def test_prepare_input_called_with_correct_data(self): reason="Qwen3.5 models were introduced in transformers-5.2.0", ), ), + pytest.param( + "trl-internal-testing/tiny-Qwen3_5MoeForConditionalGeneration-3.6", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.2.0"), + reason="Qwen3.5 models were introduced in transformers-5.2.0", + ), + ), # "trl-internal-testing/tiny-SmolVLMForConditionalGeneration", seems not to support bf16 properly ], ) diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 39f766ed5a9..1bea5a80d7d 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -1331,6 +1331,13 @@ def test_prepare_input_called_with_correct_data(self): reason="Qwen3.5 models were introduced in transformers-5.2.0", ), ), + pytest.param( + "trl-internal-testing/tiny-Qwen3_5MoeForConditionalGeneration-3.6", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.2.0"), + reason="Qwen3.5 models were introduced in transformers-5.2.0", + ), + ), # "trl-internal-testing/tiny-SmolVLMForConditionalGeneration", seems not to support bf16 properly ], ) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index c9801a2111c..bafbe746b04 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import gc import json import pathlib @@ -19,6 +20,7 @@ import pytest import torch +import torch.nn.functional as F import transformers from accelerate.utils.memory import release_memory from datasets import load_dataset @@ -29,7 +31,12 @@ from transformers.utils import is_peft_available from trl import SFTConfig, SFTTrainer -from trl.trainer.sft_trainer import DataCollatorForLanguageModeling, dft_loss +from trl.trainer.sft_trainer import ( + DataCollatorForLanguageModeling, + _chunked_cross_entropy_loss, + _patch_chunked_ce_lm_head, + dft_loss, +) from .testing_utils import ( TrlTestCase, @@ -469,6 +476,30 @@ def test_train_dft_loss(self): new_param = trainer.model.get_parameter(n) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + def test_train_chunked_nll_loss(self): + # Get the dataset + dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") + + # Initialize the trainer + training_args = SFTConfig(output_dir=self.tmp_dir, loss_type="chunked_nll", report_to="none") + trainer = SFTTrainer( + model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset + ) + + # Save the initial parameters to compare them later + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + # Train the model + trainer.train() + + # Check that the training loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + # Check the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" + def test_train_moe_model_with_aux_loss(self): # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") @@ -1639,22 +1670,8 @@ def test_tag_added_peft(self): ), # "trl-internal-testing/tiny-Idefics2ForConditionalGeneration", high memory peak, skipped for now # "trl-internal-testing/tiny-Idefics3ForConditionalGeneration", high memory peak, skipped for now - pytest.param( - "trl-internal-testing/tiny-LlavaForConditionalGeneration", - marks=pytest.mark.xfail( - Version(transformers.__version__).is_devrelease, - reason="Upstream issue with transformers 5.6.0.dev0, see #5497", - strict=True, - ), - ), - pytest.param( - "trl-internal-testing/tiny-LlavaNextForConditionalGeneration", - marks=pytest.mark.xfail( - Version(transformers.__version__).is_devrelease, - reason="Upstream issue with transformers 5.6.0.dev0, see #5497", - strict=True, - ), - ), + "trl-internal-testing/tiny-LlavaForConditionalGeneration", + "trl-internal-testing/tiny-LlavaNextForConditionalGeneration", "trl-internal-testing/tiny-Qwen2VLForConditionalGeneration", "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", # "trl-internal-testing/tiny-SmolVLMForConditionalGeneration", seems not to support bf16 properly @@ -1674,6 +1691,13 @@ def test_tag_added_peft(self): reason="Qwen3.5 models were introduced in transformers-5.2.0", ), ), + pytest.param( + "trl-internal-testing/tiny-Qwen3_5MoeForConditionalGeneration-3.6", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.2.0"), + reason="Qwen3.5 models were introduced in transformers-5.2.0", + ), + ), ], ) @require_vision @@ -1706,10 +1730,14 @@ def test_train_vlm(self, model_id): # fmt: off if ( model_id == "trl-internal-testing/tiny-Gemma3ForConditionalGeneration" and "model.vision_tower.vision_model.head" in n or - model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "model.vision_tower.vision_model.post_layernorm" in n or - model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "vision_tower.vision_model.encoder.layers.1" in n or - model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "model.vision_tower.vision_model.post_layernorm" in n or - model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "vision_tower.vision_model.encoder.layers.1" in n or + model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "model.vision_tower.vision_model.post_layernorm" in n or # transformers < 5.6.0 + model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "vision_tower.vision_model.encoder.layers.1" in n or # transformers < 5.6.0 + model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "model.vision_tower.encoder.layers.1" in n or # transformers >= 5.6.0, see #5497 + model_id == "trl-internal-testing/tiny-LlavaForConditionalGeneration" and "model.vision_tower.post_layernorm" in n or # transformers >= 5.6.0, see #5497 + model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "model.vision_tower.vision_model.post_layernorm" in n or # transformers < 5.6.0 + model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "vision_tower.vision_model.encoder.layers.1" in n or # transformers < 5.6.0 + model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "model.vision_tower.encoder.layers.1" in n or # transformers >= 5.6.0, see #5497 + model_id == "trl-internal-testing/tiny-LlavaNextForConditionalGeneration" and "model.vision_tower.post_layernorm" in n or # transformers >= 5.6.0, see #5497 model_id == "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration" and "model.visual.deepstack_merger_list" in n ): # fmt: on @@ -2308,3 +2336,295 @@ def test_train_offloading(self, model_name, packing): assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" release_memory(trainer.model, trainer) + + +_CHUNKED_CE_MODEL_IDS = [ + "trl-internal-testing/tiny-CohereForCausalLM", + pytest.param( + "trl-internal-testing/tiny-DeepseekV3ForCausalLM", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.0.0"), + reason="DeepseekV3 SDPA attention is broken in transformers < 5.0.0", + ), + ), + pytest.param( + "trl-internal-testing/tiny-DeepseekV3ForCausalLM-0528", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.0.0"), + reason="DeepseekV3 SDPA attention is broken in transformers < 5.0.0", + ), + ), + "trl-internal-testing/tiny-Gemma2ForCausalLM", + "trl-internal-testing/tiny-GemmaForCausalLM", + "trl-internal-testing/tiny-Glm4MoeForCausalLM", + "trl-internal-testing/tiny-GptOssForCausalLM", + "trl-internal-testing/tiny-LlamaForCausalLM-3.1", + "trl-internal-testing/tiny-LlamaForCausalLM-3.2", + "trl-internal-testing/tiny-LlamaForCausalLM-3", + "trl-internal-testing/tiny-MistralForCausalLM-0.1", + "trl-internal-testing/tiny-MistralForCausalLM-0.2", + "trl-internal-testing/tiny-Phi3ForCausalLM", + "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", + "trl-internal-testing/tiny-Qwen3ForCausalLM", + "trl-internal-testing/tiny-Qwen3MoeForCausalLM", +] + + +class TestChunkedCrossEntropyLoss: + B, S, H, V = 2, 8, 4, 16 + CHUNK_SIZE = 3 # deliberately small to force multiple chunks and a partial final chunk + + def _inputs(self, seed=0, ignore_positions=None, requires_grad=False): + torch.manual_seed(seed) + hidden = torch.randn(self.B, self.S, self.H, dtype=torch.float32, requires_grad=requires_grad) + weight = torch.randn(self.V, self.H, dtype=torch.float32, requires_grad=requires_grad) + labels = torch.randint(0, self.V, (self.B, self.S)) + if ignore_positions is not None: + labels[:, ignore_positions] = -100 + return hidden, weight, labels + + @staticmethod + def _reference(hidden, weight, labels, num_items_in_batch=None): + shift_h = hidden[..., :-1, :].reshape(-1, hidden.size(-1)) + shift_l = labels[..., 1:].reshape(-1) + logits = shift_h.float() @ weight.float().t() + if num_items_in_batch is None: + loss = F.cross_entropy(logits, shift_l, ignore_index=-100, reduction="mean") + else: + loss = F.cross_entropy(logits, shift_l, ignore_index=-100, reduction="sum") + loss = loss / num_items_in_batch + valid = shift_l != -100 + if valid.any(): + log_p = F.log_softmax(logits, dim=-1) + preds = logits.argmax(dim=-1) + accuracy = (preds[valid] == shift_l[valid]).float().mean() + entropy = -(log_p.exp() * log_p).sum(dim=-1)[valid].mean() + else: + accuracy = torch.zeros((), dtype=torch.float32) + entropy = torch.zeros((), dtype=torch.float32) + return loss, accuracy, entropy + + def test_forward_matches_cross_entropy(self): + """With no ignored tokens, chunked loss equals standard mean cross-entropy.""" + hidden, weight, labels = self._inputs() + n_valid = (labels[..., 1:] != -100).sum() + loss_c, correct_c, ent_sum_c = _chunked_cross_entropy_loss(hidden, weight, self.CHUNK_SIZE, labels) + loss_r, acc_r, ent_r = self._reference(hidden, weight, labels) + torch.testing.assert_close(loss_c, loss_r, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(correct_c / n_valid, acc_r, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(ent_sum_c / n_valid, ent_r, atol=1e-5, rtol=1e-5) + + def test_forward_ignore_index(self): + """Ignored labels are excluded from loss, accuracy and entropy (matches F.cross_entropy).""" + hidden, weight, labels = self._inputs(ignore_positions=slice(0, 3)) + n_valid = (labels[..., 1:] != -100).sum() + loss_c, correct_c, ent_sum_c = _chunked_cross_entropy_loss(hidden, weight, self.CHUNK_SIZE, labels) + loss_r, acc_r, ent_r = self._reference(hidden, weight, labels) + torch.testing.assert_close(loss_c, loss_r, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(correct_c / n_valid, acc_r, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(ent_sum_c / n_valid, ent_r, atol=1e-5, rtol=1e-5) + + def test_num_items_in_batch_reduction(self): + """When num_items_in_batch is provided, loss is sum / num_items_in_batch.""" + hidden, weight, labels = self._inputs(ignore_positions=slice(0, 3)) + num_items = 5 # arbitrary global denominator, != local valid count + loss_c, *_ = _chunked_cross_entropy_loss(hidden, weight, self.CHUNK_SIZE, labels, num_items_in_batch=num_items) + loss_r, *_ = self._reference(hidden, weight, labels, num_items_in_batch=num_items) + torch.testing.assert_close(loss_c, loss_r, atol=1e-5, rtol=1e-5) + + def test_num_items_in_batch_tensor(self): + """A tensor `num_items_in_batch` is accepted and produces the same result as the int form.""" + hidden, weight, labels = self._inputs() + num_items_tensor = torch.tensor(7, dtype=torch.float32) + loss_t, *_ = _chunked_cross_entropy_loss( + hidden, weight, self.CHUNK_SIZE, labels, num_items_in_batch=num_items_tensor + ) + loss_i, *_ = _chunked_cross_entropy_loss(hidden, weight, self.CHUNK_SIZE, labels, num_items_in_batch=7) + torch.testing.assert_close(loss_t, loss_i, atol=1e-6, rtol=1e-6) + + def test_backward_matches_reference(self): + """Gradients on hidden_states and lm_head weight match the standard CE path.""" + hidden_c, weight_c, labels = self._inputs(ignore_positions=slice(0, 3), requires_grad=True) + hidden_r = hidden_c.detach().clone().requires_grad_(True) + weight_r = weight_c.detach().clone().requires_grad_(True) + + loss_c, *_ = _chunked_cross_entropy_loss(hidden_c, weight_c, self.CHUNK_SIZE, labels) + loss_c.backward() + + loss_r, *_ = self._reference(hidden_r, weight_r, labels) + loss_r.backward() + + torch.testing.assert_close(hidden_c.grad, hidden_r.grad, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(weight_c.grad, weight_r.grad, atol=1e-5, rtol=1e-5) + + def test_all_ignored_returns_zero(self): + """If every label is ignored, loss/correct/entropy_sum are all zero and backward still works. + + Every trainable parameter of the chunked path (hidden_states, lm_head_weight, and lm_head_bias when present) + must receive a gradient — otherwise DDP / FSDP synchronization hangs or errors at the all-reduce step. + """ + hidden, weight, labels = self._inputs(requires_grad=True) + bias = torch.zeros(self.V, dtype=torch.float32, requires_grad=True) + labels[:] = -100 + loss, correct, ent_sum = _chunked_cross_entropy_loss( + hidden, weight, self.CHUNK_SIZE, labels, lm_head_bias=bias + ) + assert loss.item() == 0.0 + assert correct.item() == 0.0 + assert ent_sum.item() == 0.0 + assert not torch.isnan(loss) + # Backward must succeed even when n_valid == 0 (can happen with completion-only loss + # + truncation where a whole micro-batch is masked). + loss.backward() + assert hidden.grad is not None and hidden.grad.abs().sum().item() == 0.0 + assert weight.grad is not None and weight.grad.abs().sum().item() == 0.0 + assert bias.grad is not None and bias.grad.abs().sum().item() == 0.0 + + def test_shift_labels_matches_labels(self): + """`shift_labels` path (CP/SP) must match the default `labels` path after external shifting.""" + hidden, weight, labels = self._inputs(ignore_positions=slice(0, 3)) + loss_l, correct_l, ent_l = _chunked_cross_entropy_loss(hidden, weight, self.CHUNK_SIZE, labels) + # Mimic what transformers does under CP/SP: pad labels with -100, then shift. + shift_labels = F.pad(labels, (0, 1), value=-100)[..., 1:].contiguous() + loss_s, correct_s, ent_s = _chunked_cross_entropy_loss( + hidden, weight, self.CHUNK_SIZE, shift_labels=shift_labels + ) + torch.testing.assert_close(loss_s, loss_l, atol=1e-6, rtol=1e-6) + torch.testing.assert_close(correct_s, correct_l, atol=1e-6, rtol=1e-6) + torch.testing.assert_close(ent_s, ent_l, atol=1e-6, rtol=1e-6) + + def test_requires_labels_or_shift_labels(self): + """Must provide at least one of `labels` or `shift_labels`.""" + hidden, weight, _ = self._inputs() + with pytest.raises(ValueError, match="At least one"): + _chunked_cross_entropy_loss(hidden, weight, self.CHUNK_SIZE) + + def test_shift_labels_wins_when_both_provided(self): + """When both `labels` and `shift_labels` are provided (Ulysses / CP / SP path), `shift_labels` wins.""" + hidden, weight, labels = self._inputs(ignore_positions=slice(0, 3)) + shift_labels = F.pad(labels, (0, 1), value=-100)[..., 1:].contiguous() + # Chunked result with both passed in must match the shift_labels-only path. + loss_both, *_ = _chunked_cross_entropy_loss( + hidden, weight, self.CHUNK_SIZE, labels=labels, shift_labels=shift_labels + ) + loss_shift, *_ = _chunked_cross_entropy_loss(hidden, weight, self.CHUNK_SIZE, shift_labels=shift_labels) + torch.testing.assert_close(loss_both, loss_shift, atol=1e-6, rtol=1e-6) + + def test_lm_head_bias(self): + """When `lm_head_bias` is provided, chunked loss matches `F.linear(h, w, b)` followed by CE.""" + hidden, weight, labels = self._inputs(ignore_positions=slice(0, 3)) + torch.manual_seed(1) + bias = torch.randn(self.V, dtype=torch.float32) + + loss_c, *_ = _chunked_cross_entropy_loss(hidden, weight, self.CHUNK_SIZE, labels, lm_head_bias=bias) + + # Reference: full F.linear with bias, then CE over non-ignored shifted positions. + logits_ref = F.linear(hidden[..., :-1, :], weight, bias).reshape(-1, self.V) + labels_ref = labels[..., 1:].reshape(-1) + valid = labels_ref != -100 + loss_r = F.cross_entropy(logits_ref[valid], labels_ref[valid], reduction="mean") + torch.testing.assert_close(loss_c, loss_r, atol=1e-5, rtol=1e-5) + + +@require_torch_accelerator +class TestPatchChunkedCELMHead: + """Patched `forward` must be numerically equivalent to the standard HF causal-LM loss path.""" + + CHUNK_SIZE = 5 # small, to exercise the chunk loop + + def _setup(self, model_id): + ref_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32).to(torch_device) + chunked_model = copy.deepcopy(ref_model) + _patch_chunked_ce_lm_head(chunked_model, chunk_size=self.CHUNK_SIZE) + + B, S = 2, 16 + torch.manual_seed(42) + input_ids = torch.randint(0, ref_model.config.vocab_size, (B, S), device=torch_device) + labels = input_ids.clone() + labels[:, :4] = -100 # prompt-like mask + num_items = int((labels[..., 1:] != -100).sum()) + return ref_model, chunked_model, input_ids, labels, num_items + + @pytest.mark.parametrize("model_id", _CHUNKED_CE_MODEL_IDS) + def test_forward_matches_reference(self, model_id): + ref_model, chunked_model, input_ids, labels, num_items = self._setup(model_id) + + with torch.no_grad(): + ref_out = ref_model(input_ids=input_ids, labels=labels, num_items_in_batch=num_items) + out = chunked_model(input_ids=input_ids, labels=labels, num_items_in_batch=num_items) + + torch.testing.assert_close(out.loss, ref_out.loss, atol=1e-5, rtol=1e-5) + assert out.logits is None + assert out.num_correct_tokens is not None and out.num_correct_tokens.item() >= 0 + assert out.entropy_sum is not None and out.entropy_sum.item() >= 0.0 + + @pytest.mark.parametrize( + "model_id", + [ + "trl-internal-testing/tiny-Qwen3MoeForCausalLM", + "trl-internal-testing/tiny-GptOssForCausalLM", + ], + ) + def test_forward_matches_reference_with_aux_loss(self, model_id): + """MoE models with `output_router_logits=True` add `router_aux_loss_coef * load_balancing_loss` + to the main loss. The chunked path must match the reference loss and expose `aux_loss`.""" + ref_model = AutoModelForCausalLM.from_pretrained( + model_id, torch_dtype=torch.float32, output_router_logits=True + ).to(torch_device) + chunked_model = copy.deepcopy(ref_model) + _patch_chunked_ce_lm_head(chunked_model, chunk_size=self.CHUNK_SIZE) + + B, S = 2, 16 + torch.manual_seed(42) + input_ids = torch.randint(0, ref_model.config.vocab_size, (B, S), device=torch_device) + labels = input_ids.clone() + labels[:, :4] = -100 + num_items = int((labels[..., 1:] != -100).sum()) + + with torch.no_grad(): + ref_out = ref_model(input_ids=input_ids, labels=labels, num_items_in_batch=num_items) + out = chunked_model(input_ids=input_ids, labels=labels, num_items_in_batch=num_items) + + torch.testing.assert_close(out.loss, ref_out.loss, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(out.aux_loss, ref_out.aux_loss, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("model_id", _CHUNKED_CE_MODEL_IDS) + def test_backward_matches_reference(self, model_id): + ref_model, chunked_model, input_ids, labels, num_items = self._setup(model_id) + + ref_out = ref_model(input_ids=input_ids, labels=labels, num_items_in_batch=num_items) + ref_out.loss.backward() + + out = chunked_model(input_ids=input_ids, labels=labels, num_items_in_batch=num_items) + out.loss.backward() + + # lm_head gradient + torch.testing.assert_close( + chunked_model.lm_head.weight.grad, ref_model.lm_head.weight.grad, atol=1e-5, rtol=1e-5 + ) + # Base decoder gradients + for name, ref_param in ref_model.model.named_parameters(): + if ref_param.grad is None: + continue + chunked_param = chunked_model.model.get_parameter(name) + torch.testing.assert_close( + chunked_param.grad, ref_param.grad, atol=1e-5, rtol=1e-5, msg=f"gradient mismatch on model.{name}" + ) + + def test_forward_without_labels_uses_original_path(self): + """With labels=None the patched forward returns real logits (for generation / eval).""" + _, chunked_model, input_ids, _, _ = self._setup("trl-internal-testing/tiny-LlamaForCausalLM-3.2") + with torch.no_grad(): + out = chunked_model(input_ids=input_ids) + assert out.logits is not None + assert out.logits.shape[-1] == chunked_model.config.vocab_size + + def test_forward_without_labels_matches_reference(self): + """labels=None logits must match the unpatched model, including per-model post-processing + (`final_logit_softcapping`, `logit_scale`, ...). This is what makes `.generate()` safe to call on a patched + model.""" + ref_model, chunked_model, input_ids, *_ = self._setup("trl-internal-testing/tiny-CohereForCausalLM") + with torch.no_grad(): + ref_out = ref_model(input_ids=input_ids) + out = chunked_model(input_ids=input_ids) + torch.testing.assert_close(out.logits, ref_out.logits, atol=1e-5, rtol=1e-5) diff --git a/tests/test_utils.py b/tests/test_utils.py index 63540f33da1..ec740f99c11 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -23,7 +23,7 @@ import torch.nn.functional as F import transformers from packaging.version import Version -from transformers import AutoModelForCausalLM, AutoModelForImageTextToText +from transformers import AutoModelForCausalLM from transformers.testing_utils import torch_device from transformers.utils import is_peft_available @@ -33,7 +33,6 @@ _ChunkedLogProbFunction, entropy_from_logits, flush_left, - forward_masked_logits, generate_model_card, get_peft_config, hash_module, @@ -1009,98 +1008,6 @@ def test_no_op_if_not_list(self): assert torch.equal(result["pixel_values"], original) -@require_torch_accelerator -class TestForwardMaskedLogits: - @pytest.mark.parametrize( - "model_id", - [ - "trl-internal-testing/tiny-CohereForCausalLM", - "trl-internal-testing/tiny-Cohere2ForCausalLM", - "trl-internal-testing/tiny-DeepseekV3ForCausalLM", - "trl-internal-testing/tiny-DeepseekV3ForCausalLM-0528", - "trl-internal-testing/tiny-Gemma2ForCausalLM", - "trl-internal-testing/tiny-GemmaForCausalLM", - "trl-internal-testing/tiny-Glm4MoeForCausalLM", - "trl-internal-testing/tiny-GptOssForCausalLM", - "trl-internal-testing/tiny-LlamaForCausalLM-3.1", - "trl-internal-testing/tiny-LlamaForCausalLM-3.2", - "trl-internal-testing/tiny-LlamaForCausalLM-3", - "trl-internal-testing/tiny-MistralForCausalLM-0.1", - "trl-internal-testing/tiny-MistralForCausalLM-0.2", - "trl-internal-testing/tiny-Phi3ForCausalLM", - "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", - "trl-internal-testing/tiny-Qwen3ForCausalLM", - ], - ) - def test_llm(self, model_id): - model = AutoModelForCausalLM.from_pretrained(model_id, dtype="auto", device_map=torch_device) - input_ids = torch.randint(0, model.config.vocab_size, (2, 8), device=torch_device) - logits_mask = torch.tensor( - [[1, 1, 0, 0, 1, 0, 1, 0], [0, 1, 1, 0, 0, 1, 0, 1]], - device=torch_device, - ) - - full_outputs = model(input_ids=input_ids) - masked_outputs = forward_masked_logits(model, logits_mask, input_ids=input_ids) - - torch.testing.assert_close( - masked_outputs.flat_logits, - full_outputs.logits[logits_mask.bool()], - ) - - @pytest.mark.parametrize( - "model_id", - [ - "trl-internal-testing/tiny-Gemma3ForConditionalGeneration", - pytest.param( - "trl-internal-testing/tiny-Gemma4ForConditionalGeneration", - marks=pytest.mark.skipif( - Version(transformers.__version__) < Version("5.5.0"), - reason="Gemma4 models were introduced in transformers-5.5.0", - ), - ), - "trl-internal-testing/tiny-Idefics2ForConditionalGeneration", - "trl-internal-testing/tiny-Idefics3ForConditionalGeneration", - "trl-internal-testing/tiny-LlavaForConditionalGeneration", - "trl-internal-testing/tiny-LlavaNextForConditionalGeneration", - "trl-internal-testing/tiny-Qwen2VLForConditionalGeneration", - "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - # "trl-internal-testing/tiny-SmolVLMForConditionalGeneration", seems not to support bf16 properly - pytest.param( - "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", - marks=[ - pytest.mark.skipif( - Version(transformers.__version__) < Version("4.57.0"), - reason="Qwen3-VL series were introduced in transformers-4.57.0", - ), - ], - ), - pytest.param( - "trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", - marks=pytest.mark.skipif( - Version(transformers.__version__) < Version("5.2.0"), - reason="Qwen3.5 models were introduced in transformers-5.2.0", - ), - ), - ], - ) - def test_vlm(self, model_id): - model = AutoModelForImageTextToText.from_pretrained(model_id, dtype="auto", device_map=torch_device) - input_ids = torch.randint(0, model.config.text_config.vocab_size, (2, 8), device=torch_device) - logits_mask = torch.tensor( - [[1, 1, 0, 0, 1, 0, 1, 0], [0, 1, 1, 0, 0, 1, 0, 1]], - device=torch_device, - ) - - full_outputs = model(input_ids=input_ids) - masked_outputs = forward_masked_logits(model, logits_mask, input_ids=input_ids) - - torch.testing.assert_close( - masked_outputs.flat_logits, - full_outputs.logits[logits_mask.bool()], - ) - - class TestChunkedLogProbFunction: N, H, V = 64, 32, 128 CHUNK_SIZE = 32 @@ -1230,9 +1137,11 @@ def forward(self, input_ids, attention_mask=None, labels=None, **kwargs): "trl-internal-testing/tiny-LlamaForCausalLM-3", "trl-internal-testing/tiny-MistralForCausalLM-0.1", "trl-internal-testing/tiny-MistralForCausalLM-0.2", - "trl-internal-testing/tiny-Phi3ForCausalLM", + "trl-internal-testing/tiny-Phi3ForCausalLM-3", + "trl-internal-testing/tiny-Phi3ForCausalLM-3.5", "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", "trl-internal-testing/tiny-Qwen3ForCausalLM", + "trl-internal-testing/tiny-Qwen3ForCausalLM-Instruct-2507", ] diff --git a/trl/_compat.py b/trl/_compat.py index 4bfed16b7fb..31086a2e912 100644 --- a/trl/_compat.py +++ b/trl/_compat.py @@ -83,90 +83,6 @@ def _patch_vllm_logging() -> None: os.environ["VLLM_LOGGING_LEVEL"] = os.getenv("VLLM_LOGGING_LEVEL", "ERROR") -def _patch_vllm_disabled_tqdm() -> None: - """ - Fix DisabledTqdm class in vLLM. - - - Bug introduced in https://github.com/vllm-project/vllm/pull/52 - - Fixed in https://github.com/vllm-project/vllm/pull/28471 (released in v0.11.1) - - Since TRL currently supports vLLM v0.11.0-0.18.0, we patch it here - - This can be removed when TRL requires vLLM>=0.11.1 - """ - if _is_package_version_below("vllm", "0.11.1"): - try: - import vllm.model_executor.model_loader.weight_utils - from tqdm import tqdm - - class DisabledTqdm(tqdm): - def __init__(self, *args, **kwargs): - kwargs["disable"] = True - super().__init__(*args, **kwargs) - - vllm.model_executor.model_loader.weight_utils.DisabledTqdm = DisabledTqdm - except (ImportError, AttributeError) as e: - warnings.warn(f"Failed to patch vLLM DisabledTqdm: {e}", stacklevel=2) - - -def _patch_vllm_cached_tokenizer() -> None: - """ - Fix get_cached_tokenizer for transformers v5 compatibility. - - - Issue: vLLM's get_cached_tokenizer accesses all_special_tokens_extended - - Removed in transformers: https://github.com/huggingface/transformers/pull/40936 (transformers>=5.0.0) - - Fixed in https://github.com/vllm-project/vllm/pull/29686 (released in v0.12.0) - - This can be removed when TRL requires vLLM>=0.12.0 - """ - if _is_package_version_at_least("transformers", "5.0.0") and _is_package_version_below("vllm", "0.12.0"): - try: - import contextlib - import copy - - import vllm.transformers_utils.tokenizer - - def get_cached_tokenizer(tokenizer): - cached_tokenizer = copy.copy(tokenizer) - tokenizer_all_special_ids = tokenizer.all_special_ids - tokenizer_all_special_tokens = tokenizer.all_special_tokens - tokenizer_vocab = tokenizer.get_vocab() - tokenizer_len = len(tokenizer) - - max_token_id = max(tokenizer_vocab.values()) - if hasattr(tokenizer, "vocab_size"): - with contextlib.suppress(NotImplementedError): - max_token_id = max(max_token_id, tokenizer.vocab_size) - - class CachedTokenizer(tokenizer.__class__): # type: ignore - @property - def all_special_ids(self) -> list[int]: - return tokenizer_all_special_ids - - @property - def all_special_tokens(self) -> list[str]: - return tokenizer_all_special_tokens - - @property - def max_token_id(self) -> int: - return max_token_id - - def get_vocab(self) -> dict[str, int]: - return tokenizer_vocab - - def __len__(self) -> int: - return tokenizer_len - - def __reduce__(self): - return get_cached_tokenizer, (tokenizer,) - - CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" - - cached_tokenizer.__class__ = CachedTokenizer - return cached_tokenizer - - vllm.transformers_utils.tokenizer.get_cached_tokenizer = get_cached_tokenizer - except (ImportError, AttributeError) as e: - warnings.warn(f"Failed to patch vLLM cached_tokenizer: {e}", stacklevel=2) - - def _patch_transformers_hybrid_cache() -> None: """ Fix HybridCache import for transformers v5 compatibility. @@ -242,8 +158,6 @@ def _patch_transformers_parallelism_config() -> None: # Apply vLLM patches _patch_vllm_logging() -_patch_vllm_disabled_tqdm() -_patch_vllm_cached_tokenizer() # Apply transformers patches _patch_transformers_hybrid_cache() diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 35abd3aa696..d6079b838b7 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -16,7 +16,7 @@ from typing import TypeVar from jinja2 import TemplateError -from transformers import AddedToken, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin +from transformers import AddedToken, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin from .data_utils import prepare_multimodal_messages @@ -26,10 +26,10 @@ def clone_chat_template( model: PreTrainedModel, - tokenizer: PreTrainedTokenizer, + tokenizer: PreTrainedTokenizerBase, source_tokenizer_path: str, resize_to_multiple_of: int | None = 64, -) -> tuple[PreTrainedModel, PreTrainedTokenizer, list[int]]: +) -> tuple[PreTrainedModel, PreTrainedTokenizerBase, list[int]]: """ Clones a chat template from a source tokenizer to the target tokenizer and updates the model accordingly. @@ -44,7 +44,7 @@ def clone_chat_template( Args: model ([`~transformers.PreTrainedModel`]): Model to update. - tokenizer ([`~transformers.PreTrainedTokenizer`]): + tokenizer ([`~transformers.PreTrainedTokenizerBase`]): Tokenizer to update. source_tokenizer_path (`str`): Path or identifier of the pretrained tokenizer to clone from. @@ -55,7 +55,7 @@ def clone_chat_template( Returns: model ([`~transformers.PreTrainedModel`]): Updated model with resized token embeddings and EOS token configured. - tokenizer ([`~transformers.PreTrainedTokenizer`]): + tokenizer ([`~transformers.PreTrainedTokenizerBase`]): Updated tokenizer with the chat template and special tokens applied. added_tokens (`list[int]`): List of tokens that were added to the tokenizer from the source tokenizer. @@ -306,8 +306,12 @@ def clone_chat_template( } +cohere_chat_template = (_CHAT_TEMPLATES_DIR / "cohere.jinja").read_text() + deepseekv3_chat_template = (_CHAT_TEMPLATES_DIR / "deepseekv3.jinja").read_text() +gemma_chat_template = (_CHAT_TEMPLATES_DIR / "gemma.jinja").read_text() + glm4moe_chat_template = (_CHAT_TEMPLATES_DIR / "glm4moe.jinja").read_text() gptoss_chat_template = (_CHAT_TEMPLATES_DIR / "gptoss.jinja").read_text() @@ -318,6 +322,8 @@ def clone_chat_template( llama3_2_chat_template = (_CHAT_TEMPLATES_DIR / "llama3_2.jinja").read_text() +phi3_chat_template = (_CHAT_TEMPLATES_DIR / "phi3.jinja").read_text() + qwen2_5_chat_template = (_CHAT_TEMPLATES_DIR / "qwen2_5.jinja").read_text() qwen3_chat_template = (_CHAT_TEMPLATES_DIR / "qwen3.jinja").read_text() @@ -328,8 +334,10 @@ def clone_chat_template( qwen3_5_chat_template_4b_and_above = (_CHAT_TEMPLATES_DIR / "qwen3_5_4b_and_above.jinja").read_text() +qwen3_6_chat_template = (_CHAT_TEMPLATES_DIR / "qwen3_6.jinja").read_text() + -ProcessingClassT = TypeVar("ProcessingClassT", PreTrainedTokenizer, ProcessorMixin) +ProcessingClassT = TypeVar("ProcessingClassT", PreTrainedTokenizerBase, ProcessorMixin) def add_response_schema(processing_class: ProcessingClassT) -> ProcessingClassT: @@ -344,11 +352,11 @@ def add_response_schema(processing_class: ProcessingClassT) -> ProcessingClassT: and reads `self.response_schema` from the tokenizer instance. Args: - processing_class (`PreTrainedTokenizer` or `ProcessorMixin`): + processing_class (`PreTrainedTokenizerBase` or `ProcessorMixin`): Tokenizer or VLM processor to which the response schema will be added. Returns: - `PreTrainedTokenizer` or `ProcessorMixin`: + `PreTrainedTokenizerBase` or `ProcessorMixin`: The same object that was passed in, with the response schema set on the underlying tokenizer. Examples: @@ -380,7 +388,11 @@ def add_response_schema(processing_class: ProcessingClassT) -> ProcessingClassT: tokenizer.response_schema = llama3_schema elif chat_template in [qwen3_chat_template, qwen3_vl_chat_template]: tokenizer.response_schema = qwen3_schema - elif chat_template in [qwen3_5_chat_template_2b_and_below, qwen3_5_chat_template_4b_and_above]: + elif chat_template in [ + qwen3_5_chat_template_2b_and_below, + qwen3_5_chat_template_4b_and_above, + qwen3_6_chat_template, + ]: tokenizer.response_schema = qwen3_5_schema else: raise ValueError( @@ -406,7 +418,7 @@ def supports_tool_calling(processing_class) -> bool: [`~trl.data_utils.prepare_multimodal_messages`] before rendering. Args: - processing_class (`PreTrainedTokenizer` or `ProcessorMixin`): + processing_class (`PreTrainedTokenizerBase` or `ProcessorMixin`): Tokenizer or processor instance to check. Returns: @@ -444,13 +456,21 @@ def supports_tool_calling(processing_class) -> bool: # UndefinedError (subclass): template indexes into content as a list for all roles, including tool # (Idefics2, Idefics3, LlavaNext, SmolVLM) return False + except TypeError: + # Best-effort fallback for templates that reject dict args (e.g. DeepSeek-V3). This is a chat template + # bug (see transformers#45419), and the training chat template fixes it to avoid blocking users. + tool_calls[0]["function"]["arguments"] = f'{{"{_arg_key_sentinel}": "{_arg_val_sentinel}"}}' + try: + rendered = processing_class.apply_chat_template(messages, tokenize=False) + except TemplateError: + return False # All four sentinels must survive: the tool name and arguments (assistant tool_calls) AND the tool message # content. Templates that silently drop either side (basic Llama 3 drops tool_calls; Cohere2/Phi3 drop tool # messages) will fail this check. return all(s in rendered for s in (_name_sentinel, _arg_key_sentinel, _arg_val_sentinel, _content_sentinel)) -def is_chat_template_prefix_preserving(processing_class: PreTrainedTokenizer | ProcessorMixin) -> bool: +def is_chat_template_prefix_preserving(processing_class: PreTrainedTokenizerBase | ProcessorMixin) -> bool: """ Check whether the chat template preserves prefixes when applied. @@ -459,7 +479,7 @@ def is_chat_template_prefix_preserving(processing_class: PreTrainedTokenizer | P tokenizations with and without tool messages appended. Args: - processing_class (`PreTrainedTokenizer` or `ProcessorMixin`): + processing_class (`PreTrainedTokenizerBase` or `ProcessorMixin`): Tokenizer or processor instance to check. Returns: @@ -479,7 +499,8 @@ def is_chat_template_prefix_preserving(processing_class: PreTrainedTokenizer | P ] # VLM processors expect structured list-of-blocks content, and image-token expansion only kicks in when an image # is actually present, so include a dummy image to exercise the real code path. - if isinstance(processing_class, ProcessorMixin): + is_vlm = isinstance(processing_class, ProcessorMixin) + if is_vlm: from PIL import Image dummy_image = Image.new("RGB", (8, 8)) @@ -487,41 +508,61 @@ def is_chat_template_prefix_preserving(processing_class: PreTrainedTokenizer | P messages2 = prepare_multimodal_messages(messages2, images=[dummy_image]) try: - text1 = processing_class.apply_chat_template(messages1, tokenize=False) - text2 = processing_class.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) + ids1 = processing_class.apply_chat_template(messages1, tokenize=True, return_dict=False) + ids2 = processing_class.apply_chat_template( + messages2, tokenize=True, return_dict=False, add_generation_prompt=True + ) except TypeError: # Best-effort fallback for templates that reject dict args (e.g. DeepSeek-V3). This is a chat template # bug (see transformers#45419), and the training chat template fixes it to avoid blocking users. dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": "{}"}}] messages1[1]["tool_calls"] = dummy_tool_calls messages2[1]["tool_calls"] = dummy_tool_calls - text1 = processing_class.apply_chat_template(messages1, tokenize=False) - text2 = processing_class.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) + ids1 = processing_class.apply_chat_template(messages1, tokenize=True, return_dict=False) + ids2 = processing_class.apply_chat_template( + messages2, tokenize=True, return_dict=False, add_generation_prompt=True + ) + + # VLM processors return batched output (list of lists), unbatch for single conversation + if is_vlm: + ids1 = ids1[0] + ids2 = ids2[0] + + return ids2[: len(ids1)] == ids1 - return text2.startswith(text1) +cohere_training_chat_template = (_CHAT_TEMPLATES_DIR / "cohere_training.jinja").read_text() deepseekv3_training_chat_template = (_CHAT_TEMPLATES_DIR / "deepseekv3_training.jinja").read_text() +gemma_training_chat_template = (_CHAT_TEMPLATES_DIR / "gemma_training.jinja").read_text() + +glm4moe_training_chat_template = (_CHAT_TEMPLATES_DIR / "glm4moe_training.jinja").read_text() + +gptoss_training_chat_template = (_CHAT_TEMPLATES_DIR / "gptoss_training.jinja").read_text() + llama3_training_chat_template = (_CHAT_TEMPLATES_DIR / "llama3_training.jinja").read_text() +phi3_training_chat_template = (_CHAT_TEMPLATES_DIR / "phi3_training.jinja").read_text() + qwen2_5_training_chat_template = (_CHAT_TEMPLATES_DIR / "qwen2_5_training.jinja").read_text() qwen3_training_chat_template = (_CHAT_TEMPLATES_DIR / "qwen3_training.jinja").read_text() -gptoss_training_chat_template = (_CHAT_TEMPLATES_DIR / "gptoss_training.jinja").read_text() +qwen3_6_training_chat_template = (_CHAT_TEMPLATES_DIR / "qwen3_6_training.jinja").read_text() -def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: +def get_training_chat_template(tokenizer: PreTrainedTokenizerBase) -> str | None: r""" Get a training-compatible chat template, if needed. Returns a patched chat template that is prefix-preserving and includes `{%% generation %%}` / `{%% endgeneration %%}` markers for assistant-only loss masking. Returns `None` if the tokenizer's template already satisfies both - requirements. Currently DeepSeek-V3, GPT-OSS, LLaMA 3, Qwen2.5, and Qwen3 are supported. + requirements. Currently Cohere, DeepSeek-V3, Gemma, Gemma2, GLM-4-MoE, GPT-OSS, LLaMA 3, Phi-3, Qwen2.5, Qwen3, and + Qwen3.6 are supported. Args: - tokenizer (`PreTrainedTokenizer`): + tokenizer (`PreTrainedTokenizerBase`): Tokenizer instance to check. Returns: @@ -563,25 +604,42 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n\n\n\n\n\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n<|im_end|>\n<|im_start|>user\n\n6\n<|im_end|>\n<|im_start|>assistant\n' ``` """ - # First check if patching is needed - if is_chat_template_prefix_preserving(tokenizer) and "{% generation %}" in tokenizer.chat_template: + # First check if patching is needed. Prefix-preservation only matters when the template actually supports tools + # (the check itself renders a tool message), so skip it otherwise. + prefix_ok = not supports_tool_calling(tokenizer) or is_chat_template_prefix_preserving(tokenizer) + if prefix_ok and "{% generation %}" in tokenizer.chat_template: return None # No patching needed + if tokenizer.chat_template == cohere_chat_template: + return cohere_training_chat_template + if tokenizer.chat_template == deepseekv3_chat_template: return deepseekv3_training_chat_template + if tokenizer.chat_template == gemma_chat_template: + return gemma_training_chat_template + + if tokenizer.chat_template == glm4moe_chat_template: + return glm4moe_training_chat_template + if tokenizer.chat_template == gptoss_chat_template: return gptoss_training_chat_template if tokenizer.chat_template == llama3_chat_template: return llama3_training_chat_template + if tokenizer.chat_template == phi3_chat_template: + return phi3_training_chat_template + if tokenizer.chat_template == qwen2_5_chat_template: return qwen2_5_training_chat_template if tokenizer.chat_template == qwen3_chat_template: return qwen3_training_chat_template + if tokenizer.chat_template == qwen3_6_chat_template: + return qwen3_6_training_chat_template + raise ValueError( "The tokenizer's chat template is not training-compatible (missing prefix-preservation or " "`{% generation %}` markers) and patching is not supported for this template. " @@ -627,7 +685,7 @@ def _validate_tool_calls(tool_calls: list | None) -> None: tool_call["arguments"] = {} -def parse_response(processing_class: PreTrainedTokenizer | ProcessorMixin, ids: list[int]) -> dict: +def parse_response(processing_class: PreTrainedTokenizerBase | ProcessorMixin, ids: list[int]) -> dict: r""" Parse a token sequence into structured response dictionaries with fallback handling. @@ -640,7 +698,7 @@ def parse_response(processing_class: PreTrainedTokenizer | ProcessorMixin, ids: For VLM processors, automatically uses the inner tokenizer for parsing. Args: - processing_class (`PreTrainedTokenizer` or VLM processor): + processing_class (`PreTrainedTokenizerBase` or VLM processor): Tokenizer or processor with a `parse_response()` method (directly or via inner tokenizer). ids (`list[int]`): List of token sequences. diff --git a/trl/chat_templates/README.md b/trl/chat_templates/README.md index 310b09228ee..00f35491772 100644 --- a/trl/chat_templates/README.md +++ b/trl/chat_templates/README.md @@ -13,10 +13,18 @@ Jinja2 chat templates stored here serve two purposes: Used for identity comparison only. +### `cohere.jinja` + +Original Cohere Command chat template (as shipped by CohereForAI/c4ai-command-r-v01 and related checkpoints). + ### `deepseekv3.jinja` Original DeepSeek-V3 chat template. +### `gemma.jinja` + +Original Gemma chat template. Used by both Gemma (v1) and Gemma2, which ship identical templates. + ### `glm4moe.jinja` Original GLM-4-MoE chat template. @@ -49,10 +57,20 @@ Original Qwen3-VL chat template. Unlike text-only Qwen3, this template is alread Original Qwen3.5 chat templates. +### `qwen3_6.jinja` + +Original Qwen3.6 chat template (shared across `Qwen3.6-27B`, `Qwen3.6-35B-A3B`, and their FP8 variants). Differs from `qwen3_5_4b_and_above.jinja` by adding a `preserve_thinking` flag and tweaking how non-string tool-call argument values are stringified. + ## Training templates Patched templates that fix training-specific issues. Swapped in at init when tools are enabled (GRPO) or when `assistant_only_loss=True` (SFT). +### `cohere_training.jinja` + +Patched Cohere template. Diff vs `cohere.jinja`: + +Wrap assistant message output with `{% generation %}` / `{% endgeneration %}` so that `return_assistant_tokens_mask=True` produces correct masks for SFT assistant-only loss. + ### `deepseekv3_training.jinja` Patched DeepSeek-V3 template. Diff vs `deepseekv3.jinja`: @@ -60,6 +78,50 @@ Patched DeepSeek-V3 template. Diff vs `deepseekv3.jinja`: - Uses `| tojson` on `tool['function']['arguments']` so that `arguments` can be passed as a `dict` (the documented format per [transformers docs](https://huggingface.co/docs/transformers/en/chat_extras#tool-calling-example)). The original template uses raw string concatenation, which crashes on dict inputs. - Wraps assistant message output with `{% generation %}` / `{% endgeneration %}` markers for SFT assistant-only loss. +### `gemma_training.jinja` + +Patched Gemma template, shared by Gemma (v1) and Gemma2 (which ship identical chat templates). Diff vs `gemma.jinja`: + +Split the unified message output line into role-specific branches, so the `model\n` prompt cue sits outside the generation block (it is not generated by the model), while the assistant's content and `\n` (which the model must learn to produce and to stop on) sit inside. Wrap the assistant content with `{% generation %}` / `{% endgeneration %}` so that `return_assistant_tokens_mask=True` produces correct masks for SFT assistant-only loss. + +### `glm4moe_training.jinja` + +Patched GLM-4-MoE template. Diff vs `glm4moe.jinja`: + +Require both `` and `` to be present before parsing, to avoid incorrect splitting when the model generates only one tag: + +```diff +- {%- if '' in content %} ++ {%- if '' in content and '' in content %} +``` + +Wrap assistant message output with `{% generation %}` / `{% endgeneration %}` so that `return_assistant_tokens_mask=True` produces correct masks for SFT assistant-only loss. + +### `gptoss_training.jinja` + +Patched GPT-OSS template. Diff vs `gptoss.jinja`: + +Wrap assistant message output with `{% generation %}` / `{% endgeneration %}` so that `return_assistant_tokens_mask=True` produces correct masks for SFT assistant-only loss. + +### `llama3_training.jinja` + +Patched Llama 3 template. Diff vs `llama3.jinja`: + +Wrap assistant message output with `{% generation %}` / `{% endgeneration %}` so that `return_assistant_tokens_mask=True` produces correct masks for SFT assistant-only loss. + +### `phi3_training.jinja` + +Patched Phi-3 template. Diff vs `phi3.jinja`: + +Wrap assistant message output with `{% generation %}` / `{% endgeneration %}` so that +`return_assistant_tokens_mask=True` produces correct masks for SFT assistant-only loss. + +### `qwen2_5_training.jinja` + +Patched Qwen2.5 template. Diff vs `qwen2_5.jinja`: + +Wrap assistant message output with `{% generation %}` / `{% endgeneration %}` so that `return_assistant_tokens_mask=True` produces correct masks for SFT assistant-only loss. + ### `qwen3_training.jinja` Patched Qwen3 template. Diff vs `qwen3.jinja`: @@ -88,20 +150,6 @@ Always include the thinking block regardless of message position. The original c Wrap assistant message output with `{% generation %}` / `{% endgeneration %}` so that `return_assistant_tokens_mask=True` produces correct masks for SFT assistant-only loss. -### `gptoss_training.jinja` - -Patched GPT-OSS template. Diff vs `gptoss.jinja`: - -Wrap assistant message output with `{% generation %}` / `{% endgeneration %}` so that `return_assistant_tokens_mask=True` produces correct masks for SFT assistant-only loss. - -### `llama3_training.jinja` - -Patched Llama 3 template. Diff vs `llama3.jinja`: +### `qwen3_6_training.jinja` -Wrap assistant message output with `{% generation %}` / `{% endgeneration %}` so that `return_assistant_tokens_mask=True` produces correct masks for SFT assistant-only loss. - -### `qwen2_5_training.jinja` - -Patched Qwen2.5 template. Diff vs `qwen2_5.jinja`: - -Wrap assistant message output with `{% generation %}` / `{% endgeneration %}` so that `return_assistant_tokens_mask=True` produces correct masks for SFT assistant-only loss. +Patched Qwen3.6 template. Same diff as `qwen3_training.jinja` (require both `` and `` before parsing, drop the `loop.index0 > ns.last_query_index` conditional so the thinking block is always emitted, wrap assistant output in `{% generation %}` / `{% endgeneration %}`), applied to the Qwen3.6 base template. diff --git a/trl/chat_templates/cohere.jinja b/trl/chat_templates/cohere.jinja new file mode 100644 index 00000000000..eea053cdb75 --- /dev/null +++ b/trl/chat_templates/cohere.jinja @@ -0,0 +1 @@ +{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Aya, a brilliant, sophisticated, multilingual AI-assistant trained to assist human users by providing thorough responses. You are able to interact and respond to questions in 23 languages and you are powered by a multilingual model built by Cohere For AI.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %} \ No newline at end of file diff --git a/trl/chat_templates/cohere_training.jinja b/trl/chat_templates/cohere_training.jinja new file mode 100644 index 00000000000..4b9082f0f9e --- /dev/null +++ b/trl/chat_templates/cohere_training.jinja @@ -0,0 +1,6 @@ +{#- Training variant of the Cohere chat template (see cohere.jinja for the original). + Modifications vs the original: + - Added {% generation %} / {% endgeneration %} around assistant message output to support + assistant-only loss masking in SFT training. +-#} +{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Aya, a brilliant, sophisticated, multilingual AI-assistant trained to assist human users by providing thorough responses. You are able to interact and respond to questions in 23 languages and you are powered by a multilingual model built by Cohere For AI.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% generation %}{{ content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endgeneration %}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %} \ No newline at end of file diff --git a/trl/chat_templates/gemma.jinja b/trl/chat_templates/gemma.jinja new file mode 100644 index 00000000000..923ec253c8d --- /dev/null +++ b/trl/chat_templates/gemma.jinja @@ -0,0 +1,4 @@ +{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + ' +' + message['content'] | trim + ' +' }}{% endfor %}{% if add_generation_prompt %}{{'model +'}}{% endif %} \ No newline at end of file diff --git a/trl/chat_templates/gemma_training.jinja b/trl/chat_templates/gemma_training.jinja new file mode 100644 index 00000000000..290e63f8ebe --- /dev/null +++ b/trl/chat_templates/gemma_training.jinja @@ -0,0 +1,29 @@ +{#- Training variant of the Gemma chat template (see gemma.jinja for the original). + This template is shared by Gemma (v1) and Gemma2, which ship identical chat templates. + Modifications vs the original: + - Split the unified output line into role-specific branches so that the + 'model\n' header (a prompt cue, not generated by the model) sits + outside the generation block. + - Added {% generation %} / {% endgeneration %} around assistant message content to + support assistant-only loss masking in SFT training. +-#} +{{- bos_token -}} +{%- if messages[0]['role'] == 'system' -%} + {{- raise_exception('System role not supported') -}} +{%- endif -%} +{%- for message in messages -%} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%} + {{- raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') -}} + {%- endif -%} + {%- if message['role'] == 'assistant' -%} + {{- 'model\n' -}} + {%- generation -%} + {{- message['content'] | trim + '\n' -}} + {%- endgeneration -%} + {%- else -%} + {{- '' + message['role'] + '\n' + message['content'] | trim + '\n' -}} + {%- endif -%} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{- 'model\n' -}} +{%- endif -%} diff --git a/trl/chat_templates/glm4moe_training.jinja b/trl/chat_templates/glm4moe_training.jinja new file mode 100644 index 00000000000..7e02685cda7 --- /dev/null +++ b/trl/chat_templates/glm4moe_training.jinja @@ -0,0 +1,112 @@ +{#- Training variant of the GLM-4-MoE chat template (see glm4moe.jinja for the original). + Modifications vs the original: + - {%- if '' in content %} → {%- if '' in content and '' in content %} + Always check for both tags to avoid edge cases where the model generates only one tag. + - Added {% generation %} / {% endgeneration %} around assistant message output to support + assistant-only loss masking in SFT training. +-#} +[gMASK] +{%- if tools -%} +<|system|> +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{% for tool in tools %} +{{ tool | tojson(ensure_ascii=False) }} +{% endfor %} + + +For each function call, output the function name and arguments within the following XML format: +{function-name} +{arg-key-1} +{arg-value-1} +{arg-key-2} +{arg-value-2} +... +{%- endif -%} +{%- macro visible_text(content) -%} + {%- if content is string -%} + {{- content }} + {%- elif content is iterable and content is not mapping -%} + {%- for item in content -%} + {%- if item is mapping and item.type == 'text' -%} + {{- item.text }} + {%- elif item is string -%} + {{- item }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{- content }} + {%- endif -%} +{%- endmacro -%} +{%- set ns = namespace(last_user_index=-1) %} +{%- for m in messages %} + {%- if m.role == 'user' %} + {% set ns.last_user_index = loop.index0 -%} + {%- endif %} +{%- endfor %} +{% for m in messages %} +{%- if m.role == 'user' -%}<|user|> +{{ visible_text(m.content) }} +{{- '/nothink' if (enable_thinking is defined and not enable_thinking and not visible_text(m.content).endswith("/nothink")) else '' -}} +{%- elif m.role == 'assistant' -%} +<|assistant|> +{%- set reasoning_content = '' %} +{%- set content = visible_text(m.content) %} +{%- if m.reasoning_content is string %} + {%- set reasoning_content = m.reasoning_content %} +{%- else %} + {%- if '' in content and '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} +{%- endif %} +{%- generation %} +{%- if loop.index0 > ns.last_user_index and reasoning_content -%} +{{ '\n' + reasoning_content.strip() + ''}} +{%- else -%} +{{ '\n' }} +{%- endif -%} +{%- if content.strip() -%} +{{ '\n' + content.strip() }} +{%- endif -%} +{% if m.tool_calls %} +{% for tc in m.tool_calls %} +{%- if tc.function %} + {%- set tc = tc.function %} +{%- endif %} +{{ '\n' + tc.name }} +{% set _args = tc.arguments %} +{% for k, v in _args.items() %} +{{ k }} +{{ v | tojson(ensure_ascii=False) if v is not string else v }} +{% endfor %} +{% endfor %} +{% endif %} +{%- endgeneration %} +{%- elif m.role == 'tool' -%} +{%- if m.content is string -%} +{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|observation|>' }} +{%- endif %} +{{- '\n\n' }} +{{- m.content }} +{{- '\n' }} +{%- else -%} +<|observation|>{% for tr in m.content %} + + +{{ tr.output if tr.output is defined else tr }} +{% endfor -%} +{% endif -%} +{%- elif m.role == 'system' -%} +<|system|> +{{ visible_text(m.content) }} +{%- endif -%} +{%- endfor -%} +{%- if add_generation_prompt -%} + <|assistant|>{{- '\n' if (enable_thinking is defined and not enable_thinking) else '' -}} +{%- endif -%} \ No newline at end of file diff --git a/trl/chat_templates/phi3.jinja b/trl/chat_templates/phi3.jinja new file mode 100644 index 00000000000..ddb5006baa8 --- /dev/null +++ b/trl/chat_templates/phi3.jinja @@ -0,0 +1,8 @@ +{% for message in messages %}{% if message['role'] == 'system' %}{{'<|system|> +' + message['content'] + '<|end|> +'}}{% elif message['role'] == 'user' %}{{'<|user|> +' + message['content'] + '<|end|> +'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|> +' + message['content'] + '<|end|> +'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|> +' }}{% else %}{{ eos_token }}{% endif %} \ No newline at end of file diff --git a/trl/chat_templates/phi3_training.jinja b/trl/chat_templates/phi3_training.jinja new file mode 100644 index 00000000000..1536d152f7d --- /dev/null +++ b/trl/chat_templates/phi3_training.jinja @@ -0,0 +1,26 @@ +{#- Training variant of the Phi-3 chat template (see phi3.jinja for the original). + Modifications vs the original: + - Added {% generation %} / {% endgeneration %} around assistant message output to support + assistant-only loss masking in SFT training. +-#} +{%- for message in messages %} + {%- if message['role'] == 'system' %} + {{- '<|system|>\n' + message['content'] + '<|end|>\n' }} + {%- elif message['role'] == 'user' %} + {{- '<|user|>\n' + message['content'] + '<|end|>\n' }} + {%- elif message['role'] == 'assistant' %} + {{- '<|assistant|>\n' }} + {%- generation %} + {{- message['content'] + '<|end|>\n' }} + {%- endgeneration %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|assistant|>\n' }} +{%- elif messages[-1]['role'] == 'assistant' %} + {%- generation %} + {{- eos_token }} + {%- endgeneration %} +{%- else %} + {{- eos_token }} +{%- endif %} diff --git a/trl/chat_templates/qwen3_6.jinja b/trl/chat_templates/qwen3_6.jinja new file mode 100644 index 00000000000..a8755d827c0 --- /dev/null +++ b/trl/chat_templates/qwen3_6.jinja @@ -0,0 +1,154 @@ +{%- set image_count = namespace(value=0) %} +{%- set video_count = namespace(value=0) %} +{%- macro render_content(content, do_vision_count, is_system_content=false) %} + {%- if content is string %} + {{- content }} + {%- elif content is iterable and content is not mapping %} + {%- for item in content %} + {%- if 'image' in item or 'image_url' in item or item.type == 'image' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain images.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set image_count.value = image_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Picture ' ~ image_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|image_pad|><|vision_end|>' }} + {%- elif 'video' in item or item.type == 'video' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain videos.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set video_count.value = video_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Video ' ~ video_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|video_pad|><|vision_end|>' }} + {%- elif 'text' in item %} + {{- item.text }} + {%- else %} + {{- raise_exception('Unexpected item type in content.') }} + {%- endif %} + {%- endfor %} + {%- elif content is none or content is undefined %} + {{- '' }} + {%- else %} + {{- raise_exception('Unexpected content type.') }} + {%- endif %} +{%- endmacro %} +{%- if not messages %} + {{- raise_exception('No messages provided.') }} +{%- endif %} +{%- if tools and tools is iterable and tools is not mapping %} + {{- '<|im_start|>system\n' }} + {{- "# Tools\n\nYou have access to the following functions:\n\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {%- if content %} + {{- '\n\n' + content }} + {%- endif %} + {%- endif %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" %} + {%- set content = render_content(message.content, false)|trim %} + {%- if not(content.startswith('') and content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if ns.multi_step_tool %} + {{- raise_exception('No user query found in messages.') }} +{%- endif %} +{%- for message in messages %} + {%- set content = render_content(message.content, true)|trim %} + {%- if message.role == "system" %} + {%- if not loop.first %} + {{- raise_exception('System message must be at the beginning.') }} + {%- endif %} + {%- elif message.role == "user" %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- set reasoning_content = reasoning_content|trim %} + {%- if (preserve_thinking is defined and preserve_thinking is true) or (loop.index0 > ns.last_query_index) %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n\n' + content }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {%- if loop.first %} + {%- if content|trim %} + {{- '\n\n\n\n' }} + {%- else %} + {{- '\n\n' }} + {%- endif %} + {%- else %} + {{- '\n\n\n' }} + {%- endif %} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '\n' }} + {%- set args_value = args_value | string if args_value is string else args_value | tojson | safe %} + {{- args_value }} + {{- '\n\n' }} + {%- endfor %} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- raise_exception('Unexpected message role.') }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- else %} + {{- '\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/trl/chat_templates/qwen3_6_training.jinja b/trl/chat_templates/qwen3_6_training.jinja new file mode 100644 index 00000000000..e4e705768d7 --- /dev/null +++ b/trl/chat_templates/qwen3_6_training.jinja @@ -0,0 +1,162 @@ +{#- Training variant of the Qwen3.6 chat template (see qwen3_6.jinja for the original). + Modifications vs the original: + - {%- if '' in content %} → {%- if '' in content and '' in content %} + Always check for both tags to avoid edge cases where the model generates only one tag. + - Removed the loop.index0 > ns.last_query_index conditional; always include thinking block. + This makes the template prefix-preserving for the [user, assistant] → [user, assistant, tool] transition. + - Added {% generation %} / {% endgeneration %} around assistant message output to support + assistant-only loss masking in SFT training. +-#} +{%- set image_count = namespace(value=0) %} +{%- set video_count = namespace(value=0) %} +{%- macro render_content(content, do_vision_count, is_system_content=false) %} + {%- if content is string %} + {{- content }} + {%- elif content is iterable and content is not mapping %} + {%- for item in content %} + {%- if 'image' in item or 'image_url' in item or item.type == 'image' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain images.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set image_count.value = image_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Picture ' ~ image_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|image_pad|><|vision_end|>' }} + {%- elif 'video' in item or item.type == 'video' %} + {%- if is_system_content %} + {{- raise_exception('System message cannot contain videos.') }} + {%- endif %} + {%- if do_vision_count %} + {%- set video_count.value = video_count.value + 1 %} + {%- endif %} + {%- if add_vision_id %} + {{- 'Video ' ~ video_count.value ~ ': ' }} + {%- endif %} + {{- '<|vision_start|><|video_pad|><|vision_end|>' }} + {%- elif 'text' in item %} + {{- item.text }} + {%- else %} + {{- raise_exception('Unexpected item type in content.') }} + {%- endif %} + {%- endfor %} + {%- elif content is none or content is undefined %} + {{- '' }} + {%- else %} + {{- raise_exception('Unexpected content type.') }} + {%- endif %} +{%- endmacro %} +{%- if not messages %} + {{- raise_exception('No messages provided.') }} +{%- endif %} +{%- if tools and tools is iterable and tools is not mapping %} + {{- '<|im_start|>system\n' }} + {{- "# Tools\n\nYou have access to the following functions:\n\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {%- if content %} + {{- '\n\n' + content }} + {%- endif %} + {%- endif %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if messages[0].role == 'system' %} + {%- set content = render_content(messages[0].content, false, true)|trim %} + {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" %} + {%- set content = render_content(message.content, false)|trim %} + {%- if not(content.startswith('') and content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if ns.multi_step_tool %} + {{- raise_exception('No user query found in messages.') }} +{%- endif %} +{%- for message in messages %} + {%- set content = render_content(message.content, true)|trim %} + {%- if message.role == "system" %} + {%- if not loop.first %} + {{- raise_exception('System message must be at the beginning.') }} + {%- endif %} + {%- elif message.role == "user" %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content and '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- set reasoning_content = reasoning_content|trim %} + {{- '<|im_start|>' + message.role + '\n' }} + {%- generation %} + {{- '\n' + reasoning_content + '\n\n\n' + content }} + {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {%- if loop.first %} + {%- if content|trim %} + {{- '\n\n\n\n' }} + {%- else %} + {{- '\n\n' }} + {%- endif %} + {%- else %} + {{- '\n\n\n' }} + {%- endif %} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '\n' }} + {%- set args_value = args_value | string if args_value is string else args_value | tojson | safe %} + {{- args_value }} + {{- '\n\n' }} + {%- endfor %} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- endgeneration %} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- raise_exception('Unexpected message role.') }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- else %} + {{- '\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/trl/data_utils.py b/trl/data_utils.py index 89a94b1c470..c110bb98a1a 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -22,11 +22,12 @@ import pyarrow as pa import pyarrow.compute as pc import pyarrow.types -from datasets import Dataset, DatasetDict, IterableDatasetDict +from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict from transformers import PreTrainedTokenizerBase, ProcessorMixin DatasetType = TypeVar("DatasetType", Dataset, DatasetDict) +IterableDatasetType = TypeVar("IterableDatasetType", IterableDataset, IterableDatasetDict) def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | None = None) -> list[dict[str, Any]]: @@ -407,22 +408,22 @@ def _unpair_row(examples: list[dict[str, list[dict[str, str]]]]) -> list[dict[st def unpair_preference_dataset( - dataset: DatasetType, num_proc: int | None = None, desc: str | None = None -) -> DatasetType: - r""" + dataset: DatasetType | IterableDatasetType, **map_kwargs +) -> DatasetType | IterableDatasetType: + # docstyle-ignore + """ Unpair a preference dataset. Args: - dataset ([`~datasets.Dataset`] or [`~datasets.DatasetDict`]): + dataset ([`~datasets.Dataset`] or [`~datasets.DatasetDict`] or [`~datasets.IterableDataset`] or [`~datasets.IterableDatasetDict`]): Preference dataset to unpair. The dataset must have columns `"chosen"`, `"rejected"` and optionally `"prompt"`. - num_proc (`int`, *optional*): - Number of processes to use for processing the dataset. - desc (`str`, *optional*): - Meaningful description to be displayed alongside with the progress bar while mapping examples. + **map_kwargs (`dict`, *optional*): + Additional keyword arguments to pass to the dataset's map method when unpairing preferences. Returns: - [`~datasets.Dataset`]: The unpaired preference dataset. + [`~datasets.Dataset`] or [`~datasets.DatasetDict`] or [`~datasets.IterableDataset`] or [`~datasets.IterableDatasetDict`]: + The unpaired preference dataset. Example: @@ -446,7 +447,7 @@ def unpair_preference_dataset( {'prompt': 'The sky is', 'completion': ' blue.', 'label': True} ``` """ - return dataset.map(_unpair_row, batched=True, remove_columns=["chosen", "rejected"], num_proc=num_proc, desc=desc) + return dataset.map(_unpair_row, batched=True, remove_columns=["chosen", "rejected"], **map_kwargs) def maybe_unpair_preference_dataset( diff --git a/trl/experimental/async_grpo/async_grpo_trainer.py b/trl/experimental/async_grpo/async_grpo_trainer.py index a81dad5639f..25cc5082d7c 100644 --- a/trl/experimental/async_grpo/async_grpo_trainer.py +++ b/trl/experimental/async_grpo/async_grpo_trainer.py @@ -367,6 +367,7 @@ def __init__( model_name=model_name, dataset=train_dataset, reward_funcs=reward_funcs, + processing_class=processing_class, tools=tools, environment_factory=environment_factory, num_generations=self.args.num_generations, diff --git a/trl/experimental/async_grpo/async_rollout_worker.py b/trl/experimental/async_grpo/async_rollout_worker.py index 4fd11312fd2..5157c87c4dd 100644 --- a/trl/experimental/async_grpo/async_rollout_worker.py +++ b/trl/experimental/async_grpo/async_rollout_worker.py @@ -26,7 +26,7 @@ import requests from accelerate.logging import get_logger from datasets import Dataset -from transformers import AutoTokenizer +from transformers import PreTrainedTokenizerBase from trl.chat_template_utils import ( add_response_schema, @@ -90,6 +90,7 @@ def __init__( model_name: str, dataset: Dataset, reward_funcs: list[Callable[..., list[float]]], + processing_class: PreTrainedTokenizerBase, tools: list[Callable] | None = None, environment_factory: Callable[[], object] | None = None, num_generations: int = 8, @@ -165,7 +166,7 @@ def __init__( self.chat_template_kwargs = chat_template_kwargs or {} self.log_completions = log_completions self.num_completions_to_print = num_completions_to_print - self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.tokenizer = processing_class self.tokenizer = add_response_schema(self.tokenizer) # In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template # isn't, we replace it at initialization with a training-safe, prefix-preserving template. diff --git a/trl/experimental/bco/bco_trainer.py b/trl/experimental/bco/bco_trainer.py index d86776dc7c1..d807e41b8f2 100644 --- a/trl/experimental/bco/bco_trainer.py +++ b/trl/experimental/bco/bco_trainer.py @@ -24,7 +24,7 @@ from dataclasses import dataclass from operator import itemgetter from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Optional +from typing import Any, Literal import numpy as np import pandas as pd @@ -65,7 +65,7 @@ if is_peft_available(): - from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training if is_wandb_available(): import wandb @@ -76,9 +76,6 @@ if is_joblib_available(): import joblib -if TYPE_CHECKING: - from transformers import PreTrainedTokenizer - logger = logging.get_logger(__name__) RUNNING_NAME = "running.json" @@ -167,8 +164,8 @@ def load_from_json(cls, accelerator: Accelerator, json_path: str): def _tokenize( batch: dict[str, list[Any]], - tokenizer: "PreTrainedTokenizer", - embedding_tokenizer: Optional["PreTrainedTokenizer"] = None, + tokenizer: PreTrainedTokenizerBase, + embedding_tokenizer: PreTrainedTokenizerBase | None = None, ) -> dict[str, list[Any]]: """Tokenize a batch from a BCO specific dataset.""" prompt_tokenized = tokenizer(batch["prompt"], add_special_tokens=False) @@ -380,7 +377,7 @@ class BCOTrainer(_BaseTrainer): The optimizer and scheduler to use for training. preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): The function to use to preprocess the logits before computing the metrics. - peft_config (`dict`, defaults to `None`): + peft_config ([`~peft.PeftConfig`], *optional*): The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): @@ -424,7 +421,7 @@ def __init__( callbacks: list[TrainerCallback] | None = None, optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, - peft_config: dict | None = None, + peft_config: "PeftConfig | None" = None, compute_metrics: Callable[[EvalLoopOutput], dict] | None = None, model_adapter_name: str | None = None, ref_adapter_name: str | None = None, @@ -472,15 +469,21 @@ def __init__( if isinstance(ref_model, str): ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **model_init_kwargs) + # PEFT # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` # has been called in order to properly call autocast if needed. self._peft_has_been_casted_to_bf16 = False - - if not is_peft_available() and peft_config is not None: - raise ValueError( - "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models" - ) - elif is_peft_available() and peft_config is not None: + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) if isinstance(model, PeftModel): raise ValueError( "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first " diff --git a/trl/experimental/cpo/cpo_trainer.py b/trl/experimental/cpo/cpo_trainer.py index f3e0f39920e..6d59f8e660e 100644 --- a/trl/experimental/cpo/cpo_trainer.py +++ b/trl/experimental/cpo/cpo_trainer.py @@ -61,7 +61,7 @@ if is_peft_available(): - from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training if is_wandb_available(): @@ -101,7 +101,7 @@ class CPOTrainer(_BaseTrainer): The optimizer and scheduler to use for training. preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): The function to use to preprocess the logits before computing the metrics. - peft_config (`dict`, defaults to `None`): + peft_config ([`~peft.PeftConfig`], *optional*): The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): @@ -142,7 +142,7 @@ def __init__( callbacks: list[TrainerCallback] | None = None, optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, - peft_config: dict | None = None, + peft_config: "PeftConfig | None" = None, compute_metrics: Callable[[EvalLoopOutput], dict] | None = None, ): if train_dataset is None: @@ -169,15 +169,21 @@ def __init__( if isinstance(model, str): model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + # PEFT # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` # has been called in order to properly call autocast if needed. self._peft_has_been_casted_to_bf16 = False - - if not is_peft_available() and peft_config is not None: - raise ValueError( - "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" - ) - elif is_peft_available() and peft_config is not None: + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) if isinstance(model, PeftModel): raise ValueError( "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first " diff --git a/trl/experimental/distillation/distillation_trainer.py b/trl/experimental/distillation/distillation_trainer.py index 77dbcc6c08e..a4879f62b19 100644 --- a/trl/experimental/distillation/distillation_trainer.py +++ b/trl/experimental/distillation/distillation_trainer.py @@ -430,6 +430,16 @@ def __init__( # ── PEFT ── if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) model = get_peft_model(model, peft_config) # ── Data collator ── @@ -1284,7 +1294,11 @@ def _get_teacher_token_logprobs_from_server( # Size the output tensors to tightly fit the teacher logprobs. Using the full padded # sequence length would include padding positions with -inf teacher logprobs, producing - # inf in the forward pass and NaN gradients in the backward pass (0 * inf = NaN). + # +inf in the forward pass and NaN gradients in the backward pass (0 * inf = NaN). + # Shorter samples in variable-length batches still need the -inf sentinel at the tail; + # downstream loss consumers (_compute_server_sparse_top_1_divergence_loss, + # _compute_server_forward_kl_loss) neutralise those positions before the divergence + # math runs. completion_length = max( (offset + len(lps) for offset, lps in zip(completion_offsets, result["logprobs"], strict=True)), default=0, @@ -1354,6 +1368,13 @@ def _compute_server_sparse_top_1_divergence_loss( f"{missing_count}/{total_required}." ) + # Replace -inf teacher logprobs at intra-batch padding (labels == -100) with 0 so + # reverse-KL's student_probs·(log_s - log_t) does not leak +inf into the backward pass. + pad_mask_2d = ~required + pad_mask_3d = pad_mask_2d.unsqueeze(-1) + topk_teacher_lps = torch.where(pad_mask_3d, 0.0, topk_teacher_lps) + actual_teacher_lps = torch.where(pad_mask_2d, 0.0, actual_teacher_lps) + # Server path only supports "sampled" mode — config validation enforces this, but we guard # explicitly so future relaxations of the config check don't silently change behaviour. reverse_token_ids = self._get_reverse_kl_top_1_tokens(student_log_probs, completion_tokens) diff --git a/trl/experimental/dppo/dppo_trainer.py b/trl/experimental/dppo/dppo_trainer.py index 260417a1aa6..fb566d2234b 100644 --- a/trl/experimental/dppo/dppo_trainer.py +++ b/trl/experimental/dppo/dppo_trainer.py @@ -333,7 +333,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): return prompt_ids, completion_ids, sampled_logprobs, topk_logprobs, topk_token_ids else: prompt_tensors = [torch.tensor(ids) for ids in prompt_ids] - padded_ids = pad(prompt_tensors, padding_value=self.pad_token_id, padding_side="left") + padded_ids = pad(prompt_tensors, padding_value=self._tokenizer.pad_token_id, padding_side="left") attention_mask = pad([torch.ones_like(t) for t in prompt_tensors], padding_value=0, padding_side="left") generate_inputs = {"input_ids": padded_ids, "attention_mask": attention_mask} for key, value in multimodal_fields.items(): @@ -396,7 +396,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): topk_logps_chunks.append(topk_lp_t.cpu()) # Mask everything after the first EOS token - is_eos = completion_ids == self.eos_token_id + is_eos = completion_ids == self._tokenizer.eos_token_id has_eos = is_eos.any(dim=1) eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) eos_idx[has_eos] = is_eos.int().argmax(dim=1)[has_eos] @@ -592,9 +592,7 @@ async def _run_async_tools(async_coros): completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx] # Decode post-tool completions - post_tool_completions = [ - parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids - ] + post_tool_completions = [parse_response(self._tokenizer, ids) if ids else {} for ids in post_tool_ids] for idx in range(len(idxs_with_tool)): idx_with_tool = idxs_with_tool[idx] @@ -668,13 +666,12 @@ def _generate(self, prompts: list): # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. if is_conversational({"prompt": prompts[0]}): - tokenizer = self.processing_class.tokenizer if self._is_vlm else self.processing_class if ( Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5 - and hasattr(tokenizer, "response_schema") # attribute not set by default for now - and tokenizer.response_schema is not None # only works if the tokenizer has a schema + and hasattr(self._tokenizer, "response_schema") # attribute not set by default for now + and self._tokenizer.response_schema is not None # only works if the tokenizer has a schema ): - completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids] + completions = [[parse_response(self._tokenizer, ids)] for ids in completion_ids] else: contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) completions = [[{"role": "assistant", "content": content}] for content in contents] @@ -717,7 +714,7 @@ def _generate(self, prompts: list): self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) - eos_and_pad = [self.eos_token_id, self.pad_token_id] + eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id] is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device) agg_is_truncated = self.accelerator.gather(is_truncated) self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) @@ -898,7 +895,7 @@ def _generate_and_score_completions( prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] prompt_ids = pad( prompt_ids, - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, padding_side="left", pad_to_multiple_of=self.pad_to_multiple_of, ).to(device=device) @@ -909,7 +906,7 @@ def _generate_and_score_completions( completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] completion_ids = pad( completion_ids, - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of, ).to(device=device) @@ -951,7 +948,7 @@ def _generate_and_score_completions( # If mask_truncated_completions is enabled, zero out truncated completions for attention and loss masking if self.mask_truncated_completions: - eos_and_pad = [self.eos_token_id, self.pad_token_id] + eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id] is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) # Mask completion_mask for attention masking completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py index e5faa5fdce2..c5695be3cf0 100644 --- a/trl/experimental/gfpo/gfpo_trainer.py +++ b/trl/experimental/gfpo/gfpo_trainer.py @@ -118,7 +118,7 @@ def _generate_and_score_completions(self, inputs): prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] prompt_ids = pad( prompt_ids, - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, padding_side="left", pad_to_multiple_of=self.pad_to_multiple_of, ).to(device=device) @@ -132,7 +132,7 @@ def _generate_and_score_completions(self, inputs): completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] completion_ids = pad( completion_ids, - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of, ).to(device=device) @@ -155,7 +155,7 @@ def _generate_and_score_completions(self, inputs): # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask if self.mask_truncated_completions: - eos_and_pad = [self.eos_token_id, self.pad_token_id] + eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id] is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() diff --git a/trl/experimental/gold/gold_trainer.py b/trl/experimental/gold/gold_trainer.py index bc5a23aac1c..9f36d610612 100644 --- a/trl/experimental/gold/gold_trainer.py +++ b/trl/experimental/gold/gold_trainer.py @@ -1407,7 +1407,7 @@ def _func(example): **map_kwargs, ) - # Apply the chat template if needed and preserve original text + # Add EOS token if needed: non-conversational only first_example = next(iter(dataset)) if not is_conversational(first_example): if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` diff --git a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py index e6b26e1778c..523363d689a 100644 --- a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py +++ b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py @@ -116,7 +116,7 @@ def _generate_and_score_completions( prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] prompt_ids = pad( prompt_ids, - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, padding_side="left", pad_to_multiple_of=self.pad_to_multiple_of, ).to(device=device) @@ -130,7 +130,7 @@ def _generate_and_score_completions( completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] completion_ids = pad( completion_ids, - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of, ).to(device=device) @@ -161,7 +161,7 @@ def _generate_and_score_completions( # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask if self.mask_truncated_completions: - eos_and_pad = [self.eos_token_id, self.pad_token_id] + eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id] is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() @@ -665,7 +665,7 @@ def update_with_replay_buffer( if target_prompt_len > current_batch_prompt_seq_len: prompt_ids = pad( list(prompt_ids.unbind(0)), - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, pad_to_multiple_of=target_prompt_len, padding_side="left", ) @@ -676,7 +676,7 @@ def update_with_replay_buffer( if target_completion_len > current_batch_completion_seq_len: completion_ids = pad( list(completion_ids.unbind(0)), - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, pad_to_multiple_of=target_completion_len, padding_side="right", ) @@ -711,7 +711,7 @@ def update_with_replay_buffer( if sampled_data["prompt_ids"][i].size(1) < target_prompt_len: sampled_data["prompt_ids"][i] = pad( sampled_data["prompt_ids"][i], - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, pad_to_multiple_of=target_prompt_len, padding_side="left", ) @@ -726,7 +726,7 @@ def update_with_replay_buffer( if sampled_data["completion_ids"][i].size(1) < target_completion_len: sampled_data["completion_ids"][i] = pad( sampled_data["completion_ids"][i], - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, pad_to_multiple_of=target_completion_len, padding_side="right", ) diff --git a/trl/experimental/kto/kto_trainer.py b/trl/experimental/kto/kto_trainer.py index 2c8fee0b3e5..5be61fcf0b4 100644 --- a/trl/experimental/kto/kto_trainer.py +++ b/trl/experimental/kto/kto_trainer.py @@ -12,24 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect import textwrap from collections import defaultdict from collections.abc import Callable -from contextlib import contextmanager, nullcontext +from contextlib import contextmanager +from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any, Literal -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import transformers from accelerate import PartialState, logging from accelerate.utils import is_peft_model, tqdm -from datasets import Dataset, concatenate_datasets +from datasets import Dataset, IterableDataset, IterableDatasetDict, concatenate_datasets from packaging.version import Version -from torch import autocast from torch.utils.data import DataLoader, SequentialSampler from transformers import ( AutoProcessor, @@ -39,12 +37,13 @@ ProcessorMixin, TrainerCallback, ) +from transformers.data.data_collator import DataCollatorMixin from transformers.trainer_utils import EvalLoopOutput, has_length from transformers.utils import is_peft_available from ...data_utils import ( extract_prompt, - maybe_apply_chat_template, + is_conversational, unpair_preference_dataset, ) from ...import_utils import is_liger_kernel_available @@ -54,10 +53,10 @@ create_model_from_path, disable_dropout_in_model, get_config_model_id, + pad, selective_log_softmax, use_adapter, ) -from ..utils import DPODataCollatorWithPadding, peft_module_casting_to_bf16 from .kto_config import KTOConfig @@ -65,11 +64,11 @@ from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss if is_peft_available(): - from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + from peft import PeftConfig, PeftModel, get_peft_model if TYPE_CHECKING: - from transformers import PreTrainedModel, PreTrainedTokenizer + from transformers import PreTrainedModel logger = logging.get_logger(__name__) @@ -77,167 +76,80 @@ RUNNING_NAME = "running.pt" +def get_dataset_column_names(dataset: Dataset | IterableDataset) -> list[str]: + return list(next(iter(dataset)).keys()) if dataset.column_names is None else dataset.column_names + + def _get_kl_dataset(batch: dict[str, list[Any]]) -> dict[str, list[Any]]: """ Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of completions. For best results, the mismatched outputs y' used to estimate the KL term for a batch should be the same set as the matched outputs y used to estimate the rewards in that batch, just paired with different x. """ - batch["answer_input_ids"] = [batch["answer_input_ids"][-1]] + batch["answer_input_ids"][:-1] - batch["answer_attention_mask"] = [batch["answer_attention_mask"][-1]] + batch["answer_attention_mask"][:-1] + batch["completion_ids"] = [batch["completion_ids"][-1]] + batch["completion_ids"][:-1] return batch -def _tokenize( - batch: dict[str, list[Any]], - tokenizer: "PreTrainedTokenizer", -) -> dict[str, list[Any]]: - """Tokenize a batch from a KTO specific dataset.""" - prompt_tokenized = tokenizer(batch["prompt"], add_special_tokens=False) - prompt_input_ids = prompt_tokenized["input_ids"] - prompt_attention_mask = prompt_tokenized["attention_mask"] - prompt_and_completion = [ - prompt + completion for prompt, completion in zip(batch["prompt"], batch["completion"], strict=True) - ] - full_tokenized = tokenizer(prompt_and_completion, add_special_tokens=False) - full_input_ids = full_tokenized["input_ids"] - full_attention_mask = full_tokenized["attention_mask"] - - answer_input_ids = [f[len(p) :] for f, p in zip(full_input_ids, prompt_input_ids, strict=True)] - answer_attention_mask = [f[len(p) :] for f, p in zip(full_attention_mask, prompt_attention_mask, strict=True)] - - # Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]` - full_concat_input_ids = [np.concatenate([p, a]) for p, a in zip(prompt_input_ids, answer_input_ids, strict=True)] - # Prepare input tokens for token by token comparison - full_input_ids = [np.array(f) for f in full_input_ids] - for full, concat in zip(full_input_ids, full_concat_input_ids, strict=True): - if len(full) != len(concat): - raise ValueError( - "The elements in 'full_input_ids' and 'full_concat_input_ids' must have the same pairwise length." - ) - - # On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens - # can be merged together when tokenizing prompt+answer. This could result - # on the last token from the prompt being different when tokenized on its own - # vs when done as prompt+answer. - response_token_ids_start_idx = [len(p) for p in prompt_input_ids] - - # If tokenized prompt is different than both prompt+answer, then it means the - # last token has changed due to merging. - for idx, (p, f, r) in enumerate(zip(prompt_input_ids, full_input_ids, response_token_ids_start_idx, strict=True)): - if not np.array_equal(p, f[:r]): - response_token_ids_start_idx[idx] -= 1 - - prompt_input_ids = [f[:r] for f, r in zip(full_input_ids, response_token_ids_start_idx, strict=True)] - prompt_attention_mask = [f[:r] for f, r in zip(full_attention_mask, response_token_ids_start_idx, strict=True)] - - for p, m in zip(prompt_input_ids, prompt_attention_mask, strict=True): - if len(p) != len(m): - raise ValueError("Prompt input ids and attention mask should have the same length.") - - answer_input_ids = [f[r:] for f, r in zip(full_input_ids, response_token_ids_start_idx, strict=True)] - answer_attention_mask = [f[r:] for f, r in zip(full_attention_mask, response_token_ids_start_idx, strict=True)] - - output = dict( - prompt_input_ids=prompt_input_ids, - prompt_attention_mask=prompt_attention_mask, - answer_input_ids=answer_input_ids, - answer_attention_mask=answer_attention_mask, - ) - - return output - - -def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **kwargs) -> dict: - """Process tokens of a KTO specific dataset. - - At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt + - completion responses is/are too long. We truncate from the end (completion) to fit within max_length. - - We also create the labels for the completion responses, which are of length equal to the sum of the length of the - prompt and the completion response, with `-100` for the prompt tokens. +@dataclass +class DataCollatorForUnpairedPreference(DataCollatorMixin): """ - prompt = example["prompt"] - completion = example["completion"] - - batch = { - f"{kwargs['prefix']}prompt": prompt, - f"{kwargs['prefix']}completion": completion, - f"{kwargs['prefix']}label": example["label"], - } - - # Check issues below for more details - # 1. https://github.com/huggingface/trl/issues/907 - # 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257 - # 3. https://github.com/LianjiaTech/BELLE/issues/337 - - if not isinstance(prompt, str): - raise ValueError(f"prompt should be an str but got {type(prompt)}") - - if not isinstance(completion, str): - raise ValueError(f"completion should be an str but got {type(completion)}") + Data collator for unpaired preference data. Assembles completions from raw token IDs and pads sequences to the + maximum length of the batch. - # keys of format prompt_* refers to just the prompt and answer_* refers to just the answer - all_tokens = { - "prompt_input_ids": example["prompt_input_ids"], - "prompt_attention_mask": example["prompt_attention_mask"], - "answer_input_ids": example["answer_input_ids"], - "answer_attention_mask": example["answer_attention_mask"], - } + Args: + pad_token_id (`int`): + Token ID to use for padding `input_ids` sequences. + max_length (`int`, *optional*): + Maximum sequence length after assembly. Sequences longer than `max_length` are truncated from the end. + return_tensors (`str`, *optional*, defaults to `"pt"`): + The tensor type to return. Currently, only `"pt"` (PyTorch tensors) is supported. + """ - # calculate max length by checking if BOS/EOS is already there - max_length = kwargs["max_length"] - bos_token_id = kwargs["tokenizer"].bos_token_id - eos_token_id = kwargs["tokenizer"].eos_token_id - if len(all_tokens["prompt_input_ids"]) > 0 and bos_token_id != all_tokens["prompt_input_ids"][0]: - max_length -= 1 - if len(all_tokens["answer_input_ids"]) > 0 and eos_token_id != all_tokens["answer_input_ids"][-1]: - max_length -= 1 - - # if combined sequence is too long, truncate the completion (answer) from the end - prompt_length = len(all_tokens["prompt_input_ids"]) - completion_length = len(all_tokens["answer_input_ids"]) - if prompt_length + completion_length > max_length: - max_completion_length = max_length - prompt_length - for k in ["answer_input_ids", "answer_attention_mask"]: - all_tokens[k] = all_tokens[k][:max_completion_length] - - # all input_ids and attention mask as is. We then check if we need to add BOS/EOS tokens - batch[f"{kwargs['prefix']}prompt_input_ids"] = all_tokens["prompt_input_ids"] - batch[f"{kwargs['prefix']}prompt_attention_mask"] = all_tokens["prompt_attention_mask"] - batch[f"{kwargs['prefix']}completion_input_ids"] = all_tokens["prompt_input_ids"] + all_tokens["answer_input_ids"] - batch[f"{kwargs['prefix']}completion_attention_mask"] = ( - all_tokens["prompt_attention_mask"] + all_tokens["answer_attention_mask"] - ) - - # add BOS, which affects both prompt and the full completion - if bos_token_id is not None: - if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]: - batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[ - f"{kwargs['prefix']}prompt_input_ids" - ] - batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[f"{kwargs['prefix']}prompt_attention_mask"] - batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[ - f"{kwargs['prefix']}completion_input_ids" - ] - batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[ - f"{kwargs['prefix']}completion_attention_mask" - ] - # add EOS, which affects only the full completion - if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]: - batch[f"{kwargs['prefix']}completion_input_ids"] = batch[f"{kwargs['prefix']}completion_input_ids"] + [ - eos_token_id - ] - batch[f"{kwargs['prefix']}completion_attention_mask"] = batch[ - f"{kwargs['prefix']}completion_attention_mask" - ] + [1] - - batch[f"{kwargs['prefix']}completion_labels"] = batch[f"{kwargs['prefix']}completion_input_ids"][:] - batch[f"{kwargs['prefix']}completion_labels"][: len(batch[f"{kwargs['prefix']}prompt_input_ids"])] = [-100] * len( - batch[f"{kwargs['prefix']}prompt_input_ids"] - ) + pad_token_id: int + max_length: int | None = None + return_tensors: str = "pt" + + def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]: + batch = {} + for prefix, ids_key in [("completion", "completion_ids"), ("KL_completion", "KL_completion_ids")]: + if ids_key not in examples[0]: + continue + + full_ids_list = [] + labels_list = [] + for ex in examples: + prompt_ids = ex["prompt_ids"] + answer_ids = ex[ids_key] + full_ids = prompt_ids + answer_ids + labels = [-100] * len(prompt_ids) + answer_ids + if self.max_length is not None: + full_ids = full_ids[: self.max_length] + labels = labels[: self.max_length] + full_ids_list.append(full_ids) + labels_list.append(labels) + + batch[f"{prefix}_input_ids"] = pad( + [torch.tensor(ids, dtype=torch.int64) for ids in full_ids_list], + padding_value=self.pad_token_id, + padding_side="right", + ) + batch[f"{prefix}_attention_mask"] = pad( + [torch.ones(len(ids), dtype=torch.int64) for ids in full_ids_list], + padding_value=0, + padding_side="right", + ) + batch[f"{prefix}_labels"] = pad( + [torch.tensor(lbl, dtype=torch.int64) for lbl in labels_list], + padding_value=-100, + padding_side="right", + ) - return batch + if "reference_logps" in examples[0]: + batch["reference_logps"] = torch.tensor([ex["reference_logps"] for ex in examples]) + if "reference_KL_logps" in examples[0]: + batch["reference_KL_logps"] = torch.tensor([ex["reference_KL_logps"] for ex in examples]) + batch["label"] = [ex["label"] for ex in examples] + return batch class KTOTrainer(_BaseTrainer): @@ -263,9 +175,9 @@ class KTOTrainer(_BaseTrainer): state before KTO training starts. args ([`experimental.kto.KTOConfig`], *optional*): Configuration for this trainer. If `None`, a default configuration is used. - train_dataset ([`~datasets.Dataset`]): + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): The dataset to use for training. - eval_dataset ([`~datasets.Dataset`]): + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`): The dataset to use for evaluation. processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`], *optional*): Processing class used to process the data. The padding side must be set to "left". If `None`, the @@ -274,20 +186,14 @@ class KTOTrainer(_BaseTrainer): `tokenizer.eos_token` will be used as the default. data_collator ([`~transformers.DataCollator`], *optional*): The data collator to use for training. If None is specified, the default data collator - ([`experimental.utils.DPODataCollatorWithPadding`]) will be used which will pad the sequences to the - maximum length of the sequences in the batch, given a dataset of paired sequences. - model_init (`Callable[[], transformers.PreTrainedModel]`): - The model initializer to use for training. If None is specified, the default model initializer will be - used. + ([`~experimental.kto.kto_trainer.DataCollatorForUnpairedPreference`]) will be used which will pad the + sequences to the maximum length of the sequences in the batch. callbacks (`list[transformers.TrainerCallback]`): The callbacks to use for training. optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): The optimizer and scheduler to use for training. - preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): - The function to use to preprocess the logits before computing the metrics. - peft_config (`dict`, defaults to `None`): - The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in - a PEFT model. + peft_config ([`~peft.PeftConfig`], *optional*): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to metric values. @@ -313,15 +219,13 @@ def __init__( model: "str | PreTrainedModel | PeftModel", ref_model: PreTrainedModel | None = None, args: KTOConfig | None = None, - train_dataset: Dataset | None = None, - eval_dataset: Dataset | dict[str, Dataset] | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None, data_collator: DataCollator | None = None, - model_init: Callable[[], PreTrainedModel] | None = None, callbacks: list[TrainerCallback] | None = None, optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, - peft_config: dict | None = None, + peft_config: "PeftConfig | None" = None, compute_metrics: Callable[[EvalLoopOutput], dict] | None = None, ): # Args @@ -332,6 +236,16 @@ def __init__( if train_dataset is None: raise ValueError("`train_dataset` is required") + elif isinstance(train_dataset, IterableDataset): + # IterableDataset requires dispatch_batches=False because Accelerate's dispatch mode may try to concatenate + # batches from multiple processes, leading to mismatch errors. + if args.accelerator_config.dispatch_batches is True: + logger.warning( + "You are using an `IterableDataset` for training with `dispatch_batches=True`. `dispatch_batches` " + "is forced to `False` when using an `IterableDataset`. To remove this warning, unset " + "`dispatch_batches` in `KTOConfig` or set it to `False`." + ) + args.accelerator_config.dispatch_batches = False # Model if isinstance(model, str): @@ -356,29 +270,36 @@ def __init__( if processing_class is None: processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config)) if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer + self._tokenizer = processing_class.tokenizer elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class + self._tokenizer = processing_class else: raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if self._tokenizer.pad_token is None: + self._tokenizer.pad_token = self._tokenizer.eos_token - # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` - # has been called in order to properly call autocast if needed. - self._peft_has_been_casted_to_bf16 = False + # PEFT + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) + if isinstance(model, PeftModel): + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " + "and unload the existing adapter, save the resulting base model, and then pass that base model along " + "with the new `peft_config` to the trainer." + ) + # Create PEFT model + model = get_peft_model(model, peft_config) - if not is_peft_available() and peft_config is not None: - raise ValueError( - "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models" - ) - if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None: - raise ValueError( - "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " - "and unload the existing adapter, save the resulting base model, and then pass that base model along " - "with the new `peft_config` to the trainer." - ) - if is_peft_available() and isinstance(model, PeftModel) and ref_model is None: + elif is_peft_available() and isinstance(model, PeftModel) and ref_model is None: # If the model is a PEFT model with a pretrained adapter, we need to create a "ref" adapter that is a copy # of the "default" adapter, so that we can use it as the reference model during KTO training. model.add_adapter("ref", model.peft_config["default"]) @@ -387,51 +308,21 @@ def __init__( ref_name = name.replace(".default.", ".ref.") ref_param = model.get_parameter(ref_name) ref_param.data.copy_(param.data) - if is_peft_available() and peft_config is not None: - if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): - _support_gc_kwargs = hasattr( - args, "gradient_checkpointing_kwargs" - ) and "gradient_checkpointing_kwargs" in list( - inspect.signature(prepare_model_for_kbit_training).parameters - ) - prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing} - - if _support_gc_kwargs: - prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs - - model = prepare_model_for_kbit_training(model, **prepare_model_kwargs) - elif args.gradient_checkpointing: - # For backward compatibility with older versions of transformers - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) - - # get peft model with the given config - model = get_peft_model(model, peft_config) - if args.bf16 and getattr(model, "is_loaded_in_4bit", False): - peft_module_casting_to_bf16(model) - # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager - self._peft_has_been_casted_to_bf16 = True - - # For models that use gradient_checkpointing, we need to attach a hook that enables input - # to explicitly have `requires_grad=True`, otherwise training will either silently - # fail or completely fail. - elif args.gradient_checkpointing: - # For backward compatibility with older versions of transformers - if hasattr(model, "enable_input_require_grads"): - model.enable_input_require_grads() - else: - - def make_inputs_require_grad(module, input, output): - output.requires_grad_(True) - - model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally + # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489 + if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing: + model.enable_input_require_grads() + + # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the + # original paper (see https://huggingface.co/papers/2305.14314, paragraph 3). Normally, this can be done by + # passing `autocast_adapter_dtype=False` to `get_peft_model`, but this option is not yet supported for + # quantized models. See: https://github.com/huggingface/peft/issues/2889 + # Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do + if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False): + for param in model.parameters(): + if param.requires_grad: + param.data = param.data.to(torch.bfloat16) # KTO only supports causal language models, not encoder-decoder models if model is not None and hasattr(model.config, "is_encoder_decoder") and model.config.is_encoder_decoder: @@ -444,7 +335,7 @@ def make_inputs_require_grad(module, input, output): if args.max_length is None: logger.warning( - "When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init" + "When using DataCollatorForUnpairedPreference, you should set `max_length` in the KTOTrainer's init" " it will be set to `512` by default, but you should do it yourself in the future.", ) max_length = 512 @@ -452,15 +343,16 @@ def make_inputs_require_grad(module, input, output): max_length = args.max_length if data_collator is None: - data_collator = DPODataCollatorWithPadding( - pad_token_id=tokenizer.pad_token_id, + data_collator = DataCollatorForUnpairedPreference( + pad_token_id=self._tokenizer.pad_token_id, + max_length=max_length, ) if args.remove_unused_columns: args.remove_unused_columns = False # warn users logger.warning( - "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig" + "When using DataCollatorForUnpairedPreference, you should set `remove_unused_columns=False` in your KTOConfig" " we have set it for you, but you should do it yourself in the future.", ) @@ -501,7 +393,13 @@ def make_inputs_require_grad(module, input, output): # Dataset train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") if eval_dataset is not None: - eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") + if isinstance(eval_dataset, dict): + eval_dataset = { + key: self._prepare_dataset(dataset, processing_class, args, key) + for key, dataset in eval_dataset.items() + } + else: + eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream @@ -518,11 +416,9 @@ def make_inputs_require_grad(module, input, output): train_dataset=train_dataset, eval_dataset=eval_dataset, processing_class=processing_class, - model_init=model_init, compute_metrics=compute_metrics, callbacks=callbacks, optimizers=optimizers, - preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) # Reference model @@ -588,6 +484,13 @@ def make_inputs_require_grad(module, input, output): self.kto_loss_fn = LigerFusedLinearKTOLoss(beta=self.beta, use_ref_model=(self.ref_model is not None)) if self.precompute_ref_log_probs: + if isinstance(self.train_dataset, IterableDataset) or isinstance( + self.eval_dataset, (IterableDataset, IterableDatasetDict) + ): + raise ValueError( + "`precompute_ref_log_probs=True` is not supported with IterableDataset. Please use a map-style " + "Dataset or set `precompute_ref_log_probs=False`." + ) self.train_dataset = self._precompute_ref_logps( self.train_dataset, "train", @@ -608,15 +511,45 @@ def make_inputs_require_grad(module, input, output): self.args.precompute_ref_batch_size or self.args.per_device_eval_batch_size, ) + def _tokenize( + self, + processing_class: PreTrainedTokenizerBase | ProcessorMixin, + input: str | list, + **kwargs, + ) -> dict[str, list]: + """Tokenize a single example for dataset preprocessing. + + Dispatches to `apply_chat_template` for conversational input (list of message dicts) and to `__call__` for + non-conversational input (str). + + Args: + processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`]): + The tokenizer or processor to use. + input (`str` or `list`): + A string for non-conversational input, or a list of message dicts for conversational input. + **kwargs: + Forwarded to `apply_chat_template` (e.g. `add_generation_prompt`, `return_assistant_tokens_mask`). + + Returns: + `dict` with at least an `"input_ids"` key mapping to a flat `list[int]`. + """ + if isinstance(input, list): # conversational: list of message dicts + result = processing_class.apply_chat_template(input, tokenize=True, return_dict=True, **kwargs) + else: # non-conversational: plain text string + result = processing_class(text=input) + return result + def _prepare_dataset( self, - dataset: Dataset, + dataset: Dataset | IterableDataset, processing_class: PreTrainedTokenizerBase | ProcessorMixin, args: KTOConfig | None, dataset_name: str, - ) -> Dataset: + ) -> Dataset | IterableDataset: # Build the kwargs for the `map` function - map_kwargs = {"num_proc": args.dataset_num_proc} + map_kwargs = {} + if isinstance(dataset, Dataset): # IterableDataset does not support num_proc + map_kwargs["num_proc"] = args.dataset_num_proc # Compute that only on the main process for faster data processing. # see: https://github.com/huggingface/trl/pull/1255 @@ -624,73 +557,96 @@ def _prepare_dataset( # Extract the prompt if needed first_example = next(iter(dataset)) if "prompt" not in first_example: - map_kwargs["desc"] = f"Extracting prompt from {dataset_name} dataset" + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Extracting prompt from {dataset_name} dataset" dataset = dataset.map(extract_prompt, **map_kwargs) # Unpair the dataset if needed first_example = next(iter(dataset)) if "chosen" in first_example and "rejected" in first_example: - map_kwargs["desc"] = f"Unpairing {dataset_name} dataset" + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Unpairing {dataset_name} dataset" dataset = unpair_preference_dataset(dataset, **map_kwargs) - # Apply the chat template if needed - dataset = dataset.map( - maybe_apply_chat_template, - fn_kwargs={"processing_class": processing_class}, - num_proc=args.dataset_num_proc, - desc=f"Applying chat template to {dataset_name} dataset", - ) + # Add EOS token if needed: non-conversational only + first_example = next(iter(dataset)) + if not is_conversational(first_example): + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset" + + def add_eos(example, eos_token): + if not example["completion"].endswith(eos_token): + example["completion"] = example["completion"] + eos_token + return example + + dataset = dataset.map(add_eos, fn_kwargs={"eos_token": self._tokenizer.eos_token}, **map_kwargs) - tokenizer = getattr(processing_class, "tokenizer", processing_class) # Tokenize dataset - dataset = dataset.map( - _tokenize, - batched=True, - fn_kwargs={"tokenizer": tokenizer}, - num_proc=args.dataset_num_proc, - desc=f"Tokenizing {dataset_name} dataset", - ) - # Process dataset - dataset = dataset.map( - _process_tokens, - fn_kwargs={ - "prefix": "", - "tokenizer": tokenizer, - "max_length": self.max_length, - }, - num_proc=args.dataset_num_proc, - desc=f"Processing tokenized {dataset_name} dataset", - ) + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" + + def tokenize_fn(example, processing_class): + if is_conversational(example): + chat_template_kwargs = example.get("chat_template_kwargs", {}) + prompt_ids = self._tokenize( + processing_class, + example["prompt"], + add_generation_prompt=True, + **chat_template_kwargs, + )["input_ids"] + prompt_completion_ids = self._tokenize( + processing_class, + example["prompt"] + example["completion"], + **chat_template_kwargs, + )["input_ids"] + else: + prompt_ids = self._tokenize(processing_class, example["prompt"])["input_ids"] + prompt_completion_ids = self._tokenize( + processing_class, example["prompt"] + example["completion"] + )["input_ids"] + + if not prompt_completion_ids[: len(prompt_ids)] == prompt_ids: + logger.warning( + "Mismatch between tokenized prompt and the start of tokenized prompt+completion. " + "This may be due to unexpected tokenizer behavior, whitespace issues, or special " + "token handling. Verify that the tokenizer is processing text consistently." + ) + + return { + "prompt_ids": prompt_ids, + "completion_ids": prompt_completion_ids[len(prompt_ids) :], + } + + dataset = dataset.map(tokenize_fn, fn_kwargs={"processing_class": processing_class}, **map_kwargs) # Get KL datasets if needed if self.calculate_KL: + + def rename_kl_fn(example): + return {"KL_completion_ids": example["completion_ids"]} + # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n) + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Extracting KL {dataset_name} dataset" kl_dataset = dataset.map( - _get_kl_dataset, - batched=True, - batch_size=args.per_device_train_batch_size, - num_proc=args.dataset_num_proc, - desc=f"Extracting KL {dataset_name} dataset", + _get_kl_dataset, batched=True, batch_size=args.per_device_train_batch_size, **map_kwargs ) + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Assembling KL {dataset_name} dataset" + column_names = get_dataset_column_names(dataset) kl_dataset = kl_dataset.map( - _process_tokens, - fn_kwargs={ - "prefix": "KL_", - "tokenizer": tokenizer, - "max_length": self.max_length, - }, - num_proc=args.dataset_num_proc, - remove_columns=[c for c in kl_dataset.column_names if c in dataset.column_names], - desc=f"Processing tokenized {dataset_name} KL dataset", + rename_kl_fn, + remove_columns=[c for c in get_dataset_column_names(kl_dataset) if c in column_names], + **map_kwargs, ) # merge the datasets dataset = concatenate_datasets([dataset, kl_dataset], axis=1) - # calculate dataset desirability balance - if dataset_name == "train": + # Calculate dataset desirability balance + if dataset_name == "train" and isinstance(dataset, Dataset): # IterableDataset does not support len num_desirable = max(sum(dataset["label"]), 1) num_undesirable = max(len(dataset["label"]) - num_desirable, 1) # "label" is binary @@ -1196,12 +1152,7 @@ def compute_loss( return_outputs=False, num_items_in_batch=None, ) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]: - compute_loss_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - - with compute_loss_context_manager: - loss, metrics = self.get_batch_loss_metrics(model, inputs) + loss, metrics = self.get_batch_loss_metrics(model, inputs) # Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class: loss = loss.to(self.args.device) @@ -1237,10 +1188,7 @@ def prediction_step( else: ignore_keys = [] - prediction_context_manager = ( - autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext() - ) - with torch.no_grad(), prediction_context_manager: + with torch.no_grad(): loss, metrics = self.get_batch_loss_metrics(model, inputs) # force log the metrics diff --git a/trl/experimental/nash_md/nash_md_trainer.py b/trl/experimental/nash_md/nash_md_trainer.py index f8b10b3392d..fd2f7c816c3 100644 --- a/trl/experimental/nash_md/nash_md_trainer.py +++ b/trl/experimental/nash_md/nash_md_trainer.py @@ -42,7 +42,7 @@ if is_peft_available(): - from peft import PeftModel + from peft import PeftConfig, PeftModel class GeometricMixtureWrapper(GenerationMixin): @@ -133,7 +133,7 @@ class NashMDTrainer(OnlineDPOTrainer): Processing class used to process the data. If provided, will be used to automatically process the inputs for the model, and it will be saved along the model to make it easier to rerun an interrupted training or reuse the fine-tuned model. - peft_config (`dict`): + peft_config ([`~peft.PeftConfig`], *optional*): The peft config to use for training. compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to @@ -177,7 +177,7 @@ def __init__( | FeatureExtractionMixin | ProcessorMixin | None = None, - peft_config: dict | None = None, + peft_config: "PeftConfig | None" = None, compute_metrics: Callable[[EvalPrediction], dict] | None = None, callbacks: list[TrainerCallback] | None = None, optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), diff --git a/trl/experimental/online_dpo/online_dpo_trainer.py b/trl/experimental/online_dpo/online_dpo_trainer.py index 4c7adef3b6d..cb57eba41fa 100644 --- a/trl/experimental/online_dpo/online_dpo_trainer.py +++ b/trl/experimental/online_dpo/online_dpo_trainer.py @@ -282,6 +282,18 @@ def __init__( self.is_encoder_decoder = model.config.is_encoder_decoder self.is_vision_model = model.config.model_type in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.keys() + # PEFT + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): model = prepare_peft_model(model, peft_config, args) @@ -349,17 +361,14 @@ def __init__( # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer + self._tokenizer = processing_class.tokenizer elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class + self._tokenizer = processing_class else: raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - - self.pad_token_id = tokenizer.pad_token_id - self.eos_token_id = tokenizer.eos_token_id + if self._tokenizer.pad_token is None: + self._tokenizer.pad_token = self._tokenizer.eos_token # Vision tokens for VLM support self.image_token_id = getattr(processing_class, "image_token_id", None) @@ -368,11 +377,11 @@ def __init__( # Get the image token string for token collapsing self.image_token = None if self.image_token_id is not None: - self.image_token = tokenizer.decode([self.image_token_id]) + self.image_token = self._tokenizer.decode([self.image_token_id]) # Define the collator if not provided if data_collator is None: - data_collator = DPODataCollatorWithPadding(pad_token_id=self.pad_token_id) + data_collator = DPODataCollatorWithPadding(pad_token_id=self._tokenizer.pad_token_id) # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream @@ -505,9 +514,9 @@ def __init__( generation_kwargs = { "max_new_tokens": args.max_new_tokens, "do_sample": True, - "pad_token_id": self.pad_token_id, - "bos_token_id": tokenizer.bos_token_id, - "eos_token_id": self.eos_token_id, + "pad_token_id": self._tokenizer.pad_token_id, + "bos_token_id": self._tokenizer.bos_token_id, + "eos_token_id": self._tokenizer.eos_token_id, "temperature": self.temperature, "top_k": self.top_k, "top_p": self.top_p, @@ -583,8 +592,8 @@ def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: OnlineDPO return model def _generate_vllm(self, prompts, images=None): - eos_token_id = self.eos_token_id - pad_token_id = self.pad_token_id + eos_token_id = self._tokenizer.eos_token_id + pad_token_id = self._tokenizer.pad_token_id # Generate completion_ids and prompt_ids based on mode if self.vllm_mode == "server": @@ -893,8 +902,8 @@ def process_vision_row( def _generate(self, model, prompts, images=None): """Generate completions using the model""" device = next(model.parameters()).device - eos_token_id = self.eos_token_id - pad_token_id = self.pad_token_id + eos_token_id = self._tokenizer.eos_token_id + pad_token_id = self._tokenizer.pad_token_id # Apply chat template and tokenize the input inputs = [{"prompt": prompt} for prompt in prompts] @@ -923,9 +932,7 @@ def _generate(self, model, prompts, images=None): else: # If the chat template doesn't use the image token, remove all instances if self.vision_end_token_id is not None: - escaped_eoi_token = re.escape( - self.processing_class.tokenizer.decode([self.vision_end_token_id]) - ) + escaped_eoi_token = re.escape(self._tokenizer.decode([self.vision_end_token_id])) prompts_text = [ re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text ] @@ -1118,7 +1125,7 @@ def training_step( else: prompt_ids, prompt_mask, completion_ids, completion_mask = self._generate(model, prompts, images) - contain_eos_token = torch.any(completion_ids == self.eos_token_id, dim=-1) + contain_eos_token = torch.any(completion_ids == self._tokenizer.eos_token_id, dim=-1) # Extract vision inputs if available for VLM support vision_inputs = None diff --git a/trl/experimental/orpo/orpo_trainer.py b/trl/experimental/orpo/orpo_trainer.py index 59d7636efb7..eb8378b02db 100644 --- a/trl/experimental/orpo/orpo_trainer.py +++ b/trl/experimental/orpo/orpo_trainer.py @@ -62,7 +62,7 @@ if is_peft_available(): - from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training + from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training if is_wandb_available(): @@ -112,7 +112,7 @@ class ORPOTrainer(_BaseTrainer): The optimizer and scheduler to use for training. preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): The function to use to preprocess the logits before computing the metrics. - peft_config (`dict`, defaults to `None`): + peft_config ([`~peft.PeftConfig`], *optional*): The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): @@ -151,7 +151,7 @@ def __init__( callbacks: list[TrainerCallback] | None = None, optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, - peft_config: dict | None = None, + peft_config: "PeftConfig | None" = None, compute_metrics: Callable[[EvalLoopOutput], dict] | None = None, ): if train_dataset is None: @@ -178,15 +178,21 @@ def __init__( if isinstance(model, str): model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + # PEFT # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16` # has been called in order to properly call autocast if needed. self._peft_has_been_casted_to_bf16 = False - - if not is_peft_available() and peft_config is not None: - raise ValueError( - "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" - ) - elif is_peft_available() and peft_config is not None: + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) if isinstance(model, PeftModel): raise ValueError( "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first " diff --git a/trl/experimental/ppo/ppo_trainer.py b/trl/experimental/ppo/ppo_trainer.py index 6366f987ec4..0d3e32f0a46 100644 --- a/trl/experimental/ppo/ppo_trainer.py +++ b/trl/experimental/ppo/ppo_trainer.py @@ -416,12 +416,18 @@ def __init__( "[Approximating KL Divergence](http://joschu.net/blog/kl-approx.html) for details." ) - # peft support - if not is_peft_available() and peft_config is not None: - raise ImportError( - "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" - ) - elif is_peft_available() and peft_config is not None: + # PEFT + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) if isinstance(self.policy_model, PeftModel): raise ValueError( "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first " diff --git a/trl/experimental/prm/prm_trainer.py b/trl/experimental/prm/prm_trainer.py index 7b26b69bd82..89cac4ddd44 100644 --- a/trl/experimental/prm/prm_trainer.py +++ b/trl/experimental/prm/prm_trainer.py @@ -44,7 +44,7 @@ if is_peft_available(): - from peft import PeftModel + from peft import PeftConfig, PeftModel logger = logging.get_logger(__name__) @@ -127,7 +127,7 @@ class PRMTrainer(_BaseTrainer): The optimizer and scheduler to use for training. preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): The function to use to preprocess the logits before computing the metrics. - peft_config (`dict`, defaults to `None`): + peft_config ([`~peft.PeftConfig`], *optional*): The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model. """ @@ -167,11 +167,23 @@ def __init__( None, ), preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, - peft_config: dict | None = None, + peft_config: "PeftConfig | None" = None, ): if train_dataset is None: raise ValueError("`train_dataset` is required") + # PEFT + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): model = prepare_peft_model(model, peft_config, args) diff --git a/trl/experimental/sdft/sdft_trainer.py b/trl/experimental/sdft/sdft_trainer.py index 5bf6095c2a0..e74ce16f58c 100644 --- a/trl/experimental/sdft/sdft_trainer.py +++ b/trl/experimental/sdft/sdft_trainer.py @@ -195,11 +195,23 @@ def __init__( else inspect.signature(model.get_base_model().forward).parameters.keys() ) - if is_peft_available() and is_peft_model(model) and peft_config is not None: - raise ValueError( - "You passed a `PeftModel` instance together with a `peft_config` to SDFTTrainer. Pass either a base " - "model with `peft_config`, or a pre-wrapped PEFT model." - ) + # PEFT + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) + if is_peft_model(model): + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to SDFTTrainer. Pass either a base " + "model with `peft_config`, or a pre-wrapped PEFT model." + ) if peft_config is not None or (is_peft_available() and getattr(model, "peft_config", None) is not None): model = prepare_peft_model(model, peft_config, args) @@ -209,17 +221,15 @@ def __init__( ) if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer + self._tokenizer = processing_class.tokenizer elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class + self._tokenizer = processing_class else: raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if self._tokenizer.pad_token is None: + self._tokenizer.pad_token = self._tokenizer.eos_token - self.pad_token_id = tokenizer.pad_token_id - self.eos_token_id = tokenizer.eos_token_id self.max_prompt_length = args.max_prompt_length self.max_completion_length = args.max_completion_length self.num_generations = args.num_generations @@ -239,9 +249,9 @@ def __init__( generation_kwargs = { "max_new_tokens": self.max_completion_length, "do_sample": True, - "pad_token_id": tokenizer.pad_token_id, - "bos_token_id": tokenizer.bos_token_id, - "eos_token_id": tokenizer.eos_token_id, + "pad_token_id": self._tokenizer.pad_token_id, + "bos_token_id": self._tokenizer.bos_token_id, + "eos_token_id": self._tokenizer.eos_token_id, "temperature": args.temperature, "top_p": args.top_p, "top_k": args.top_k, @@ -398,7 +408,7 @@ def _generate_completion_ids(self, prompts: list[Any]) -> tuple[torch.Tensor, to prompt_length = generate_inputs["input_ids"].size(1) completion_ids = prompt_completion_ids[:, prompt_length:] - is_eos = completion_ids == self.eos_token_id + is_eos = completion_ids == self._tokenizer.eos_token_id eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device) eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] seq_idx = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1) @@ -408,7 +418,7 @@ def _generate_completion_ids(self, prompts: list[Any]) -> tuple[torch.Tensor, to completion_ids = [torch.tensor(ids, device=self.accelerator.device) for ids in completion_ids_list] completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] return ( - pad(completion_ids, padding_value=self.pad_token_id, padding_side="right"), + pad(completion_ids, padding_value=self._tokenizer.pad_token_id, padding_side="right"), pad(completion_mask, padding_value=0, padding_side="right"), ) diff --git a/trl/experimental/sdpo/sdpo_trainer.py b/trl/experimental/sdpo/sdpo_trainer.py index ef84a17a44c..66195da4c58 100644 --- a/trl/experimental/sdpo/sdpo_trainer.py +++ b/trl/experimental/sdpo/sdpo_trainer.py @@ -87,7 +87,9 @@ def _tokenize_teacher_messages( teacher_prompt_ids = [ids.to(device) for ids in teacher_prompt_ids_list] teacher_prompt_mask = [torch.ones(len(ids), dtype=torch.long, device=device) for ids in teacher_prompt_ids] return TokenizedPromptBatch( - prompt_ids=pad(teacher_prompt_ids, padding_value=self.trainer.pad_token_id, padding_side="left"), + prompt_ids=pad( + teacher_prompt_ids, padding_value=self.trainer._tokenizer.pad_token_id, padding_side="left" + ), prompt_mask=pad(teacher_prompt_mask, padding_value=0, padding_side="left"), ) @@ -115,7 +117,7 @@ def build( # Use separate variables so the original completion_ids/completion_mask stay unpadded for the # teacher concat (they must match the student's sequence length for logits_to_keep alignment). padded_completion_ids = self.trainer.accelerator.pad_across_processes( - completion_ids, dim=1, pad_index=self.trainer.pad_token_id + completion_ids, dim=1, pad_index=self.trainer._tokenizer.pad_token_id ) all_completion_ids = self.trainer.accelerator.gather(padded_completion_ids) all_prompts = gather_object(prompts) @@ -193,7 +195,7 @@ def build( if demo_idx is None: raise RuntimeError("Expected a successful demonstration index for an active SDPO teacher prompt.") demo_ids = all_completion_ids[demo_idx] - demo_ids = demo_ids[demo_ids != self.trainer.processing_class.pad_token_id] + demo_ids = demo_ids[demo_ids != self.trainer._tokenizer.pad_token_id] demo_text = self.trainer.processing_class.decode(demo_ids, skip_special_tokens=True) if self.trainer.args.remove_thinking_from_demonstration: diff --git a/trl/experimental/self_distillation/base_self_distillation_trainer.py b/trl/experimental/self_distillation/base_self_distillation_trainer.py index bd9abb95164..a332c7caf8a 100644 --- a/trl/experimental/self_distillation/base_self_distillation_trainer.py +++ b/trl/experimental/self_distillation/base_self_distillation_trainer.py @@ -104,6 +104,18 @@ def __init__( else inspect.signature(model.get_base_model().forward).parameters.keys() ) + # PEFT + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) if peft_config is not None or (is_peft_available() and getattr(model, "peft_config", None) is not None): model = prepare_peft_model(model, peft_config, args) @@ -113,17 +125,15 @@ def __init__( ) if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer + self._tokenizer = processing_class.tokenizer elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class + self._tokenizer = processing_class else: raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if self._tokenizer.pad_token is None: + self._tokenizer.pad_token = self._tokenizer.eos_token - self.pad_token_id = tokenizer.pad_token_id - self.eos_token_id = tokenizer.eos_token_id self.temperature = args.temperature self.max_prompt_length = args.max_prompt_length self.max_completion_length = args.max_completion_length @@ -151,9 +161,9 @@ def __init__( generation_kwargs = { "max_new_tokens": self.max_completion_length, "do_sample": True, - "pad_token_id": tokenizer.pad_token_id, - "bos_token_id": tokenizer.bos_token_id, - "eos_token_id": tokenizer.eos_token_id, + "pad_token_id": self._tokenizer.pad_token_id, + "bos_token_id": self._tokenizer.bos_token_id, + "eos_token_id": self._tokenizer.eos_token_id, "temperature": args.temperature, "top_p": args.top_p, "top_k": args.top_k, diff --git a/trl/experimental/self_distillation/online_rollout_mixin.py b/trl/experimental/self_distillation/online_rollout_mixin.py index 490724582dc..93caf5e2eeb 100644 --- a/trl/experimental/self_distillation/online_rollout_mixin.py +++ b/trl/experimental/self_distillation/online_rollout_mixin.py @@ -110,7 +110,7 @@ def _generate_transformers(self, prompts): prompt_mask = generate_inputs["attention_mask"] prompt_length = prompt_ids.size(1) completion_ids = prompt_completion_ids[:, prompt_length:] - is_eos = completion_ids == self.eos_token_id + is_eos = completion_ids == self._tokenizer.eos_token_id eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device) eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] seq_idx = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1) @@ -173,15 +173,17 @@ def _generate_and_score_completions(self, inputs): prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list] prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] - prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left").to(device=device) + prompt_ids = pad(prompt_ids, padding_value=self._tokenizer.pad_token_id, padding_side="left").to(device=device) prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left").to(device=device) completion_ids = [torch.tensor(ids) for ids in completion_ids_list] completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right").to(device=device) + completion_ids = pad(completion_ids, padding_value=self._tokenizer.pad_token_id, padding_side="right").to( + device=device + ) completion_mask = pad(completion_mask, padding_value=0, padding_side="right").to(device=device) if self.mask_truncated_completions: - eos_and_pad = [self.eos_token_id, self.pad_token_id] + eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id] is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() @@ -240,7 +242,7 @@ def _generate_and_score_completions(self, inputs): self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) - eos_and_pad = [self.eos_token_id, self.pad_token_id] + eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id] is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) agg_is_truncated = self.accelerator.gather(is_truncated) self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) diff --git a/trl/experimental/self_distillation/teacher_context.py b/trl/experimental/self_distillation/teacher_context.py index 5e1020c91a7..2448b78a712 100644 --- a/trl/experimental/self_distillation/teacher_context.py +++ b/trl/experimental/self_distillation/teacher_context.py @@ -80,6 +80,6 @@ def tokenize_prompts(self, prompts: list[Any]) -> TokenizedPromptBatch: prompt_ids = [torch.tensor(ids, device=self.trainer.accelerator.device) for ids in prompt_ids] prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] return TokenizedPromptBatch( - prompt_ids=pad(prompt_ids, padding_value=self.trainer.pad_token_id, padding_side="left"), + prompt_ids=pad(prompt_ids, padding_value=self.trainer._tokenizer.pad_token_id, padding_side="left"), prompt_mask=pad(prompt_mask, padding_value=0, padding_side="left"), ) diff --git a/trl/experimental/ssd/ssd_trainer.py b/trl/experimental/ssd/ssd_trainer.py index ea378753653..a2b6250c0de 100644 --- a/trl/experimental/ssd/ssd_trainer.py +++ b/trl/experimental/ssd/ssd_trainer.py @@ -135,11 +135,23 @@ def __init__( else inspect.signature(model.get_base_model().forward).parameters.keys() ) - if is_peft_available() and is_peft_model(model) and peft_config is not None: - raise ValueError( - "You passed a `PeftModel` instance together with a `peft_config` to SSDTrainer. Pass either a base " - "model with `peft_config`, or a pre-wrapped PEFT model." - ) + # PEFT + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) + if is_peft_model(model): + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to SSDTrainer. Pass either a base " + "model with `peft_config`, or a pre-wrapped PEFT model." + ) if peft_config is not None or (is_peft_available() and getattr(model, "peft_config", None) is not None): model = prepare_peft_model(model, peft_config, args) @@ -149,17 +161,15 @@ def __init__( ) if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer + self._tokenizer = processing_class.tokenizer elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class + self._tokenizer = processing_class else: raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if self._tokenizer.pad_token is None: + self._tokenizer.pad_token = self._tokenizer.eos_token - self.pad_token_id = tokenizer.pad_token_id - self.eos_token_id = tokenizer.eos_token_id self.max_prompt_length = args.max_prompt_length self.max_completion_length = args.max_completion_length # SSD always samples a single completion per prompt (N=1 in the paper). @@ -177,9 +187,9 @@ def __init__( generation_kwargs = { "max_new_tokens": self.max_completion_length, "do_sample": True, - "pad_token_id": tokenizer.pad_token_id, - "bos_token_id": tokenizer.bos_token_id, - "eos_token_id": tokenizer.eos_token_id, + "pad_token_id": self._tokenizer.pad_token_id, + "bos_token_id": self._tokenizer.bos_token_id, + "eos_token_id": self._tokenizer.eos_token_id, "temperature": args.temperature, "top_p": args.top_p, "top_k": args.top_k, @@ -362,7 +372,7 @@ def _generate_completion_ids_vllm(self, prompts: list[Any]) -> tuple[torch.Tenso completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] completion_mask = [torch.ones(len(ids), dtype=torch.long, device=device) for ids in completion_ids_list] return ( - pad(completion_ids, padding_value=self.pad_token_id, padding_side="right"), + pad(completion_ids, padding_value=self._tokenizer.pad_token_id, padding_side="right"), pad(completion_mask, padding_value=0, padding_side="right"), ) @@ -394,7 +404,7 @@ def _generate_completion_ids_transformers(self, prompts: list[Any]) -> tuple[tor prompt_length = generate_inputs["input_ids"].size(1) completion_ids = prompt_completion_ids[:, prompt_length:] - is_eos = completion_ids == self.eos_token_id + is_eos = completion_ids == self._tokenizer.eos_token_id eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=completion_ids.device) eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] seq_idx = torch.arange(is_eos.size(1), device=completion_ids.device).expand(is_eos.size(0), -1) @@ -404,7 +414,7 @@ def _generate_completion_ids_transformers(self, prompts: list[Any]) -> tuple[tor completion_ids = [torch.tensor(ids, device=self.accelerator.device) for ids in completion_ids_list] completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] return ( - pad(completion_ids, padding_value=self.pad_token_id, padding_side="right"), + pad(completion_ids, padding_value=self._tokenizer.pad_token_id, padding_side="right"), pad(completion_mask, padding_value=0, padding_side="right"), ) diff --git a/trl/experimental/tpo/__init__.py b/trl/experimental/tpo/__init__.py new file mode 100644 index 00000000000..e07ec04eb1b --- /dev/null +++ b/trl/experimental/tpo/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .tpo_config import TPOConfig +from .tpo_trainer import TPOTrainer + + +__all__ = ["TPOConfig", "TPOTrainer"] diff --git a/trl/experimental/tpo/tpo.py b/trl/experimental/tpo/tpo.py new file mode 100644 index 00000000000..e755ecf1a08 --- /dev/null +++ b/trl/experimental/tpo/tpo.py @@ -0,0 +1,215 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl[peft]", +# "trackio", +# "kernels", +# ] +# /// + +""" +Triple Preference Optimization (TPO) training. + +TPO requires a *triple-preference* dataset where each example contains a `chosen`, a `rejected` and a `reference` +(gold) completion for the same prompt. Two dataset paths are supported out of the box: + +- Use the published + [`tpo-alignment/triple-preference-ultrafeedback-40K`](https://huggingface.co/datasets/tpo-alignment/triple-preference-ultrafeedback-40K) + dataset directly. It already has the `prompt` / `reference` / `chosen` / `rejected` schema. +- Pass `--dataset_name openbmb/UltraFeedback` and the script automatically builds the triple-preference dataset as + described in the TPO paper (Saeidi et al., 2025): the response with the highest `overall_score` becomes `reference`, + the second-highest becomes `chosen`, and the lowest becomes `rejected`. + +In both cases, if the dataset is in standard (plain-string) format it is auto-wrapped into the conversational format so +that the model's chat template is applied — this matches how Instruct models like `Qwen/Qwen3-0.6B` are trained. + +Usage: + +Full training: + +```bash +python trl/experimental/tpo/tpo.py \ + --dataset_name tpo-alignment/triple-preference-ultrafeedback-40K \ + --model_name_or_path Qwen/Qwen3-0.6B \ + --per_device_train_batch_size 2 \ + --max_steps 1000 \ + --learning_rate 5e-7 \ + --gradient_accumulation_steps 8 \ + --beta 0.01 \ + --tpo_alpha 1.0 \ + --output_dir Qwen3-0.6B-TPO \ + --no_remove_unused_columns +``` + +TPO-L (length-normalized variant with target reward margin): + +```bash +python trl/experimental/tpo/tpo.py \ + --dataset_name tpo-alignment/triple-preference-ultrafeedback-40K \ + --model_name_or_path Qwen/Qwen3-0.6B \ + --per_device_train_batch_size 2 \ + --max_steps 1000 \ + --learning_rate 5e-7 \ + --gradient_accumulation_steps 8 \ + --beta 0.01 \ + --tpo_alpha 1.0 \ + --loss_type tpo-l \ + --tpo_l_gamma 0.5 \ + --output_dir Qwen3-0.6B-TPO-L \ + --no_remove_unused_columns +``` + +LoRA: + +```bash +python trl/experimental/tpo/tpo.py \ + --dataset_name tpo-alignment/triple-preference-ultrafeedback-40K \ + --model_name_or_path Qwen/Qwen3-0.6B \ + --per_device_train_batch_size 2 \ + --max_steps 1000 \ + --learning_rate 5e-6 \ + --gradient_accumulation_steps 8 \ + --output_dir Qwen3-0.6B-TPO-LoRA \ + --no_remove_unused_columns \ + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 +``` +""" + +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser + +from trl import ModelConfig, ScriptArguments, get_kbit_device_map, get_peft_config, get_quantization_config +from trl.experimental.tpo import TPOConfig, TPOTrainer + + +def build_triple_preference_from_ultrafeedback(example): + """ + Build a TPO triple-preference example from a raw UltraFeedback row. + + Following the TPO paper (Saeidi et al., 2025), completions are sorted by `overall_score` and we pick: + - the highest-scored response as the gold `reference`, + - the second-highest as `chosen`, + - the lowest as `rejected`. + + Emits the *conversational* format so that [`TPOTrainer`] applies the model's chat template automatically (see + `trl.data_utils.is_conversational`). Completions with a missing `overall_score` or `response` are filtered out; if + fewer than 3 valid completions remain, the returned example contains `None` values and should be filtered out + downstream. + """ + scored = [c for c in example["completions"] if c.get("overall_score") is not None and c.get("response")] + if len(scored) < 3: + return {"prompt": None, "reference": None, "chosen": None, "rejected": None} + scored.sort(key=lambda c: c["overall_score"], reverse=True) + return { + "prompt": [{"role": "user", "content": example["instruction"]}], + "reference": [{"role": "assistant", "content": scored[0]["response"]}], + "chosen": [{"role": "assistant", "content": scored[1]["response"]}], + "rejected": [{"role": "assistant", "content": scored[-1]["response"]}], + } + + +def to_conversational(example): + """ + Wrap a standard-format triple-preference example (plain strings) in the *conversational* format, so that + [`TPOTrainer`] applies the model's chat template automatically. This is the format expected by Instruct models; for + non-Instruct base models the standard format can be used directly. + """ + return { + "prompt": [{"role": "user", "content": example["prompt"]}], + "reference": [{"role": "assistant", "content": example["reference"]}], + "chosen": [{"role": "assistant", "content": example["chosen"]}], + "rejected": [{"role": "assistant", "content": example["rejected"]}], + } + + +if __name__ == "__main__": + parser = HfArgumentParser((ScriptArguments, TPOConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_into_dataclasses() + + ################ + # Model & Tokenizer + ################ + dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype) + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + dtype=dtype, + ) + quantization_config = get_quantization_config(model_args) + if quantization_config is not None: + model_kwargs["device_map"] = get_kbit_device_map() + model_kwargs["quantization_config"] = quantization_config + + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + ################ + # Dataset + ################ + dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) + + # Auto-build the triple-preference schema from raw UltraFeedback. + first_example = next(iter(dataset[script_args.dataset_train_split])) + if "completions" in first_example and "instruction" in first_example: + dataset = dataset.map( + build_triple_preference_from_ultrafeedback, + remove_columns=list(first_example.keys()), + ) + dataset = dataset.filter(lambda ex: ex["reference"] is not None) + first_example = next(iter(dataset[script_args.dataset_train_split])) + + # Auto-wrap standard-format triple-preference data (plain strings) into conversational messages so the + # model's chat template gets applied. This matches how Instruct models are trained and is what the TPO + # paper's data preparation produces. + if {"prompt", "chosen", "rejected", "reference"}.issubset(first_example) and isinstance( + first_example["prompt"], str + ): + dataset = dataset.map(to_conversational) + + ################ + # Training + ################ + trainer = TPOTrainer( + model, + args=training_args, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, + processing_class=tokenizer, + peft_config=get_peft_config(model_args), + ) + + # train and save the model + trainer.train() + + # Run a final evaluation pass and persist the metrics + if training_args.eval_strategy != "no": + metrics = trainer.evaluate() + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Save and push to hub + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/trl/experimental/tpo/tpo_config.py b/trl/experimental/tpo/tpo_config.py new file mode 100644 index 00000000000..f2b4c4a30fa --- /dev/null +++ b/trl/experimental/tpo/tpo_config.py @@ -0,0 +1,163 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Any + +from ...trainer.base_config import _BaseConfig + + +@dataclass +class TPOConfig(_BaseConfig): + # docstyle-ignore + r""" + Configuration class for the [`experimental.tpo.TPOTrainer`]. + + This class includes only the parameters that are specific to TPO training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may + differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + > Parameters that control the model + + model_init_kwargs (`dict[str, Any]`, *optional*): + Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` + argument of the [`experimental.tpo.TPOTrainer`] is provided as a string. + disable_dropout (`bool`, *optional*, defaults to `True`): + Whether to disable dropout in the model. + + > Parameters that control the data preprocessing + + dataset_num_proc (`int`, *optional*): + Number of processes to use for processing the dataset. + max_length (`int` or `None`, *optional*, defaults to `1024`): + Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the left or + right depending on the `truncation_mode`. If `None`, no truncation is applied. + truncation_mode (`str`, *optional*, defaults to `"keep_start"`): + Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_start"` and + `"keep_end"`. + pad_to_multiple_of (`int`, *optional*): + If set, the sequences will be padded to a multiple of this value. + + > Parameters that control the training + + loss_type (`str`, *optional*, defaults to `"sigmoid"`): + Type of loss to use. Possible values are: + + - `"sigmoid"`: sigmoid loss from the original [TPO](https://huggingface.co/papers/2405.16681) paper. + - `"hinge"`: hinge loss on the normalized likelihood from the + [SLiC](https://huggingface.co/papers/2305.10425) paper. + - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper. + - `"tpo-l"`: length-normalized TPO variant from the + [TPO](https://huggingface.co/papers/2405.16681) paper, which adds a target reward margin + `tpo_l_gamma` to the Bradley-Terry objective. + + beta (`float`, *optional*, defaults to `0.01`): + Parameter controlling the temperature of the TPO loss. For the IPO loss (`loss_type="ipo"`), β is the + regularization parameter denoted by τ in the [paper](https://huggingface.co/papers/2310.12036). + label_smoothing (`float`, *optional*, defaults to `0.0`): + Label smoothing factor. + tpo_alpha (`float`, *optional*, defaults to `1.0`): + Weight of the supervised negative log-likelihood term computed on the gold (`reference`) response in TPO + training. Setting `tpo_alpha=0.0` disables the NLL term and skips the corresponding forward pass. + tpo_l_gamma (`float`, *optional*, defaults to `0.5`): + Target reward margin γ for the TPO-L loss, used only when `loss_type="tpo-l"`. + + > [!NOTE] + > These parameters have default values different from [`~transformers.TrainingArguments`]: + > - `logging_steps`: Defaults to `10` instead of `500`. + > - `gradient_checkpointing`: Defaults to `True` instead of `False`. + > - `bf16`: Defaults to `True` if `fp16` is not set, instead of `False`. + > - `learning_rate`: Defaults to `5e-7` instead of `5e-5`. + """ + + _VALID_DICT_FIELDS = _BaseConfig._VALID_DICT_FIELDS + ["model_init_kwargs"] + + # Parameters whose default values are overridden from TrainingArguments + learning_rate: float = field( + default=5e-7, + metadata={"help": "The initial learning rate for AdamW."}, + ) + + # Parameters that control the model + model_init_kwargs: dict[str, Any] | str | None = field( + default=None, + metadata={ + "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of " + "the `TPOTrainer` is provided as a string." + }, + ) + disable_dropout: bool = field( + default=True, + metadata={"help": "Whether to disable dropout in the model."}, + ) + + # Parameters that control the data preprocessing + dataset_num_proc: int | None = field( + default=None, + metadata={"help": "Number of processes to use for processing the dataset."}, + ) + max_length: int | None = field( + default=1024, + metadata={ + "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from " + "the left or right depending on the `truncation_mode`. If `None`, no truncation is applied." + }, + ) + truncation_mode: str = field( + default="keep_start", + metadata={ + "help": "Truncation mode to use when the sequence exceeds `max_length`.", + "choices": ["keep_end", "keep_start"], + }, + ) + pad_to_multiple_of: int | None = field( + default=None, + metadata={"help": "If set, the sequences will be padded to a multiple of this value."}, + ) + + # Parameters that control the training + loss_type: str = field( + default="sigmoid", + metadata={ + "help": "Type of loss to use.", + "choices": ["sigmoid", "hinge", "ipo", "tpo-l"], + }, + ) + beta: float = field( + default=0.01, + metadata={ + "help": "Parameter controlling the temperature of the TPO loss. For the IPO loss (`loss_type='ipo'`), this " + "value is the regularization parameter denoted by τ in the IPO paper." + }, + ) + label_smoothing: float = field( + default=0.0, + metadata={"help": "Label smoothing factor."}, + ) + tpo_alpha: float = field( + default=1.0, + metadata={ + "help": "Weight of the supervised NLL term computed on the gold (`reference`) response in TPO training. " + "Setting `tpo_alpha=0.0` disables the NLL term and skips the corresponding forward pass." + }, + ) + tpo_l_gamma: float = field( + default=0.5, + metadata={"help": "Target reward margin γ for the TPO-L loss, used only when `loss_type='tpo-l'`."}, + ) diff --git a/trl/experimental/tpo/tpo_trainer.py b/trl/experimental/tpo/tpo_trainer.py new file mode 100644 index 00000000000..93bb193fa8f --- /dev/null +++ b/trl/experimental/tpo/tpo_trainer.py @@ -0,0 +1,787 @@ +# Copyright 2020-2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import textwrap +from collections import defaultdict +from collections.abc import Callable +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import torch +import torch.nn.functional as F +import transformers +from accelerate import PartialState +from accelerate.logging import get_logger +from accelerate.utils import is_peft_model +from datasets import Dataset, IterableDataset +from packaging.version import Version +from transformers import ( + AutoProcessor, + DataCollator, + PreTrainedModel, + PreTrainedTokenizerBase, +) +from transformers.data.data_collator import DataCollatorMixin +from transformers.trainer_callback import TrainerCallback +from transformers.trainer_utils import EvalPrediction +from transformers.utils import is_peft_available + +from ...data_utils import extract_prompt, is_conversational +from ...trainer.base_trainer import _BaseTrainer +from ...trainer.utils import ( + create_model_from_path, + disable_dropout_in_model, + entropy_from_logits, + get_config_model_id, + pad, + selective_log_softmax, +) +from .tpo_config import TPOConfig + + +if is_peft_available(): + from peft import PeftConfig, PeftModel, get_peft_model + + +logger = get_logger(__name__) + + +def _extract_triple_prompt(example: dict) -> dict: + """Extract the shared prompt from `chosen`/`rejected` and also strip it from `reference`. + + Wraps [`~trl.data_utils.extract_prompt`] — which only rewrites `chosen` and `rejected` — and additionally strips + the extracted prompt prefix from the `reference` (gold) completion. This is specific to TPO and assumes that the + `reference` completion shares the same implicit prompt prefix as `chosen` and `rejected`. If it does not, a + `ValueError` is raised asking the caller to provide an explicit `prompt` column. + """ + extracted = extract_prompt(example) + prompt = extracted["prompt"] + reference = example["reference"] + if reference[: len(prompt)] != prompt: + raise ValueError( + "The `reference` completion does not start with the implicit prompt extracted from `chosen`/`rejected`. " + "Either provide an explicit `prompt` column, or make sure the `reference` completion shares the same " + "prompt prefix as the `chosen` and `rejected` completions." + ) + extracted["reference"] = reference[len(prompt) :] + return extracted + + +@dataclass +class DataCollatorForTriplePreference(DataCollatorMixin): + """ + Data collator used for triple-preference data. Inputs are dynamically padded to the maximum length of a batch. + + This collator expects each example in the input list to be a dictionary containing the keys `"prompt_ids"`, + `"chosen_ids"` and `"rejected_ids"`. When `include_reference=True` (the default) each example must additionally + contain `"reference_ids"`. The collator returns a dictionary containing the following keys: + - `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch. When + `include_reference=True`, the first third of the batch corresponds to the `"chosen_ids"`, the second third to + the `"rejected_ids"` and the last third to the `"reference_ids"`. When `include_reference=False`, the first + half corresponds to the `"chosen_ids"` and the second half to the `"rejected_ids"` (matching the layout of + [`~trl.trainer.dpo_trainer.DataCollatorForPreference`]). + - `"attention_mask"`: Tensor of attention mask, padded to the maximum length of the batch. + - `"completion_mask"`: Tensor indicating the positions of the completion tokens, padded to the maximum length of + the batch. + + Args: + pad_token_id (`int`): + Token ID to use for padding. + max_length (`int`, *optional*): + Maximum length of the sequences after concatenation. Sequences longer than `max_length` are truncated + before padding, which avoids allocating oversized tensors for batches containing very long sequences. + truncation_mode (`str`, *optional*, defaults to `"keep_start"`): + Truncation mode when a concatenated sequence exceeds `max_length`. Possible values are `"keep_end"` and + `"keep_start"`. + pad_to_multiple_of (`int`, *optional*): + If set, the sequences will be padded to a multiple of this value. + return_tensors (`str`, *optional*, defaults to `"pt"`): + Type of Tensor to return. Only `"pt"` is currently supported. + include_reference (`bool`, *optional*, defaults to `True`): + Whether to include the `"reference_ids"` branch in the collated batch. When `False`, the collator emits + only the chosen/rejected halves and skips the gold-response sequences entirely, which matches the behavior + expected when `tpo_alpha=0.0` (no NLL term). + + Examples: + ```python + >>> from trl.experimental.tpo.tpo_trainer import DataCollatorForTriplePreference + + >>> collator = DataCollatorForTriplePreference(pad_token_id=0) + >>> examples = [ + ... {"prompt_ids": [1, 2, 3], "chosen_ids": [4, 5], "rejected_ids": [6], "reference_ids": [7, 8]}, + ... {"prompt_ids": [9, 10], "chosen_ids": [11], "rejected_ids": [12, 13], "reference_ids": [14]}, + ... ] + >>> collator(examples) + {'input_ids': tensor([[ 1, 2, 3, 4, 5], + [ 9, 10, 11, 0, 0], + [ 1, 2, 3, 6, 0], + [ 9, 10, 12, 13, 0], + [ 1, 2, 3, 7, 8], + [ 9, 10, 14, 0, 0]]), + 'attention_mask': tensor([[1, 1, 1, 1, 1], + [1, 1, 1, 0, 0], + [1, 1, 1, 1, 0], + [1, 1, 1, 1, 0], + [1, 1, 1, 1, 1], + [1, 1, 1, 0, 0]]), + 'completion_mask': tensor([[0, 0, 0, 1, 1], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 1, 1, 0], + [0, 0, 0, 1, 1], + [0, 0, 1, 0, 0]])} + ``` + """ + + pad_token_id: int + max_length: int | None = None + truncation_mode: str = "keep_start" + pad_to_multiple_of: int | None = None + return_tensors: str = "pt" + include_reference: bool = True + + def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]: + prompt_chosen_ids = [example["prompt_ids"] + example["chosen_ids"] for example in examples] + prompt_rejected_ids = [example["prompt_ids"] + example["rejected_ids"] for example in examples] + chosen_mask = [[0] * len(example["prompt_ids"]) + [1] * len(example["chosen_ids"]) for example in examples] + rejected_mask = [[0] * len(example["prompt_ids"]) + [1] * len(example["rejected_ids"]) for example in examples] + if self.include_reference: + prompt_reference_ids = [example["prompt_ids"] + example["reference_ids"] for example in examples] + reference_mask = [ + [0] * len(example["prompt_ids"]) + [1] * len(example["reference_ids"]) for example in examples + ] + + if self.max_length is not None: + if self.truncation_mode == "keep_start": + sl = slice(None, self.max_length) + elif self.truncation_mode == "keep_end": + sl = slice(-self.max_length, None) + else: + raise ValueError( + f"Unsupported truncation mode: {self.truncation_mode}, expected 'keep_start' or 'keep_end'" + ) + prompt_chosen_ids = [ids[sl] for ids in prompt_chosen_ids] + prompt_rejected_ids = [ids[sl] for ids in prompt_rejected_ids] + chosen_mask = [m[sl] for m in chosen_mask] + rejected_mask = [m[sl] for m in rejected_mask] + if self.include_reference: + prompt_reference_ids = [ids[sl] for ids in prompt_reference_ids] + reference_mask = [m[sl] for m in reference_mask] + + chosen_attention_mask = [[1] * len(ids) for ids in prompt_chosen_ids] + rejected_attention_mask = [[1] * len(ids) for ids in prompt_rejected_ids] + input_ids = prompt_chosen_ids + prompt_rejected_ids + attention_mask = chosen_attention_mask + rejected_attention_mask + completion_mask = chosen_mask + rejected_mask + if self.include_reference: + reference_attention_mask = [[1] * len(ids) for ids in prompt_reference_ids] + input_ids = input_ids + prompt_reference_ids + attention_mask = attention_mask + reference_attention_mask + completion_mask = completion_mask + reference_mask + + # Convert to tensor + input_ids = [torch.tensor(ids) for ids in input_ids] + attention_mask = [torch.tensor(m, dtype=torch.long) for m in attention_mask] + completion_mask = [torch.tensor(m, dtype=torch.long) for m in completion_mask] + + # Pad + output = {} + output["input_ids"] = pad( + input_ids, + padding_value=self.pad_token_id, + padding_side="right", + pad_to_multiple_of=self.pad_to_multiple_of, + ) + output["attention_mask"] = pad( + attention_mask, + padding_value=0, + padding_side="right", + pad_to_multiple_of=self.pad_to_multiple_of, + ) + output["completion_mask"] = pad( + completion_mask, + padding_value=0, + padding_side="right", + pad_to_multiple_of=self.pad_to_multiple_of, + ) + return output + + +class TPOTrainer(_BaseTrainer): + """ + Trainer for Triple Preference Optimization (TPO) method. This algorithm was initially proposed in the paper [Triple + Preference Optimization: Achieving Better Alignment using a Single Step + Optimization](https://huggingface.co/papers/2405.16681). This class is a wrapper around the + [`~transformers.Trainer`] class and inherits all of its attributes and methods. + + Args: + model (`str` or [`~transformers.PreTrainedModel`] or [`~peft.PeftModel`]): + Model to be trained. Can be either: + + - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a + path to a *directory* containing model weights saved using + [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded + using `.from_pretrained` (where `` is derived from the model + config) with the keyword arguments in `args.model_init_kwargs`. + - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. + - A [`~peft.PeftModel`] object. Only causal language models are supported. + args ([`experimental.tpo.TPOConfig`], *optional*): + Configuration for this trainer. If `None`, a default configuration is used. + data_collator ([`~transformers.DataCollator`], *optional*): + Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. + Will default to [`~trl.experimental.tpo.tpo_trainer.DataCollatorForTriplePreference`]. Custom collators + must truncate sequences before padding; the trainer does not apply post-collation truncation. + train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): + Dataset to use for training. TPO requires a *triple-preference* dataset: each sample must contain a + `"chosen"`, a `"rejected"` and a `"reference"` (gold) completion. The format of the samples can be either: + + - [Standard](dataset_formats#standard): Each sample contains plain text. + - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role + and content). + eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`): + Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. + processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*): + Processing class used to process the data. If `None`, the processing class is loaded from the model's name + with [`~transformers.AutoProcessor.from_pretrained`]. A padding token, `tokenizer.pad_token`, must be set. + If the processing class has not set a padding token, `tokenizer.eos_token` will be used as the default. + compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): + The function that will be used to compute metrics at evaluation. Must take a + [`~transformers.EvalPrediction`] and return a dictionary string to metric values. + callbacks (list of [`~transformers.TrainerCallback`], *optional*): + List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed + in [here](https://huggingface.co/docs/transformers/main_classes/callback). + + If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] + method. + optimizers (`tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None]`, *optional*, defaults to `(None, None)`): + A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your + model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`. + peft_config ([`~peft.PeftConfig`], *optional*): + PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + """ + + _tag_names = ["trl", "tpo"] + _name = "TPO" + _paper = { + "title": "Triple Preference Optimization: Achieving Better Alignment using a Single Step Optimization", + "id": "2405.16681", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @misc{saeidi2025triplepreferenceoptimizationachieving, + title = {{Triple Preference Optimization: Achieving Better Alignment using a Single Step Optimization}}, + author = {Amir Saeidi and Shivanshu Verma and Aswin RRV and Kashif Rasul and Chitta Baral}, + year = 2025, + eprint = {2405.16681}, + archivePrefix= {arXiv}, + primaryClass = {cs.CL}, + url = {https://arxiv.org/abs/2405.16681}, + }"""), + } + + def __init__( + self, + model: "str | PreTrainedModel | PeftModel", + args: TPOConfig | None = None, + data_collator: DataCollator | None = None, + train_dataset: Dataset | IterableDataset | None = None, + eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None, + processing_class: PreTrainedTokenizerBase | None = None, + compute_metrics: Callable[[EvalPrediction], dict] | None = None, + callbacks: list[TrainerCallback] | None = None, + optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), + peft_config: "PeftConfig | None" = None, + ): + # Args + if args is None: + model_name = model if isinstance(model, str) else get_config_model_id(model.config) + model_name = model_name.split("/")[-1] + args = TPOConfig(f"{model_name}-TPO") + + if train_dataset is None: + raise ValueError("`train_dataset` is required") + elif isinstance(train_dataset, IterableDataset): + # IterableDataset requires dispatch_batches=False because Accelerate's dispatch mode may try to concatenate + # batches from multiple processes, leading to mismatch errors. + if args.accelerator_config.dispatch_batches is True: + logger.warning( + "You are using an `IterableDataset` for training with `dispatch_batches=True`. `dispatch_batches` " + "is forced to `False` when using an `IterableDataset`. To remove this warning, unset " + "`dispatch_batches` in `TPOConfig` or set it to `False`." + ) + args.accelerator_config.dispatch_batches = False + + # Model + if isinstance(model, str): + model_init_kwargs = args.model_init_kwargs or {} + # Distributed training requires device_map=None ("auto" fails) + if args.distributed_state.distributed_type in ["MULTI_GPU", "DEEPSPEED"]: + model_init_kwargs["device_map"] = None + model = create_model_from_path(model, **model_init_kwargs) + else: + if args.model_init_kwargs is not None: + logger.warning( + "You passed `model_init_kwargs` to the `TPOConfig`, but your model is already instantiated. " + "The `model_init_kwargs` will be ignored." + ) + + # Processing class + if processing_class is None: + processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config)) + if not isinstance(processing_class, PreTrainedTokenizerBase): + raise TypeError( + "The `processing_class` must be a `PreTrainedTokenizerBase`. `TPOTrainer` does not currently " + "support vision-language models." + ) + self._tokenizer = processing_class + + if self._tokenizer.pad_token is None: + self._tokenizer.pad_token = self._tokenizer.eos_token + + # PEFT + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) + if is_peft_model(model): + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " + "and unload the existing adapter, save the resulting base model, and then pass that base model along " + "with the new `peft_config` to the trainer." + ) + # Create PEFT model + model = get_peft_model(model, peft_config) + + # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally + # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489 + if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing: + model.enable_input_require_grads() + + # Data collator. When `tpo_alpha=0.0`, the NLL term on the gold response is disabled, so we can drop the + # reference branch from the batch entirely — this spares the model from computing logits for a third of + # each step. + if data_collator is None: + data_collator = DataCollatorForTriplePreference( + pad_token_id=self._tokenizer.pad_token_id, + max_length=args.max_length, + truncation_mode=args.truncation_mode, + pad_to_multiple_of=args.pad_to_multiple_of, + include_reference=args.tpo_alpha != 0.0, + ) + + # Training arguments + self.beta = args.beta + self.loss_type = args.loss_type + self.label_smoothing = args.label_smoothing + self.tpo_alpha = args.tpo_alpha + self.tpo_l_gamma = args.tpo_l_gamma + if self.loss_type in ["hinge", "ipo"] and self.label_smoothing > 0: + logger.warning( + f"You are using the {self.loss_type} loss type that does not support label smoothing. The " + "`label_smoothing` parameter will be ignored. Set `label_smoothing` to `0.0` to remove this warning." + ) + + # Dataset + train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train") + if eval_dataset is not None: + if isinstance(eval_dataset, dict): + eval_dataset = { + key: self._prepare_dataset(dataset, processing_class, args, key) + for key, dataset in eval_dataset.items() + } + else: + eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval") + + # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was + # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream + # (see https://github.com/huggingface/transformers/pull/43203) and is released (most likely in 5.0.0), we + # default to the recommended non-reentrant behavior here, while preserving any user-provided value. + if args.gradient_checkpointing and Version(transformers.__version__) < Version("5.0.0"): + args.gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} + args.gradient_checkpointing_kwargs.setdefault("use_reentrant", False) + + super().__init__( + model=model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + ) + + # Disable dropout in the model + if args.disable_dropout: + disable_dropout_in_model(model) + + # Initialize the metrics + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._total_train_tokens = 0 + + # Add tags to the model + self.model.add_model_tags(self._tag_names) + + def _tokenize( + self, + processing_class: PreTrainedTokenizerBase, + input: str | list, + **kwargs, + ) -> dict[str, list]: + """Tokenize a single example for dataset preprocessing. + + Dispatches to `apply_chat_template` for conversational input (list of message dicts) and to `__call__` for + non-conversational input (str). + + Args: + processing_class ([`~transformers.PreTrainedTokenizerBase`]): + The tokenizer to use. + input (`str` or `list`): + A string for non-conversational input, or a list of message dicts for conversational input. + **kwargs: + Forwarded to `apply_chat_template` (e.g. `add_generation_prompt`). + + Returns: + `dict` with at least an `"input_ids"` key mapping to a flat `list[int]`. + """ + if isinstance(input, list): # conversational: list of message dicts + return processing_class.apply_chat_template(input, tokenize=True, return_dict=True, **kwargs) + # non-conversational: plain text string + return processing_class(text=input) + + def _prepare_dataset( + self, + dataset: Dataset | IterableDataset, + processing_class: PreTrainedTokenizerBase, + args: TPOConfig, + dataset_name: str, + ) -> Dataset | IterableDataset: + # Validate that the triple-preference columns are present + first_example = next(iter(dataset)) + if "chosen" not in first_example or "rejected" not in first_example: + raise ValueError( + "TPO requires a triple-preference dataset with `chosen`, `rejected` and `reference` columns, but the " + f"dataset is missing `chosen` or `rejected`. Got columns: {list(first_example.keys())}." + ) + if "reference" not in first_example: + raise ValueError( + "TPO requires a triple-preference dataset with `chosen`, `rejected` and `reference` columns, but the " + f"dataset is missing the `reference` (gold) column. Got columns: {list(first_example.keys())}." + ) + + # Build the kwargs for the `map` function + map_kwargs = {} + if isinstance(dataset, Dataset): # IterableDataset does not support num_proc + map_kwargs["num_proc"] = args.dataset_num_proc + + with PartialState().main_process_first(): + # Extract the prompt if needed. Unlike DPO, we must also strip the extracted prompt from the reference + # column (see `_extract_triple_prompt`), which assumes the reference shares the same implicit prompt. + first_example = next(iter(dataset)) + if "prompt" not in first_example: + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Extracting prompt from {dataset_name} dataset" + dataset = dataset.map(_extract_triple_prompt, **map_kwargs) + + # Add EOS to completions for non-conversational data + first_example = next(iter(dataset)) + if not is_conversational(first_example): + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset" + + def add_eos(example, eos_token): + if not example["chosen"].endswith(eos_token): + example["chosen"] = example["chosen"] + eos_token + if not example["rejected"].endswith(eos_token): + example["rejected"] = example["rejected"] + eos_token + if not example["reference"].endswith(eos_token): + example["reference"] = example["reference"] + eos_token + return example + + dataset = dataset.map(add_eos, fn_kwargs={"eos_token": processing_class.eos_token}, **map_kwargs) + + # Tokenize the dataset + if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` + map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset" + + def tokenize_fn(example, processing_class): + tools = example.get("tools") + tools = json.loads(tools) if isinstance(tools, str) else tools + output = {} + if is_conversational(example): + prompt_ids = self._tokenize( + processing_class, + example["prompt"], + tools=tools, + add_generation_prompt=True, + **example.get("chat_template_kwargs", {}), + )["input_ids"] + prompt_chosen_ids = self._tokenize( + processing_class, + example["prompt"] + example["chosen"], + tools=tools, + **example.get("chat_template_kwargs", {}), + )["input_ids"] + prompt_rejected_ids = self._tokenize( + processing_class, + example["prompt"] + example["rejected"], + tools=tools, + **example.get("chat_template_kwargs", {}), + )["input_ids"] + prompt_reference_ids = self._tokenize( + processing_class, + example["prompt"] + example["reference"], + tools=tools, + **example.get("chat_template_kwargs", {}), + )["input_ids"] + else: + prompt_ids = self._tokenize(processing_class, example["prompt"])["input_ids"] + prompt_chosen_ids = self._tokenize(processing_class, example["prompt"] + example["chosen"])[ + "input_ids" + ] + prompt_rejected_ids = self._tokenize(processing_class, example["prompt"] + example["rejected"])[ + "input_ids" + ] + prompt_reference_ids = self._tokenize(processing_class, example["prompt"] + example["reference"])[ + "input_ids" + ] + + # Check if the tokenized prompt starts with the tokenized prompt+completion + if not prompt_chosen_ids[: len(prompt_ids)] == prompt_ids: + logger.warning( + "Mismatch between tokenized prompt and the start of tokenized prompt+chosen. " + "This may be due to unexpected tokenizer behavior, whitespace issues, or special " + "token handling. Verify that the tokenizer is processing text consistently." + ) + if not prompt_rejected_ids[: len(prompt_ids)] == prompt_ids: + logger.warning( + "Mismatch between tokenized prompt and the start of tokenized prompt+rejected. " + "This may be due to unexpected tokenizer behavior, whitespace issues, or special " + "token handling. Verify that the tokenizer is processing text consistently." + ) + if not prompt_reference_ids[: len(prompt_ids)] == prompt_ids: + logger.warning( + "Mismatch between tokenized prompt and the start of tokenized prompt+reference. " + "This may be due to unexpected tokenizer behavior, whitespace issues, or special " + "token handling. Verify that the tokenizer is processing text consistently." + ) + + output["prompt_ids"] = prompt_ids + output["chosen_ids"] = prompt_chosen_ids[len(prompt_ids) :] + output["rejected_ids"] = prompt_rejected_ids[len(prompt_ids) :] + output["reference_ids"] = prompt_reference_ids[len(prompt_ids) :] + return output + + dataset = dataset.map(tokenize_fn, fn_kwargs={"processing_class": processing_class}, **map_kwargs) + + return dataset + + def _set_signature_columns_if_needed(self): + # If `self.args.remove_unused_columns` is True, non-signature columns are removed. + # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids" + # and "attention_mask"). + if self._signature_columns is None: + self._signature_columns = ["prompt_ids", "chosen_ids", "rejected_ids", "reference_ids"] + + def _compute_loss(self, model, inputs, return_outputs): + mode = "train" if self.model.training else "eval" + + # When `tpo_alpha=0.0` the NLL term is disabled and the collator drops the reference branch, so the batch + # is laid out as `[chosen, rejected]` (n_branches=2). Otherwise it is `[chosen, rejected, reference]` + # (n_branches=3). + n_branches = 3 if self.tpo_alpha != 0.0 else 2 + + _non_model_keys = {"completion_mask"} + model_kwargs = {k: v for k, v in inputs.items() if k not in _non_model_keys} + model_kwargs["use_cache"] = False + outputs = model(**model_kwargs) + + input_ids = inputs["input_ids"] + completion_mask = inputs["completion_mask"] + shift_logits = outputs.logits[..., :-1, :].contiguous() + shift_labels = input_ids[..., 1:].contiguous() + shift_completion_mask = completion_mask[..., 1:].contiguous() + per_token_logps = selective_log_softmax(shift_logits, shift_labels) + per_token_logps[shift_completion_mask == 0] = 0.0 # mask out non-completion tokens + + # Length-normalized for IPO and TPO-L (matches the SimPO-style implicit reward used by the TPO paper); + # summed otherwise. + if self.loss_type in ("ipo", "tpo-l"): + completion_lengths = shift_completion_mask.sum(dim=1).clamp(min=1) + logps = per_token_logps.sum(dim=1) / completion_lengths + else: + logps = per_token_logps.sum(dim=1) + logps_chunks = logps.chunk(n_branches, dim=0) + chosen_logps, rejected_logps = logps_chunks[0], logps_chunks[1] + + # Contrastive loss between chosen and rejected. Unlike DPO, TPO does not subtract reference-model log-probs: + # the "reference" in TPO is a gold response used in the NLL term below, not a separate reference policy. + delta_score = chosen_logps - rejected_logps + + if self.loss_type == "sigmoid": + per_sequence_loss = ( + -F.logsigmoid(self.beta * delta_score) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * delta_score) * self.label_smoothing + ) + + elif self.loss_type == "hinge": + per_sequence_loss = torch.relu(1 - self.beta * delta_score) + + elif self.loss_type == "ipo": + # (Eq. 17) of the IPO paper where beta is the regularization parameter for the IPO loss, denoted by τ. + per_sequence_loss = (delta_score - 1 / (2 * self.beta)) ** 2 + + elif self.loss_type == "tpo-l": + # Length-normalized TPO-L variant: subtract a target reward margin γ/β before the sigmoid. + gamma_logratios = self.tpo_l_gamma / self.beta + shifted_delta = delta_score - gamma_logratios + per_sequence_loss = ( + -F.logsigmoid(self.beta * shifted_delta) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * shifted_delta) * self.label_smoothing + ) + + else: + raise ValueError( + f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'tpo-l']" + ) + + loss = per_sequence_loss.mean() + + # NLL loss on the gold (`reference`) response. Mirrors the `"sft"` loss branch of `DPOTrainer._compute_loss`: + # we restrict the cross-entropy to the completion tokens of the reference sequence and let `F.cross_entropy` + # average over them. The NLL contribution is folded into the main `loss` (matching DPO/SFT convention: the + # individual NLL term is not logged separately). + if n_branches == 3: + _, _, ref_logits = shift_logits.chunk(3, dim=0) + _, _, ref_labels = shift_labels.chunk(3, dim=0) + _, _, ref_mask = shift_completion_mask.chunk(3, dim=0) + ref_mask = ref_mask.bool() + nll_loss = F.cross_entropy(ref_logits[ref_mask], ref_labels[ref_mask]) + loss = loss + self.tpo_alpha * nll_loss + + # Log the metrics + # Entropy + per_token_entropy = entropy_from_logits(shift_logits.detach()) + entropy = per_token_entropy[shift_completion_mask.bool()].mean() + entropy = self.accelerator.gather_for_metrics(entropy).mean().item() + self._metrics[mode]["entropy"].append(entropy) + + # Number of tokens + if mode == "train": + num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item() + self._total_train_tokens += num_tokens_in_batch + self._metrics[mode]["num_tokens"] = [self._total_train_tokens] + + # Average logits for chosen and rejected completions + logits_chunks = shift_logits.detach().chunk(n_branches, dim=0) + mask_chunks = shift_completion_mask.chunk(n_branches, dim=0) + labels_chunks = shift_labels.chunk(n_branches, dim=0) + chosen_logits, rejected_logits = logits_chunks[0], logits_chunks[1] + chosen_mask, rejected_mask = mask_chunks[0], mask_chunks[1] + chosen_labels = labels_chunks[0] + total_chosen_logits = chosen_logits[chosen_mask.bool()].mean(-1).sum() + total_chosen_tokens = chosen_mask.sum() + total_rejected_logits = rejected_logits[rejected_mask.bool()].mean(-1).sum() + total_rejected_tokens = rejected_mask.sum() + total_chosen_logits = self.accelerator.gather_for_metrics(total_chosen_logits).sum().item() + total_chosen_tokens = self.accelerator.gather_for_metrics(total_chosen_tokens).sum().item() + total_rejected_logits = self.accelerator.gather_for_metrics(total_rejected_logits).sum().item() + total_rejected_tokens = self.accelerator.gather_for_metrics(total_rejected_tokens).sum().item() + avg_chosen_logits = total_chosen_logits / total_chosen_tokens if total_chosen_tokens > 0 else 0.0 + avg_rejected_logits = total_rejected_logits / total_rejected_tokens if total_rejected_tokens > 0 else 0.0 + self._metrics[mode]["logits/chosen"].append(avg_chosen_logits) + self._metrics[mode]["logits/rejected"].append(avg_rejected_logits) + + # Token accuracy for the chosen completions + predictions = chosen_logits.argmax(dim=-1) + chosen_bool_mask = chosen_mask.bool() + correct_predictions = (predictions == chosen_labels) & chosen_bool_mask + total_tokens = chosen_bool_mask.sum() + correct_tokens = correct_predictions.sum() + correct_tokens = self.accelerator.gather_for_metrics(correct_tokens) + total_tokens = self.accelerator.gather_for_metrics(total_tokens) + total_sum = total_tokens.sum() + accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0 + self._metrics[mode]["mean_token_accuracy"].append(accuracy) + + # Rewards for chosen and rejected completions (β · log π_θ as in the SimPO/TPO implicit reward) + chosen_rewards = self.beta * chosen_logps.detach() + rejected_rewards = self.beta * rejected_logps.detach() + agg_chosen_rewards = self.accelerator.gather(chosen_rewards) + agg_rejected_rewards = self.accelerator.gather(rejected_rewards) + self._metrics[mode]["rewards/chosen"].append(agg_chosen_rewards.mean().item()) + self._metrics[mode]["rewards/rejected"].append(agg_rejected_rewards.mean().item()) + + # Reward accuracy + reward_accuracies = (chosen_rewards > rejected_rewards).float() + agg_reward_accuracies = self.accelerator.gather(reward_accuracies) + self._metrics[mode]["rewards/accuracies"].append(agg_reward_accuracies.mean().item()) + + # Reward margins + margins = chosen_rewards - rejected_rewards + agg_margins = self.accelerator.gather(margins) + self._metrics[mode]["rewards/margins"].append(agg_margins.mean().item()) + + # Average log probabilities for chosen and rejected completions + self._metrics[mode]["logps/chosen"].append(self.accelerator.gather(chosen_logps).mean().item()) + self._metrics[mode]["logps/rejected"].append(self.accelerator.gather(rejected_logps).mean().item()) + + return (loss, outputs) if return_outputs else loss + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + return self._compute_loss(model, inputs, return_outputs) + + def log(self, logs: dict[str, float], start_time: float | None = None) -> None: + mode = "train" if self.model.training else "eval" + metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. + if mode == "eval": + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs = {**logs, **metrics} + super().log(logs, start_time) + self._metrics[mode].clear() + + # During eval, Trainer calls prediction_step. If no labels are present in the inputs, it only runs forward and + # returns logits. We override prediction_step to force compute_loss, because this trainer doesn't involve labels. + def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None): + inputs = self._prepare_inputs(inputs) + with torch.no_grad(), self.compute_loss_context_manager(): + if prediction_loss_only: + loss = self.compute_loss(model, inputs, return_outputs=False) + logits, labels = None, None + else: + loss, outputs = self.compute_loss(model, inputs, return_outputs=True) + logits, labels = outputs.logits, inputs["input_ids"] + return loss, logits, labels + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) diff --git a/trl/experimental/xpo/xpo_trainer.py b/trl/experimental/xpo/xpo_trainer.py index f88125e2916..29fe8d7a6a6 100644 --- a/trl/experimental/xpo/xpo_trainer.py +++ b/trl/experimental/xpo/xpo_trainer.py @@ -41,7 +41,7 @@ if is_peft_available(): - from peft import PeftModel + from peft import PeftConfig, PeftModel class XPOTrainer(OnlineDPOTrainer): @@ -74,7 +74,7 @@ class XPOTrainer(OnlineDPOTrainer): Processing class used to process the data. If provided, will be used to automatically process the inputs for the model, and it will be saved along the model to make it easier to rerun an interrupted training or reuse the fine-tuned model. - peft_config (`dict`): + peft_config ([`~peft.PeftConfig`], *optional*): The peft config to use for training. compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*): The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to @@ -117,7 +117,7 @@ def __init__( | ProcessorMixin | None = None, reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None, - peft_config: dict | None = None, + peft_config: "PeftConfig | None" = None, compute_metrics: Callable[[EvalPrediction], dict] | None = None, callbacks: list[TrainerCallback] | None = None, optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), diff --git a/trl/import_utils.py b/trl/import_utils.py index f3e7c43452f..a8adca36f34 100644 --- a/trl/import_utils.py +++ b/trl/import_utils.py @@ -105,9 +105,9 @@ def is_uvicorn_available() -> bool: def is_vllm_available(min_version: str | None = None) -> bool: _vllm_available, _vllm_version = _is_package_available("vllm", return_version=True) if _vllm_available: - if not (Version("0.11.0") <= Version(_vllm_version) <= Version("0.18.0")): + if not (Version("0.12.0") <= Version(_vllm_version) <= Version("0.18.0")): warnings.warn( - f"TRL currently supports vLLM versions from 0.11.0 to 0.18.0. You have version {_vllm_version} " + f"TRL currently supports vLLM versions from 0.12.0 to 0.18.0. You have version {_vllm_version} " "installed. We recommend installing a supported version to avoid compatibility issues.", stacklevel=2, ) diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index faf9bf865bc..ccf182eeaac 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -14,6 +14,7 @@ import argparse import base64 +import json import logging import math import os @@ -214,6 +215,9 @@ class ScriptArguments: Distributed executor backend for vLLM. Set to `"ray"` to distribute tensor parallel workers across multiple nodes via a Ray cluster. Required when `tensor_parallel_size` exceeds the number of local GPUs. If not set, vLLM defaults to the multiproc backend (single-node only). + speculative_config (`str`, *optional*): + JSON string for vLLM speculative decoding config, forwarded to `LLM(speculative_config=...)`. When unset, + speculative decoding is disabled. Example: `'{"method": "qwen3_next_mtp", "num_speculative_tokens": 5}'`. """ model: str = field( @@ -318,6 +322,13 @@ class ScriptArguments: "GPUs. If not set, vLLM defaults to the multiproc backend (single-node only)." }, ) + speculative_config: str | None = field( + default=None, + metadata={ + "help": "JSON string for vLLM speculative decoding config. " + 'Example: \'{"method": "qwen3_next_mtp", "num_speculative_tokens": 5}\'' + }, + ) def llm_worker( @@ -350,6 +361,7 @@ def llm_worker( distributed_executor_backend=script_args.distributed_executor_backend, # Important so temperature scaling/logit tweaking affects the TIS log probs logprobs_mode="processed_logprobs", + speculative_config=json.loads(script_args.speculative_config) if script_args.speculative_config else None, ) # Send ready signal to parent process @@ -398,7 +410,6 @@ def chunk_list(lst: list, n: int) -> list[list]: def main(script_args: ScriptArguments): import asyncio - from packaging.version import Version from transformers import is_vision_available from trl.generation.vllm_generation import extract_logprobs @@ -428,16 +439,11 @@ def main(script_args: ScriptArguments): raise ImportError("vLLM is required to run the vLLM serve script. Please install it using `pip install vllm`.") import uvicorn - import vllm from fastapi import FastAPI from pydantic import BaseModel from vllm import SamplingParams from vllm.sampling_params import StructuredOutputsParams - - if Version(vllm.__version__) <= Version("0.11.0"): - from vllm.utils import get_open_port - else: - from vllm.utils.network_utils import get_open_port + from vllm.utils.network_utils import get_open_port if is_vision_available(): from PIL import Image diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index b40144e7206..b92a494b3c0 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -557,24 +557,39 @@ def __init__( # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer + self._tokenizer = processing_class.tokenizer self._is_vlm = True elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class + self._tokenizer = processing_class self._is_vlm = False else: raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token + if self._tokenizer.pad_token is None: + self._tokenizer.pad_token = self._tokenizer.eos_token - if is_peft_available() and is_peft_model(model) and peft_config is not None: - raise ValueError( - "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " - "and unload the existing adapter, save the resulting base model, and then pass that base model along " - "with the new `peft_config` to the trainer." - ) - if is_peft_available() and is_peft_model(model) and ref_model is None: + # PEFT + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) + if is_peft_model(model): + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " + "and unload the existing adapter, save the resulting base model, and then pass that base model along " + "with the new `peft_config` to the trainer." + ) + # Create PEFT model + model = get_peft_model(model, peft_config) + + elif is_peft_available() and is_peft_model(model) and ref_model is None: # If the model is a PEFT model with a pretrained adapter, we need to create a "ref" adapter that is a copy # of the "default" adapter, so that we can use it as the reference model during DPO training. model.add_adapter("ref", model.peft_config["default"]) @@ -584,10 +599,6 @@ def __init__( ref_param = model.get_parameter(ref_name) ref_param.data.copy_(param.data) - # Create PEFT model - if peft_config is not None: - model = get_peft_model(model, peft_config) - # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489 if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing: @@ -629,16 +640,16 @@ def __init__( if data_collator is None and not self._is_vision_dataset: # Get the pad token: if not provided, use the one from the processing class or the eos token # if the processing class does not have a pad token. - pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token - if pad_token not in tokenizer.get_vocab(): + pad_token = args.pad_token or self._tokenizer.pad_token or self._tokenizer.eos_token + if pad_token not in self._tokenizer.get_vocab(): raise ValueError( f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given " f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists " "in the vocabulary before using it as a padding token." ) - tokenizer.pad_token = pad_token + self._tokenizer.pad_token = pad_token data_collator = DataCollatorForPreference( - pad_token_id=tokenizer.pad_token_id, + pad_token_id=self._tokenizer.pad_token_id, max_length=args.max_length, truncation_mode=args.truncation_mode, pad_to_multiple_of=args.pad_to_multiple_of, @@ -880,7 +891,7 @@ def _prepare_dataset( map_kwargs["desc"] = f"Extracting prompt from {dataset_name} dataset" dataset = dataset.map(extract_prompt, **map_kwargs) - # Apply the chat template if needed + # Add EOS token if needed: non-conversational only first_example = next(iter(dataset)) if not is_conversational(first_example): if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` @@ -893,8 +904,7 @@ def add_eos(example, eos_token): example["rejected"] = example["rejected"] + eos_token return example - eos_token = processing_class.tokenizer.eos_token if self._is_vlm else processing_class.eos_token - dataset = dataset.map(add_eos, fn_kwargs={"eos_token": eos_token}, **map_kwargs) + dataset = dataset.map(add_eos, fn_kwargs={"eos_token": self._tokenizer.eos_token}, **map_kwargs) # Tokenize the dataset if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` @@ -1383,8 +1393,14 @@ def _compute_loss(self, model, inputs, return_outputs): # Log the metrics # Entropy per_token_entropy = entropy_from_logits(shift_logits.detach()) - entropy = per_token_entropy[shift_completion_mask.bool()].mean() - entropy = self.accelerator.gather_for_metrics(entropy).mean().item() + mask = shift_completion_mask + entropy_sum = (per_token_entropy * mask).sum() + total_tokens = mask.sum() + + # Gather counts across ranks and weight-average + entropy_sum = self.accelerator.gather_for_metrics(entropy_sum).sum() + total_tokens = self.accelerator.gather_for_metrics(total_tokens).sum() + entropy = (entropy_sum / total_tokens).item() if total_tokens > 0 else 0.0 self._metrics[mode]["entropy"].append(entropy) # Number of tokens diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 295770890b1..3f616d7508c 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -319,40 +319,53 @@ def __init__( # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer + self._tokenizer = processing_class.tokenizer self._is_vlm = True elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class + self._tokenizer = processing_class self._is_vlm = False else: raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") + if self._tokenizer.pad_token is None: + self._tokenizer.pad_token = self._tokenizer.eos_token + # Resolve vision placeholder token IDs once. Used by the forward pass to rebuild mm_token_type_ids # when tool responses inject images into the completion (see _generate forward_kwargs block). self._image_pad_token_id = None self._video_pad_token_id = None if self._is_vlm: for candidate in ("<|image_pad|>", "<|image|>"): - tid = tokenizer.convert_tokens_to_ids(candidate) - if tid != tokenizer.unk_token_id: + tid = self._tokenizer.convert_tokens_to_ids(candidate) + if tid != self._tokenizer.unk_token_id: self._image_pad_token_id = tid break - tid = tokenizer.convert_tokens_to_ids("<|video_pad|>") - if tid != tokenizer.unk_token_id: + tid = self._tokenizer.convert_tokens_to_ids("<|video_pad|>") + if tid != self._tokenizer.unk_token_id: self._video_pad_token_id = tid - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - self.pad_token_id = tokenizer.pad_token_id - self.eos_token_id = tokenizer.eos_token_id + # PEFT + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) + if is_peft_model(model): + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " + "and unload the existing adapter, save the resulting base model, and then pass that base model along " + "with the new `peft_config` to the trainer." + ) + # Create PEFT model + model = get_peft_model(model, peft_config) - if is_peft_available() and is_peft_model(model) and peft_config is not None: - raise ValueError( - "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " - "and unload the existing adapter, save the resulting base model, and then pass that base model along " - "with the new `peft_config` to the trainer." - ) - if is_peft_available() and is_peft_model(model) and args.beta != 0.0: + elif is_peft_available() and is_peft_model(model) and args.beta != 0.0: # If the model is a PEFT model with a pretrained adapter, we need to create a "ref" adapter that is a copy # of the "default" adapter, so that we can use it as the reference model during GRPO training. model.add_adapter("ref", model.peft_config["default"]) @@ -362,10 +375,6 @@ def __init__( ref_param = model.get_parameter(ref_name) ref_param.data.copy_(param.data) - # Create PEFT model - if peft_config is not None: - model = get_peft_model(model, peft_config) - # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489 if is_peft_available() and is_peft_model(model) and args.gradient_checkpointing: @@ -528,8 +537,7 @@ def __init__( # While waiting for broader adoption, we provide this utility function to manually set the response schema for # known chat templates. `response_schema` lives on the (inner) tokenizer, since `parse_response` is a tokenizer # method that reads `self.response_schema`. - tokenizer = processing_class.tokenizer if self._is_vlm else processing_class - if self.tools and getattr(tokenizer, "response_schema", None) is None: + if self.tools and getattr(self._tokenizer, "response_schema", None) is None: processing_class = add_response_schema(processing_class) # In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template # isn't, we replace it at initialization with a training-safe, prefix-preserving template. @@ -785,9 +793,9 @@ def cast_outputs_to_original_dtype(module, args, output): generation_kwargs = { "max_new_tokens": self.max_completion_length, "do_sample": True, - "pad_token_id": tokenizer.pad_token_id, - "bos_token_id": tokenizer.bos_token_id, - "eos_token_id": tokenizer.eos_token_id, + "pad_token_id": self._tokenizer.pad_token_id, + "bos_token_id": self._tokenizer.bos_token_id, + "eos_token_id": self._tokenizer.eos_token_id, "temperature": self.temperature, "top_p": self.top_p, "top_k": self.top_k, @@ -1370,7 +1378,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): else: # Regular generation path: left-pad token IDs into tensors prompt_tensors = [torch.tensor(ids) for ids in prompt_ids] - padded_ids = pad(prompt_tensors, padding_value=self.pad_token_id, padding_side="left") + padded_ids = pad(prompt_tensors, padding_value=self._tokenizer.pad_token_id, padding_side="left") attention_mask = pad([torch.ones_like(t) for t in prompt_tensors], padding_value=0, padding_side="left") generate_inputs = {"input_ids": padded_ids, "attention_mask": attention_mask} # For VLMs, include multimodal fields as tensors (pixel_values, image_grid_thw, etc.) @@ -1403,7 +1411,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): completion_ids = prompt_completion_ids[:, prompt_length:] # Mask everything after the first EOS token - is_eos = completion_ids == self.eos_token_id + is_eos = completion_ids == self._tokenizer.eos_token_id eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) @@ -1459,7 +1467,7 @@ def _get_tool_suffix_ids(self, tool_messages): # When we compute `suffix_ids` by slicing `full_ids`, we must align the slicing boundary to # EOS (not EOS + newline). Templates that don't use EOS as end-of-turn (e.g. Gemma uses # ) skip this trimming. - eos_positions = [i for i, tok_id in enumerate(prefix_ids) if tok_id == self.eos_token_id] + eos_positions = [i for i, tok_id in enumerate(prefix_ids) if tok_id == self._tokenizer.eos_token_id] if eos_positions: prefix_ids = prefix_ids[: eos_positions[-1] + 1] @@ -1658,9 +1666,7 @@ async def _run_async_tools(async_coros): completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx] # Decode post-tool completions. - post_tool_completions = [ - parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids - ] + post_tool_completions = [parse_response(self._tokenizer, ids) if ids else {} for ids in post_tool_ids] # Add post-tool completions to the existing completions for idx in range(len(idxs_with_tool)): @@ -1710,13 +1716,12 @@ def _generate(self, prompts: list): # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. if is_conversational({"prompt": prompts[0]}): - tokenizer = self.processing_class.tokenizer if self._is_vlm else self.processing_class if ( Version(transformers.__version__) >= Version("5.0.0") # parse_response added in v5 - and hasattr(tokenizer, "response_schema") # attribute not set by default for now - and tokenizer.response_schema is not None # only works if the tokenizer has a schema + and hasattr(self._tokenizer, "response_schema") # attribute not set by default for now + and self._tokenizer.response_schema is not None # only works if the tokenizer has a schema ): - completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids] + completions = [[parse_response(self._tokenizer, ids)] for ids in completion_ids] else: contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) completions = [[{"role": "assistant", "content": content}] for content in contents] @@ -1770,7 +1775,7 @@ def _generate(self, prompts: list): self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) # Identify sequences that terminated with EOS and log their lengths - eos_and_pad = [self.eos_token_id, self.pad_token_id] + eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id] is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device) agg_is_truncated = self.accelerator.gather(is_truncated) self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) @@ -1868,7 +1873,7 @@ def _generate_and_score_completions( prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] prompt_ids = pad( prompt_ids, - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, padding_side="left", pad_to_multiple_of=self.pad_to_multiple_of, ).to(device=device) @@ -1879,7 +1884,7 @@ def _generate_and_score_completions( completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] completion_ids = pad( completion_ids, - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of, ).to(device=device) @@ -1906,7 +1911,7 @@ def _generate_and_score_completions( # If mask_truncated_completions is enabled, zero out truncated completions for attention and loss masking if self.mask_truncated_completions: - eos_and_pad = [self.eos_token_id, self.pad_token_id] + eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id] is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) # Mask completion_mask for attention masking completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index e1e3ad2534c..91fba84be7c 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -404,8 +404,24 @@ def __init__( else: added_tokens = [] - # PEFT configuration and model wrapping + # PEFT if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) + if is_peft_model(model): + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " + "and unload the existing adapter, save the resulting base model, and then pass that base model along " + "with the new `peft_config` to the trainer." + ) if added_tokens: # Ensure that the added tokens are trainable if peft_config.trainable_token_indices is None: @@ -414,7 +430,6 @@ def __init__( peft_config.trainable_token_indices["embed_tokens"] = added_tokens else: peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens) - # Ensure that the lm_head is trainable if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save: logger.warning( @@ -428,16 +443,7 @@ def __init__( peft_config.modules_to_save = ["lm_head"] else: peft_config.modules_to_save.append("lm_head") - - if is_peft_available() and is_peft_model(model) and peft_config is not None: - raise ValueError( - "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " - "and unload the existing adapter, save the resulting base model, and then pass that base model along " - "with the new `peft_config` to the trainer." - ) - - # Create PEFT model - if peft_config is not None: + # Create PEFT model model = get_peft_model(model, peft_config) # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally @@ -559,7 +565,7 @@ def _prepare_dataset( with PartialState().main_process_first(): if not is_processed: - # Add EOS token to the end of the sequences if needed + # Add EOS token if needed: non-conversational only first_example = next(iter(dataset)) if not is_conversational(first_example): if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 6924332a336..311c2ba81af 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -267,24 +267,37 @@ def __init__( # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer + self._tokenizer = processing_class.tokenizer elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class + self._tokenizer = processing_class else: raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - self.pad_token_id = tokenizer.pad_token_id - self.eos_token_id = tokenizer.eos_token_id + if self._tokenizer.pad_token is None: + self._tokenizer.pad_token = self._tokenizer.eos_token - if is_peft_available() and is_peft_model(model) and peft_config is not None: - raise ValueError( - "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " - "and unload the existing adapter, save the resulting base model, and then pass that base model along " - "with the new `peft_config` to the trainer." - ) - if is_peft_available() and is_peft_model(model): + # PEFT + if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) + if is_peft_model(model): + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " + "and unload the existing adapter, save the resulting base model, and then pass that base model along " + "with the new `peft_config` to the trainer." + ) + # Create PEFT model + model = get_peft_model(model, peft_config) + + elif is_peft_available() and is_peft_model(model): # If the model is a PEFT model with a pretrained adapter, we need to create a "ref" adapter that is a copy # of the "default" adapter, so that we can use it as the reference model during the training. model.add_adapter("ref", model.peft_config["default"]) @@ -294,10 +307,6 @@ def __init__( ref_param = model.get_parameter(ref_name) ref_param.data.copy_(param.data) - # Create PEFT model - if peft_config is not None: - model = get_peft_model(model, peft_config) - # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489 if is_peft_available() and is_peft_model(model) and args.gradient_checkpointing: @@ -332,14 +341,14 @@ def __init__( self.reward_func_names.append(reward_funcs[i].__name__) self.reward_funcs = reward_funcs - self._has_async_reward_funcs = any(inspect.iscoroutinefunction(func) for func in self.reward_funcs) - if self._has_async_reward_funcs: - self.async_reward_loop_thread, self.async_reward_loop, self.async_reward_loop_ready_event = ( - start_event_loop_in_daemon(name="RLOOTrainer-AsyncRewardLoop") + self._has_async_funcs = any(inspect.iscoroutinefunction(func) for func in self.reward_funcs) + if self._has_async_funcs: + self.async_loop_thread, self.async_loop, self.async_loop_ready_event = start_event_loop_in_daemon( + name="RLOOTrainer-AsyncRewardLoop" ) # wait until the event loop is running in the daemon thread - self.async_reward_loop_ready_event.wait() - atexit.register(shutdown_event_loop_in_daemon, self.async_reward_loop_thread, self.async_reward_loop) + self.async_loop_ready_event.wait() + atexit.register(shutdown_event_loop_in_daemon, self.async_loop_thread, self.async_loop) # Reward weights if args.reward_weights is not None: @@ -532,9 +541,9 @@ def __init__( generation_kwargs = { "max_new_tokens": self.max_completion_length, "do_sample": True, - "pad_token_id": tokenizer.pad_token_id, - "bos_token_id": tokenizer.bos_token_id, - "eos_token_id": tokenizer.eos_token_id, + "pad_token_id": self._tokenizer.pad_token_id, + "bos_token_id": self._tokenizer.bos_token_id, + "eos_token_id": self._tokenizer.eos_token_id, "temperature": self.temperature, "top_p": self.top_p, "top_k": self.top_k, @@ -866,7 +875,7 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): # Execute async custom functions in parallel using asyncio.gather if async_funcs_info: - async def _invoke_async_reward(index, func, func_name): + async def _invoke_async(index, func, func_name): with profiling_context(self, func_name): output = await func( prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs @@ -875,10 +884,10 @@ async def _invoke_async_reward(index, func, func_name): return index, output async def _run_async_funcs(): - coros = [_invoke_async_reward(i, func, func_name) for (i, func, func_name) in async_funcs_info] + coros = [_invoke_async(i, func, func_name) for (i, func, func_name) in async_funcs_info] return await asyncio.gather(*coros) - async_results = asyncio.run_coroutine_threadsafe(_run_async_funcs(), self.async_reward_loop).result() + async_results = asyncio.run_coroutine_threadsafe(_run_async_funcs(), self.async_loop).result() for idx, output_reward_func in async_results: rewards_per_func[:, idx] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) @@ -993,7 +1002,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): else: # Regular generation path: left-pad token IDs into tensors prompt_tensors = [torch.tensor(ids) for ids in prompt_ids] - padded_ids = pad(prompt_tensors, padding_value=self.pad_token_id, padding_side="left") + padded_ids = pad(prompt_tensors, padding_value=self._tokenizer.pad_token_id, padding_side="left") attention_mask = pad([torch.ones_like(t) for t in prompt_tensors], padding_value=0, padding_side="left") generate_inputs = {"input_ids": padded_ids, "attention_mask": attention_mask} # For VLMs, include multimodal fields as tensors (pixel_values, image_grid_thw, etc.) @@ -1026,7 +1035,7 @@ def _generate_single_turn(self, prompt_ids, images, multimodal_fields): completion_ids = prompt_completion_ids[:, prompt_length:] # Mask everything after the first EOS token - is_eos = completion_ids == self.eos_token_id + is_eos = completion_ids == self._tokenizer.eos_token_id eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) @@ -1073,7 +1082,7 @@ def _generate(self, prompts: list): self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) # Identify sequences that terminated with EOS and log their lengths - eos_and_pad = [self.eos_token_id, self.pad_token_id] + eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id] is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device) agg_is_truncated = self.accelerator.gather(is_truncated) self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) @@ -1127,7 +1136,7 @@ def _generate_and_score_completions( prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] prompt_ids = pad( prompt_ids, - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, padding_side="left", pad_to_multiple_of=self.pad_to_multiple_of, ).to(device=device) @@ -1138,7 +1147,7 @@ def _generate_and_score_completions( completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] completion_ids = pad( completion_ids, - padding_value=self.pad_token_id, + padding_value=self._tokenizer.pad_token_id, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of, ).to(device=device) @@ -1148,7 +1157,7 @@ def _generate_and_score_completions( # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask if self.mask_truncated_completions: - eos_and_pad = [self.eos_token_id, self.pad_token_id] + eos_and_pad = [self._tokenizer.eos_token_id, self._tokenizer.pad_token_id] # Mask completion_mask for attention masking is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() diff --git a/trl/trainer/sft_config.py b/trl/trainer/sft_config.py index 240790f7fb5..7c8da47fb29 100644 --- a/trl/trainer/sft_config.py +++ b/trl/trainer/sft_config.py @@ -99,8 +99,16 @@ class SFTConfig(_BaseConfig): on the assistant responses, which is supported only for [conversational](#conversational) datasets. If `False`, loss is computed on the entire sequence. loss_type (`str`, *optional*, defaults to `"nll"`): - Type of loss to use. Possible values are `"nll"` (negative log-likelihood, default) and `"dft"` (Dynamic - Fine-Tuning, as described in [this paper](https://huggingface.co/papers/2508.05629)). + Type of loss to use. Possible values are: + + - `"nll"`: standard negative log-likelihood (default). + - `"dft"`: Dynamic Fine-Tuning, as described in + [this paper](https://huggingface.co/papers/2508.05629). + - `"chunked_nll"`: same math as `"nll"`, but the `lm_head` projection is computed on non-ignored tokens + only (positions with `labels == -100` are dropped before the matmul) and the cross-entropy is processed + in chunks of tokens to reduce peak activation memory. Not compatible with `use_liger_kernel`, PEFT, or + VLM models. Under FSDP2, set `fsdp_reshard_after_forward false` in the accelerate config — the chunked + path otherwise re-gathers `lm_head.weight` per chunk during backward, adding noticeable wall-time. activation_offloading (`bool`, *optional*, defaults to `False`): Whether to offload the activations to the CPU. @@ -256,10 +264,10 @@ class SFTConfig(_BaseConfig): loss_type: str = field( default="nll", metadata={ - "help": ( - 'Type of loss to use. Possible values are `"nll"` (negative log-likelihood, default) and `"dft"` ' - "(Dynamic Fine-Tuning, as described in https://huggingface.co/papers/2508.05629)." - ) + "help": "Type of loss to use. Possible values are `'nll'` (negative log-likelihood, default), `'dft'` " + "(Dynamic Fine-Tuning, https://huggingface.co/papers/2508.05629), and `'chunked_nll'` (same math as " + "`'nll'` but skips the `'lm_head'` matmul on ignored tokens and chunks the CE to reduce peak memory; not " + "compatible with Liger, PEFT, or VLM)." }, ) activation_offloading: bool = field( diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 68e78ec48bd..940dd2e4d22 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -15,6 +15,7 @@ import contextlib import json import os +import types import warnings from collections import defaultdict from collections.abc import Callable @@ -22,8 +23,10 @@ from pathlib import Path from typing import Any +import accelerate import torch import torch.nn as nn +import torch.nn.functional as F import transformers from accelerate import PartialState from accelerate.logging import get_logger @@ -39,6 +42,7 @@ TrainingArguments, ) from transformers.data.data_collator import DataCollatorMixin +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalPrediction from transformers.utils import is_peft_available @@ -65,10 +69,243 @@ ) +_CHUNKED_LM_HEAD_CHUNK_SIZE = 256 + + if is_peft_available(): from peft import PeftConfig, PeftModel, PeftType, get_peft_model +@dataclass +class _ChunkedCELMHeadOutput(CausalLMOutputWithPast): + """`CausalLMOutputWithPast` with extra fields populated by the chunked-CE path.""" + + num_correct_tokens: torch.Tensor | None = None + entropy_sum: torch.Tensor | None = None + aux_loss: torch.Tensor | None = None + + +def _chunk(h, w, b, lbl, logit_scale, final_logit_softcapping): + logits = h.float() @ w.float().t() + if b is not None: + logits = logits + b.float() + if logit_scale != 1.0: + logits = logits * logit_scale + if final_logit_softcapping is not None: + logits = final_logit_softcapping * torch.tanh(logits / final_logit_softcapping) + log_p = F.log_softmax(logits, dim=-1) + chunk_loss = F.nll_loss(log_p, lbl, reduction="sum") + chunk_correct = (logits.argmax(dim=-1) == lbl).sum().float() + chunk_entropy = -(log_p.exp() * log_p).sum(dim=-1).sum() + return chunk_loss, chunk_correct, chunk_entropy + + +def _chunked_cross_entropy_loss( + hidden_states: torch.Tensor, + lm_head_weight: torch.Tensor, + chunk_size: int, + labels: torch.Tensor | None = None, + shift_labels: torch.Tensor | None = None, + num_items_in_batch: torch.Tensor | int | None = None, + logit_scale: float = 1.0, + final_logit_softcapping: float | None = None, + lm_head_bias: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Memory-efficient next-token cross-entropy over hidden states and an `lm_head` weight. + + The full `lm_head` projection is never materialized. Positions where labels equal `-100` are dropped before the + matmul, and the remaining tokens are processed in chunks of `chunk_size`. Each chunk's `[chunk_size, vocab_size]` + logits tensor is kept alive only during its own forward/backward pass via gradient checkpointing, so peak + logits-activation memory is `chunk_size * vocab_size` instead of `batch_size * seq_len * vocab_size`. + + At least one of `labels` or `shift_labels` must be provided. Passing `labels` alone is the standard path and + triggers the internal `labels[..., 1:]` / `hidden_states[..., :-1, :]` shift. Passing `shift_labels` skips the + shift and assumes the caller has already aligned labels with hidden states — this is the contract used under + context / sequence parallelism, where labels are shifted before being sharded. If both are provided, `shift_labels` + wins (matching [`~transformers.loss.ForCausalLMLoss`]). + + Args: + hidden_states (`torch.Tensor`): + Base decoder output of shape `(B, S, H)`, i.e. before the `lm_head` projection. + lm_head_weight (`torch.Tensor`): + Weight of the `lm_head` linear layer, shape `(V, H)`. + chunk_size (`int`): + Number of valid tokens processed per chunk. Peak memory scales linearly with this. + labels (`torch.Tensor`, *optional*): + Labels of shape `(B, S)`. Positions equal to `-100` are excluded from both the `lm_head` matmul and the + loss. Mutually exclusive with `shift_labels`. + shift_labels (`torch.Tensor`, *optional*): + Pre-shifted labels of shape `(B, S)`, aligned with `hidden_states` (position `i` predicts + `shift_labels[i]`). Mutually exclusive with `labels`. + num_items_in_batch (`torch.Tensor`, `int` or `None`, *optional*): + Total number of valid tokens across the global batch, as plumbed by [`~transformers.Trainer`]. When + provided, the loss is reduced as `sum / num_items_in_batch`, matching the gradient-accumulation-correct + behavior of HF's default cross-entropy. When `None`, reduction is `mean` over local valid tokens. + logit_scale (`float`, *optional*, defaults to `1.0`): + Multiplier applied to each chunk's logits before the cross-entropy, matching the `logit_scale` behavior of + Cohere-style models. + final_logit_softcapping (`float`, *optional*): + If set, applies `softcap * tanh(logits / softcap)` to each chunk's logits before the cross-entropy, + matching the `final_logit_softcapping` behavior of Gemma-style models. Applied after `logit_scale`. + lm_head_bias (`torch.Tensor`, *optional*): + Bias of the `lm_head` linear layer, shape `(V,)`. Added to each chunk's logits when provided. + + Returns: + `tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: scalar loss, number of correctly-predicted tokens (count), + and sum of per-token Shannon entropy (in nats) — all over the local batch. Raw sums are returned so callers can + reduce correctly across ranks. + """ + if labels is None and shift_labels is None: + raise ValueError("At least one of `labels` or `shift_labels` must be provided.") + + if shift_labels is not None: + hidden = hidden_states.reshape(-1, hidden_states.size(-1)) + labels = shift_labels.reshape(-1) + else: + hidden = hidden_states[..., :-1, :].reshape(-1, hidden_states.size(-1)) + labels = labels[..., 1:].reshape(-1) + + valid = labels != -100 + hidden = hidden[valid] + labels = labels[valid] + n_valid = hidden.size(0) + + correct = hidden.new_zeros((), dtype=torch.float32) + entropy_sum = hidden.new_zeros((), dtype=torch.float32) + if n_valid == 0: + # Whole micro-batch masked (e.g. completion-only loss + truncation). Keep the loss connected + # to the autograd graph through every trainable parameter so `.backward()` succeeds and DDP / + # FSDP gradient sync doesn't hang on a missing param. + loss = (hidden_states.float().sum() + lm_head_weight.float().sum()) * 0.0 + if lm_head_bias is not None: + loss = loss + lm_head_bias.float().sum() * 0.0 + return loss, correct, entropy_sum + + loss = hidden.new_zeros((), dtype=torch.float32) + + for start in range(0, n_valid, chunk_size): + h_chunk = hidden[start : start + chunk_size] + lbl_chunk = labels[start : start + chunk_size] + chunk_loss, chunk_correct, chunk_entropy = torch.utils.checkpoint.checkpoint( + _chunk, + h_chunk, + lm_head_weight, + lm_head_bias, + lbl_chunk, + logit_scale, + final_logit_softcapping, + use_reentrant=False, + ) + loss = loss + chunk_loss + correct = correct + chunk_correct + entropy_sum = entropy_sum + chunk_entropy + + if num_items_in_batch is None: + loss = loss / n_valid + else: + if isinstance(num_items_in_batch, torch.Tensor): + num_items_in_batch = num_items_in_batch.to(loss.device) + loss = loss / num_items_in_batch + return loss, correct, entropy_sum + + +def _patch_chunked_ce_lm_head(model: torch.nn.Module, chunk_size: int) -> None: + """ + Patch a causal LM so its `forward` computes the language modeling loss via [`_chunked_cross_entropy_loss`] when + `labels` are provided. + + The patched forward calls the base decoder directly (`model.get_decoder()`) to obtain hidden states, skips the + `lm_head` matmul on positions with `labels == -100`, and computes the cross-entropy in chunks of `chunk_size` valid + tokens. It returns a [`_ChunkedCELMHeadOutput`] with `loss` set, `logits` set to `None`, and `token_accuracy` / + `entropy` fields set to the mean values over non-ignored tokens. Also accepts pre-shifted `shift_labels` in place + of `labels`, for the context / sequence parallelism path. When both are `None`, the original forward is invoked so + generation and labels-free evaluation preserve any per-model logits post-processing (e.g. `logit_scale`, + `final_logit_softcapping`). + + For MoE models with `output_router_logits=True`, the load-balancing auxiliary loss is added to the main loss with + the same coefficient (`router_aux_loss_coef`) and formula (`load_balancing_loss_func`) used by the model's own + forward, so the chunked path remains numerically equivalent to the reference. + + Not supported yet: VLM / multimodal models whose forward injects visual tokens outside the base decoder, and + PEFT-wrapped models. + """ + final_logit_softcapping = getattr(model.config, "final_logit_softcapping", None) + logit_scale = getattr(model.config, "logit_scale", 1.0) + original_forward = model.forward + + def _chunked_ce_forward( + self: torch.nn.Module, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + labels: torch.Tensor | None = None, + num_items_in_batch: torch.Tensor | int | None = None, + shift_labels: torch.Tensor | None = None, + output_router_logits: bool | None = None, + **kwargs, + ) -> CausalLMOutputWithPast: + # Without labels, fall back to the original forward so generation and labels-free evaluation + # preserve any per-model logits post-processing (e.g. Cohere `logit_scale`, Gemma + # `final_logit_softcapping`, `logits_to_keep` slicing). + if labels is None and shift_labels is None: + if output_router_logits is not None: + kwargs["output_router_logits"] = output_router_logits + return original_forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs) + + if output_router_logits is None: + output_router_logits = getattr(self.config, "output_router_logits", False) + + kwargs.pop("use_cache", None) + decoder_kwargs = {} + if output_router_logits: + decoder_kwargs["output_router_logits"] = True + outputs: BaseModelOutputWithPast = self.get_decoder()( + input_ids=input_ids, attention_mask=attention_mask, use_cache=False, **decoder_kwargs, **kwargs + ) + hidden_states = outputs.last_hidden_state + + loss, num_correct_tokens, entropy_sum = _chunked_cross_entropy_loss( + hidden_states, + self.lm_head.weight, + chunk_size, + labels=labels, + shift_labels=shift_labels, + num_items_in_batch=num_items_in_batch, + logit_scale=logit_scale, + final_logit_softcapping=final_logit_softcapping, + lm_head_bias=self.lm_head.bias, + ) + + aux_loss = None + if output_router_logits: + # Mirror the per-family MoE forward: add `router_aux_loss_coef * load_balancing_loss_func(...)` to + # the main loss. Mixtral is the source of truth — every MoE family (Qwen3Moe, GptOss, OLMoE, + # Qwen2Moe, DBRX, JetMoE, PhiMoE, …) pulls this function from mixtral via the modular system, so a + # single import keeps us in lockstep with upstream for every family we test. + from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func + + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + loss = loss + self.router_aux_loss_coef * aux_loss.to(loss.device) + + return _ChunkedCELMHeadOutput( + loss=loss, + logits=None, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + num_correct_tokens=num_correct_tokens, + entropy_sum=entropy_sum, + aux_loss=aux_loss, + ) + + model.forward = types.MethodType(_chunked_ce_forward, model) + + logger = get_logger(__name__) @@ -702,22 +939,22 @@ def __init__( # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): - tokenizer = processing_class.tokenizer + self._tokenizer = processing_class.tokenizer self._is_vlm = True elif isinstance(processing_class, PreTrainedTokenizerBase): - tokenizer = processing_class + self._tokenizer = processing_class self._is_vlm = False else: raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`") if args.eos_token is not None: - if args.eos_token not in tokenizer.get_vocab(): + if args.eos_token not in self._tokenizer.get_vocab(): raise ValueError( f"The specified `eos_token` ('{args.eos_token}') is not found in the vocabulary of the given " f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists " "in the vocabulary before using it as an EOS token." ) - tokenizer.eos_token = args.eos_token + self._tokenizer.eos_token = args.eos_token if args.chat_template_path is not None: if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")): @@ -754,8 +991,24 @@ def __init__( "tokens in input_ids. Use truncation_mode='keep_start' (the default) or set max_length=None." ) - # PEFT configuration and model wrapping + # PEFT if peft_config is not None: + if not is_peft_available(): + raise ImportError( + "You passed `peft_config` but the `peft` library is not installed. " + "Install it with `pip install trl[peft]`." + ) + if not isinstance(peft_config, PeftConfig): + raise TypeError( + f"`peft_config` must be a `peft.PeftConfig` instance (e.g. `peft.LoraConfig`), " + f"got {type(peft_config).__name__}." + ) + if is_peft_model(model): + raise ValueError( + "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " + "and unload the existing adapter, save the resulting base model, and then pass that base model along " + "with the new `peft_config` to the trainer." + ) if added_tokens: # Ensure that the added tokens are trainable if peft_config.trainable_token_indices is None: @@ -764,7 +1017,6 @@ def __init__( peft_config.trainable_token_indices["embed_tokens"] = added_tokens else: peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens) - # Ensure that the lm_head is trainable if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save: logger.warning( @@ -778,16 +1030,7 @@ def __init__( peft_config.modules_to_save = ["lm_head"] else: peft_config.modules_to_save.append("lm_head") - - if is_peft_available() and is_peft_model(model) and peft_config is not None: - raise ValueError( - "You passed a `PeftModel` instance together with a `peft_config` to the trainer. Please first merge " - "and unload the existing adapter, save the resulting base model, and then pass that base model along " - "with the new `peft_config` to the trainer." - ) - - # Create PEFT model - if peft_config is not None: + # Create PEFT model model = get_peft_model(model, peft_config) # PEFT + DeepSpeed ZeRO-3 requires reentrant checkpointing. For more details, see @@ -880,16 +1123,16 @@ def __init__( if data_collator is None and not self._is_vision_dataset: # Get the pad token: if not provided, use the one from the processing class or the eos token # if the processing class does not have a pad token. - pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token - if pad_token not in tokenizer.get_vocab(): + pad_token = args.pad_token or self._tokenizer.pad_token or self._tokenizer.eos_token + if pad_token not in self._tokenizer.get_vocab(): raise ValueError( f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given " f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists " "in the vocabulary before using it as a padding token." ) - tokenizer.pad_token = pad_token + self._tokenizer.pad_token = pad_token data_collator = DataCollatorForLanguageModeling( - pad_token_id=tokenizer.pad_token_id, + pad_token_id=self._tokenizer.pad_token_id, max_length=None if self.padding_free else args.max_length, truncation_mode=args.truncation_mode, completion_only_loss=self.completion_only_loss, @@ -976,8 +1219,22 @@ def __init__( "passing a `compute_loss_func` is not allowed." ) compute_loss_func = dft_loss + elif args.loss_type == "chunked_nll": + # Same math as `"nll"` but the `lm_head` matmul is skipped on ignored tokens and the CE is computed in + # chunks of tokens. Implemented by patching the model's forward before `super().__init__` so accelerate + # wraps the patched forward. + if self._is_vlm: + raise NotImplementedError("`loss_type='chunked_nll'` is not supported for VLM models yet.") + if peft_config is not None or is_peft_model(model): + raise NotImplementedError("`loss_type='chunked_nll'` is not supported with PEFT yet.") + _patch_chunked_ce_lm_head(model, chunk_size=_CHUNKED_LM_HEAD_CHUNK_SIZE) else: - raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.") + raise ValueError( + f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll', 'dft', and " + "'chunked_nll'." + ) + elif args.loss_type == "chunked_nll": + raise ValueError("`loss_type='chunked_nll'` is not compatible with `use_liger_kernel=True`.") # Transformers explicitly set use_reentrant=True in the past to silence a PyTorch warning, but the default was # never updated once PyTorch switched to recommending use_reentrant=False. Until that change lands upstream @@ -1010,6 +1267,24 @@ def __init__( self.aux_loss_enabled = getattr(model.config, "output_router_logits", False) + # Under FSDP2 with `reshard_after_forward=True` (accelerate's default), the chunked CE path triggers a + # redundant `lm_head.weight` all-gather per chunk during backward, adding significant wall-time. Setting + # `reshard_after_forward=False` keeps the un-wrapped `lm_head` resident and closes the gap without meaningfully + # affecting peak memory. + # `AcceleratorState.is_fsdp2` was added in accelerate 1.6.0; guard so older (but still-supported) versions + # don't `AttributeError` on every SFTTrainer init. + if ( + args.loss_type == "chunked_nll" + and Version(accelerate.__version__) >= Version("1.6.0") + and self.accelerator.state.is_fsdp2 + and self.accelerator.state.fsdp_plugin.reshard_after_forward + ): + logger.warning( + "`loss_type='chunked_nll'` under FSDP2 with `reshard_after_forward=True` is significantly slower than " + "necessary due to per-chunk all-gathers of `lm_head.weight`. Consider passing " + "`--fsdp_reshard_after_forward false` to `accelerate launch` (or equivalent in your FSDP config)." + ) + # Initialize the metrics self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} self._total_train_tokens = 0 @@ -1102,7 +1377,7 @@ def _func(example): **map_kwargs, ) - # Apply the chat template if needed + # Add EOS token if needed: non-conversational only first_example = next(iter(dataset)) if not is_conversational(first_example): if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc` @@ -1115,10 +1390,9 @@ def add_eos(example, eos_token): example["completion"] = example["completion"] + eos_token return example - eos_token = processing_class.tokenizer.eos_token if self._is_vlm else processing_class.eos_token dataset = dataset.map( add_eos, - fn_kwargs={"eos_token": eos_token}, + fn_kwargs={"eos_token": self._tokenizer.eos_token}, remove_columns="messages" if "messages" in column_names else None, # renamed to "text" **map_kwargs, ) @@ -1286,24 +1560,38 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N ) # Compute entropy - if not self.args.use_liger_kernel: # liger doesn't return logits + if self.args.loss_type == "chunked_nll": + shift_labels = inputs["shift_labels"] if "shift_labels" in inputs else labels[..., 1:] + n_valid = self.accelerator.gather_for_metrics((shift_labels != -100).sum()).sum() + entropy_sum = self.accelerator.gather_for_metrics(outputs.entropy_sum).sum() + entropy = (entropy_sum / n_valid).item() if n_valid > 0 else 0.0 + self._metrics[mode]["entropy"].append(entropy) + elif not self.args.use_liger_kernel: # liger doesn't return logits with torch.no_grad(): - per_token_entropy = entropy_from_logits(outputs.logits) - # When using Prompt Tuning, skip the virtual tokens in logits before entropy computation, since they - # do not correspond to actual input tokens. + if "shift_labels" in inputs: + # When using CP or SP, labels are pre-shifted. + shift_logits = outputs.logits.contiguous() + shift_labels = inputs["shift_labels"] + else: + shift_logits = outputs.logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # Prompt Tuning and P-Tuning output logits for virtual tokens but Prefix-Tuning does not. if ( self.num_virtual_tokens > 0 and model.peft_config[model.active_adapter].peft_type != PeftType.PREFIX_TUNING ): - per_token_entropy = per_token_entropy[:, self.num_virtual_tokens :] - if "attention_mask" in inputs: - attention_mask = inputs["attention_mask"] - entropy = torch.sum(per_token_entropy * attention_mask) / attention_mask.sum() - elif "position_ids" in inputs: - entropy = torch.mean(per_token_entropy) - else: - raise ValueError("Expected 'attention_mask' or 'position_ids' in inputs.") - entropy = self.accelerator.gather_for_metrics(entropy).mean().item() + shift_logits = shift_logits[:, self.num_virtual_tokens :, :] + + per_token_entropy = entropy_from_logits(shift_logits) + mask = shift_labels != -100 + entropy_sum = (per_token_entropy * mask).sum() + total_tokens = mask.sum() + + # Gather counts across ranks and weight-average + entropy_sum = self.accelerator.gather_for_metrics(entropy_sum).sum() + total_tokens = self.accelerator.gather_for_metrics(total_tokens).sum() + entropy = (entropy_sum / total_tokens).item() if total_tokens > 0 else 0.0 self._metrics[mode]["entropy"].append(entropy) if mode == "train": @@ -1319,7 +1607,13 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N self._total_train_tokens += num_tokens_in_batch self._metrics[mode]["num_tokens"] = [self._total_train_tokens] - if self.args.use_liger_kernel: + if self.args.loss_type == "chunked_nll": + shift_labels = inputs["shift_labels"] if "shift_labels" in inputs else labels[..., 1:] + n_valid = self.accelerator.gather_for_metrics((shift_labels != -100).sum()).sum() + correct = self.accelerator.gather_for_metrics(outputs.num_correct_tokens).sum() + accuracy = (correct / n_valid).item() if n_valid > 0 else 0.0 + self._metrics[mode]["mean_token_accuracy"].append(accuracy) + elif self.args.use_liger_kernel: if hasattr(outputs, "token_accuracy") and outputs.token_accuracy is not None: token_accuracy = self.accelerator.gather_for_metrics(outputs.token_accuracy).mean().item() self._metrics[mode]["mean_token_accuracy"].append(token_accuracy) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index c93cfa76a89..2fdae036c70 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -22,7 +22,6 @@ import types from collections.abc import Mapping, Sequence, Sized from contextlib import contextmanager -from dataclasses import dataclass from importlib.metadata import version from itertools import accumulate from typing import TypeVar @@ -43,7 +42,6 @@ is_comet_available, is_trackio_available, ) -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.auto.auto_factory import _BaseAutoModelClass from transformers.utils import ( is_peft_available, @@ -1054,66 +1052,6 @@ def get_config_model_id(config: PretrainedConfig) -> str: return getattr(config, "_name_or_path", "") -@dataclass -class CausalLMOutputWithPastAndFlatLogits(CausalLMOutputWithPast): - flat_logits: torch.Tensor | None = None - - -def forward_masked_logits( - model: PreTrainedModel, logits_mask: torch.LongTensor, **kwargs -) -> CausalLMOutputWithPastAndFlatLogits: - """ - Run a Causal LM forward pass while computing logits only for masked positions to reduce memory usage. - - These are always equal: - - ```python - full_outputs = model(input_ids=input_ids) - masked_outputs = forward_masked_logits(model, mask, input_ids=input_ids) - - assert torch.equal( - masked_outputs.flat_logits, - full_outputs.logits[mask.bool()], - ) - ``` - - Args: - model ([`~transformers.PreTrainedModel`]): - A causal language model. - logits_mask (`torch.LongTensor`): - Boolean-like tensor indicating which token positions should have logits computed. Shape should match the - input sequence shape in `kwargs` (typically `[batch, seq_len]`). - **kwargs: - Keyword arguments forwarded to the inner decoder (e.g., `input_ids`, `attention_mask`, `past_key_values`). - - Returns: - `CausalLMOutputWithPastAndFlatLogits`: Output containing logits only for the unmasked positions. - - Raises: - ValueError: If `logits_to_keep` or `labels` are provided in `kwargs`. - """ - if kwargs.get("logits_to_keep") is not None: - raise ValueError("`logits_to_keep` is not supported by this forward helper.") - if kwargs.get("labels") is not None: - raise ValueError("`labels` is not yet supported by this forward helper.") - - outputs: BaseModelOutputWithPast = model.get_decoder()(**kwargs) - hidden_states = outputs.last_hidden_state - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - flat_logits = model.lm_head(hidden_states[logits_mask.bool()]) - if hasattr(model, "logit_scale"): # CohereForCausalLM has this attribute - flat_logits = flat_logits * model.logit_scale - - return CausalLMOutputWithPastAndFlatLogits( - flat_logits=flat_logits, - # We use .get(...) because some models like FalconMambaForCausalLM don't return past_key_values or attentions - past_key_values=outputs.get("past_key_values"), - hidden_states=outputs.hidden_states, - attentions=outputs.get("attentions"), - ) - - @contextmanager def use_adapter(model: "PeftModel", adapter_name: str | None): """