
Primitive called with ArrayBox object on forward pass when argument passed by keyword #681

@andre-al

Description

For context, I am working on a project where I need to use functions outside autograd which do not support autograd's ArrayBox arguments, but have easy enough derivatives that I can implement them directly and extend autograd.

When doing so, I was running into intermittent issues (the worst kind of issue) where these functions were sometimes still being called with ArrayBox arguments despite being defined as primitives, and when that happened an error was raised.

After some tinkering, I finally traced the core issue: when a primitive is called with a keyword argument, the forward pass still receives the ArrayBox version of that argument rather than the unboxed value.

I include a minimal working example below:

from autograd import grad
from autograd.extend import primitive, defvjp

@primitive
def f(x):
  print(f'x={x}')
  return x**2
def f_vjp(ans, x):
  return lambda g: 2*g*x
defvjp(f, f_vjp)

def g(y):
  return f(y)

def h(y):
  return f(x=y)

dg = grad(g)
print(f'dg={dg(1.)}')

dh = grad(h)
print(f'dh={dh(1.)}')

This code outputs:

x=1.0
dg=2.0
x=Autograd ArrayBox with value 1.0
dh=2.0

In words, f(x) is declared a primitive. All it does is print(x) and return x**2.
Both g and h are direct wrappers around f, but while g calls f with a positional x argument, h does so by keyword.

When evaluating grad(g) at a point, the forward pass prints the argument when it evaluates the primitive f, correctly outputting x=1.0, "unboxed" as a plain float.

When evaluating grad(h) instead, the forward pass prints that x is an ArrayBox with value 1.0. While this is not problematic here, it would be if the function had been made a primitive precisely because it is incompatible with boxed arguments.

I haven't checked the source code to see whether there is a simple way to fix this behavior. If there isn't, a well-described warning would at least save others the time I spent puzzling over what was causing these errors.
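In the meantime, one possible user-side workaround is to rebind keyword arguments to positional ones before the primitive is called. The sketch below is only an illustration under assumptions: kwargs_to_positional, f_raw and recording_f are hypothetical names, not part of autograd, and recording_f is a stand-in for primitive(f_raw) so that the example runs without autograd installed. It also assumes, as the behavior described above suggests, that only positional arguments are unboxed.

```python
import functools
import inspect


def kwargs_to_positional(raw_fn, wrapped_fn):
    """Return a caller that forwards keyword arguments to wrapped_fn positionally.

    raw_fn is the undecorated function, used only for its signature.
    wrapped_fn is the decorated version, e.g. primitive(raw_fn).
    Both names are hypothetical helpers, not autograd API.
    """
    sig = inspect.signature(raw_fn)

    @functools.wraps(raw_fn)
    def caller(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        # bound.args holds every positional-or-keyword parameter in order;
        # bound.kwargs holds only keyword-only parameters.
        return wrapped_fn(*bound.args, **bound.kwargs)

    return caller


def f_raw(x):
    return x**2


# Stand-in for primitive(f_raw): records how it was actually called.
seen = []


def recording_f(*args, **kwargs):
    seen.append((args, kwargs))
    return f_raw(*args, **kwargs)


f = kwargs_to_positional(f_raw, recording_f)
result = f(x=3.0)
print(seen)  # shows whether x arrived positionally or by keyword
```

With autograd, the same idea would presumably be applied as f = kwargs_to_positional(f_raw, primitive(f_raw)), with defvjp registered on the primitive itself. Keyword-only parameters are still forwarded by keyword, so this does not help for those.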
