-
Notifications
You must be signed in to change notification settings - Fork 3.9k
[codex] Fix Mamba conv params under fine-grained FSDP gather #4467
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
ilml
wants to merge
3
commits into
NVIDIA:main
Choose a base branch
from
ilml:codex/fix-mamba-fsdp-direct-conv-gather
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -70,6 +70,7 @@ def __init__( | |
| A_log_cp1: torch.Tensor, | ||
| D_cp1: torch.Tensor, | ||
| D_has_hdim: bool, | ||
| mixer=None, | ||
| ) -> None: | ||
| if not HAVE_EINOPS: | ||
| raise ImportError("einops is required by the Mamba model but cannot be imported") | ||
|
|
@@ -84,6 +85,7 @@ def __init__( | |
| self.A_log_cp1 = A_log_cp1 | ||
| self.D_cp1 = D_cp1 | ||
| self.D_has_hdim = D_has_hdim | ||
| self._mixer = mixer | ||
|
|
||
| self.cp_size = self.cp_group.size() | ||
|
|
||
|
|
@@ -231,24 +233,29 @@ def conv1d_channels(self): | |
| def get_conv1d_weight(self) -> torch.Tensor: | ||
| """Returns a slice of the conv1d weight relevant to the current context parallel rank""" | ||
| # weight shape: [conv_dim, 1, d_conv] | ||
| return self._slice_conv_param(self.conv1d_cp1.weight) | ||
| conv1d = self._mixer.conv1d if self._mixer is not None else self.conv1d_cp1 | ||
| return self._slice_conv_param(conv1d.weight) | ||
|
Comment on lines
-234
to
+237
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Have to ask, where/when do we have "stale" references such that we can no longer directly retrieve the weight? |
||
|
|
||
| def get_conv1d_bias(self) -> torch.Tensor: | ||
| """Returns a slice of the conv1d bias relevant to the current context parallel rank""" | ||
| # bias shape: [conv_dim] | ||
| return self._slice_conv_param(self.conv1d_cp1.bias) | ||
| conv1d = self._mixer.conv1d if self._mixer is not None else self.conv1d_cp1 | ||
| return self._slice_conv_param(conv1d.bias) | ||
|
|
||
| def get_dt_bias(self) -> torch.Tensor: | ||
| """Returns a slice of dt_bias relevant to the current context parallel rank""" | ||
| return self._slice_vector_param(self.dt_bias_cp1) | ||
| param = self._mixer.dt_bias if self._mixer is not None else self.dt_bias_cp1 | ||
| return self._slice_vector_param(param) | ||
|
|
||
| def get_A_log(self) -> torch.Tensor: | ||
| """Returns a slice of A_log relevant to the current context parallel rank""" | ||
| return self._slice_vector_param(self.A_log_cp1) | ||
| param = self._mixer.A_log if self._mixer is not None else self.A_log_cp1 | ||
| return self._slice_vector_param(param) | ||
|
|
||
| def get_D(self) -> torch.Tensor: | ||
| """Returns a slice of D relevant to the current context parallel rank""" | ||
| return self._slice_vector_param(self.D_cp1, has_hdim=self.D_has_hdim) | ||
| param = self._mixer.D if self._mixer is not None else self.D_cp1 | ||
| return self._slice_vector_param(param, has_hdim=self.D_has_hdim) | ||
|
|
||
| def _slice_conv_param(self, param: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Early note, what about the pre-backward param unshard?
We need to ensure that any rogue weights are re-sharded, and un-sharded during the backward pass. Did you check that the Conv-1D weights are re-sharded?