Skip to content

Commit 02a9831

Browse files
committed
added note
1 parent 755f14b commit 02a9831

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

src/pytti/tensor_tools.py

+1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def format_module(module, dest, *args, **kwargs) -> torch.tensor:
6565
return format_input(output, module, dest)
6666

6767

68+
# https://pytorch.org/docs/stable/autograd.html#function
6869
class ReplaceGrad(torch.autograd.Function):
6970
"""
7071
returns x_forward during forward pass, but evaluates derivates as though

0 commit comments

Comments
 (0)