
Commit a876432

soulitzer authored and pytorchmergebot committed
Expose torch._will_engine_execute_node (pytorch#84773)
Addresses: pytorch#83617

This PR adds a way to query the TLS graph task's exec_info, a map from Node to a bool indicating whether that node will be executed in the current backward pass (as determined by the inputs= argument of .grad or .backward).

- this works with both custom Function nodes and normal codegened nodes
- to be able to verify that the pyobject passed is an actual node, we now store pointers to the PyTypeObjects in a set on registration
- error out when .backward is called without inputs= to avoid silently returning True

Alternatives:

- not sure if it is possible to bind to Python from a raw pointer to Node; at least we wouldn't be able to use the existing logic, and the Python object should only hold a weak reference to the Node
- other solutions to the motivating issue seem to require more extensive modification to the engine

See the linked issue for an example of usage.

Pull Request resolved: pytorch#84773
Approved by: https://github.com/albanD
1 parent 8dd4542 commit a876432
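
For orientation, a minimal usage sketch distilled from the test added below; the query is only valid while a backward pass is running, e.g. inside a hook (tensor names and shapes are illustrative, not part of this commit):

import torch

a = torch.randn(2, 3, requires_grad=True)
a2 = torch.randn(2, 3, requires_grad=True)
b = a * a2  # b.grad_fn lies on the path to a; a2's AccumulateGrad does not

def hook(grad):
    # True: b.grad_fn must run to reach inputs=(a,)
    print(torch._C._will_engine_execute_node(b.grad_fn))
    # False: a2's gradient is not requested, so its AccumulateGrad is pruned
    acc_a2 = a2.clone().grad_fn.next_functions[0][0]
    print(torch._C._will_engine_execute_node(acc_a2))

b.register_hook(hook)
b.sum().backward(inputs=(a,))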

File tree

6 files changed: +155 -6 lines changed

test/test_autograd.py

+73
@@ -394,6 +394,79 @@ def test_not_implemented_fwad(self):
             # if forward AD ends up being implemented for torch.igamma, choose a different op
             torch.igamma(dual_x, dual_x)

+    def test_will_engine_execute_node(self):
+        counter = [0]
+
+        class MyFunction(Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x * 2
+
+            @staticmethod
+            def backward(ctx, gO):
+                return gO * 2
+
+        def get_grad_fn(t):
+            if t.requires_grad and t.grad_fn is None:
+                return t.clone().grad_fn.next_functions[0][0]
+            else:
+                return t.grad_fn
+
+        a = torch.randn(2, 3, 4, requires_grad=True)
+        a2 = torch.randn(2, 3, 4, requires_grad=True)
+        b = a * a2
+        b2 = b.cos()
+        c = MyFunction.apply(b)
+
+        should_execute = list(map(get_grad_fn, (a, b, c)))
+        should_not_execute = list(map(get_grad_fn, (a2, b2)))
+
+        def fn(x):
+            counter[0] += 1
+
+            for g in should_execute:
+                self.assertTrue(torch._C._will_engine_execute_node(g))
+
+            for g in should_not_execute:
+                self.assertFalse(torch._C._will_engine_execute_node(g))
+
+        b.register_hook(fn)
+        c.register_hook(fn)
+
+        # .backward(inputs=) is OK
+        out = c.sum()
+        torch.autograd.backward(out, inputs=(a,), retain_graph=True)
+        self.assertEqual(counter[0], 2)
+
+        # .backward() is OK
+        should_execute = list(map(get_grad_fn, (a, a2, b, c)))
+        should_not_execute = list(map(get_grad_fn, (b2,)))
+        torch.autograd.backward(out, retain_graph=True)
+
+        # .grad is NOT OK when leaf is passed (this is the current state, subject to change)
+        with self.assertRaisesRegex(RuntimeError, "are currently running autograd.grad()"):
+            torch.autograd.grad(out, (a,))
+
+        # .grad is OK when non-leaf is passed
+        a = torch.randn(1, 2, 3, requires_grad=True) * 2
+        b = a * 2
+
+        def fn(x):
+            # Check a non-leaf
+            counter[0] += 1
+            self.assertTrue(torch._C._will_engine_execute_node(b.grad_fn))
+        b.register_hook(fn)
+        counter[0] = 0
+        torch.autograd.grad(b.sum(), (a,))
+        self.assertEqual(counter[0], 1)
+
+        # Verify other errors are raised
+        with self.assertRaisesRegex(RuntimeError, "during the backward pass"):
+            torch._C._will_engine_execute_node(out.grad_fn)
+
+        with self.assertRaisesRegex(RuntimeError, "expects an grad_fn"):
+            torch._C._will_engine_execute_node(out)
+
     def test_accumulate_grad(self):
         grad_output = torch.ones(5, 5)

torch/csrc/Module.cpp

+54
@@ -38,8 +38,10 @@
 #include <torch/csrc/THP.h>
 #include <torch/csrc/TypeInfo.h>
 #include <torch/csrc/api/include/torch/python/init.h>
+#include <torch/csrc/autograd/python_cpp_function.h>
 #include <torch/csrc/autograd/python_enum_tag.h>
 #include <torch/csrc/autograd/python_fft_functions.h>
+#include <torch/csrc/autograd/python_function.h>
 #include <torch/csrc/autograd/python_legacy_variable.h>
 #include <torch/csrc/autograd/python_linalg_functions.h>
 #include <torch/csrc/autograd/python_nested_functions.h>
@@ -708,6 +710,54 @@ PyObject* THPModule_isEnabledXNNPACK(PyObject* _unused, PyObject* noargs) {
   Py_RETURN_FALSE;
 }

+PyObject* THPModule_willEngineExecuteNode(PyObject* _unused, PyObject* arg) {
+  HANDLE_TH_ERRORS
+  bool isTHPFunction = THPFunction_Check(arg);
+  bool isTHPCppFunction = torch::autograd::THPCppFunction_Check(arg);
+  THPUtils_assert(
+      isTHPFunction || isTHPCppFunction,
+      "_will_engine_execute_node expects an grad_fn, "
+      "but got %s",
+      THPUtils_typename(arg));
+  const auto exec_info = torch::autograd::get_current_graph_task_exec_info();
+  THPUtils_assert(
+      exec_info,
+      "_get_should_execute_nodes should only be called during the backward pass");
+  torch::autograd::Node* node;
+  std::shared_ptr<torch::autograd::Node> node_sp;
+  if (isTHPFunction) {
+    node_sp = ((THPFunction*)arg)->cdata.lock();
+    node = node_sp.get();
+  } else {
+    node = ((torch::autograd::THPCppFunction*)arg)->cdata.get();
+  }
+  if (exec_info->empty()) {
+    // .backward() without inputs= arg
+    const auto nodes_in_graph =
+        torch::autograd::get_current_graph_task_nodes_in_graph();
+    auto it = nodes_in_graph->find(node);
+    if (it == nodes_in_graph->end()) {
+      Py_RETURN_FALSE;
+    } else {
+      Py_RETURN_TRUE;
+    }
+  } else {
+    // .grad or .backward when inputs= is passed
+    auto it = exec_info->find(node);
+    if (it == exec_info->end() || !it->second.should_execute()) {
+      Py_RETURN_FALSE;
+    } else {
+      THPUtils_assert(
+          !(node->topological_nr() == 0 && it->second.captures_),
+          "A leaf node was passed to _will_engine_execute_node but we are "
+          "currently running autograd.grad(). This is currently not supported.");
+      Py_RETURN_TRUE;
+    }
+  }
+
+  END_HANDLE_TH_ERRORS
+}
+
 PyObject* THPModule_setDefaultMobileCPUAllocator(
     PyObject* _unused,
     PyObject* noargs) {
@@ -888,6 +938,10 @@ static PyMethodDef TorchMethods[] = {
     {"_set_qengine", THPModule_setQEngine, METH_O, nullptr},
     {"_supported_qengines", THPModule_supportedQEngines, METH_NOARGS, nullptr},
     {"_is_xnnpack_enabled", THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr},
+    {"_will_engine_execute_node",
+     THPModule_willEngineExecuteNode,
+     METH_O,
+     nullptr},
     {"_set_default_mobile_cpu_allocator",
      THPModule_setDefaultMobileCPUAllocator,
      METH_NOARGS,
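
Note the two branches above: a plain .backward() leaves exec_info empty, so membership in nodes_in_graph_ decides the answer, while .grad() and .backward(inputs=) consult exec_info and additionally reject captured leaf nodes under autograd.grad(). A rough sketch of that last restriction, mirroring the new test (illustrative tensors, not part of this diff):

import torch

a = torch.randn(3, requires_grad=True)
b = a * 2

def hook(grad):
    acc_a = a.clone().grad_fn.next_functions[0][0]  # a's AccumulateGrad (a leaf node)
    # Raises RuntimeError: leaf nodes are unsupported while autograd.grad() runs
    torch._C._will_engine_execute_node(acc_a)

b.register_hook(hook)
torch.autograd.grad(b.sum(), (a,))  # the query above raises during this call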

torch/csrc/autograd/engine.cpp

+5 -2
@@ -381,6 +381,10 @@ get_current_graph_task_exec_info() {
   return current_graph_task ? &current_graph_task->exec_info_ : nullptr;
 }

+const std::unordered_set<Node*>* get_current_graph_task_nodes_in_graph() {
+  return current_graph_task ? &current_graph_task->nodes_in_graph_ : nullptr;
+}
+
 bool get_current_graph_task_keep_graph() {
   return current_graph_task ? current_graph_task->keep_graph_ : true;
 }
@@ -1006,7 +1010,6 @@ auto Engine::compute_dependencies(
     GraphTask& task,
     uint64_t min_topo_nr) -> void {
   // Computes the number of dependencies for each function which requires grad
-  std::unordered_set<Node*> seen;
   std::vector<Node*> queue{root};
   bool might_use_cuda = at::globalContext().hasCUDA();
   bool will_use_cuda = false;
@@ -1026,7 +1029,7 @@ auto Engine::compute_dependencies(
     for (const auto& edge : fn->next_edges()) {
       if (auto next_ptr = edge.function.get()) {
         dependencies[next_ptr] += 1;
-        const bool was_inserted = seen.insert(next_ptr).second;
+        const bool was_inserted = task.nodes_in_graph_.insert(next_ptr).second;
         if (was_inserted)
           queue.push_back(next_ptr);
       }

torch/csrc/autograd/graph_task.h

+4
@@ -31,6 +31,8 @@ struct GraphTask : std::enable_shared_from_this<GraphTask> {
   std::unordered_map<Node*, InputBuffer> not_ready_;
   std::unordered_map<Node*, int> dependencies_;

+  // Records the nodes that are in the graph
+  std::unordered_set<Node*> nodes_in_graph_;
   // Note [Exec info]
   // Exec info is created for each GraphTask, which allows filtering paths on
   // the graph that are not needed. It has a bit complicated semantics. If it's
@@ -186,6 +188,8 @@ class GraphTaskGuard {

 TORCH_API const std::unordered_map<Node*, GraphTask::ExecInfo>*
 get_current_graph_task_exec_info();
+TORCH_API const std::unordered_set<Node*>*
+get_current_graph_task_nodes_in_graph();
 TORCH_API bool get_current_graph_task_keep_graph();
 void add_node_to_current_graph_task_exec_info(Node* fn);

torch/csrc/autograd/python_cpp_function.cpp

+17 -4
@@ -203,7 +203,8 @@ PyTypeObject* _initFunctionPyTypeObject(
   return &type;
 }

-static std::unordered_map<std::type_index, THPObjectPtr> cpp_function_types;
+static std::unordered_map<std::type_index, THPObjectPtr> cpp_function_types_map;
+static std::unordered_set<PyTypeObject*> cpp_function_types_set;

 struct DefaultFunctionType {
   DefaultFunctionType() : type() {
@@ -231,10 +232,10 @@ PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata) {
     Py_INCREF(cdata->pyobj());
   } else {
     auto& fn = *cdata;
-    auto it = cpp_function_types.find(std::type_index(typeid(fn)));
+    auto it = cpp_function_types_map.find(std::type_index(typeid(fn)));
     // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
     PyTypeObject* type;
-    if (it == cpp_function_types.end()) {
+    if (it == cpp_function_types_map.end()) {
       type = &default_type.type;
     } else {
       type = (PyTypeObject*)it->second.get();
@@ -255,7 +256,9 @@ PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata) {

 void registerCppFunction(const std::type_info& type, PyTypeObject* pytype) {
   Py_INCREF((PyObject*)pytype);
-  cpp_function_types[std::type_index(type)] = THPObjectPtr((PyObject*)pytype);
+  cpp_function_types_map[std::type_index(type)] =
+      THPObjectPtr((PyObject*)pytype);
+  cpp_function_types_set.insert(pytype);
 }

 PyObject* registerFunctionHook(Node& fn, PyObject* hook) {
@@ -317,5 +320,15 @@ PyObject* registerFunctionPreHook(Node& fn, PyObject* hook) {
   return handle;
 }

+bool THPCppFunction_Check(PyObject* obj) {
+  THPObjectPtr type = THPObjectPtr(PyObject_Type(obj));
+  if (cpp_function_types_set.find((PyTypeObject*)type.get()) ==
+      cpp_function_types_set.end()) {
+    return false;
+  } else {
+    return true;
+  }
+}
+
 } // namespace autograd
 } // namespace torch
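
For reference, the two kinds of grad_fn objects the checks recognize are visible from Python: codegen'd nodes get a per-type PyTypeObject registered into cpp_function_types_set, while custom autograd.Function nodes are THPFunction instances handled by THPFunction_Check. A small illustrative sketch (not part of this diff):

import torch
from torch.autograd import Function

class Double(Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, gO):
        return gO * 2

x = torch.randn(2, requires_grad=True)
print(type(x.exp().grad_fn))          # codegen'd node, e.g. ExpBackward0
print(type(Double.apply(x).grad_fn))  # custom Function node (a THPFunction)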

torch/csrc/autograd/python_cpp_function.h

+2
@@ -96,5 +96,7 @@ PyTypeObject* createForwardFunctionPyTypeObject(
 void registerCppFunction(const std::type_info& type, PyTypeObject* pytype);
 PyObject* functionToPyObject(const std::shared_ptr<Node>& cdata);

+bool THPCppFunction_Check(PyObject* obj);
+
 } // namespace autograd
 } // namespace torch
