7 changes: 7 additions & 0 deletions recipes/full_finetune_distributed.py
@@ -29,6 +29,7 @@
from torchtune.datasets import ConcatDataset
from torchtune.modules.embedding_utils import resize_token_embeddings
from torchtune.modules.loss import SFTLoss
from recipes.validation import validate_custom_sharding_config
from torchtune.modules.moe import utils as moe_utils
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import (
@@ -669,6 +670,12 @@ def _setup_model(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

# Validate custom_sharded_layers configuration
validate_custom_sharding_config(
self._loss_fn,
custom_sharded_layers,
parallelism_enabled=self.parallel_dims.dp_shard_enabled or self.parallel_dims.cp_enabled
)
# Apply Fully Sharded Data Parallelism to the model
if self.parallel_dims.dp_shard_enabled or self.parallel_dims.cp_enabled:
fsdp_shard_conditions = [
9 changes: 9 additions & 0 deletions recipes/lora_finetune_distributed.py
@@ -26,6 +26,7 @@
from torchtune.data import padded_collate_packed
from torchtune.datasets import ConcatDataset
from torchtune.modules.loss import SFTLoss
from recipes.validation import validate_custom_sharding_config
from torchtune.modules.peft import (
AdapterModule,
get_adapter_params,
@@ -530,6 +531,14 @@ def _setup_model(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

# Validate custom_sharded_layers configuration
validate_custom_sharding_config(
self._loss_fn,
custom_sharded_layers,
parallelism_enabled=self.parallel_dims.dp_shard_enabled or self.parallel_dims.cp_enabled
)

# Apply Fully Sharded Data Parallelism to the model
if self.parallel_dims.dp_shard_enabled or self.parallel_dims.cp_enabled:
# For FSDP sharding
57 changes: 57 additions & 0 deletions recipes/validation.py
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Shared validation utilities for recipes."""

from typing import Optional, Set
from torchtune.modules.loss import LinearCrossEntropyLoss


def validate_custom_sharding_config(
loss_fn,
custom_sharded_layers: Optional[list[str]],
required_layer: str = "output",
parallelism_enabled: Optional[bool] = None,
available_layers: Optional[Set[str]] = None,
) -> None:
"""
Validates custom_sharded_layers configuration for specific loss functions.

Args:
loss_fn: The loss function instance
custom_sharded_layers: List of layer names to shard, or None
required_layer: The layer name that must be included (default: "output")
parallelism_enabled: If explicitly False, validation is skipped; None or True runs it (default: None)
available_layers: Optional set of valid layer names for typo checking

Raises:
ValueError: If validation fails
"""
# Skip when nothing to validate
if not custom_sharded_layers:
return

# Skip validation if parallelism is explicitly disabled
if parallelism_enabled is False:
return

# Only enforce when the loss needs the output projection
needs_output_proj = isinstance(loss_fn, LinearCrossEntropyLoss)

if needs_output_proj and required_layer not in custom_sharded_layers:
raise ValueError(
f"When using {type(loss_fn).__name__} with custom_sharded_layers, "
f"'{required_layer}' must be included to ensure tensor compatibility. "
f"Example: custom_sharded_layers = ['tok_embeddings', '{required_layer}']."
)

# Optional: catch typos early
if available_layers is not None:
unknown = set(custom_sharded_layers) - set(available_layers)
if unknown:
raise ValueError(
f"Unknown layer(s) in custom_sharded_layers: {sorted(unknown)}"
)
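
For reference, a minimal usage sketch of the new helper run outside a recipe. This is not part of the PR; the available_layers set below is illustrative only, reusing the tok_embeddings/output names that appear elsewhere in this diff.

from torchtune.modules.loss import LinearCrossEntropyLoss
from recipes.validation import validate_custom_sharding_config

loss_fn = LinearCrossEntropyLoss()

# Passes: the output projection is sharded alongside the embeddings.
validate_custom_sharding_config(loss_fn, ["tok_embeddings", "output"])

# Raises ValueError: LinearCrossEntropyLoss requires 'output' in custom_sharded_layers.
try:
    validate_custom_sharding_config(loss_fn, ["tok_embeddings"])
except ValueError as err:
    print(err)

# Raises ValueError: 'outptu' is not in the (illustrative) available_layers set.
try:
    validate_custom_sharding_config(
        loss_fn,
        ["tok_embeddings", "output", "outptu"],
        available_layers={"tok_embeddings", "output", "layers"},
    )
except ValueError as err:
    print(err)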
64 changes: 64 additions & 0 deletions tests/recipes/test_validation.py
@@ -0,0 +1,64 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
from unittest.mock import Mock
from torchtune.modules.loss import LinearCrossEntropyLoss
from recipes.validation import validate_custom_sharding_config


class TestValidateCustomShardingConfig:
"""Unit tests for validate_custom_sharding_config"""

def test_missing_output_raises_error(self):
"""Test that validation raises error when output is missing"""
loss_fn = LinearCrossEntropyLoss()
custom_sharded_layers = ['tok_embeddings']

with pytest.raises(ValueError, match="'output' must be included"):
validate_custom_sharding_config(loss_fn, custom_sharded_layers)

def test_with_output_passes(self):
"""Test that validation passes when output is included"""
loss_fn = LinearCrossEntropyLoss()
custom_sharded_layers = ['tok_embeddings', 'output']

# Should not raise
validate_custom_sharding_config(loss_fn, custom_sharded_layers)

def test_none_layers_passes(self):
"""Test that validation passes when custom_sharded_layers is None"""
loss_fn = LinearCrossEntropyLoss()

# Should not raise
validate_custom_sharding_config(loss_fn, None)

def test_empty_layers_passes(self):
"""Test that validation passes when custom_sharded_layers is empty"""
loss_fn = LinearCrossEntropyLoss()

# Should not raise
validate_custom_sharding_config(loss_fn, [])

def test_parallelism_disabled_skips_validation(self):
"""Test that validation is skipped when parallelism is disabled"""
loss_fn = LinearCrossEntropyLoss()
custom_sharded_layers = ['tok_embeddings'] # Missing output

# Should not raise because parallelism_enabled=False
validate_custom_sharding_config(
loss_fn,
custom_sharded_layers,
parallelism_enabled=False
)

def test_non_linear_ce_loss_passes(self):
"""Test that non-LinearCrossEntropyLoss doesn't require output"""
loss_fn = Mock() # Some other loss function
custom_sharded_layers = ['tok_embeddings'] # Missing output

# Should not raise
validate_custom_sharding_config(loss_fn, custom_sharded_layers)
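
The tests above do not exercise the available_layers typo check. A hedged sketch of what such a test could look like, written standalone so it runs on its own under pytest; the layer names are assumptions, not taken from the PR.

import pytest
from torchtune.modules.loss import LinearCrossEntropyLoss
from recipes.validation import validate_custom_sharding_config


def test_unknown_layer_with_available_layers_raises():
    """Sketch: unknown layer names are rejected when available_layers is provided"""
    loss_fn = LinearCrossEntropyLoss()
    custom_sharded_layers = ['tok_embeddings', 'output', 'outptu']  # deliberate typo

    with pytest.raises(ValueError, match="Unknown layer"):
        validate_custom_sharding_config(
            loss_fn,
            custom_sharded_layers,
            available_layers={'tok_embeddings', 'output'},
        )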