3 | 3 | # This source code is licensed under the BSD license found in the |
4 | 4 | # LICENSE file in the root directory of this source tree. |
5 | 5 | import logging |
| 6 | +import operator |
6 | 7 | from collections import defaultdict |
7 | 8 | from dataclasses import dataclass |
8 | 9 | from typing import Optional, Union |
@@ -207,6 +208,26 @@ def mark_nodes_as_must_save_to_stage_recomputation( |
207 | 208 | if node.meta.get("recompute", None) is not None: |
208 | 209 | # do not mess with allgather nodes that have already been marked recompute! |
209 | 210 | continue |
| 211 | + if node.target is operator.getitem: 
| 212 | + # we need to be a bit careful: we are trying to manually emulate setting "recompute" tags 
| 213 | + # in the same way that torch.compile does when it encounters userland SAC. 
| 214 | + # 
| 215 | + # torch.compile does this by using TorchDispatchModes to intercept ops as they are traced, 
| 216 | + # and setting their "recompute" tag. 
| 217 | + # 
| 218 | + # However, TorchDispatchModes *only* intercept OpOverloads (and HOPs). 
| 219 | + # getitem is neither, and so in vanilla torch.compile usage, 
| 220 | + # getitem nodes receive no tags. 
| 221 | + # 
| 222 | + # What happens if we blindly set all nodes to PREFER_RECOMPUTE? Example bad outcome: 
| 223 | + # - the user is using attention, so we see this series of ops in the joint graph: 
| 224 | + # attention_fw -> getitem -> attention_bw (the getitem is an output used for the bw) 
| 225 | + # - the user runs SAC and marks attention_fw as MUST_SAVE 
| 226 | + # - if we mark getitem as PREFER_RECOMPUTE and attention_fw as MUST_SAVE, 
| 227 | + # the partitioner ends up generating an invalid graph. 
| 228 | + # Today the partitioner relies on the fact that getitem's recompute behavior 
| 229 | + # is implicitly determined by the recompute behavior of the multi-output op preceding it. 
| 230 | + continue 
210 | 231 | node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE |
211 | 232 | # add an arbitrarily large graph id. I'm assuming 100000 here, which should be fine |
212 | 233 | # and is the same one we add for the all-gather nodes 
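
For reference, a quick way to see the dispatch gap described in the comment above (a sketch relying on `torch._ops.OpOverload`, the class that a `TorchDispatchMode` actually intercepts):

```python
import operator

import torch

# aten ops are OpOverloads, so a TorchDispatchMode sees them during tracing
# and can attach a "recompute" tag to the resulting graph nodes.
print(isinstance(torch.ops.aten.mm.default, torch._ops.OpOverload))  # True

# operator.getitem is a plain Python builtin, not an OpOverload (or a HOP),
# so it is never intercepted and never receives a tag.
print(isinstance(operator.getitem, torch._ops.OpOverload))  # False
```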
@@ -327,7 +348,7 @@ def ac_joint_pass(graph: torch.fx.Graph, ac_stage_size_in_GiB: float = 2.0): |
327 | 348 | # policy, but this is not working yet |
328 | 349 | save_list = { |
329 | 350 | torch.ops.aten.mm.default, |
330 | | - # torch.ops.aten._scaled_dot_product_efficient_attention.default, |
331 | | - # torch.ops.aten._scaled_dot_product_flash_attention.default, |
| 351 | + torch.ops.aten._scaled_dot_product_efficient_attention.default, |
| 352 | + torch.ops.aten._scaled_dot_product_flash_attention.default, |
332 | 353 | } |
333 | 354 | _apply_ac_policy(graph, save_list=save_list) |
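
For context, a minimal sketch of what the same save-list policy looks like as userland SAC via `torch.utils.checkpoint` (the `block`, `x`, and `w` names are illustrative, not from this PR; `_apply_ac_policy` sets the equivalent tags directly on the joint graph instead):

```python
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
}

def policy_fn(ctx, op, *args, **kwargs):
    # Save matmul / attention outputs for the backward pass; prefer
    # recomputing everything else inside the checkpointed region.
    if op in save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def block(x, w):
    # Toy region: the mm output is saved, the relu is recomputed.
    return torch.relu(torch.mm(x, w))

x = torch.randn(8, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)
out = checkpoint(
    block, x, w,
    use_reentrant=False,
    context_fn=lambda: create_selective_checkpoint_contexts(policy_fn),
)
out.sum().backward()
```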