diff --git a/autograd/numpy/numpy_boxes.py b/autograd/numpy/numpy_boxes.py index b9c73963..a5fab2a7 100644 --- a/autograd/numpy/numpy_boxes.py +++ b/autograd/numpy/numpy_boxes.py @@ -4,6 +4,7 @@ from autograd.builtins import SequenceBox from autograd.extend import Box, primitive +from autograd.tracer import trace_primitives_map from . import numpy_wrapper as anp @@ -18,16 +19,40 @@ class ArrayBox(Box): def __getitem__(A, idx): return A[idx] - # Constants w.r.t float data just pass though - shape = property(lambda self: self._value.shape) - ndim = property(lambda self: self._value.ndim) - size = property(lambda self: self._value.size) - dtype = property(lambda self: self._value.dtype) + # Basic array attributes just pass through + # Single wrapped scalars are presented as 0-dim, 1-size arrays. + shape = property(lambda self: anp.shape(self._value)) + ndim = property(lambda self: anp.ndim(self._value)) + size = property(lambda self: anp.size(self._value)) + dtype = property(lambda self: anp.result_type(self._value)) + T = property(lambda self: anp.transpose(self)) def __array_namespace__(self, *, api_version: Union[str, None] = None): return anp + # Calls to wrapped ufuncs first forward further handling to the ufunc + # dispatching mechanism, which allows any other operands to also try + # handling the ufunc call. See also tracer.primitive. + # + # In addition, implementing __array_ufunc__ allows ufunc calls to propagate + # through non-differentiable array-like objects (e.g. xarray.DataArray) into + # ArrayBoxes which might be contained within, upon which __array_ufunc__ + # below would call autograd's wrapper for the ufunc. For example, given a + # DataArray `a` containing an ArrayBox, this lets us write `np.abs(a)` + # instead of requiring the xarray-specific `xr.apply_func(np.abs, a)`. + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + if method != "__call__": + return NotImplemented + if "out" in kwargs: + return NotImplemented + if ufunc_wrapper := trace_primitives_map.get(ufunc): + try: + return ufunc_wrapper(*inputs, called_by_autograd_dispatcher=True, **kwargs) + except NotImplementedError: + return NotImplemented + return NotImplemented + def __len__(self): return len(self._value) diff --git a/autograd/numpy/numpy_wrapper.py b/autograd/numpy/numpy_wrapper.py index ab97f214..1f7fe49c 100644 --- a/autograd/numpy/numpy_wrapper.py +++ b/autograd/numpy/numpy_wrapper.py @@ -25,15 +25,29 @@ class IntdtypeSubclass(cls): def wrap_namespace(old, new): unchanged_types = {float, int, type(None), type} int_types = {_np.int8, _np.int16, _np.int32, _np.int64, _np.integer} + obj_to_wrapped = [] for name, obj in old.items(): - if obj in notrace_functions: - new[name] = notrace_primitive(obj) - elif callable(obj) and type(obj) is not type: - new[name] = primitive(obj) - elif type(obj) is type and obj in int_types: - new[name] = wrap_intdtype(obj) - elif type(obj) in unchanged_types: - new[name] = obj + # Map multiple names of the same object (e.g. conj/conjugate) + # to the same wrapped object + for mapped_obj, wrapped in obj_to_wrapped: + if mapped_obj is obj: + new[name] = wrapped + break + else: + if obj in notrace_functions: + wrapped = notrace_primitive(obj) + new[name] = wrapped + obj_to_wrapped.append((obj, wrapped)) + elif callable(obj) and type(obj) is not type: + wrapped = primitive(obj) + new[name] = wrapped + obj_to_wrapped.append((obj, wrapped)) + elif type(obj) is type and obj in int_types: + wrapped = wrap_intdtype(obj) + new[name] = wrapped + obj_to_wrapped.append((obj, wrapped)) + elif type(obj) in unchanged_types: + new[name] = obj wrap_namespace(_np.__dict__, globals()) diff --git a/autograd/tracer.py b/autograd/tracer.py index 30fa72a6..a42d66bc 100644 --- a/autograd/tracer.py +++ b/autograd/tracer.py @@ -2,6 +2,10 @@ from collections import defaultdict from contextlib import contextmanager +import numpy as np + +import autograd + from .util import subvals, toposort from .wrap_util import wraps @@ -33,28 +37,61 @@ def new_root(cls, *args, **kwargs): return root +trace_primitives_map = {} + + def primitive(f_raw): """ Wraps a function so that its gradient can be specified and its invocation can be recorded. For examples, see the docs.""" @wraps(f_raw) - def f_wrapped(*args, **kwargs): + def f_wrapped(*args, called_by_autograd_dispatcher=False, **kwargs): boxed_args, trace, node_constructor = find_top_boxed_args(args) if boxed_args: + # If we are a wrapper around a ufunc, first forward further handling to + # the ufunc dispatching mechanism (if we aren't already running inside it) + # by calling the ufunc. This allows other operands to also try to handle + # the call (it's still possible our handling attempt below will get the + # first shot; the handlers order is determined by the dispatch mechanism). + # + # For example, consider multiplying an ndarray wrapped inside an ArrayBox + # by an xarray.DataArray. The handling below will fail: The ndarray will + # be unboxed and multiplied by the DataArray resulting in a DataArray, + # for which `new_box` will raise an exception. In contrast, the DataArray's + # handling of the call might succeed: it might contain an ndarray, either + # plain or boxed in an ArrayBox, in which case it will be multiplied by + # the other ArrayBox yielding a new ArrayBox, which will be stored in a new + # DataArray. + if ( + isinstance(f_raw, np.ufunc) + and not called_by_autograd_dispatcher + and any(isinstance(arg, autograd.numpy.numpy_boxes.ArrayBox) for arg in args) + ): + return f_raw(*args, **kwargs) + argvals = subvals(args, [(argnum, box._value) for argnum, box in boxed_args]) if f_wrapped in notrace_primitives[node_constructor]: - return f_wrapped(*argvals, **kwargs) + return f_wrapped( + *argvals, called_by_autograd_dispatcher=called_by_autograd_dispatcher, **kwargs + ) parents = tuple(box._node for _, box in boxed_args) argnums = tuple(argnum for argnum, _ in boxed_args) - ans = f_wrapped(*argvals, **kwargs) + ans = f_wrapped(*argvals, called_by_autograd_dispatcher=called_by_autograd_dispatcher, **kwargs) node = node_constructor(ans, f_wrapped, argvals, kwargs, argnums, parents) - return new_box(ans, trace, node) + try: + box = new_box(ans, trace, node) + return box + except: + if called_by_autograd_dispatcher: + raise NotImplementedError + raise else: return f_raw(*args, **kwargs) f_wrapped.fun = f_raw f_wrapped._is_autograd_primitive = True + trace_primitives_map[f_raw] = f_wrapped return f_wrapped