[Prototype] Option to configure layers independently #168


Draft: wants to merge 10 commits into base `main`.
2 changes: 1 addition & 1 deletion Megatron-LM
2 changes: 1 addition & 1 deletion docs/developer_guide/conversion.md
@@ -232,7 +232,7 @@ Continuing our `AwesomeModel` handler example, we define:
def _create_weight_converters(self) -> list[WeightConverter]:
converters = []
# The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`.
num_layers = self._model.config.base_model.transformer.num_layers
num_layers = self._model.config.base_model.layers.default.num_layers

# A simple renaming example, for the word embeddings.
converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight"))
1 change: 1 addition & 0 deletions fast_llm/config.py
@@ -888,6 +888,7 @@ def __init_subclass__(cls):
valid=value.pop("valid", base_class_field.valid),
default=value.pop("default", base_class_field.default),
default_factory=value.pop("default_factory", base_class_field.default_factory),
init=value.pop("init", base_class_field.init),
repr=value.pop("repr", base_class_field.repr),
hash=value.pop("hash", base_class_field.hash),
compare=value.pop("compare", base_class_field.compare),
18 changes: 16 additions & 2 deletions fast_llm/engine/config_utils/tensor_space.py
@@ -113,10 +113,12 @@ class TensorSpace:
_is_setup: bool = False
_distributed: "Distributed"

def __init__(self, distributed_config: DistributedConfig):
def __init__(self, distributed_config: DistributedConfig, _parent: "TensorSpace|None" = None):
self._distributed_config = distributed_config
self._tensor_dims: dict[str, TensorDim] = {}
self.add_tensor_dim(TensorDim(DefaultDimNames.scalar, 1))
self._parent = _parent
self._sub_spaces: dict[str, TensorSpace] = {}

def setup(self, distributed: "Distributed") -> None:
assert distributed.config is self._distributed_config
@@ -146,5 +148,17 @@ def add_tensor_dim(self, dim: TensorDim) -> None:
Assert.eq(dim.parallel_dim, self._distributed_config.distributed_dims[dim.parallel_dim.name])
self._tensor_dims[dim.name] = dim

def add_sub_space(self, name: str) -> "TensorSpace":
self._sub_spaces[name] = TensorSpace(self._distributed_config, _parent=self)
return self._sub_spaces[name]

def get_sub_space(self, name: str) -> "TensorSpace":
return self._sub_spaces[name]

def get_tensor_dim(self, name: str) -> TensorDim:
return self._tensor_dims[name]
if name in self._tensor_dims:
return self._tensor_dims[name]
elif self._parent is not None:
return self._parent.get_tensor_dim(name)
else:
raise KeyError(name)
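
For readers skimming the diff: the new `_parent` and `_sub_spaces` fields give each layer range its own namespace of tensor dimensions that falls back to the parent space on lookup. A standalone sketch of that behaviour (plain Python, not the Fast-LLM classes):

```python
# Minimal sketch of the parent-fallback lookup added to TensorSpace above.
class Space:
    def __init__(self, parent=None):
        self._dims: dict[str, int] = {}
        self._parent = parent
        self._subs: dict[str, "Space"] = {}

    def add_sub_space(self, name: str) -> "Space":
        self._subs[name] = Space(parent=self)
        return self._subs[name]

    def add_dim(self, name: str, size: int) -> None:
        self._dims[name] = size

    def get_dim(self, name: str) -> int:
        if name in self._dims:
            return self._dims[name]
        if self._parent is not None:
            return self._parent.get_dim(name)
        raise KeyError(name)

root = Space()
root.add_dim("hidden", 4096)
sub = root.add_sub_space("transformer_layers_0")
sub.add_dim("kv_channels", 64)

assert sub.get_dim("kv_channels") == 64  # local hit
assert sub.get_dim("hidden") == 4096     # falls back to the parent
# root.get_dim("kv_channels") would raise KeyError: parents do not see child dims.
```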
33 changes: 20 additions & 13 deletions fast_llm/layers/language_model/config.py
@@ -32,11 +32,7 @@ class LanguageModelKwargs:

@config_class()
class LanguageModelArchitectureConfig(BaseModelArchitectureConfig):
transformer: TransformerArchitectureConfig = Field(
default_factory=TransformerArchitectureConfig,
desc="Configuration for the transformer architecture.",
hint=FieldHint.core,
)
layers: TransformerArchitectureConfig = Field(default_factory=TransformerArchitectureConfig)
max_position_embeddings: int = Field(
default=2048,
desc="Number of absolute position embeddings, if applicable.",
@@ -60,11 +56,11 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig):

def _validate(self) -> None:
if self.use_position_embeddings is None:
self.use_position_embeddings = not self.transformer.rotary.enabled
self.use_position_embeddings = not self.layers.default.rotary.enabled
super()._validate()

def setup_tensor_space(self, tensor_space: TensorSpace) -> None:
self.transformer.setup_tensor_space(tensor_space)
self.layers.setup_tensor_space(tensor_space)
tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor)

# Embedding dimensions
@@ -97,6 +93,17 @@ def from_flat_dict(
cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered")
return super().from_flat_dict(default, strict)

@classmethod
def _from_dict(
cls,
default: dict[str, typing.Any],
strict: bool = True,
flat: bool = False,
) -> typing.Self:
# TODO v0.x: Remove backward compatibility.
cls._handle_renamed_field(default, "transformer", ("layers", "default"))
return super()._from_dict(default, strict, flat)


@config_class()
class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig):
@@ -111,7 +118,7 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig):

architecture_class = LanguageModelArchitectureConfig

transformer: TransformerConfig = FieldUpdate(default_factory=TransformerConfig)
layers: TransformerConfig = FieldUpdate(default_factory=TransformerConfig)
init_method_std_embed: float = Field(
default=None,
desc="Initialization scale for the vocabulary embedding and output weights (logits).",
@@ -175,14 +182,14 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig):
)

def _validate(self) -> None:
if self.transformer.init_method_std is None:
self.transformer.init_method_std = self.transformer.hidden_size**-0.5
if self.layers.default.init_method_std is None:
self.layers.default.init_method_std = self.layers.default.hidden_size**-0.5
if self.init_method_std_embed is None:
self.init_method_std_embed = self.transformer.init_method_std
self.init_method_std_embed = self.layers.default.init_method_std
if self.init_method_max_embed is None:
self.init_method_max_embed = self.transformer.init_method_max
self.init_method_max_embed = self.layers.default.init_method_max
if self.init_method_min_embed is None:
self.init_method_min_embed = self.transformer.init_method_min
self.init_method_min_embed = self.layers.default.init_method_min
if self.init_method_max_embed is not None and self.init_method_min_embed is not None:
Assert.leq(self.init_method_min_embed, self.init_method_max_embed)
super()._validate()
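
The `_from_dict` hook above is the backward-compatibility shim for the rename: a serialized config that used to nest everything under `transformer` is remapped onto `layers.default` before normal parsing. A rough sketch of the before/after shape, with placeholder values:

```python
# Placeholder values; only the nesting matters.
old_config = {
    "transformer": {"num_layers": 32, "hidden_size": 4096},
    "max_position_embeddings": 2048,
}

# After cls._handle_renamed_field(default, "transformer", ("layers", "default")),
# the dict is expected to look like this before super()._from_dict runs:
new_config = {
    "layers": {"default": {"num_layers": 32, "hidden_size": 4096}},
    "max_position_embeddings": 2048,
}
```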
4 changes: 2 additions & 2 deletions fast_llm/layers/language_model/embedding.py
@@ -37,13 +37,13 @@ def __init__(
self._tensor_space = tensor_space
self._residual_dtype = (
self._distributed_config.optimization_dtype
if config.transformer.full_precision_residual
if config.layers.default.full_precision_residual
else self._distributed_config.training_dtype
).torch
self._group_size = self._distributed_config.tensor_parallel
self._sequence_parallel = self._distributed_config.sequence_tensor_parallel
self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings
self._dropout_p = config.transformer.hidden_dropout
self._dropout_p = config.layers.default.hidden_dropout
self._use_absolute_position_embeddings = config.use_absolute_position_embeddings

hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden)
4 changes: 2 additions & 2 deletions fast_llm/layers/language_model/head.py
@@ -40,7 +40,7 @@ def __init__(
tensor_space: TensorSpace,
):
super().__init__(config)
self._debug_transformer = config.transformer.debug_transformer
self._debug_transformer = config.layers.default.debug_transformer
self._tie_word_embeddings = config.tie_word_embeddings
self._tensor_space = tensor_space

@@ -56,7 +56,7 @@ def __init__(

hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden)

self.final_norm = config.transformer.normalization.get_layer(hidden_dim)
self.final_norm = config.layers.default.normalization.get_layer(hidden_dim)
self._logits_scale_factor = config.logits_scale_factor
self._z_loss_factor = config.logit_z_loss

14 changes: 5 additions & 9 deletions fast_llm/layers/transformer/attention.py
@@ -10,11 +10,7 @@
from fast_llm.functional.rotary import apply_rotary_embeddings
from fast_llm.functional.triton.rotary import triton_rotary_autograd_
from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear
from fast_llm.layers.transformer.config import (
TransformerConfig,
TransformerDimNames,
TransformerKwargs,
)
from fast_llm.layers.transformer.config import TransformerDimNames, TransformerKwargs, TransformerLayerConfig
from fast_llm.logging import log_distributed_grad, log_distributed_tensor
from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_
from fast_llm.utils import Assert
@@ -69,14 +65,14 @@ class Attention(torch.nn.Module):

def __init__(
self,
config: TransformerConfig,
config: TransformerLayerConfig,
tensor_space: TensorSpace,
layer_index,
):
super().__init__()
self._config = config
self._tensor_space = tensor_space
Assert.in_range_incl(layer_index, 1, self._config.num_layers)
Assert.in_range(layer_index, 0, self._config.num_layers)
self._layer_index = layer_index
self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel
self._debug_transformer = self._config.debug_transformer
@@ -161,10 +157,10 @@ def _attn_fused(
query,
key,
beta=0,
alpha=self._softmax_scale / self._layer_index,
alpha=self._softmax_scale / (self._layer_index + 1),
).view(b, self._local_head_groups, sq, self._local_heads_per_group, sk)

attn_weights = attn_weights.to(torch.float32) * self._layer_index
attn_weights = attn_weights.to(torch.float32) * (self._layer_index + 1)
attn_weights = torch.where(mask, attn_weights, mask_value)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype)

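
The `_attn_fused` hunks follow the indexing change: with `layer_index` now 0-based, the scale-down/scale-up trick in the fused path must use `layer_index + 1` so the factor never becomes zero. A standalone sanity check of the identity (illustrative, not part of the PR):

```python
import torch

softmax_scale = 0.125
layer_index = 3  # now 0-based

q = torch.randn(4, 16)
k = torch.randn(4, 16)

# Reference: scale applied directly.
plain = (q @ k.T) * softmax_scale
# Fused path: scale divided by (layer_index + 1) inside the matmul to keep
# low-precision intermediates small, then multiplied back after the upcast.
scaled_down = (q @ k.T) * (softmax_scale / (layer_index + 1))
recovered = scaled_down * (layer_index + 1)

torch.testing.assert_close(plain, recovered)
```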
144 changes: 139 additions & 5 deletions fast_llm/layers/transformer/config.py
@@ -4,7 +4,16 @@
import typing
import warnings

from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none
from fast_llm.config import (
Config,
Field,
FieldHint,
FieldUpdate,
check_field,
config_class,
process_field,
skip_valid_if_none,
)
from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace
@@ -156,7 +165,7 @@ class AddLinearBiasChoices(str, enum.Enum):


@config_class()
class TransformerArchitectureConfig(BaseModelArchitectureConfig):
class TransformerLayerArchitectureConfig(BaseModelArchitectureConfig):
_abstract = False
normalization: NormalizationArchitectureConfig = Field(
default_factory=NormalizationArchitectureConfig,
@@ -367,7 +376,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None:


@config_class()
class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig):
class TransformerLayerConfig(TransformerLayerArchitectureConfig, BaseModelConfig):
normalization: NormalizationConfig = FieldUpdate(default_factory=NormalizationConfig)
rotary: RotaryConfig = FieldUpdate(default_factory=RotaryConfig)
# Default: hidden_size**-0.5
@@ -618,8 +627,133 @@ def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool:
DataType.bfloat16,
)

# Config parameter `window_size` can only be used with flash attention
if not use_flash_attention:
Assert.is_(self.window_size, None)
assert self.max_window_layers is None

return use_flash_attention


@config_class()
class RangeConfig(Config):
"""
A configuration that defines a range of values, to be used for example in python `slice` or `range`.
"""

# TODO: Not specific to transformers, move elsewhere?
begin: int = Field(
default=0,
desc="The beginning of the range.",
hint=FieldHint.optional,
)
end: int | None = Field(
default=None,
desc="The end of the range (excluded).",
hint=FieldHint.optional,
)
step: int = Field(
default=1,
desc="The step for the range.",
hint=FieldHint.optional,
)

def in_range(self, index) -> bool:
"""
Checks whether `index` is in `range(begin, end, step)`.
"""
return (
index >= self.begin and (self.end is None or index < self.end) and ((index - self.begin) % self.step == 0)
)


def process_config_updates(updates: dict[str | tuple[str, ...], typing.Any]) -> dict[tuple[str, ...], typing.Any]:
return {(tuple(key.split("/")) if isinstance(key, str) else key): value for (key, value) in updates.items()}


@config_class()
class TransformerLayerRangeArchitectureConfig(BaseModelArchitectureConfig):
_abstract = False
layer_ranges: list[RangeConfig] = Field(
default_factory=RangeConfig,
desc="Layer range.",
hint=FieldHint.core,
)
updates: dict[tuple[str, ...], typing.Any] = Field(
default_factory=dict, valid=process_field(process_config_updates)
)
config: TransformerLayerArchitectureConfig = Field(init=False)
_default: TransformerLayerArchitectureConfig = Field(init=False)

def setup(self, default: TransformerLayerArchitectureConfig) -> None:
assert not hasattr(self, "_default")
self._default = default

def _validate(self) -> None:
assert hasattr(self, "_default")
assert len(self.layer_ranges) > 0
super()._validate()
# Create the full config from the default and updates.
# We use `default.from_dict` so we also have the appropriate class in `TransformerLayerRangeConfig`.
# For the architecture class we need to set `strict=False` because of possible non-architecture parameters.
self.config = self._default.from_dict(self._default, self.updates, strict=isinstance(self, BaseModelConfig))
self.config.validate()

def in_range(self, index) -> bool:
return any(layer_range.in_range(index) for layer_range in self.layer_ranges)


@config_class()
class TransformerLayerRangeConfig(TransformerLayerRangeArchitectureConfig, BaseModelConfig):
config: TransformerLayerConfig = FieldUpdate()
_default: TransformerLayerConfig = FieldUpdate()


@config_class()
class TransformerArchitectureConfig(BaseModelArchitectureConfig):
_abstract = False
layers: list[TransformerLayerRangeArchitectureConfig] = Field(default_factory=list)
default: TransformerLayerArchitectureConfig = Field(default_factory=TransformerLayerArchitectureConfig)

def _validate(self) -> None:
for layer in self.layers:
layer.setup(self.default)
super()._validate()
for layer in self.layers:
# Hidden size must match the default
Assert.eq(layer.config.hidden_size, self.default.hidden_size)
# TODO: Move elsewhere? Kept here because used in a few places like default initialization.
Assert.eq(layer.config.num_layers, self.default.num_layers)
# TODO: Rotary preprocessor doesn't support variations across layers.
Assert.eq(layer.config.rotary.to_serialized(), self.default.rotary.to_serialized())

def get_layer_config_and_tensor_space(
self, index: int, tensor_space: TensorSpace
) -> tuple[TransformerLayerArchitectureConfig, TensorSpace]:
for i, layer in enumerate(self.layers):
if layer.in_range(index):
return layer.config, tensor_space.get_sub_space(f"transformer_layers_{i}")
return self.default, tensor_space

def setup_tensor_space(self, tensor_space: TensorSpace) -> None:
assert self._validated
self.default.setup_tensor_space(tensor_space)
for i, layer in enumerate(self.layers):
layer.config.setup_tensor_space(tensor_space.add_sub_space(f"transformer_layers_{i}"))


@config_class()
class TransformerConfig(TransformerArchitectureConfig, BaseModelConfig):
layers: list[TransformerLayerRangeConfig] = FieldUpdate()
default: TransformerLayerConfig = FieldUpdate(default_factory=TransformerLayerConfig)

def _validate(self) -> None:
super()._validate()
for layer in self.layers:
# Residual precision must match the default
Assert.eq(layer.config.full_precision_residual, self.default.full_precision_residual)
if self.layers:
warnings.warn("Variable layer configuration is experimental. Use with caution.")

def get_layer_config_and_tensor_space(
self, index: int, tensor_space: TensorSpace
) -> tuple[TransformerLayerConfig, TensorSpace]:
return super().get_layer_config_and_tensor_space(index, tensor_space)
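
Putting the new classes together: `TransformerConfig.default` holds the baseline `TransformerLayerConfig`, each entry in `layers` selects indices via `layer_ranges` and applies `updates` (keys are `/`-separated paths) on top of the default, and `get_layer_config_and_tensor_space` returns the first matching range. A hypothetical serialized sketch, using only fields that appear in this diff; exact parsing depends on the config machinery:

```python
transformer = {
    "default": {
        "num_layers": 24,
        "hidden_size": 4096,
    },
    "layers": [
        {
            # Layer indices 8..15 (end is exclusive), per RangeConfig.in_range.
            "layer_ranges": [{"begin": 8, "end": 16, "step": 1}],
            # "/"-separated keys are split into tuples by process_config_updates;
            # values override the corresponding fields of `default`.
            "updates": {"window_size": 4096, "hidden_dropout": 0.1},
        }
    ],
}
```

With this layout, `get_layer_config_and_tensor_space(10, tensor_space)` would return the overridden config together with the `transformer_layers_0` sub-space, while any index outside 8..15 falls back to `default` and the root tensor space.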
4 changes: 2 additions & 2 deletions fast_llm/layers/transformer/mixture_of_experts.py
@@ -13,9 +13,9 @@
from fast_llm.layers.common.linear import Linear
from fast_llm.layers.transformer.config import (
RoutingType,
TransformerConfig,
TransformerDimNames,
TransformerKwargs,
TransformerLayerConfig,
TransformerLossNames,
)
from fast_llm.layers.transformer.mlp import MLPBase
@@ -40,7 +40,7 @@ class MixtureOfExpertMLP(MLPBase):

_group: ProcessGroup

def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp"):
def __init__(self, config: TransformerLayerConfig, tensor_space: TensorSpace, name: str = "mlp"):
Assert.gt(config.num_experts, 1)
# TODO: Implement?
assert not config.add_linear_biases, "Biases not supported for MoE."