
[PP + EP][Stage I] PP x Mamba #134

@lchu6

Working Items

Modifications on Mamba code (lchu6/mamba@fd4fa08)

  • nn.ModuleList -> nn.ModuleDict. The Mamba blocks are currently stored in an nn.ModuleList, whose dynamic FQNs cause problems for checkpoint saving under PP. For example, with 2 PP ranks and 2 blocks, we want PP rank 0 to hold block 0 and PP rank 1 to hold block 1, so block 0 is "removed" on rank 1. After the removal, block 1's FQN changes from "Block1" to "Block0" because it is now the 0th element of the ModuleList. This is undesirable: we want static FQNs for all modules. Switching to nn.ModuleDict guarantees stable FQNs because the FQN is tied to the dict key, so "Block1" is always "Block1" (see the ModuleDict sketch after this list).
  • Implement None layers for embedding, norm_f and lm_head. Non-first PP stages set embedding to None, and non-last PP stages set both norm_f and lm_head to None. The current implementation does not handle None modules, so the Mamba code has to be modified to accept optional None values for these modules (a sketch follows the list).
  • Handle hidden_states + residual properly. Unlike most model blocks, a Mamba block outputs hidden_states and residual as two separate tensors. Summing them before returning would work, but keeping them separate enables fused_add_norm and speeds up the forward pass, which is why Mamba does it. We need to "pack" and "unpack" these two tensors properly when communicating them across PP ranks (see the pack/unpack sketch after the notes below).
  • Make all outputs Tensors. Following the Hugging Face style, Mamba wraps its model output in a CausalLMOutput object (a NamedTuple). However, PP performs shape validation across ranks by calling .shape on all inputs/outputs to assert shape equality, so we need to change the output from a NamedTuple to a plain Tensor (a toy illustration follows the list).
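
The sketch below (not the actual Mamba code; Block and Backbone are hypothetical stand-ins) shows why an nn.ModuleDict keyed by the layer index gives stable FQNs when a PP rank deletes blocks it does not own:

```python
import torch.nn as nn

class Block(nn.Module):
    """Stand-in for a Mamba block; internals omitted."""
    def __init__(self, d_model: int):
        super().__init__()
        self.mixer = nn.Linear(d_model, d_model)  # placeholder mixer

class Backbone(nn.Module):
    def __init__(self, d_model: int, n_layer: int):
        super().__init__()
        # Keyed by the layer index as a string, so the FQN of layer 1 is
        # always "layers.1.*", no matter which layers this PP rank keeps.
        self.layers = nn.ModuleDict({str(i): Block(d_model) for i in range(n_layer)})

model = Backbone(d_model=16, n_layer=2)
del model.layers["0"]  # PP rank 1 drops block 0
# Prints ["layers.1.mixer.weight", "layers.1.mixer.bias"]; with an
# nn.ModuleList, re-indexing would have renamed these to "layers.0.*".
print([name for name, _ in model.named_parameters()])
```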
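
A minimal sketch of the None-able layers, assuming the Mamba module names embedding, norm_f and lm_head; the stage-construction logic here is illustrative, not the actual implementation:

```python
import torch.nn as nn

class MambaStage(nn.Module):
    """One PP stage: embedding / norm_f / lm_head may each be None."""
    def __init__(self, d_model, vocab_size, blocks, is_first: bool, is_last: bool):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model) if is_first else None
        self.layers = blocks  # nn.ModuleDict holding this stage's Mamba blocks
        self.norm_f = nn.LayerNorm(d_model) if is_last else None
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False) if is_last else None

    def forward(self, x):
        # Non-first stages receive hidden states from the previous stage,
        # so there is no embedding lookup.
        if self.embedding is not None:
            x = self.embedding(x)
        for layer in self.layers.values():
            x = layer(x)
        # Non-last stages simply forward hidden states to the next stage.
        if self.norm_f is not None:
            x = self.norm_f(x)
        if self.lm_head is not None:
            x = self.lm_head(x)
        return x
```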
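
A toy illustration of the shape-validation problem with the NamedTuple output (CausalLMOutput recreated here for illustration):

```python
from collections import namedtuple
import torch

CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])

logits = torch.randn(2, 8, 32)      # [b, s, vocab]
wrapped = CausalLMOutput(logits=logits)

print(logits.shape)                 # torch.Size([2, 8, 32])
# print(wrapped.shape)              # AttributeError: the NamedTuple has no .shape,
                                    # so PP's cross-rank shape check fails;
                                    # return the logits tensor directly instead.
```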

Some notes:

  1. Unlike other model implementations with a clear, flat top-level hierarchy (embedding -> [blocks] -> norm -> lm_head), Mamba has backbone -> lm_head, with embedding and norm_f living inside backbone. We need to be a little careful with this slightly nested hierarchy: making layers None-able may require changes in several forward() functions, and each one should be checked carefully.
  2. We might want to revisit the Mamba block output at a later stage. As mentioned above, Mamba outputs hidden_states and residual separately in order to fuse the add_norm, but this means we have to communicate 2 x [b, s, h] between PP ranks instead of [b, s, h]. We may later want to change the output to the sum of hidden_states and residual: the add_norm becomes slightly slower because it is no longer fused, but the tensor communicated between PP ranks is halved in size. The sketch below shows both options.
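
A sketch of hypothetical pack/unpack helpers for the cross-stage communication, assuming hidden_states and residual are both [b, s, h]; the summed variant from note 2 is included for comparison:

```python
import torch

def pack_stage_output(hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    # Stack into a single [2, b, s, h] tensor so the PP runtime sends and
    # shape-checks one tensor between stages instead of two.
    return torch.stack((hidden_states, residual), dim=0)

def unpack_stage_input(packed: torch.Tensor):
    # Inverse of pack_stage_output on the receiving stage, run before the
    # fused add_norm of that stage's first block.
    hidden_states, residual = packed.unbind(dim=0)
    return hidden_states, residual

def pack_summed(hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    # Alternative from note 2: sum before sending, so only [b, s, h] crosses
    # the PP boundary, at the cost of a non-fused add + norm on the next stage.
    return hidden_states + residual
```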

Modifications on Training code

TODO
