-
Notifications
You must be signed in to change notification settings - Fork 7
Add DeepSeekV3 and figure out what's needed to support it #29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
fmassa
wants to merge
30
commits into
main
Choose a base branch
from
fmassa/deepseekv3
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,179
−24
Draft
Changes from 1 commit
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
52ea0c1
[WIP] Add basic DeepSeekV3
fmassa 0d3ae2d
Lint
fmassa 98d9dfd
Workarounds to make graph capture pass
fmassa 61a63c4
Add dummy propagation rules just to see what we need to implement
fmassa 67eb264
Cleanup
fmassa 86d53ff
prims.fma comes from softmax_backward
fmassa 7864f4d
Make _geenrate_dummy_strategy more generic
fmassa 60ccf1a
Add proper redistribute_cost to dummy strategies
fmassa dbbc205
Hack around missing dtypes in compute estimation and handle grouped_m…
fmassa d92f8c6
Add representative batch size
fmassa e25ff7b
Fix grouped_mm stride issue
wconstab 3b7e7fa
get DS3 running forward, OOM at backward
wconstab 3833a06
WIP factory_strategy
wconstab 3740b45
Start rebasing on top of main
fmassa 39fedfd
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa 6bec5f5
Fixes so that it runs
fmassa ce1c0a5
[WIP] Plumb fake_mode to avoid materializing memory
fmassa 5d79bec
Use more representative values for DS3 example
fmassa daea5a2
Add approximate flop formula to grouped_mm
fmassa 6d350e0
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa 418ad55
Glimpses of having DeepSeekV3 returning a reasonable solution
fmassa fce321f
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa 6d5747a
Use with_implicit_strategies instead of my generate_dummy_strategy
fmassa e0ae8a2
[WIP] Convert view->mm->view into matmul
fmassa 1b83581
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa cf1229d
Merge branch 'main' of github.com:pytorch-labs/autoparallel into fmas…
fmassa 4fe5a40
Merge branch 'main' of github.com:meta-pytorch/autoparallel into fmas…
fmassa 67542ad
Remove sharding rules that have been since moved to PyTorch
fmassa 779e808
Merge branch 'main' of github.com:meta-pytorch/autoparallel into fmas…
fmassa 124034e
Fixes after rebase
fmassa File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -668,6 +668,127 @@ def index_rule(mesh, op_schema): | |
| return out_strat | ||
|
|
||
|
|
||
| @register_opschema_rule(torch.ops.aten.sort.stable) | ||
| def sort_rule(mesh, op_schema): | ||
| op = torch.ops.aten.topk.default | ||
| out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[ | ||
| op | ||
| ]( | ||
| op_schema | ||
| ) | ||
| return out_strat | ||
|
|
||
|
|
||
| @register_opschema_rule(torch.ops.aten.gather.default) | ||
| def gather_strategy(mesh, op_schema): | ||
| from torch.distributed.tensor._op_schema import PlacementList | ||
| from torch.distributed.tensor._ops._embedding_ops import _MaskPartial | ||
| from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy | ||
|
|
||
| input_strategy = op_schema.args_schema[0] | ||
| dim = op_schema.args_schema[1] | ||
| index_strategy = op_schema.args_schema[2] | ||
|
|
||
| input_shape = input_strategy.shape | ||
| index_shape = index_strategy.shape | ||
|
|
||
| single_mesh_dim_strategies = [] | ||
|
|
||
| # placement list stores placements of [output, input, index] | ||
| # first we always have replicate all for inputs and output | ||
| all_replicate: PlacementList = [Replicate()] * 3 | ||
| single_mesh_dim_strategies.append(all_replicate) | ||
|
|
||
| # input sharding, input sharded, index accepts mask partial, output follows index | ||
| # this only works when the input is sharded on the gather dimension, and | ||
| # index has size 1 on the gather dimension | ||
| if index_shape[dim] == 1: | ||
| index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim) | ||
| input_sharding: PlacementList = [ | ||
| index_partial_placement, | ||
| Shard(dim), | ||
| index_partial_placement, | ||
| ] | ||
| single_mesh_dim_strategies.append(input_sharding) | ||
|
|
||
| # index sharding, input replicated, index sharded, output follows index | ||
| # this only works when the sharding dimension is the gather dimension | ||
| index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim)] | ||
| single_mesh_dim_strategies.append(index_sharding) | ||
|
|
||
| if len(input_shape) == len(index_shape): | ||
| for d in range(len(input_shape)): | ||
| if d != dim: | ||
| sharding = [Shard(d), Shard(d), Shard(d)] | ||
| single_mesh_dim_strategies.append(sharding) | ||
|
|
||
| return expand_to_full_mesh_op_strategy( | ||
| mesh, op_schema, single_mesh_dim_strategies, input_index=1 | ||
| ) | ||
|
|
||
|
|
||
| @register_opschema_rule(torch.ops.aten.scatter_add.default) | ||
| def scatter_add_strategy(mesh, op_schema): | ||
|
||
| from torch.distributed.tensor._op_schema import PlacementList | ||
|
|
||
| # from torch.distributed.tensor._ops._embedding_ops import _MaskPartial | ||
| from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy | ||
|
|
||
| input_strategy = op_schema.args_schema[0] | ||
| dim = op_schema.args_schema[1] | ||
| index_strategy = op_schema.args_schema[2] | ||
| # src_strategy = op_schema.args_schema[3] | ||
|
|
||
| input_shape = input_strategy.shape | ||
| index_shape = index_strategy.shape | ||
|
|
||
| single_mesh_dim_strategies = [] | ||
|
|
||
| # placement list stores placements of [output, input, index] | ||
| # first we always have replicate all for inputs and output | ||
| all_replicate: PlacementList = [Replicate()] * 4 | ||
| single_mesh_dim_strategies.append(all_replicate) | ||
|
|
||
| """ | ||
| # input sharding, input sharded, index accepts mask partial, output follows index | ||
| # this only works when the input is sharded on the gather dimension, and | ||
| # index has size 1 on the gather dimension | ||
| if index_shape[dim] == 1: | ||
| index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim) | ||
| input_sharding: PlacementList = [ | ||
| index_partial_placement, | ||
| Shard(dim), | ||
| index_partial_placement, | ||
| ] | ||
| single_mesh_dim_strategies.append(input_sharding) | ||
| """ | ||
| # index sharding, input replicated, index sharded, output follows index | ||
| # this only works when the sharding dimension is the gather dimension | ||
| index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim), Shard(dim)] | ||
| single_mesh_dim_strategies.append(index_sharding) | ||
|
|
||
| if len(input_shape) == len(index_shape): | ||
| for d in range(len(input_shape)): | ||
| if d != dim: | ||
| sharding = [Shard(d), Shard(d), Shard(d), Shard(d)] | ||
| single_mesh_dim_strategies.append(sharding) | ||
|
|
||
| return expand_to_full_mesh_op_strategy( | ||
| mesh, op_schema, single_mesh_dim_strategies, input_index=1 | ||
| ) | ||
|
|
||
|
|
||
| @register_opschema_rule(torch.ops.aten.slice_scatter.default) | ||
| def slice_scatter_rule(mesh, op_schema): | ||
| op = torch.ops.aten.slice_scatter.default | ||
| out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[ | ||
| op | ||
| ]( | ||
| op_schema | ||
| ) | ||
| return out_strat | ||
|
|
||
|
|
||
| def sdpa_rule(op, mesh, op_schema): | ||
| out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[ | ||
| op | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.