### Working Items

#### Modifications on Mamba code (lchu6/mamba@fd4fa08)
- `nn.ModuleList` -> `nn.ModuleDict`. Currently the Mamba blocks are stored in a `ModuleList`, which has dynamic FQNs that cause issues for checkpoint saving when using PP. For example, say we have 2 PP ranks and 2 blocks in the model. We want PP rank 0 to hold block 0 and PP rank 1 to hold block 1, so we "remove" block 0 on PP rank 1. In that case the FQN of block 1 changes from "Block1" to "Block0", because that block now becomes the "0th element in the ModuleList". This is undesired, since we want static FQNs for all modules. When switching to `ModuleDict` we can guarantee the FQN, as it is always attached to the "key" of the dict, so "Block1" will always be "Block1" (see the first sketch after this list).
- Implement `None` layers for `embedding`, `norm_f` and `lm_head`. Non-first PP stages will have `embedding` as None, while non-last PP stages will have both `norm_f` and `lm_head` as None. The current implementation does not handle None modules, so we need to modify the Mamba code to allow optional None values for these modules (see the second sketch after this list).
- Handle `hidden_states` + `residual` properly. Unlike other model blocks, a Mamba block outputs both `hidden_states` and `residual` separately. This isn't strictly necessary, since the two could be summed before being returned, but keeping them separate makes `fused_add_norm` possible and accelerates the run, so Mamba did it this way. We need to "pack" and "unpack" these two properly as we communicate them across PP ranks (see the packing sketch after the notes below).
- Make all outputs Tensor. Following HF style, Mamba uses a CausalLMOutput object (NamedTuple) as the model output. However, PP does shape validation across ranks by calling `.shape` on all inputs/outputs for a shape-equality assertion, so we need to change the output from a NamedTuple to a regular Tensor.
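As a quick illustration of the FQN issue, here is a standalone PyTorch sketch (toy `Linear` blocks, not the actual Mamba model code):

```python
import torch.nn as nn

# With ModuleList, deleting block 0 shifts the remaining block's FQN:
# "1.weight" silently becomes "0.weight".
list_blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(2)])
del list_blocks[0]
print([name for name, _ in list_blocks.named_parameters()])
# ['0.weight', '0.bias']

# With ModuleDict, the FQN is tied to the key, so "block1" stays "block1"
# no matter which blocks a given PP rank keeps.
dict_blocks = nn.ModuleDict({f"block{i}": nn.Linear(4, 4) for i in range(2)})
del dict_blocks["block0"]
print([name for name, _ in dict_blocks.named_parameters()])
# ['block1.weight', 'block1.bias']
```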
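And a minimal sketch of the None-able-layer and Tensor-output changes, assuming a simplified stage module (the class and signatures below are illustrative, not the actual `MambaLMHeadModel` interface):

```python
import torch
import torch.nn as nn

class PPFriendlyMambaStage(nn.Module):
    """Illustrative PP stage module, not the real MambaLMHeadModel."""

    def __init__(self, embedding=None, blocks=None, norm_f=None, lm_head=None):
        super().__init__()
        self.embedding = embedding            # None on non-first PP stages
        self.blocks = blocks if blocks is not None else nn.ModuleDict()
        self.norm_f = norm_f                  # None on non-last PP stages
        self.lm_head = lm_head                # None on non-last PP stages

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # First stage receives token ids and embeds them; later stages
        # receive hidden states from the previous PP rank and skip this.
        if self.embedding is not None:
            x = self.embedding(x)
        # Real Mamba blocks also carry a separate residual stream; that is
        # omitted here and covered by the packing sketch further below.
        for block in self.blocks.values():
            x = block(x)
        # Only the last stage applies the final norm and LM head.
        if self.norm_f is not None:
            x = self.norm_f(x)
        if self.lm_head is not None:
            x = self.lm_head(x)
        # Return a plain Tensor (not a CausalLMOutput NamedTuple) so PP
        # shape validation can call .shape on the stage output directly.
        return x
```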
Some notes:
- Unlike other model implementations, where we have a very clear top-level flat hierarchy (`embedding` -> [blocks] -> `norm` -> `lm_head`), Mamba has `backbone` -> `lm_head`, where `backbone` contains `embedding` and `norm`. So we need to be a little careful with this slightly nested hierarchy, and when making layers None-able we might need to modify several different `forward()` functions and check each one carefully.
- We might want to revisit and change the Mamba block output at a later stage. As mentioned previously, Mamba outputs `hidden_states` and `residual` separately in order to fuse the `add_norm`. However, this means we need to communicate 2 x [b, s, h] between PP ranks rather than [b, s, h]. So we might want to revisit this later and change the output to the sum of hidden_states and residual rather than keeping them separate, which would make the add_norm a tiny bit slower (non-fused) but would halve the size of the tensor communicated among PP ranks (both options are sketched below).
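A hedged sketch of the two communication options (hypothetical helper names, not code from the branch):

```python
import torch

def pack(hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    # Option A: send both streams so the next stage can still use the
    # fused add_norm; this communicates 2 x [b, s, h] between PP ranks.
    return torch.stack([hidden_states, residual], dim=0)

def unpack(packed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    hidden_states, residual = packed.unbind(dim=0)
    return hidden_states, residual

def pack_summed(hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    # Option B: sum before sending; this halves the payload to [b, s, h],
    # at the cost of a plain (non-fused) add + norm on the receiving stage.
    return hidden_states + residual
```

In short, option A keeps `fused_add_norm` at twice the communication volume, while option B trades a slightly slower add_norm for half the cross-rank traffic.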
#### Modifications on Training code
TODO