diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py
index 8071a086..86572dd7 100644
--- a/fast_llm/layers/transformer/attention.py
+++ b/fast_llm/layers/transformer/attention.py
@@ -10,7 +10,12 @@
 from fast_llm.functional.rotary import apply_rotary_embeddings
 from fast_llm.functional.triton.rotary import triton_rotary_autograd_
 from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear
-from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs
+from fast_llm.layers.transformer.config import (
+    TransformerConfig,
+    TransformerDimNames,
+    TransformerKwargs,
+    TransformerSubLayerKeys,
+)
 from fast_llm.logging import log_distributed_grad, log_distributed_tensor
 from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_
 from fast_llm.utils import Assert
@@ -102,7 +107,7 @@ def __init__(
         self.query = OutputParallelLinear(
             hidden_dim,
             self._tensor_space.get_tensor_dim(TransformerDimNames.composite_query),
-            bias=self._config.add_linear_biases,
+            bias=self._config.should_add_linear_bias(self._layer_index, TransformerSubLayerKeys.attn_query),
             weight_init_method=init_method_qkv,
             bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_,
             sequence_parallel=self._sequence_parallel,
@@ -111,7 +116,7 @@ def __init__(
         self.key_value = OutputParallelLinear(
             hidden_dim,
             self._tensor_space.get_tensor_dim(TransformerDimNames.composite_key_value),
-            bias=self._config.add_linear_biases,
+            bias=self._config.should_add_linear_bias(self._layer_index, TransformerSubLayerKeys.attn_key_value),
             weight_init_method=init_method_qkv,
             bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_,
             sequence_parallel=self._sequence_parallel,
@@ -123,7 +128,7 @@ def __init__(
         self.dense = InputParallelLinear(
             self._tensor_space.get_tensor_dim(TransformerDimNames.composite_dense),
             hidden_dim,
-            bias=self._config.add_linear_biases,
+            bias=self._config.should_add_linear_bias(self._layer_index, TransformerSubLayerKeys.attn_dense),
             weight_init_method=init_method_std_attn_proj,
             bias_init_method=init_method_std_attn_proj if self._config.random_bias_init else init_zeros_,
             sequence_parallel=self._sequence_parallel,
diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py
index 1b4e7749..de4dcaee 100644
--- a/fast_llm/layers/transformer/config.py
+++ b/fast_llm/layers/transformer/config.py
@@ -1,6 +1,8 @@
 import enum
+import itertools
 import logging
 import math
+import re
 import typing
 import warnings
 
@@ -149,6 +151,14 @@ class RotaryConfig(RotaryArchitectureConfig, BaseModelConfig):
     pass
 
 
+class TransformerSubLayerKeys(str, enum.Enum):
+    attn_query = "layers.self_attn.query"
+    attn_key_value = "layers.self_attn.key_value"
+    attn_dense = "layers.self_attn.dense"
+    mlp_layer1 = "layers.mlp.layer_1"
+    mlp_layer2 = "layers.mlp.layer_2"
+
+
 @config_class()
 class TransformerArchitectureConfig(BaseModelArchitectureConfig):
     _abstract = False
@@ -174,7 +184,11 @@ class TransformerArchitectureConfig(BaseModelArchitectureConfig):
         hint=FieldHint.core,
         valid=check_field(Assert.gt, 0),
     )
-    add_linear_biases: bool = Field(default=True, desc="Add biases to all dense layers.", hint=FieldHint.core)
+    add_linear_biases: bool | dict[TransformerSubLayerKeys, str] = Field(
+        default=True,
+        desc="Add biases to all or selected dense layers. Accepted values: True, False, or a dict with keys from TransformerSubLayerKeys and values as '*' or index ranges.",
+        hint=FieldHint.core,
+    )
     ffn_hidden_size: int = Field(
         default=None,
         desc="Hidden dimension of the MLP intermediate state. Default: 4 * hidden_size.",
         hint=FieldHint.core,
         valid=check_field(Assert.gt, 0),
     )
@@ -234,6 +248,10 @@ class TransformerArchitectureConfig(BaseModelArchitectureConfig):
         hint=FieldHint.feature,
     )
 
+    _parsed_add_linear_biases: bool | dict[TransformerSubLayerKeys, set[int] | str] = Field(
+        default=None, init=False, repr=False
+    )
+
     def _validate(self) -> None:
         if self.ffn_hidden_size is None:
             self.ffn_hidden_size = 4 * self.hidden_size
@@ -243,7 +261,13 @@ def _validate(self) -> None:
         self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu
         self.projection_size = self.num_attention_heads * self.kv_channels
         self.num_unshared_experts = self.num_experts - self.num_shared_experts
+
+        # Validate before the parent validation so an invalid TransformerSubLayerKeys key raises an assertion error.
+        self._validate_add_linear_biases()
+        self._parse_add_linear_biases()
+
         super()._validate()
+
         if not TritonConfig.TRITON_ENABLED:
             warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.")
@@ -251,6 +275,56 @@ def _validate(self) -> None:
         Assert.leq(self.num_shared_experts, self.num_experts)
         Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts)
         Assert.multiple(self.num_attention_heads, self.head_groups)
 
+    def _validate_add_linear_biases(self) -> None:
+        """Validate the `add_linear_biases` parameter."""
+        if isinstance(self.add_linear_biases, dict):
+            Assert.gt(len(self.add_linear_biases), 0)
+            for key, value in self.add_linear_biases.items():
+                Assert.incl(key, TransformerSubLayerKeys)  # Assert valid sublayer key
+                Assert.custom(
+                    lambda val: val == "*" or re.match(r"^\d+(:\d+(:\d+)?)?(,\s*\d+(:\d+(:\d+)?)?)*$", val),
+                    value,
+                )  # Assert valid range format
+
+    def _parse_add_linear_biases(self) -> None:
+        """Parse `add_linear_biases` and store the result for quick lookup."""
+        if isinstance(self.add_linear_biases, bool):
+            self._parsed_add_linear_biases = self.add_linear_biases
+            return
+
+        parsed = {}
+        for key, value in self.add_linear_biases.items():
+            parsed[key] = self._parse_indices(value)
+        self._parsed_add_linear_biases = parsed
+
+    def _parse_indices(self, indices_str: str) -> set[int]:
+        """Parse layer indices from a string like '1:10:2, 20, 30' or '*'."""
+        indices = []
+        # Layers are numbered from 1 because layer 0 is the embedding layer in Fast-LLM.
+        if indices_str == "*":
+            indices.extend(range(1, self.num_layers + 1))
+        else:
+            for part in indices_str.split(","):
+                part = part.strip()
+                if ":" in part:
+                    parts = list(map(int, part.split(":")))
+                    start, stop = parts[0] + 1, parts[1] + 1
+                    step = parts[2] if len(parts) == 3 else 1
+                    indices.extend(range(start, stop, step))
+                else:
+                    indices.append(int(part) + 1)
+        return set(itertools.chain(indices))
+
+    def should_add_linear_bias(self, layer_index: int, sublayer_key: TransformerSubLayerKeys) -> bool:
+        """Check if linear bias should be added for a given layer and sublayer."""
+        if isinstance(self._parsed_add_linear_biases, bool):
+            return self._parsed_add_linear_biases
+
+        if sublayer_key in self._parsed_add_linear_biases:
+            return layer_index in self._parsed_add_linear_biases[sublayer_key]
+
+        return False
+
     @classmethod
     def _from_dict(
         cls,
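For reference, a minimal standalone sketch of how `_parse_indices` interprets these strings (illustrative only, not part of the patch): "*" expands to every transformer layer, while "start:stop[:step]" ranges and single indices are shifted by one because layer 0 is the embedding layer.

# Illustrative sketch only (not part of the patch); mirrors _parse_indices() semantics.
def parse_indices(indices_str: str, num_layers: int) -> set[int]:
    """Expand '*', 'start:stop[:step]' ranges, or single indices into 1-based internal layer indices."""
    if indices_str == "*":
        return set(range(1, num_layers + 1))
    indices: set[int] = set()
    for part in indices_str.split(","):
        part = part.strip()
        if ":" in part:
            bounds = list(map(int, part.split(":")))
            step = bounds[2] if len(bounds) == 3 else 1
            indices.update(range(bounds[0] + 1, bounds[1] + 1, step))
        else:
            indices.add(int(part) + 1)
    return indices


# Matches the expectations in tests/test_config.py below.
assert parse_indices("1:10:3, 9", num_layers=10) == {2, 5, 8, 10}
assert parse_indices("5:7", num_layers=10) == {6, 7}
assert parse_indices("*", num_layers=10) == set(range(1, 11))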
diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py
index 85c6686f..e374f31f 100644
--- a/fast_llm/layers/transformer/mixture_of_experts.py
+++ b/fast_llm/layers/transformer/mixture_of_experts.py
@@ -40,11 +40,11 @@ class MixtureOfExpertMLP(MLPBase):
 
     _group: ProcessGroup
 
-    def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"):
+    def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, name: str = "mlp"):
         Assert.gt(config.num_experts, 1)
         # TODO: Implement?
         assert not config.add_linear_biases, "Biases not supported for MoE."
-        super().__init__(config, tensor_space, name)
+        super().__init__(config, tensor_space, layer_index, name)
         self._config = config
         self._tensor_space = tensor_space
         self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory
diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py
index 76ebfcc0..02b0b4ca 100644
--- a/fast_llm/layers/transformer/mlp.py
+++ b/fast_llm/layers/transformer/mlp.py
@@ -8,15 +8,16 @@
 from fast_llm.functional.config import TritonConfig
 from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd
 from fast_llm.layers.common.linear import LinearBase
-from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames
+from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerSubLayerKeys
 from fast_llm.tensor import init_normal_, init_zeros_
 from fast_llm.utils import Assert
 
 
 class MLPBase(Layer, ABC):
-    def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"):
+    def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, name: str = "mlp"):
         super().__init__()
         self._name = name
+        self._layer_index = layer_index
 
         init_method_1 = init_normal_(
             std=config.init_method_std_mlp_1,
@@ -42,7 +43,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"):
         self.layer_1 = LinearBase(
             hidden_dim,
             tensor_space.get_tensor_dim(TransformerDimNames.composite_gated_expert_mlp),
-            bias=config.add_linear_biases,
+            bias=config.should_add_linear_bias(self._layer_index, TransformerSubLayerKeys.mlp_layer1),
             weight_init_method=init_method_1,
             bias_init_method=init_method_1 if config.random_bias_init else init_zeros_,
             lr_scale=tuple(config.mlp_lr_scale),
@@ -50,7 +51,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"):
         self.layer_2 = LinearBase(
             self._intermediate_dim,
             hidden_dim,
-            bias=config.add_linear_biases,
+            bias=config.should_add_linear_bias(self._layer_index, TransformerSubLayerKeys.mlp_layer2),
             weight_init_method=init_method_2,
             bias_init_method=init_method_2 if config.random_bias_init else init_zeros_,
             auto_bias_grad_accumulation=tensor_space.distributed_config.tensor_parallel > 1,
@@ -60,9 +61,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"):
 
 
 class MLP(MLPBase):
-    def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"):
+    def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, name: str = "mlp"):
         Assert.eq(config.num_experts, 1)
-        super().__init__(config, tensor_space, name)
+        super().__init__(config, tensor_space, layer_index, name)
 
     def forward(
         self,
diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py
index 4780dd3a..df326c04 100644
--- a/fast_llm/layers/transformer/transformer.py
+++ b/fast_llm/layers/transformer/transformer.py
@@ -43,7 +43,7 @@ def __init__(
         self.self_attn = Attention(self._config, self._tensor_space, layer_index)
 
         self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)(
-            self._config, self._tensor_space, f"{self.name} mlp"
+            self._config, self._tensor_space, self._layer_index, f"{self.name} mlp"
         )
 
     @torch.compile
diff --git a/tests/test_attention.py b/tests/test_attention.py
index db856787..c8b91d76 100644
--- a/tests/test_attention.py
+++ b/tests/test_attention.py
@@ -1,6 +1,8 @@
 import unittest.mock
 
 from fast_llm.layers.transformer.attention import Attention
 from fast_llm.layers.transformer.config import TransformerConfig
+from fast_llm.engine.distributed.config import DistributedConfig
+from fast_llm.engine.config_utils.tensor_space import TensorSpace
 
 
 def test_decide_window_size():
@@ -20,3 +22,17 @@ def test_decide_window_size():
     # Arrange - Case 3: max_window_layers is None (always return window_size)
     attention._config = TransformerConfig(window_size=512, max_window_layers=None)
     assert attention._decide_window_size() == 512
+
+
+def test_attention_constructor():
+    transformer_conf = TransformerConfig(
+        num_layers=2,
+        num_attention_heads=2,
+        hidden_size=16,
+    )
+    distributed_config = DistributedConfig()
+    tensor_space = TensorSpace(distributed_config=distributed_config)
+    transformer_conf.setup_tensor_space(tensor_space)
+
+    Attention(transformer_conf, tensor_space, 1)
diff --git a/tests/test_config.py b/tests/test_config.py
index 86c99a23..2f28f423 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -5,7 +5,11 @@
 
 import yaml
 
-from fast_llm.layers.transformer.config import TransformerConfig
+from fast_llm.layers.transformer.config import (
+    TransformerConfig,
+    TransformerArchitectureConfig,
+    TransformerSubLayerKeys,
+)
 from fast_llm.utils import Assert
 from fast_llm.engine.distributed.config import DistributedConfig
 from fast_llm.engine.config_utils.data_type import DataType
@@ -84,3 +88,96 @@ def test_do_use_flash_attention():
     mock_distributed_config.training_dtype = DataType.float32
     with pytest.raises(AssertionError):
         config.do_use_flash_attention(mock_distributed_config)
+
+
+@pytest.fixture
+def config_with_true_biases():
+    """Fixture for TransformerArchitectureConfig with True add_linear_biases."""
+    return TransformerArchitectureConfig(add_linear_biases=True)
+
+
+@pytest.fixture
+def config_with_false_biases():
+    """Fixture for TransformerArchitectureConfig with False add_linear_biases."""
+    return TransformerArchitectureConfig(add_linear_biases=False)
+
+
+@pytest.fixture
+def config_with_dict_biases():
+    """Fixture for TransformerArchitectureConfig with dict add_linear_biases."""
+    return TransformerArchitectureConfig(
+        num_layers=10,
+        add_linear_biases={
+            "layers.self_attn.query": "*",
+            "layers.mlp.layer_1": "1:10:3, 9",
+            "layers.mlp.layer_2": "5:7",
+        },
+    )
+
+
+def test_add_linear_biases_bool_true(config_with_true_biases):
+    """Test case for add_linear_biases set to True (default)."""
+    assert config_with_true_biases._parsed_add_linear_biases == True
+
+
+def test_add_linear_biases_bool_false(config_with_false_biases):
+    """Test case for add_linear_biases set to False."""
+    assert config_with_false_biases._parsed_add_linear_biases == False
+
+
+def test_add_linear_biases_dict_valid(config_with_dict_biases):
+    """Test case for add_linear_biases with valid dictionary."""
+    assert config_with_dict_biases._parsed_add_linear_biases == {
+        TransformerSubLayerKeys.attn_query: {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
+        TransformerSubLayerKeys.mlp_layer1: {2, 5, 8, 10},
+        TransformerSubLayerKeys.mlp_layer2: {6, 7},
+    }
+
+
+def test_invalid_key_in_dict():
+    """Test case where an invalid key is provided in add_linear_biases dictionary."""
+    with pytest.raises(AssertionError):
+        # Using an invalid key in the dictionary.
+        TransformerArchitectureConfig(add_linear_biases={"invalid_key": "*"})
+
+
+def test_invalid_range_format():
+    """Test case where invalid range format is provided."""
+    with pytest.raises(AssertionError):
+        TransformerArchitectureConfig(add_linear_biases={TransformerSubLayerKeys.mlp_layer1: "1:10:3, abc"})
+
+
+def test_empty_add_linear_biases():
+    """Test case for empty add_linear_biases dictionary."""
+    with pytest.raises(AssertionError):  # Expecting AssertionError for invalid empty dictionary
+        TransformerArchitectureConfig(add_linear_biases={})
+
+
+def test_should_add_linear_bias_for_layer_and_sublayer(config_with_dict_biases):
+    """Test case for should_add_linear_bias based on layer index and sublayer key."""
+
+    # Layer 1 and sublayer mlp_layer1
+    assert config_with_dict_biases.should_add_linear_bias(1, TransformerSubLayerKeys.mlp_layer1) == False
+
+    # Layer 2 and sublayer mlp_layer1
+    assert config_with_dict_biases.should_add_linear_bias(2, TransformerSubLayerKeys.mlp_layer1) == True
+
+    # Layer 9 and sublayer mlp_layer1
+    assert config_with_dict_biases.should_add_linear_bias(9, TransformerSubLayerKeys.mlp_layer1) == False
+
+    # Layer 6 and sublayer mlp_layer2
+    assert config_with_dict_biases.should_add_linear_bias(6, TransformerSubLayerKeys.mlp_layer2) == True
+
+    # Layer 5 and sublayer attn_query
+    assert config_with_dict_biases.should_add_linear_bias(5, TransformerSubLayerKeys.attn_query) == True
+
+
+def test_should_add_linear_bias_for_bool_true(config_with_true_biases):
+    """Test case for add_linear_biases set to True (should always return True)."""
+    assert config_with_true_biases.should_add_linear_bias(10, TransformerSubLayerKeys.mlp_layer1) == True
+
+
+def test_should_add_linear_bias_for_bool_false(config_with_false_biases):
+    """Test case for add_linear_biases set to False (should always return False)."""
+    assert config_with_false_biases.should_add_linear_bias(10, TransformerSubLayerKeys.mlp_layer1) == False
diff --git a/tests/test_mlp.py b/tests/test_mlp.py
new file mode 100644
index 00000000..4d343ec0
--- /dev/null
+++ b/tests/test_mlp.py
@@ -0,0 +1,33 @@
+from fast_llm.layers.transformer.mlp import MLP
+from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP
+from fast_llm.layers.transformer.config import TransformerConfig
+from fast_llm.engine.distributed.config import DistributedConfig
+from fast_llm.engine.config_utils.tensor_space import TensorSpace
+
+
+def test_mlp_constructor():
+    transformer_conf = TransformerConfig(
+        num_layers=2,
+        num_attention_heads=2,
+        hidden_size=16,
+    )
+    distributed_config = DistributedConfig()
+    tensor_space = TensorSpace(distributed_config=distributed_config)
+    transformer_conf.setup_tensor_space(tensor_space)
+
+    MLP(transformer_conf, tensor_space, 1, "name")
+
+
+def test_moe_mlp_constructor():
+    transformer_conf = TransformerConfig(
+        num_layers=2,
+        num_attention_heads=2,
+        hidden_size=16,
+        num_experts=2,
+        add_linear_biases=False,
+    )
+    distributed_config = DistributedConfig()
+    tensor_space = TensorSpace(distributed_config=distributed_config)
+    transformer_conf.setup_tensor_space(tensor_space)
+
+    MixtureOfExpertMLP(transformer_conf, tensor_space, 1, "name")
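As a usage sketch (illustrative only, not part of the patch; it reuses the dict from the config_with_dict_biases fixture above and assumes the same construction path as those tests), sublayer keys that are absent from the dict, such as "layers.self_attn.dense" here, hit the final `return False` branch of `should_add_linear_bias`, so those sublayers are built without biases on every layer:

# Illustrative usage sketch, not part of the patch.
from fast_llm.layers.transformer.config import TransformerArchitectureConfig, TransformerSubLayerKeys

config = TransformerArchitectureConfig(
    num_layers=10,
    add_linear_biases={
        "layers.self_attn.query": "*",
        "layers.mlp.layer_1": "1:10:3, 9",
        "layers.mlp.layer_2": "5:7",
    },
)
assert config.should_add_linear_bias(3, TransformerSubLayerKeys.attn_query)       # "*" covers every layer
assert not config.should_add_linear_bias(3, TransformerSubLayerKeys.mlp_layer2)   # 3 not in {6, 7}
assert not config.should_add_linear_bias(3, TransformerSubLayerKeys.attn_dense)   # unlisted key -> no bias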