Skip to content

Commit 0002f19

Browse files
aliafzal authored and
facebook-github-bot committed
ModelDeltaTracker initial checkin (#3057)
Summary: Pull Request resolved: #3057

# Summary

This PR is an initial check-in that introduces `ModelDeltaTracker`. ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in models that use TorchRec. It is particularly useful for:

1. Identifying which embedding rows were accessed during model execution
2. Retrieving the latest delta or unique rows for a model
3. Computing top-k changed embeddings
4. Supporting streaming of updated embeddings between systems during online training

The tracker works with `ShardedEmbeddingCollection` and `ShardedEmbeddingBagCollection` modules and supports different tracking modes (support for optimizer modes will be added in follow-up diffs):

* ID_ONLY: Only tracks which IDs were accessed
* EMBEDDING: Tracks both IDs and their embedding values

## Key features

* Multiple consumer support (each consumer can track its own state)
* Configurable deletion policy for tracked data
* Ability to retrieve delta information for specific consumers

This utility helps optimize training workflows by enabling systems to focus on the most recently changed embeddings rather than processing the entire embedding table.

Reviewed By: chouxi

Differential Revision: D75853147

fbshipit-source-id: 8ce3960af7819c7d13b073605350570fb3af18b6
1 parent 18c1380 commit 0002f19

File tree

1 file changed

+114
-0
lines changed

1 file changed

+114
-0
lines changed
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
from typing import Dict, List, Optional, Union
10+
11+
import torch
12+
13+
from torch import nn
14+
from torchrec.distributed.embedding import ShardedEmbeddingCollection
15+
from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
16+
from torchrec.distributed.model_tracker.types import (
17+
DeltaRows,
18+
EmbdUpdateMode,
19+
TrackingMode,
20+
)
21+
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
22+
23+
# Maps each supported TrackingMode to the EmbdUpdateMode it uses.
UPDATE_MODE_MAP: Dict[TrackingMode, EmbdUpdateMode] = {
    # ID_ONLY: just the IDs are tracked; no additional state is stored.
    TrackingMode.ID_ONLY: EmbdUpdateMode.NONE,
    # EMBEDDING: uses EmbdUpdateMode.FIRST so that the earliest embedding values
    # since the last checkpoint or snapshot are the ones stored. Those values feed
    # the topk delta-row computation, which is currently done by running
    # (new_emb - old_emb).norm().topk().
    TrackingMode.EMBEDDING: EmbdUpdateMode.FIRST,
}
31+
32+
# Tracking is currently only supported for ShardedEmbeddingCollection and
# ShardedEmbeddingBagCollection.
SUPPORTED_MODULES = Union[
    ShardedEmbeddingCollection,
    ShardedEmbeddingBagCollection,
]
34+
35+
36+
class ModelDeltaTracker:
    """
    Track and retrieve unique IDs (and optionally embeddings/states) for supported modules.

    ModelDeltaTracker provides a way to track and retrieve unique IDs for supported modules, along with optional support
    for tracking corresponding embeddings or states. This is useful for identifying and retrieving the latest delta or
    unique rows for a given model, which can help compute topk or to stream updated embeddings from predictors to trainers during
    online training. Tracked IDs or states are retrieved by calling the get_delta() method.

    NOTE: this is the initial check-in of the interface. The constructor only stores its arguments and
    every other method is a placeholder (see the individual method docstrings).

    Args:
        model (nn.Module): the model to track.
        consumers (List[str], optional): list of consumers to track. Each consumer will
            have its own batch offset index. Every get_delta invocation will
            only return the new ids for the given consumer since its last get_delta
            call. If None or empty, a single consumer named DEFAULT_CONSUMER is used.
        delete_on_read (bool, optional): whether to delete the tracked ids after all consumers have read them.
            Default: True.
        mode (TrackingMode, optional): tracking mode to use from supported tracking modes. Default: TrackingMode.ID_ONLY.
    """

    # Consumer name used when the caller does not supply any consumers.
    DEFAULT_CONSUMER: str = "default"

    def __init__(
        self,
        model: nn.Module,
        consumers: Optional[List[str]] = None,
        delete_on_read: bool = True,
        mode: TrackingMode = TrackingMode.ID_ONLY,
    ) -> None:
        self._model = model
        # `or` makes both None and an empty list fall back to the default consumer.
        self._consumers: List[str] = consumers or [self.DEFAULT_CONSUMER]
        self._delete_on_read = delete_on_read
        self._mode = mode

    def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None:
        """
        Record IDs from a given KeyedJaggedTensor and embeddings/parameter states.

        Placeholder: the current implementation does nothing.

        Args:
            kjt (KeyedJaggedTensor): the KeyedJaggedTensor to record.
            states (torch.Tensor): the states to record.
        """

    def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]:
        """
        Return a dictionary of hit local IDs for each sparse feature. The IDs are first keyed by submodule FQN.

        Placeholder: the current implementation ignores `consumer` and always returns an empty dict.

        Args:
            consumer (str, optional): The consumer to retrieve IDs for. If not specified, "default" is
                intended to be used as the default consumer.

        Returns:
            Dict[str, DeltaRows]: currently always an empty dictionary.
        """
        return {}

    def fqn_to_feature_names(self, module: nn.Module) -> Dict[str, List[str]]:
        """
        Returns a mapping from FQN to feature names for a given module.

        Placeholder: the current implementation ignores `module` and always returns an empty dict.

        Args:
            module (nn.Module): the module to retrieve feature names for.

        Returns:
            Dict[str, List[str]]: currently always an empty dictionary.
        """
        return {}

    def clear(self, consumer: Optional[str] = None) -> None:
        """
        Clear tracked IDs for a given consumer.

        Placeholder: the current implementation does nothing.

        Args:
            consumer (str, optional): The consumer to clear IDs/States for. If not specified, "default" is
                intended to be used as the default consumer.
        """

    def compact(self, start_idx: int, end_idx: int) -> None:
        """
        Compact tracked IDs for a given range of indices.

        Placeholder: the current implementation does nothing.

        Args:
            start_idx (int): Starting index for compaction.
            end_idx (int): Ending index for compaction.
        """

0 commit comments

Comments
 (0)