
support numpy.take #743

@jcmgray

Description


Currently, tracing the gradient through `take` is not supported:

```python
import numpy as np
import autograd as ag
import autograd.numpy as anp

rng = np.random.default_rng(42)
x = rng.uniform(size=(3, 4, 5))
idx = rng.integers(0, 4, size=(6,))

def foo(x, idx):
    # works:
    # return x[:, idx, :].sum()
    # doesn't work:
    return anp.take(x, idx, axis=1).sum()

gfoo = ag.grad(foo, argnum=0)
gfoo(x, idx).shape
# NotImplementedError: VJP of take wrt argnums (0,) not defined
```

`take` is mostly just convenient syntax for a subset of getitem (`x[...]`) indexing, so I suppose there is nothing fundamentally blocking it.

Metadata


Assignees

No one assigned

    Labels

    PR welcome, good first issue

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
