|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# All rights reserved. |
| 4 | +# |
| 5 | +# This source code is licensed under the BSD-style license found in the |
| 6 | +# LICENSE file in the root directory of this source tree. |
| 7 | + |
| 8 | +# pyre-strict |
| 9 | +from typing import Dict, List, Optional, Union |
| 10 | + |
| 11 | +import torch |
| 12 | + |
| 13 | +from torch import nn |
| 14 | +from torchrec.distributed.embedding import ShardedEmbeddingCollection |
| 15 | +from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection |
| 16 | +from torchrec.distributed.model_tracker.types import ( |
| 17 | + DeltaRows, |
| 18 | + EmbdUpdateMode, |
| 19 | + TrackingMode, |
| 20 | +) |
| 21 | +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor |
| 22 | + |
| 23 | +UPDATE_MODE_MAP: Dict[TrackingMode, EmbdUpdateMode] = { |
| 24 | + # Only IDs are tracked, no additional state is stored. |
| 25 | + TrackingMode.ID_ONLY: EmbdUpdateMode.NONE, |
| 26 | + # TrackingMode.EMBEDDING utilizes EmbdUpdateMode.FIRST to ensure that |
| 27 | + # the earliest embedding values are stored since the last checkpoint or snapshot. |
| 28 | + # This mode is used for computing topk delta rows, which is currently achieved by running (new_emb - old_emb).norm().topk(). |
| 29 | + TrackingMode.EMBEDDING: EmbdUpdateMode.FIRST, |
| 30 | +} |
| 31 | + |
| 32 | +# Tracking is current only supported for ShardedEmbeddingCollection and ShardedEmbeddingBagCollection. |
| 33 | +SUPPORTED_MODULES = Union[ShardedEmbeddingCollection, ShardedEmbeddingBagCollection] |
| 34 | + |
| 35 | + |
| 36 | +class ModelDeltaTracker: |
| 37 | + r""" |
| 38 | +
|
| 39 | + ModelDeltaTracker provides a way to track and retrieve unique IDs for supported modules, along with optional support |
| 40 | + for tracking corresponding embeddings or states. This is useful for identifying and retrieving the latest delta or |
| 41 | + unique rows for a given model, which can help compute topk or to stream updated embeddings from predictors to trainers during |
| 42 | + online training. Unique IDs or states can be retrieved by calling the get_unique() method. |
| 43 | +
|
| 44 | + Args: |
| 45 | + model (nn.Module): the model to track. |
| 46 | + consumers (List[str], optional): list of consumers to track. Each consumer will |
| 47 | + have its own batch offset index. Every get_unique_ids invocation will |
| 48 | + only return the new ids for the given consumer since last get_unique_ids |
| 49 | + call. |
| 50 | + delete_on_read (bool, optional): whether to delete the tracked ids after all consumers have read them. |
| 51 | + mode (TrackingMode, optional): tracking mode to use from supported tracking modes. Default: TrackingMode.ID_ONLY. |
| 52 | + """ |
| 53 | + |
| 54 | + DEFAULT_CONSUMER: str = "default" |
| 55 | + |
| 56 | + def __init__( |
| 57 | + self, |
| 58 | + model: nn.Module, |
| 59 | + consumers: Optional[List[str]] = None, |
| 60 | + delete_on_read: bool = True, |
| 61 | + mode: TrackingMode = TrackingMode.ID_ONLY, |
| 62 | + ) -> None: |
| 63 | + self._model = model |
| 64 | + self._consumers: List[str] = consumers or [self.DEFAULT_CONSUMER] |
| 65 | + self._delete_on_read = delete_on_read |
| 66 | + self._mode = mode |
| 67 | + pass |
| 68 | + |
| 69 | + def record_lookup(self, kjt: KeyedJaggedTensor, states: torch.Tensor) -> None: |
| 70 | + """ |
| 71 | + Record Ids from a given KeyedJaggedTensor and embeddings/ parameter states. |
| 72 | +
|
| 73 | + Args: |
| 74 | + kjt (KeyedJaggedTensor): the KeyedJaggedTensor to record. |
| 75 | + states (torch.Tensor): the states to record. |
| 76 | + """ |
| 77 | + pass |
| 78 | + |
| 79 | + def get_delta(self, consumer: Optional[str] = None) -> Dict[str, DeltaRows]: |
| 80 | + """ |
| 81 | + Return a dictionary of hit local IDs for each sparse feature. The IDs are first keyed by submodule FQN. |
| 82 | +
|
| 83 | + Args: |
| 84 | + consumer (str, optional): The consumer to retrieve IDs for. If not specified, "default" is used as the default consumer. |
| 85 | + """ |
| 86 | + return {} |
| 87 | + |
| 88 | + def fqn_to_feature_names(self, module: nn.Module) -> Dict[str, List[str]]: |
| 89 | + """ |
| 90 | + Returns a mapping from FQN to feature names for a given module. |
| 91 | +
|
| 92 | + Args: |
| 93 | + module (nn.Module): the module to retrieve feature names for. |
| 94 | + """ |
| 95 | + return {} |
| 96 | + |
| 97 | + def clear(self, consumer: Optional[str] = None) -> None: |
| 98 | + """ |
| 99 | + Clear tracked IDs for a given consumer. |
| 100 | +
|
| 101 | + Args: |
| 102 | + consumer (str, optional): The consumer to clear IDs/States for. If not specified, "default" is used as the default consumer. |
| 103 | + """ |
| 104 | + pass |
| 105 | + |
| 106 | + def compact(self, start_idx: int, end_idx: int) -> None: |
| 107 | + """ |
| 108 | + Compact tracked IDs for a given range of indices. |
| 109 | +
|
| 110 | + Args: |
| 111 | + start_idx (int): Starting index for compaction. |
| 112 | + end_idx (int): Ending index for compaction. |
| 113 | + """ |
| 114 | + pass |
0 commit comments