diff --git a/tests/test_utils.py b/tests/test_utils.py index 0061dd5e5e..9fd7ed9e0f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -26,6 +26,7 @@ from trl.trainer.utils import ( DataCollatorForChatML, batch_generation, + compute_token_accuracy, decode_and_strip_padding, flush_left, generate_model_card, @@ -451,3 +452,57 @@ def test_no_tensors(self): expected_mask = torch.tensor([[1, 1, 1], [1, 1, 0]]) self.assertTrue(torch.equal(new_mask, expected_mask)) + + +class TestComputeTokenAccuracy(unittest.TestCase): + def test_basic_accuracy(self): + # Test basic accuracy computation + logits = torch.tensor([[[0.9, 0.1], [0.8, 0.2]], [[0.3, 0.7], [0.6, 0.4]]]) # Shape: [2, 2, 2] + labels = torch.tensor([[1, 0], [1, 0]]) # Shape: [2, 2] + accuracy = compute_token_accuracy(logits, labels) + self.assertAlmostEqual(accuracy, 0.75) # 3 correct out of 4 tokens + + def test_with_ignore_index(self): + # Test accuracy computation with ignored tokens + logits = torch.tensor([[[0.9, 0.1], [0.8, 0.2]], [[0.3, 0.7], [0.6, 0.4]]]) + labels = torch.tensor([[1, -100], [1, 0]]) # -100 is ignored + accuracy = compute_token_accuracy(logits, labels, ignore_index=-100) + self.assertAlmostEqual(accuracy, 2 / 3) # 2 correct out of 3 non-ignored tokens + + def test_all_ignored(self): + # Test case where all tokens are ignored + logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]]) + labels = torch.tensor([[-100, -100]]) + accuracy = compute_token_accuracy(logits, labels) + self.assertEqual(accuracy, 0.0) # No valid tokens to compute accuracy + + def test_perfect_accuracy(self): + # Test case with 100% accuracy + logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]]) + labels = torch.tensor([[1, 0]]) + accuracy = compute_token_accuracy(logits, labels) + self.assertEqual(accuracy, 1.0) # All predictions correct + + def test_zero_accuracy(self): + # Test case with 0% accuracy + logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]]) + labels = torch.tensor([[0, 1]]) + accuracy = compute_token_accuracy(logits, labels) + self.assertEqual(accuracy, 0.0) # All predictions wrong + + def test_batch_accuracy(self): + # Test accuracy computation across multiple batches + logits = torch.tensor( + [ + [[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]], # Batch 1 + [[0.2, 0.8], [0.7, 0.3], [0.6, 0.4]], # Batch 2 + ] + ) + labels = torch.tensor( + [ + [1, 0, 1], # Batch 1 + [1, 0, -100], # Batch 2 (last token ignored) + ] + ) + accuracy = compute_token_accuracy(logits, labels) + self.assertAlmostEqual(accuracy, 0.8) diff --git a/trl/__init__.py b/trl/__init__.py index 4d5e9e041e..44a4333d53 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -100,7 +100,7 @@ "XPOTrainer", ], "trainer.callbacks": ["MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"], - "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"], + "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config", "compute_token_accuracy"], } try: @@ -200,7 +200,7 @@ XPOTrainer, ) from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback - from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config + from .trainer.utils import compute_token_accuracy, get_kbit_device_map, get_peft_config, get_quantization_config try: if not is_diffusers_available(): diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index 85968218cc..9ef887864a 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -76,6 +76,7 @@ "disable_dropout_in_model", "empty_cache", "peft_module_casting_to_bf16", + "compute_token_accuracy", ], "xpo_config": ["XPOConfig"], "xpo_trainer": ["XPOTrainer"], @@ -144,6 +145,7 @@ DataCollatorForCompletionOnlyLM, RunningMoments, compute_accuracy, + compute_token_accuracy, disable_dropout_in_model, empty_cache, peft_module_casting_to_bf16, diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 9e2e5fe04f..086a5bac79 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -16,15 +16,18 @@ import inspect import os import warnings +from collections import defaultdict from typing import Callable, Optional, Union import datasets import torch import torch.nn as nn +import transformers from accelerate.state import PartialState from datasets import Dataset from datasets.arrow_writer import SchemaInferenceError from datasets.builder import DatasetGenerationError +from packaging import version from transformers import ( AutoModelForCausalLM, AutoTokenizer, @@ -48,6 +51,7 @@ from .utils import ( ConstantLengthDataset, DataCollatorForCompletionOnlyLM, + compute_token_accuracy, generate_model_card, get_comet_experiment_url, peft_module_casting_to_bf16, @@ -304,6 +308,9 @@ def make_inputs_require_grad(module, input, output): UserWarning, ) + # Initialize the metrics + self._metrics = defaultdict(list) + super().__init__( model=model, args=args, @@ -546,3 +553,42 @@ def create_model_card( ) model_card.save(os.path.join(self.args.output_dir, "README.md")) + + def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): + """ + Compute training loss and additionally compute token accuracies + """ + (loss, outputs) = super().compute_loss( + model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch + ) + + # Compute token accuracy if we have labels + if "labels" in inputs: + shift_logits = outputs.logits[..., :-1, :].contiguous() + shift_labels = inputs["labels"][..., 1:].contiguous() + + # Gather logits and labels from all GPUs first + shift_logits = self.accelerator.gather_for_metrics(shift_logits) + shift_labels = self.accelerator.gather_for_metrics(shift_labels) + + # Then compute accuracy on the gathered tensors + if self.accelerator.is_main_process: + accuracy = compute_token_accuracy(shift_logits, shift_labels) + self._metrics["mean_token_accuracy"].append(accuracy) + + return (loss, outputs) if return_outputs else loss + + def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: + metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics + + # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` + # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. + if next(iter(logs.keys())).startswith("eval_"): + metrics = {f"eval_{key}": val for key, val in metrics.items()} + + logs = {**logs, **metrics} + if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): + super().log(logs, start_time) + else: # transformers<=4.46 + super().log(logs) + self._metrics.clear() diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 719d952f1f..029e5639ab 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1647,3 +1647,24 @@ def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor return mask else: return mask, *tensors + + +def compute_token_accuracy(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> float: + """ + Compute the mean token accuracy. + """ + # Get predictions + predictions = logits.argmax(dim=-1) + + # Create mask for non-padding tokens (assuming pad_token_id is ignore_index) + mask = labels != ignore_index + + # Calculate accuracy only on non-padding tokens + correct_predictions = (predictions == labels) & mask + total_tokens = mask.sum() + correct_tokens = correct_predictions.sum() + + # Calculate accuracy + accuracy = correct_tokens.item() / total_tokens.item() if total_tokens > 0 else 0.0 + + return accuracy