We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 755f14b commit 02a9831Copy full SHA for 02a9831
src/pytti/tensor_tools.py
@@ -65,6 +65,7 @@ def format_module(module, dest, *args, **kwargs) -> torch.tensor:
65
return format_input(output, module, dest)
66
67
68
+# https://pytorch.org/docs/stable/autograd.html#function
69
class ReplaceGrad(torch.autograd.Function):
70
"""
71
returns x_forward during forward pass, but evaluates derivates as though
0 commit comments