
Commit 2f00cdd

Concatenated dim (#336)
1 parent 0a4ce53 commit 2f00cdd

File tree

23 files changed, +471 −244 lines changed


fast_llm/engine/config_utils/tensor_space.py

Lines changed: 246 additions & 62 deletions
Large diffs are not rendered by default.

fast_llm/engine/multi_stage/fsdp.py

Lines changed: 7 additions & 25 deletions
@@ -441,39 +441,21 @@ def _get_parameter_shard_indices_in_full_weight(
         where it is located in the shard if it exists, or -1 if it's not in the shard.
         Used to determine the location of each entry in a different distributed configuration.
         """
-
-        # Create an empty index for the global parameter.
-        index = torch.full(
-            parameter_meta.global_shape,
-            -1,
-            dtype=torch.int64,
-            device=device,
-        )
         # Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard
         begin, end = self._get_parameter_range_in_shard(parameter_name)
 
-        buffer_index = parameter_meta.global_to_local(index, expand=True)
-        # Copying directly into `buffer_index` requires a view of the tensor, which may not be feasible.
-        # In that case, we work with a separate tensor to be copied back into `buffer_index`.
-        try:
-            buffer_index_flat = buffer_index.view(-1)
-            is_view = True
-        except RuntimeError:
-            buffer_index_flat = buffer_index.new_full((buffer_index.numel(),), -1)
-            is_view = False
-
-        # Copy the shard indices at their respective positions in the flat buffer index.
-        buffer_index_flat[
+        # Create an empty local index to hold the local shard indices.
+        buffer_index = torch.full_like(parameter_meta, -1, dtype=torch.int64, device=device)
+
+        # Copy the shard indices at their respective positions in the buffer index.
+        buffer_index.flatten()[
             self._index_buffer_to_param(
                 self._fsdp_dim.rank * self._shard_size, parameter_name
             ) : self._index_buffer_to_param((self._fsdp_dim.rank + 1) * self._shard_size, parameter_name)
         ].copy_(torch.arange(begin, end, dtype=torch.int64, device=device))
 
-        # If needed, copy the flat buffer index back into the index.
-        if not is_view:
-            buffer_index.copy_(buffer_index_flat.view_as(buffer_index))
-
-        return index
+        # Create a global index from the local one.
+        return parameter_meta.local_to_global_partial(buffer_index, -1)
 
     def copy_shard_overlaps(
         self,
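
The refactor drops the global-shape index and the view/copy fallback: the index is now built directly in the local buffer's shape and mapped back to global coordinates by the new `ParameterMeta.local_to_global_partial`. Below is a minimal standalone sketch of the fill pattern in plain PyTorch, with made-up sizes; in the real code the slice bounds come from `_index_buffer_to_param` and `_get_parameter_range_in_shard`.

import torch

# Hypothetical sizes: a local parameter buffer of 12 entries, of which
# entries [4, 10) belong to this rank's shard, at shard offsets [100, 106).
local_numel = 12
buffer_begin, buffer_end = 4, 10   # slice of the buffer owned by this rank
begin, end = 100, 106              # corresponding indices within the shard

# Every entry starts at -1, meaning "not present in this shard".
buffer_index = torch.full((local_numel,), -1, dtype=torch.int64)

# Write the shard indices at their respective positions in the (flat) buffer.
buffer_index[buffer_begin:buffer_end] = torch.arange(begin, end, dtype=torch.int64)

print(buffer_index)
# tensor([ -1,  -1,  -1,  -1, 100, 101, 102, 103, 104, 105,  -1,  -1])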

fast_llm/engine/multi_stage/stage_base.py

Lines changed: 3 additions & 2 deletions
@@ -185,8 +185,9 @@ def initialize_weights(self) -> None:
             # Multi-gpu init may be different because of TP or FSDP (different shape), or PP (not on device)
             global_shape = meta.global_shape
 
-            if self._distributed_config.reproducible_init and (
-                global_shape.numel() != parameter.numel() or not self._mode.on_device
+            if meta.requires_global_initialization or (
+                self._distributed_config.reproducible_init
+                and (global_shape.numel() != parameter.numel() or not self._mode.on_device)
             ):
                 # Initialize all global weights on every gpu, then select the appropriate slice if applicable.
                 global_param = parameter.new_empty(global_shape, device=self._distributed.device)
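
Restated as a standalone predicate (names are illustrative, not Fast-LLM's API): the global initialization path is now forced whenever the parameter meta asks for it, and otherwise the old reproducible-init condition still applies.

def needs_global_init(
    requires_global_initialization: bool,
    reproducible_init: bool,
    global_numel: int,
    local_numel: int,
    on_device: bool,
) -> bool:
    # Force the global path when the meta demands it; otherwise keep the
    # previous rule: reproducible init with a reshaped or off-device buffer.
    return requires_global_initialization or (
        reproducible_init and (global_numel != local_numel or not on_device)
    )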

fast_llm/layers/common/config.py

Lines changed: 2 additions & 4 deletions
@@ -99,7 +99,7 @@ class LayerNormalizationBaseConfig(NormalizationConfig):
     )
 
     def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm":
-        from fast_llm.tensor import init_uniform_
+        from fast_llm.tensor import init_uniform_centered_
 
         kwargs = {
             "hidden_dim": hidden_dim,
@@ -110,9 +110,7 @@ def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm":
         }
         if self.initialization_range:
             mean = 0 if self.zero_centered else 1
-            kwargs["weight_init_method"] = init_uniform_(
-                mean - self.initialization_range, mean + self.initialization_range
-            )
+            kwargs["weight_init_method"] = init_uniform_centered_(self.initialization_range, mean=mean)
         return self.module_class(**kwargs)
 
     @property
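
Both calls describe the same distribution: a uniform draw over [mean − range, mean + range]. A rough stand-in for the arithmetic the new helper encapsulates (the real `init_uniform_centered_` in `fast_llm.tensor` constructs an init method, since it is assigned to `weight_init_method`, rather than filling a tensor in place as this sketch does):

import torch

def uniform_centered_(tensor: torch.Tensor, range_: float, mean: float = 0.0) -> torch.Tensor:
    # Same interval the old call spelled out as (mean - range, mean + range).
    return tensor.uniform_(mean - range_, mean + range_)

weight = torch.empty(8)
uniform_centered_(weight, 0.02, mean=1.0)  # zero_centered=False -> centered at 1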

fast_llm/layers/common/linear.py

Lines changed: 4 additions & 4 deletions
@@ -94,8 +94,8 @@ def __init__(
         transposed_weight: bool = False,
         lr_scale: float | None | tuple[float | None, ...] = None,
     ):
-        assert in_dim.parallel_dim is None
-        assert out_dim.parallel_dim is None
+        assert not in_dim.is_parallel
+        assert not out_dim.is_parallel
         super().__init__(
             in_dim,
             out_dim,
@@ -132,7 +132,7 @@ def __init__(
         sequence_parallel: bool = False,
         lr_scale: float | None | tuple[float | None, ...] = None,
     ):
-        assert in_dim.parallel_dim is None
+        assert not in_dim.is_parallel
         self._group_size = 1 if out_dim.parallel_dim is None else out_dim.parallel_dim.size
         self._sequence_parallel = sequence_parallel and self._group_size > 1
         super().__init__(
@@ -176,7 +176,7 @@ def __init__(
         transposed_weight: bool = False,
         lr_scale: float | None | tuple[float | None, ...] = None,
     ):
-        assert out_dim.parallel_dim is None
+        assert not out_dim.is_parallel
         self._group_size = 1 if in_dim.parallel_dim is None else in_dim.parallel_dim.size
         self._sequence_parallel = sequence_parallel and self._group_size > 1
         super().__init__(
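
The same `parallel_dim is None` → `not ...is_parallel` substitution appears in normalization.py and peft.py below. Judging by the peft.py hunk, where a `parallel_dim.size == 1` dimension used to pass the old check, `is_parallel` presumably counts a dimension as parallel only when it is actually split across ranks. A plausible sketch of such a property, purely illustrative and not the actual `TensorDim` implementation:

import dataclasses

@dataclasses.dataclass
class DistributedDim:
    size: int = 1

@dataclasses.dataclass
class TensorDimSketch:
    name: str
    global_size: int
    parallel_dim: DistributedDim | None = None

    @property
    def is_parallel(self) -> bool:
        # Parallel only if attached to a distributed dim that is actually split.
        return self.parallel_dim is not None and self.parallel_dim.size > 1

assert not TensorDimSketch("hidden", 1024).is_parallel
assert not TensorDimSketch("hidden", 1024, DistributedDim(size=1)).is_parallel
assert TensorDimSketch("vocab", 32000, DistributedDim(size=4)).is_parallel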

fast_llm/layers/common/normalization.py

Lines changed: 2 additions & 2 deletions
@@ -158,7 +158,7 @@ def __init__(
         lr_scale: float | None = None,
     ):
         super().__init__()
-        assert hidden_dim.parallel_dim is None
+        assert not hidden_dim.is_parallel
         self._eps = eps
         self._zero_centered = zero_centered
         if implementation == NormalizationImplementation.auto:
@@ -242,7 +242,7 @@ def __init__(
         lr_scale: float | None = None,
     ):
         super().__init__()
-        assert hidden_dim.parallel_dim is None
+        assert not hidden_dim.is_parallel
         self._eps = eps
         self._zero_centered = zero_centered
         if implementation == NormalizationImplementation.auto:

fast_llm/layers/common/peft.py

Lines changed: 2 additions & 2 deletions
@@ -19,12 +19,12 @@ def lora_linear(
 ):
     layer.weight.requires_grad = False
     in_dim = layer._in_dim
+    assert not in_dim.is_parallel, "LoRA not supported with tensor parallelism."
     if in_dim.parallel_dim is not None:
-        assert in_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism."
         in_dim = TensorDim(in_dim.name, in_dim.global_size)
     out_dim = layer._out_dim
+    assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism."
     if out_dim.parallel_dim is not None:
-        assert out_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism."
         out_dim = TensorDim(out_dim.name, out_dim.global_size)
     if out_channel_begin is not None or out_channel_end is not None:
         if out_channel_begin is None:

fast_llm/layers/language_model/embedding.py

Lines changed: 4 additions & 4 deletions
@@ -46,10 +46,10 @@ def __init__(
         self._dropout_p = config.transformer.hidden_dropout
         self._use_absolute_position_embeddings = config.use_absolute_position_embeddings
 
-        hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden)
-        vocab_dim = tensor_space.get_tensor_dim(
+        hidden_dim = tensor_space[TransformerDimNames.hidden]
+        vocab_dim = tensor_space[
             LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
-        )
+        ]
 
         if self._parallel_embeddings:
             self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size
@@ -66,7 +66,7 @@ def __init__(
             )
         if self._use_absolute_position_embeddings:
             self.position_embeddings_weight = ParameterMeta.from_dims(
-                (tensor_space.get_tensor_dim(LanguageModelDimNames.position_embed), hidden_dim),
+                (tensor_space[LanguageModelDimNames.position_embed], hidden_dim),
                 init_method=init_normal_(
                     std=config.init_method_std_embed,
                     min_val=config.init_method_min_embed,
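
Throughout the model code (here and in head.py and preprocessing.py below), `tensor_space.get_tensor_dim(name)` becomes `tensor_space[name]`, so `TensorSpace` is presumably indexable now. The pattern is plain dict-style lookup; a minimal sketch of what such a `__getitem__` might delegate to (illustrative, not the actual class or its registration API):

class TensorSpaceSketch:
    """Dict-like registry of named tensor dims (illustrative only)."""

    def __init__(self) -> None:
        self._tensor_dims: dict[str, object] = {}

    def add_tensor_dim(self, name: str, dim: object) -> None:
        self._tensor_dims[name] = dim

    def __getitem__(self, name: str) -> object:
        # Equivalent of the old get_tensor_dim(name) accessor.
        return self._tensor_dims[name]

space = TensorSpaceSketch()
space.add_tensor_dim("hidden", ("hidden", 1024))
hidden_dim = space["hidden"]  # replaces space.get_tensor_dim("hidden")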

fast_llm/layers/language_model/head.py

Lines changed: 5 additions & 5 deletions
@@ -61,7 +61,7 @@ def __init__(
         if self._cross_entropy_splits is not None and self._sequence_parallel:
             assert not self._parallel_embeddings
 
-        hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden)
+        hidden_dim = self._tensor_space[TransformerDimNames.hidden]
 
         self._loss_coefficient = (
             config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0
@@ -108,9 +108,9 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None:
         if self._tie_word_embeddings or self._prediction_distance > 0:
             return
         # untie embedding weights
-        vocab_dim = self._tensor_space.get_tensor_dim(
+        vocab_dim = self._tensor_space[
             LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab
-        )
+        ]
         self.output_weights = ParameterMeta.from_dims(
             (vocab_dim, hidden_dim),
             init_method=init_normal_(
@@ -338,9 +338,9 @@ def _logits_cross_entropy_forward_backward(
             logits_scale_factor=self._logits_scale_factor,
         )
         if self._debug_transformer and self._cross_entropy_splits is None:
-            vocab_dim = self._tensor_space.get_tensor_dim(
+            vocab_dim = self._tensor_space[
                 LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp
-            )
+            ]
             dims = [*kwargs[TransformerKwargs.hidden_dims][:-1], vocab_dim]
             sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first])
             dims[sequence_index] = (

fast_llm/layers/language_model/preprocessing.py

Lines changed: 2 additions & 2 deletions
@@ -28,7 +28,7 @@ def __init__(
         assert config.use_absolute_position_embeddings
         self._tensor_space = tensor_space
         self._distributed_config = self._tensor_space.distributed_config
-        self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar)
+        self._scalar_dim = self._tensor_space[DefaultDimNames.scalar]
 
     def _create_tensors(self, sequence_length: int) -> None:
         if sequence_length <= self._tensor_cache_max_sequence_length:
@@ -76,7 +76,7 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace):
         self._config = config
         self._tensor_space = tensor_space
         self._distributed_config = self._tensor_space.distributed_config
-        self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar)
+        self._scalar_dim = self._tensor_space[DefaultDimNames.scalar]
 
     def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None:
         return
