diff --git a/torchrec_dlrm/dlrm_main.py b/torchrec_dlrm/dlrm_main.py
index 7025b450..ea188204 100644
--- a/torchrec_dlrm/dlrm_main.py
+++ b/torchrec_dlrm/dlrm_main.py
@@ -24,10 +24,27 @@
     TOTAL_TRAINING_SAMPLES,
 )
 from torchrec.datasets.utils import Batch
+from torchrec.distributed.comm import get_local_size
 from torchrec.distributed import TrainPipelineSparseDist
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
+from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.model_parallel import DistributedModelParallel
-from torchrec.distributed.types import ModuleSharder
+from torchrec.distributed.planner import (
+    EmbeddingShardingPlanner,
+    ParameterConstraints,
+    Topology,
+)
+from torchrec.distributed.planner.constants import (
+    INTRA_NODE_BANDWIDTH,
+    CROSS_NODE_BANDWIDTH,
+    HBM_CAP,
+    DDR_CAP,
+)
+from torchrec.distributed.types import (
+    ModuleSharder,
+    ShardingEnv,
+    ShardingType,
+)
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
 from tqdm import tqdm
@@ -197,18 +214,50 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
         default=0.20,
         help="Learning rate after change point in first epoch.",
     )
-    parser.set_defaults(
-        pin_memory=None,
-        mmap_mode=None,
-        shuffle_batches=None,
-        change_lr=None,
-    )
     parser.add_argument(
         "--adagrad",
         dest="adagrad",
         action="store_true",
         help="Flag to determine if adagrad optimizer should be used.",
     )
+    parser.add_argument(
+        "--sharding_type",
+        type=str,
+        choices=[st.value for st in ShardingType],
+        help="ShardingType constraint for all embedding tables",
+    )
+    parser.add_argument(
+        "--compute_kernel",
+        type=str,
+        choices=[ck.value for ck in EmbeddingComputeKernel],
+        help="ComputeKernel constraint for all embedding tables",
+    )
+    parser.add_argument(
+        "--intra_host_bw",
+        type=float,
+        default=INTRA_NODE_BANDWIDTH,
+    )
+    parser.add_argument(
+        "--inter_host_bw",
+        type=float,
+        default=CROSS_NODE_BANDWIDTH,
+    )
+    parser.add_argument(
+        "--hbm_cap",
+        type=int,
+        default=HBM_CAP,
+    )
+    parser.add_argument(
+        "--ddr_cap",
+        type=int,
+        default=DDR_CAP,
+    )
+    parser.set_defaults(
+        pin_memory=None,
+        mmap_mode=None,
+        shuffle_batches=None,
+        change_lr=None,
+    )
     return parser.parse_args(argv)

@@ -534,10 +583,38 @@ def main(argv: List[str]) -> None:
         EmbeddingBagCollectionSharder(fused_params=fused_params),
     ]

+    pg = dist.GroupMember.WORLD
+    assert pg is not None, "Process group is not initialized"
+    env = ShardingEnv.from_process_group(pg)
+    if any(a is not None for a in [args.sharding_type, args.compute_kernel]):
+        sharding_types = [args.sharding_type] if args.sharding_type else None
+        compute_kernels = [args.compute_kernel] if args.compute_kernel else None
+        constraints = {
+            f"t_{feature_name}": ParameterConstraints(
+                sharding_types=sharding_types,
+                compute_kernels=compute_kernels,
+            )
+            for feature_name in DEFAULT_CAT_NAMES
+        }
+    else:
+        constraints = None
+    planner = EmbeddingShardingPlanner(
+        topology=Topology(
+            world_size=env.world_size,
+            local_world_size=get_local_size(env.world_size),
+            compute_device=device.type,
+            hbm_cap=args.hbm_cap,
+            ddr_cap=args.ddr_cap,
+            intra_host_bw=args.intra_host_bw,
+            inter_host_bw=args.inter_host_bw,
+            batch_size=args.batch_size,
+        ),
+        constraints=constraints,
+    )
+    plan = planner.collective_plan(train_model, sharders, pg)
+
     model = DistributedModelParallel(
         module=train_model,
         device=device,
         sharders=cast(List[ModuleSharder[nn.Module]], sharders),
+        plan=plan,
     )

     def optimizer_with_params():
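A hypothetical invocation exercising the new flags, assuming an 8-GPU torchrun launch of dlrm_main.py and that "table_wise" is among the ShardingType choices in the installed torchrec build; paths and dataset arguments are omitted:

    # forces every t_<feature> table to table-wise sharding; leaving both
    # --sharding_type and --compute_kernel unset keeps constraints=None
    torchrun --nnodes 1 --nproc_per_node 8 dlrm_main.py --sharding_type table_wise

When no constraint flags are passed, constraints is None and the EmbeddingShardingPlanner is free to choose sharding types and compute kernels on its own, so existing launch commands keep their previous behavior.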