From f505ea83e7b7a529e9538214ea23ab9cbe46fcf4 Mon Sep 17 00:00:00 2001 From: Dan Raviv Date: Thu, 10 Jul 2025 16:23:55 -0700 Subject: [PATCH] Map multiple names of the same object to the same wrapped object Avoid creating two different wrappers for the same object in wrap_namespace. Generally cleaner, but more importantly towards better ufuncs support, which will require mapping from ufuncs to their wrapped version. Another minor pro is that this theoretically avoids needing to define grads for two names of the same object, e.g., we could define the VJP for anp.conj but not for anp.conjugate since they are actually the same wrapped object. However, it's minor and theoretical because whether or not two "equivalent" numpy functions/ufuncs are the same object is apparently an implementation detail, e.g., np.amax is documented as an alias of np.max, but they actually aren't the same object, and the equivalent ufuncs np.degrees and np.rad2deg are two different ones (tested on numpy 2.3.1). Because this might change between different numpy version, it's more robust to keep the grad definitions for both names. --- autograd/numpy/numpy_wrapper.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/autograd/numpy/numpy_wrapper.py b/autograd/numpy/numpy_wrapper.py index ab97f214..1f7fe49c 100644 --- a/autograd/numpy/numpy_wrapper.py +++ b/autograd/numpy/numpy_wrapper.py @@ -25,15 +25,29 @@ class IntdtypeSubclass(cls): def wrap_namespace(old, new): unchanged_types = {float, int, type(None), type} int_types = {_np.int8, _np.int16, _np.int32, _np.int64, _np.integer} + obj_to_wrapped = [] for name, obj in old.items(): - if obj in notrace_functions: - new[name] = notrace_primitive(obj) - elif callable(obj) and type(obj) is not type: - new[name] = primitive(obj) - elif type(obj) is type and obj in int_types: - new[name] = wrap_intdtype(obj) - elif type(obj) in unchanged_types: - new[name] = obj + # Map multiple names of the same object (e.g. conj/conjugate) + # to the same wrapped object + for mapped_obj, wrapped in obj_to_wrapped: + if mapped_obj is obj: + new[name] = wrapped + break + else: + if obj in notrace_functions: + wrapped = notrace_primitive(obj) + new[name] = wrapped + obj_to_wrapped.append((obj, wrapped)) + elif callable(obj) and type(obj) is not type: + wrapped = primitive(obj) + new[name] = wrapped + obj_to_wrapped.append((obj, wrapped)) + elif type(obj) is type and obj in int_types: + wrapped = wrap_intdtype(obj) + new[name] = wrapped + obj_to_wrapped.append((obj, wrapped)) + elif type(obj) in unchanged_types: + new[name] = obj wrap_namespace(_np.__dict__, globals())