
Commit 2c573ce

fix getitem handling in existing SAC tag pass, add attention back to example SAC run (#123)
* fix getitem handling in existing SAC tag pass, turn on attention SAC in example
* cleanup
1 parent 52d7a17 commit 2c573ce

File tree

1 file changed: +23 -2 lines changed


autoparallel/activation_checkpointing.py

Lines changed: 23 additions & 2 deletions
@@ -3,6 +3,7 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 import logging
+import operator
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import Optional, Union
@@ -207,6 +208,26 @@ def mark_nodes_as_must_save_to_stage_recomputation(
         if node.meta.get("recompute", None) is not None:
             # do not mess with allgather nodes that have already been marked recompute!
             continue
+        if node.target is operator.getitem:
+            # we need to be a bit careful: we are trying to manually emulate setting "recompute" tags
+            # in the same way that torch.compile does when it encounters userland SAC.
+            #
+            # torch.compile does this by using TorchDispatchModes to intercept ops as they are traced,
+            # and setting their "recompute" tag.
+            #
+            # However, TorchDispatchModes *only* intercept OpOverloads (and HOPs);
+            # getitem is neither, and so in vanilla torch.compile usage,
+            # getitem nodes receive no tags.
+            #
+            # What happens if we blindly set all nodes to PREFER_RECOMPUTE? Example bad outcome:
+            # - user is using attention, so we see this series of ops in the joint graph:
+            #   attention_fw -> getitem -> attention_bw (the getitem is an output used for the bw)
+            # - user runs SAC, and marks attention_fw as MUST_SAVE
+            # - if we mark getitem as PREFER_RECOMPUTE, and attention_fw as MUST_SAVE,
+            #   the partitioner ends up generating an invalid graph.
+            # Today the partitioner relies on the fact that getitem's recompute behavior
+            # is implicitly determined by the recompute behavior of the multi-output op preceding it.
+            continue
         node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE
         # add an arbitrarily large graph id. I'm assuming 100000 here, which should be fine
         # and is the same we add for the all-gather nodes
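
As an illustrative aside (not part of the diff; the hand-built graph and node names below are made up), the attention_fw -> getitem pattern described in the new comment can be sketched with torch.fx: the policy tags the multi-output attention node, while the getitem node is left untagged so the partitioner keeps its recompute behavior consistent with its producer.

import operator

import torch
import torch.fx as fx
from torch.utils.checkpoint import CheckpointPolicy

# Hand-built sketch of: attention_fw -> getitem (an output later used by the backward)
graph = fx.Graph()
q, k, v = (graph.placeholder(name) for name in ("q", "k", "v"))
attn_fw = graph.call_function(
    torch.ops.aten._scaled_dot_product_flash_attention.default, (q, k, v)
)
attn_out = graph.call_function(operator.getitem, (attn_fw, 0))
graph.output(attn_out)

# The SAC policy tags the multi-output attention node ...
attn_fw.meta["recompute"] = CheckpointPolicy.MUST_SAVE
# ... but the getitem node gets no tag of its own: its recompute behavior
# implicitly follows attn_fw, the multi-output op that produced it.
for node in graph.nodes:
    print(node.op, node.name, node.meta.get("recompute"))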
@@ -327,7 +348,7 @@ def ac_joint_pass(graph: torch.fx.Graph, ac_stage_size_in_GiB: float = 2.0):
     # policy, but this is not working yet
     save_list = {
         torch.ops.aten.mm.default,
-        # torch.ops.aten._scaled_dot_product_efficient_attention.default,
-        # torch.ops.aten._scaled_dot_product_flash_attention.default,
+        torch.ops.aten._scaled_dot_product_efficient_attention.default,
+        torch.ops.aten._scaled_dot_product_flash_attention.default,
     }
     _apply_ac_policy(graph, save_list=save_list)
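
As a rough sketch of what this save_list corresponds to in userland SAC (not part of the commit; the helper names _save_ops, _policy, and checkpointed are made up here), the same policy expressed with torch.utils.checkpoint would look roughly like:

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Ops whose forward outputs are saved for the backward instead of recomputed,
# mirroring the save_list above (matmuls plus the two SDPA variants).
_save_ops = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
}

def _policy(ctx, op, *args, **kwargs):
    if op in _save_ops:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def checkpointed(fn, *args):
    # Selective activation checkpointing requires the non-reentrant checkpoint.
    return checkpoint(
        fn,
        *args,
        use_reentrant=False,
        context_fn=lambda: create_selective_checkpoint_contexts(_policy),
    )

Re-enabling the two SDPA overloads in the save_list means the attention forward outputs are kept for the backward rather than recomputed, which is what the "add attention back to example SAC run" part of the commit title refers to.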
