|
38 | 38 | #include <torch/csrc/THP.h>
|
39 | 39 | #include <torch/csrc/TypeInfo.h>
|
40 | 40 | #include <torch/csrc/api/include/torch/python/init.h>
|
| 41 | +#include <torch/csrc/autograd/python_cpp_function.h> |
41 | 42 | #include <torch/csrc/autograd/python_enum_tag.h>
|
42 | 43 | #include <torch/csrc/autograd/python_fft_functions.h>
|
| 44 | +#include <torch/csrc/autograd/python_function.h> |
43 | 45 | #include <torch/csrc/autograd/python_legacy_variable.h>
|
44 | 46 | #include <torch/csrc/autograd/python_linalg_functions.h>
|
45 | 47 | #include <torch/csrc/autograd/python_nested_functions.h>
|
@@ -708,6 +710,54 @@ PyObject* THPModule_isEnabledXNNPACK(PyObject* _unused, PyObject* noargs) {
|
708 | 710 | Py_RETURN_FALSE;
|
709 | 711 | }
|
710 | 712 |
|
| 713 | +PyObject* THPModule_willEngineExecuteNode(PyObject* _unused, PyObject* arg) { |
| 714 | + HANDLE_TH_ERRORS |
| 715 | + bool isTHPFunction = THPFunction_Check(arg); |
| 716 | + bool isTHPCppFunction = torch::autograd::THPCppFunction_Check(arg); |
| 717 | + THPUtils_assert( |
| 718 | + isTHPFunction || isTHPCppFunction, |
| 719 | + "_will_engine_execute_node expects an grad_fn, " |
| 720 | + "but got %s", |
| 721 | + THPUtils_typename(arg)); |
| 722 | + const auto exec_info = torch::autograd::get_current_graph_task_exec_info(); |
| 723 | + THPUtils_assert( |
| 724 | + exec_info, |
| 725 | + "_get_should_execute_nodes should only be called during the backward pass"); |
| 726 | + torch::autograd::Node* node; |
| 727 | + std::shared_ptr<torch::autograd::Node> node_sp; |
| 728 | + if (isTHPFunction) { |
| 729 | + node_sp = ((THPFunction*)arg)->cdata.lock(); |
| 730 | + node = node_sp.get(); |
| 731 | + } else { |
| 732 | + node = ((torch::autograd::THPCppFunction*)arg)->cdata.get(); |
| 733 | + } |
| 734 | + if (exec_info->empty()) { |
| 735 | + // .backward() without inputs= arg |
| 736 | + const auto nodes_in_graph = |
| 737 | + torch::autograd::get_current_graph_task_nodes_in_graph(); |
| 738 | + auto it = nodes_in_graph->find(node); |
| 739 | + if (it == nodes_in_graph->end()) { |
| 740 | + Py_RETURN_FALSE; |
| 741 | + } else { |
| 742 | + Py_RETURN_TRUE; |
| 743 | + } |
| 744 | + } else { |
| 745 | + // .grad or .backward when inputs= is passed |
| 746 | + auto it = exec_info->find(node); |
| 747 | + if (it == exec_info->end() || !it->second.should_execute()) { |
| 748 | + Py_RETURN_FALSE; |
| 749 | + } else { |
| 750 | + THPUtils_assert( |
| 751 | + !(node->topological_nr() == 0 && it->second.captures_), |
| 752 | + "A leaf node was passed to _will_engine_execute_node but we are " |
| 753 | + "currently running autograd.grad(). This is currently not supported."); |
| 754 | + Py_RETURN_TRUE; |
| 755 | + } |
| 756 | + } |
| 757 | + |
| 758 | + END_HANDLE_TH_ERRORS |
| 759 | +} |
| 760 | + |
711 | 761 | PyObject* THPModule_setDefaultMobileCPUAllocator(
|
712 | 762 | PyObject* _unused,
|
713 | 763 | PyObject* noargs) {
|
@@ -888,6 +938,10 @@ static PyMethodDef TorchMethods[] = {
|
888 | 938 | {"_set_qengine", THPModule_setQEngine, METH_O, nullptr},
|
889 | 939 | {"_supported_qengines", THPModule_supportedQEngines, METH_NOARGS, nullptr},
|
890 | 940 | {"_is_xnnpack_enabled", THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr},
|
| 941 | + {"_will_engine_execute_node", |
| 942 | + THPModule_willEngineExecuteNode, |
| 943 | + METH_O, |
| 944 | + nullptr}, |
891 | 945 | {"_set_default_mobile_cpu_allocator",
|
892 | 946 | THPModule_setDefaultMobileCPUAllocator,
|
893 | 947 | METH_NOARGS,
|
|
0 commit comments