Open
Labels: PR welcome, good first issue
Description
Currently, autograd cannot trace the gradient through anp.take:
import numpy as np
import autograd as ag
import autograd.numpy as anp
rng = np.random.default_rng(42)
x = rng.uniform(size=(3, 4, 5))
idx = rng.integers(0, 4, size=(6,))
def foo(x, idx):
    # Works:
    # return x[:, idx, :].sum()
    # Doesn't work:
    return anp.take(x, idx, axis=1).sum()
gfoo = ag.grad(foo, argnum=0)
gfoo(x, rng.integers(0, 4, size=(6,))).shape
# NotImplementedError: VJP of take wrt argnums (0,) not defined

anp.take is mostly just convenient syntax for a subset of getitem indexing (here, anp.take(x, idx, axis=1) selects the same elements as x[:, idx, :]), so I suppose nothing fundamentally blocks supporting it.