Skip to content

Commit 8cda1a4

Browse files
jd7-trfacebook-github-bot
authored andcommitted
Add new argument for sharding type (#2969)
Summary: Pull Request resolved: #2969 Add new arg to `shard_quant_model` to allow users to specify sharding_type, which is currently hardcoded to `ShardingType.TABLE_WISE`. Still use `ShardingType.TABLE_WISE` as default value so there is no impact. Reviewed By: aporialiao Differential Revision: D73540534 fbshipit-source-id: bebefd06f38b98df23120e3a1217b5d5e92a357b
1 parent cc7f1d0 commit 8cda1a4

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

torchrec/inference/modules.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,7 @@ def shard_quant_model(
499499
device_memory_size: Optional[int] = None,
500500
constraints: Optional[Dict[str, ParameterConstraints]] = None,
501501
ddr_cap: Optional[int] = None,
502+
sharding_type: ShardingType = ShardingType.TABLE_WISE,
502503
) -> Tuple[torch.nn.Module, ShardingPlan]:
503504
"""
504505
Shard a quantized TorchRec model, used for generating the most optimal model for inference and
@@ -534,6 +535,10 @@ def shard_quant_model(
534535
quant_model = quantize_inference_model(module)
535536
sharded_model, _ = shard_quant_model(quant_model)
536537
"""
538+
# TODO(T220572301): remove after new sharding types are validated.
539+
assert (
540+
sharding_type == ShardingType.TABLE_WISE
541+
), "Only table-wise sharding is supported now."
537542

538543
if constraints is None:
539544
table_fqns = []
@@ -552,7 +557,7 @@ def shard_quant_model(
552557
constraints = {}
553558
for name in table_fqns:
554559
constraints[name] = ParameterConstraints(
555-
sharding_types=[ShardingType.TABLE_WISE.value],
560+
sharding_types=[sharding_type.value],
556561
compute_kernels=[EmbeddingComputeKernel.QUANT.value],
557562
)
558563

0 commit comments

Comments
 (0)