Merged

60 commits
82eed2b  TP mamba (jlamypoirier, Jul 21, 2025)
4e310c7  TP mamba (jlamypoirier, Jul 22, 2025)
3cc4118  fix (jlamypoirier, Jul 22, 2025)
9f7f75c  fix (jlamypoirier, Jul 22, 2025)
4054e04  fixes (jlamypoirier, Jul 23, 2025)
0014cc6  fix (jlamypoirier, Jul 23, 2025)
47ad548  fixes (jlamypoirier, Jul 23, 2025)
6a074fa  fixes (jlamypoirier, Jul 23, 2025)
d66651f  Update external (jlamypoirier, Jul 23, 2025)
50083ba  SSM debugging (jlamypoirier, Jul 24, 2025)
5006328  Merge branch 'main' into tp_mamba (jlamypoirier, Jul 24, 2025)
13176bd  Merge branch 'debug_mamba' into tp_mamba (jlamypoirier, Jul 24, 2025)
7b32699  stuff (jlamypoirier, Jul 24, 2025)
73f591f  Merge branch 'debug_mamba' into tp_mamba (jlamypoirier, Jul 24, 2025)
1feccc8  stuff (jlamypoirier, Jul 24, 2025)
e528b50  misc (jlamypoirier, Jul 24, 2025)
b49c42f  misc (jlamypoirier, Jul 24, 2025)
bb4dcd9  Merge branch 'debug_mamba' into tp_mamba (jlamypoirier, Jul 24, 2025)
c1b7f44  misc (jlamypoirier, Jul 24, 2025)
31f5d41  misc (jlamypoirier, Jul 24, 2025)
051bb07  Merge branch 'debug_mamba' into tp_mamba (jlamypoirier, Jul 24, 2025)
0a9ff25  misc (jlamypoirier, Jul 24, 2025)
e7d9636  Parallel discrete mamba 2 (jlamypoirier, Jul 24, 2025)
c14b764  Mamba 2, misc (jlamypoirier, Jul 25, 2025)
b605bd2  doc (jlamypoirier, Jul 25, 2025)
5eea938  fix (jlamypoirier, Jul 28, 2025)
0a3e2a7  Merge branch 'debug_mamba' into tp_mamba (jlamypoirier, Jul 28, 2025)
2e6d082  fixes (jlamypoirier, Jul 28, 2025)
b6c8613  misc (jlamypoirier, Jul 28, 2025)
f0c04cf  Merge remote-tracking branch 'origin/main' into debug_mamba (jlamypoirier, Jul 28, 2025)
acdfab1  Merge branch 'debug_mamba' into tp_mamba (jlamypoirier, Jul 28, 2025)
e536af9  Concatenated dim (jlamypoirier, Jul 28, 2025)
017f5cc  fixes (jlamypoirier, Jul 28, 2025)
93e4c94  Merge branch 'concatenated_dim' into tp_mamba (jlamypoirier, Jul 28, 2025)
c41efc2  doc (jlamypoirier, Jul 28, 2025)
0b8bd5d  cleanup (jlamypoirier, Jul 28, 2025)
6bf06d6  fix (jlamypoirier, Jul 29, 2025)
2ddc3a7  fix (jlamypoirier, Jul 29, 2025)
c0f1597  Merge branch 'concatenated_dim' into tp_mamba (jlamypoirier, Jul 29, 2025)
cef7c15  fix (jlamypoirier, Jul 30, 2025)
5a0eabc  Merge remote-tracking branch 'origin/main' into debug_mamba (jlamypoirier, Aug 8, 2025)
dd288df  Merge branch 'debug_mamba' into concatenated_dim (jlamypoirier, Aug 8, 2025)
defd6e0  Merge branch 'concatenated_dim' into tp_mamba (jlamypoirier, Aug 8, 2025)
8abf258  fixes (jlamypoirier, Aug 8, 2025)
be99372  Merge branch 'main' into debug_mamba (jlamypoirier, Aug 12, 2025)
a505f3a  Merge branch 'debug_mamba' into concatenated_dim (jlamypoirier, Aug 12, 2025)
0cc859a  Merge remote-tracking branch 'origin/main' into concatenated_dim (jlamypoirier, Aug 12, 2025)
bd4ff0d  doc (jlamypoirier, Aug 12, 2025)
fd3307d  Merge branch 'concatenated_dim' into tp_mamba (jlamypoirier, Aug 12, 2025)
0e2e124  stuff (jlamypoirier, Aug 12, 2025)
9a2a7a2  Pr comments (jlamypoirier, Aug 21, 2025)
8c382a9  Cleanup (jlamypoirier, Aug 21, 2025)
019e43d  Cleanup (jlamypoirier, Aug 21, 2025)
3e0f3e5  Cleanup (jlamypoirier, Aug 21, 2025)
1abdd19  fixes (jlamypoirier, Aug 21, 2025)
7c24292  fixes (jlamypoirier, Aug 21, 2025)
af2964b  fixes (jlamypoirier, Aug 21, 2025)
188587e  Merge branch 'main' into concatenated_dim (jlamypoirier, Sep 17, 2025)
e111509  Merge branch 'concatenated_dim' into tp_mamba (jlamypoirier, Sep 17, 2025)
b29d657  Merge remote-tracking branch 'origin/main' into tp_mamba (jlamypoirier, Sep 18, 2025)
2 changes: 1 addition & 1 deletion Megatron-LM
224 changes: 162 additions & 62 deletions fast_llm/engine/config_utils/tensor_space.py
@@ -1,13 +1,18 @@
import logging
import math
import typing

from fast_llm.engine.distributed.config import DistributedConfig, DistributedDim
from fast_llm.utils import Assert, div

if typing.TYPE_CHECKING:
import torch

from fast_llm.core.distributed import ProcessGroup
from fast_llm.engine.distributed.distributed import Distributed

logger = logging.getLogger(__name__)


class TensorDim:
def __init__(self, name: str, global_size: int | None, parallel_dim: DistributedDim | None = None):
@@ -19,11 +24,11 @@ def __init__(self, name: str, global_size: int | None, parallel_dim: Distributed

def __repr__(self) -> str:
return (
f"TensorDim("
f"{type(self).__name__}("
f"name={self._name},"
f" size={self._size},"
f" global_size={self._global_size},"
f" parallel_dim={None if self.parallel_dim is None else self._parallel_dim}"
f" parallel_dim={self._parallel_dim}"
f")"
)

@@ -38,83 +43,180 @@ def name(self) -> str:
def size(self) -> int:
return self._size

@property
def expanded_shape(self) -> tuple[int, ...]:
return (self._size,)

@property
def ndim(self) -> int:
return 1

@property
def global_size(self) -> int:
return self._global_size

@property
def global_expanded_shape(self) -> tuple[int, ...]:
return (self._size if self._parallel_dim is None else self._size * self._parallel_dim.size,)
def is_parallel(self) -> bool:
return self._parallel_dim is not None and self._parallel_dim.size > 1

@property
def parallel_dim(self) -> DistributedDim | None:
# TODO: Make more flexible for derived classes?
return self._parallel_dim

@property
def parallel_dim_index(self) -> int | None:
return None if self._parallel_dim is None else 0

@property
def parallel_group(self) -> "ProcessGroup|None":
# TODO: Make more flexible for derived classes?
return None if self._parallel_dim is None else self._parallel_dim.group

def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self:
assert self.parallel_dim is not None
assert self.is_parallel
return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim)

def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor":
if self.is_parallel:
from fast_llm.core.ops import gather_op

return gather_op(tensor, self.parallel_group, dim)
else:
return tensor

def local_to_global_partial(
self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1
) -> "torch.Tensor":
if self.is_parallel:
output = tensor.new_full((*tensor.shape[:dim], self.parallel_dim.size, *tensor.shape[dim:]), fill_value)
output.narrow(dim, self.parallel_dim.rank, 1).copy_(tensor.unsqueeze(dim)).squeeze(dim)
return output.flatten(dim, dim + 1)
else:
return tensor

def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor":
return (
tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank]
if self.parallel_dim is not None and self.parallel_dim.size > 1
else tensor
)


class CompositeTensorDim(TensorDim):
def __init__(self, name: str, dims: tuple[TensorDim, ...]):
# TODO: Recursive composition??
parallel_dims = [(i, dim.parallel_dim) for i, dim in enumerate(dims) if dim.parallel_dim]
Assert.leq(len(parallel_dims), 1)
def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]):
parallel_dim = None
for dim, tensor_dim in enumerate(tensor_dims):
if tensor_dim.parallel_dim is not None:
# TODO: Allow more than one parallel subdim?
assert parallel_dim is None
parallel_dim = tensor_dim.parallel_dim
self._parallel_dim_index = dim

super().__init__(
name=name,
global_size=math.prod(dim.global_size for dim in dims),
parallel_dim=parallel_dims[0][1] if parallel_dims else None,
)
self._dims = dims
self._parallel_dim_index = (
sum(dim.ndim for dim in self._dims[: parallel_dims[0][0]])
+ self._dims[parallel_dims[0][0]].parallel_dim_index
if parallel_dims
else None
global_size=math.prod(dim.global_size for dim in tensor_dims),
parallel_dim=parallel_dim,
)
self._tensor_dims = tensor_dims

@property
def dims(self) -> tuple[TensorDim, ...]:
return self._dims
def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self:
assert self._parallel_dim_index is not None
dims = list(self._tensor_dims)
dims[self._parallel_dim_index] = dims[self._parallel_dim_index].replace_parallel_dim(distributed_dim)
return CompositeTensorDim(self.name, tuple(dims))

@property
def ndim(self) -> int:
return sum(dim.ndim for dim in self._dims)
def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor":
tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims])
for i, tensor_dim in enumerate(self._tensor_dims):
tensor = tensor_dim.local_to_global(tensor, dim + i)

@property
def expanded_shape(self) -> tuple[int, ...]:
return sum((dim.expanded_shape for dim in self._dims), ())
return tensor.flatten(dim, dim + len(self._tensor_dims) - 1)

@property
def global_expanded_shape(self) -> tuple[int, ...]:
return sum((dim.global_expanded_shape for dim in self._dims), ())
def local_to_global_partial(
self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1
) -> "torch.Tensor":
tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims])
for i, tensor_dim in enumerate(self._tensor_dims):
tensor = tensor_dim.local_to_global_partial(tensor, dim + i)

return tensor.flatten(dim, dim + len(self._tensor_dims) - 1)

def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor":
tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims])
for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))):
tensor = tensor_dim.global_to_local(tensor, dim + i)
return tensor if expand else tensor.flatten(dim, dim + len(self._tensor_dims) - 1)

@property
def parallel_dim_index(self) -> int | None:
return self._parallel_dim_index

class ConcatenatedTensorDim(TensorDim):
def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]):
parallel_dim = tensor_dims[0].parallel_dim
for dim, tensor_dim in enumerate(tensor_dims[1:]):
# TODO: Allow more flexibility?
Assert.is_(tensor_dim.parallel_dim, parallel_dim)

super().__init__(
name=name,
global_size=sum(dim.global_size for dim in tensor_dims),
parallel_dim=parallel_dim,
)
self._tensor_dims = tensor_dims

def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self:
assert self.parallel_dim_index is not None
dims = list(self.dims)
dims[self.parallel_dim_index] = dims[self.parallel_dim_index].replace_parallel_dim(distributed_dim)
return CompositeTensorDim(self.name, tuple(dims))
assert self.is_parallel
return ConcatenatedTensorDim(
self.name, tuple(tensor_dim.replace_parallel_dim(distributed_dim) for tensor_dim in self._tensor_dims)
)

def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor":
import torch

return (
torch.concatenate(
[
tensor_dim.local_to_global(tensor_, dim)
for tensor_, tensor_dim in zip(
tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim),
self._tensor_dims,
strict=True,
)
],
dim,
)
if self.is_parallel
else tensor
)

def local_to_global_partial(
self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1
) -> "torch.Tensor":
import torch

return (
torch.concatenate(
[
tensor_dim.local_to_global_partial(tensor_, dim)
for tensor_, tensor_dim in zip(
tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim),
self._tensor_dims,
strict=True,
)
],
dim,
)
if self.is_parallel
else tensor
)

def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor":
if self.is_parallel and expand:
raise NotImplementedError()
import torch

return (
torch.concatenate(
[
tensor_dim.global_to_local(tensor_, dim)
for tensor_, tensor_dim in zip(
tensor.split([tensor_dim.global_size for tensor_dim in self._tensor_dims], dim),
self._tensor_dims,
strict=True,
)
],
dim,
)
if self.is_parallel
else tensor
)


class DefaultDimNames:
@@ -147,21 +249,19 @@ def distributed(self) -> "Distributed":
assert self._is_setup
return self._distributed

def add_tensor_dim(self, dim: TensorDim) -> None:
if isinstance(dim, CompositeTensorDim):
for dim_ in dim.dims:
Assert.incl(dim_.name, self._tensor_dims)
Assert.eq(dim_, self._tensor_dims[dim_.name])
if dim.name in self._tensor_dims:
Assert.eq(dim, self._tensor_dims[dim.name])
def add_tensor_dim(self, tensor_dim: TensorDim) -> None:
if tensor_dim.name in self._tensor_dims:
Assert.eq(tensor_dim, self._tensor_dims[tensor_dim.name])
else:
if dim.parallel_dim is not None:
assert dim.parallel_dim.name in self._distributed_config.distributed_dims, dim.parallel_dim.name
if tensor_dim.parallel_dim is not None:
assert (
tensor_dim.parallel_dim.name in self._distributed_config.distributed_dims
), tensor_dim.parallel_dim.name
Assert.eq(
dim.parallel_dim.__dict__,
self._distributed_config.distributed_dims[dim.parallel_dim.name].__dict__,
tensor_dim.parallel_dim.__dict__,
self._distributed_config.distributed_dims[tensor_dim.parallel_dim.name].__dict__,
)
self._tensor_dims[dim.name] = dim
self._tensor_dims[tensor_dim.name] = tensor_dim

def get_tensor_dim(self, name: str) -> TensorDim:
def __getitem__(self, name: str) -> TensorDim:
return self._tensor_dims[name]
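
Note: the following is a minimal, standalone sketch (plain PyTorch, not the Fast-LLM classes in this diff) of the sharding pattern that CompositeTensorDim.global_to_local and local_to_global implement when exactly one sub-dim is tensor-parallel. The sub-dim sizes, the parallel group size, and the concatenation standing in for gather_op are illustrative assumptions, not values from this PR.

import torch

# Hypothetical composite dim (heads, head_dim) = (8, 4), with "heads" split
# across a tensor-parallel group of size 2 (assumed values for illustration).
heads, head_dim, tp_size = 8, 4, 2
global_tensor = torch.arange(heads * head_dim, dtype=torch.float32)  # global size 32

def global_to_local(tensor: torch.Tensor, rank: int, dim: int = 0) -> torch.Tensor:
    # Mirror CompositeTensorDim.global_to_local: unflatten into sub-dims,
    # chunk the parallel sub-dim for this rank, then flatten back.
    t = tensor.unflatten(dim, (heads, head_dim))
    t = t.chunk(tp_size, dim)[rank]            # shard the "heads" sub-dim
    return t.flatten(dim, dim + 1)             # local size 16 per rank

def local_to_global(shards: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
    # Stand-in for the gather_op call: concatenate the per-rank shards along
    # the parallel sub-dim, then flatten back to the composite dim.
    ts = [s.unflatten(dim, (heads // tp_size, head_dim)) for s in shards]
    return torch.cat(ts, dim).flatten(dim, dim + 1)

locals_ = [global_to_local(global_tensor, rank) for rank in range(tp_size)]
assert torch.equal(local_to_global(locals_), global_tensor)

The unflatten / act-per-sub-dim / flatten pattern is what lets the composite dim defer the actual communication (or chunking) to its parallel sub-dim.
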
32 changes: 7 additions & 25 deletions fast_llm/engine/multi_stage/fsdp.py
@@ -441,39 +441,21 @@ def _get_parameter_shard_indices_in_full_weight(
where it is located in the shard if it exists, or -1 if it's not in the shard.
Used to determine the location of each entry in a different distributed configuration.
"""

# Create an empty index for the global parameter.
index = torch.full(
parameter_meta.global_shape,
-1,
dtype=torch.int64,
device=device,
)
# Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard
begin, end = self._get_parameter_range_in_shard(parameter_name)

buffer_index = parameter_meta.global_to_local(index, expand=True)
# Copying directly into `buffer_index` requires a view of the tensor, which may not be feasible.
# In that case, we work with a separate tensor to be copied back into `buffer_index`.
try:
buffer_index_flat = buffer_index.view(-1)
is_view = True
except RuntimeError:
buffer_index_flat = buffer_index.new_full((buffer_index.numel(),), -1)
is_view = False

# Copy the shard indices at their respective positions in the flat buffer index.
buffer_index_flat[
# Create an empty local index to hold the local shard indices.
buffer_index = torch.full_like(parameter_meta, -1, dtype=torch.int64, device=device)

# Copy the shard indices at their respective positions in the buffer index.
buffer_index.flatten()[
self._index_buffer_to_param(
self._fsdp_dim.rank * self._shard_size, parameter_name
) : self._index_buffer_to_param((self._fsdp_dim.rank + 1) * self._shard_size, parameter_name)
].copy_(torch.arange(begin, end, dtype=torch.int64, device=device))

# If needed, copy the flat buffer index back into the index.
if not is_view:
buffer_index.copy_(buffer_index_flat.view_as(buffer_index))

return index
# Create a global index from the local one.
return parameter_meta.local_to_global_partial(buffer_index, -1)

def copy_shard_overlaps(
self,
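
The fsdp.py change above replaces the hand-rolled global index construction with local_to_global_partial. A minimal sketch of what that primitive does, with made-up sizes and specialized to dim=0 so no process group is needed: each rank writes its local slice into a global-sized tensor pre-filled with -1, instead of performing an all-gather.

import torch

# Illustrative sizes and rank, not taken from the PR.
world_size, local_size, rank = 4, 6, 1
local_index = torch.arange(rank * local_size, (rank + 1) * local_size)  # this rank's shard indices

def local_to_global_partial(local: torch.Tensor, rank: int, fill_value: int = -1) -> torch.Tensor:
    # Same shape math as TensorDim.local_to_global_partial for dim=0:
    # allocate (world_size, local_size), fill, write our row, flatten.
    out = local.new_full((world_size, *local.shape), fill_value)
    out[rank].copy_(local)
    return out.flatten(0, 1)

global_index = local_to_global_partial(local_index, rank)
# Entries owned by other ranks stay at -1; ours hold their shard positions.
assert global_index[rank * local_size : (rank + 1) * local_size].equal(local_index)
assert (global_index[: rank * local_size] == -1).all()
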
7 changes: 5 additions & 2 deletions fast_llm/engine/multi_stage/stage_base.py
@@ -185,12 +185,15 @@ def initialize_weights(self) -> None:
# Multi-gpu init may be different because of TP or FSDP (different shape), or PP (not on device)
global_shape = meta.global_shape

if self._distributed_config.reproducible_init and (
global_shape.numel() != parameter.numel() or not self._mode.on_device
if meta.requires_global_initialization or (
self._distributed_config.reproducible_init
and (global_shape.numel() != parameter.numel() or not self._mode.on_device)
):
# Initialize all global weights on every gpu, then select the appropriate slice if applicable.
global_param = parameter.new_empty(global_shape, device=self._distributed.device)
meta.init_parameter(global_param, distributed=self._distributed)
# It happens.
Assert.eq(global_param.shape, global_shape)
if self._mode.on_device:
parameter.copy_(fsdp.parameter_global_to_shard(global_param, meta.tensor_name))
elif self._mode.on_device:
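
As a rough illustration of the global-then-slice initialization path this branch takes (not the actual Fast-LLM code; the seed, shapes, and even split along dim 0 are assumptions): every rank runs the same global init, then keeps only its shard, so the sharded result matches what a single-GPU run would produce.

import torch

tp_size, rank, global_shape = 2, 0, (8, 16)

generator = torch.Generator().manual_seed(1234)   # same seed on every rank
global_param = torch.empty(global_shape)
global_param.normal_(generator=generator)         # identical global init everywhere

# Each rank keeps only its slice (here: an even split along dim 0).
local_param = global_param.chunk(tp_size, 0)[rank]
assert local_param.shape == (global_shape[0] // tp_size, global_shape[1])
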
6 changes: 2 additions & 4 deletions fast_llm/layers/common/config.py
@@ -99,7 +99,7 @@ class LayerNormalizationBaseConfig(NormalizationConfig):
)

def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm":
from fast_llm.tensor import init_uniform_
from fast_llm.tensor import init_uniform_centered_

kwargs = {
"hidden_dim": hidden_dim,
@@ -110,9 +110,7 @@ def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "
}
if self.initialization_range:
mean = 0 if self.zero_centered else 1
kwargs["weight_init_method"] = init_uniform_(
mean - self.initialization_range, mean + self.initialization_range
)
kwargs["weight_init_method"] = init_uniform_centered_(self.initialization_range, mean=mean)
return self.module_class(**kwargs)

@property
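
The config.py change relies on init_uniform_centered_(range, mean=mean) producing the same distribution as init_uniform_(mean - range, mean + range). A sketch of that equivalence with plain tensor ops; note the real Fast-LLM helpers return initializer callables rather than acting on a tensor directly, so the signatures below are simplified assumptions for illustration only.

import torch

def init_uniform_(tensor: torch.Tensor, low: float, high: float) -> torch.Tensor:
    # Uniform init over an explicit [low, high] interval.
    return tensor.uniform_(low, high)

def init_uniform_centered_(tensor: torch.Tensor, rng: float, mean: float = 0.0) -> torch.Tensor:
    # Same distribution as init_uniform_(tensor, mean - rng, mean + rng).
    return tensor.uniform_(mean - rng, mean + rng)

w = torch.empty(4, 4)
init_uniform_centered_(w, 0.02, mean=1.0)   # e.g. zero_centered=False, range 0.02
assert ((w - 1.0).abs() <= 0.02).all()
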