From 1a6e8b40d60e85d4372cf2d425096e7d501082ab Mon Sep 17 00:00:00 2001
From: Jscaldwell55
Date: Mon, 4 Aug 2025 06:31:14 -0500
Subject: [PATCH] Add validation for LinearCrossEntropyLoss with custom_sharded_layers

When using LinearCrossEntropyLoss with custom_sharded_layers in FSDP,
'output' must be included in the layer list to ensure tensor type
compatibility.

- Added a shared validation module in recipes/validation.py
- Integrated the validation into full_finetune_distributed and
  lora_finetune_distributed
- Added comprehensive unit tests
- Added a clear error message that guides users to the correct configuration

Fixes #2856
---
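Usage sketch (kept below the ---, so it is not part of the commit message): a
minimal illustration of the behavior the new check enforces, using only the
function and loss imported in this patch. In recipe config terms this
corresponds to custom_sharded_layers: ['tok_embeddings', 'output'], as in the
error message below; the exact call sites shown here are illustrative, not the
recipe code itself.

    from torchtune.modules.loss import LinearCrossEntropyLoss
    from recipes.validation import validate_custom_sharding_config

    # OK: 'output' is included, as required for LinearCrossEntropyLoss
    validate_custom_sharding_config(
        LinearCrossEntropyLoss(), ["tok_embeddings", "output"]
    )

    # Raises ValueError: 'output' is missing from custom_sharded_layers
    validate_custom_sharding_config(
        LinearCrossEntropyLoss(), ["tok_embeddings"]
    )
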
+ +"""Shared validation utilities for recipes.""" + +from typing import Optional, Set +from torchtune.modules.loss import LinearCrossEntropyLoss + + +def validate_custom_sharding_config( + loss_fn, + custom_sharded_layers: Optional[list[str]], + required_layer: str = "output", + parallelism_enabled: Optional[bool] = None, + available_layers: Optional[Set[str]] = None, +) -> None: + """ + Validates custom_sharded_layers configuration for specific loss functions. + + Args: + loss_fn: The loss function instance + custom_sharded_layers: List of layer names to shard, or None + required_layer: The layer name that must be included (default: "output") + parallelism_enabled: If False, skip validation (default: None) + available_layers: Optional set of valid layer names for typo checking + + Raises: + ValueError: If validation fails + """ + # Skip when nothing to validate + if not custom_sharded_layers: + return + + # Skip validation if parallelism is explicitly disabled + if parallelism_enabled is False: + return + + # Only enforce when the loss needs the output projection + needs_output_proj = isinstance(loss_fn, LinearCrossEntropyLoss) + + if needs_output_proj and required_layer not in custom_sharded_layers: + raise ValueError( + f"When using {type(loss_fn).__name__} with custom_sharded_layers, " + f"'{required_layer}' must be included to ensure tensor compatibility. " + f"Example: custom_sharded_layers = ['tok_embeddings', '{required_layer}']." + ) + + # Optional: catch typos early + if available_layers is not None: + unknown = set(custom_sharded_layers) - set(available_layers) + if unknown: + raise ValueError( + f"Unknown layer(s) in custom_sharded_layers: {sorted(unknown)}" + ) diff --git a/tests/recipes/test_validation.py b/tests/recipes/test_validation.py new file mode 100644 index 0000000000..37bf3f86bd --- /dev/null +++ b/tests/recipes/test_validation.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+import pytest
+from unittest.mock import Mock
+from torchtune.modules.loss import LinearCrossEntropyLoss
+from recipes.validation import validate_custom_sharding_config
+
+
+class TestValidateCustomShardingConfig:
+    """Unit tests for validate_custom_sharding_config"""
+
+    def test_missing_output_raises_error(self):
+        """Test that validation raises an error when output is missing"""
+        loss_fn = LinearCrossEntropyLoss()
+        custom_sharded_layers = ['tok_embeddings']
+
+        with pytest.raises(ValueError, match="'output' must be included"):
+            validate_custom_sharding_config(loss_fn, custom_sharded_layers)
+
+    def test_with_output_passes(self):
+        """Test that validation passes when output is included"""
+        loss_fn = LinearCrossEntropyLoss()
+        custom_sharded_layers = ['tok_embeddings', 'output']
+
+        # Should not raise
+        validate_custom_sharding_config(loss_fn, custom_sharded_layers)
+
+    def test_none_layers_passes(self):
+        """Test that validation passes when custom_sharded_layers is None"""
+        loss_fn = LinearCrossEntropyLoss()
+
+        # Should not raise
+        validate_custom_sharding_config(loss_fn, None)
+
+    def test_empty_layers_passes(self):
+        """Test that validation passes when custom_sharded_layers is empty"""
+        loss_fn = LinearCrossEntropyLoss()
+
+        # Should not raise
+        validate_custom_sharding_config(loss_fn, [])
+
+    def test_parallelism_disabled_skips_validation(self):
+        """Test that validation is skipped when parallelism is disabled"""
+        loss_fn = LinearCrossEntropyLoss()
+        custom_sharded_layers = ['tok_embeddings']  # Missing output
+
+        # Should not raise because parallelism_enabled=False
+        validate_custom_sharding_config(
+            loss_fn,
+            custom_sharded_layers,
+            parallelism_enabled=False
+        )
+
+    def test_non_linear_ce_loss_passes(self):
+        """Test that non-LinearCrossEntropyLoss doesn't require output"""
+        loss_fn = Mock()  # Some other loss function
+        custom_sharded_layers = ['tok_embeddings']  # Missing output
+
+        # Should not raise
+        validate_custom_sharding_config(loss_fn, custom_sharded_layers)