Status: Open
Labels: bug (This item is a bug that needs to be investigated or fixed), help wanted (Help on this item is much appreciated)
Description
The traceback is as follows:

```
_____________________________________________________________________________ test_dtypes _____________________________________________________________________________
[gw10] darwin -- Python 3.12.7 /Users/agriyakhetarpal/Desktop/autograd/.nox/nightly-tests/bin/python3

    def test_dtypes():
        def f(x):
            return np.real(np.sum(x**2))

        # Array y with dtype np.float32
        y = np.random.randn(10, 10).astype(np.float32)
>       assert grad(f)(y).dtype.type is np.float32
E       AssertionError: assert <class 'numpy.float64'> is <class 'numpy.float32'>
E        +  where <class 'numpy.float64'> = dtype('float64').type
E        +    where dtype('float64') = array([[ 0.99342829, -0.2765286 , 1.29537714, 3.04605961, -0.46830675,\n -0.46827391, 3.15842557, 1.53486943...8995, -1.40410614, -0.65532428, -0.78421628,\n -2.92702985, 0.59224057, 0.52211052, 0.01022691, -0.46917427]]).dtype
E        +      where array([[ 0.99342829, -0.2765286 , 1.29537714, 3.04605961, -0.46830675,\n -0.46827391, 3.15842557, 1.53486943...8995, -1.40410614, -0.65532428, -0.78421628,\n -2.92702985, 0.59224057, 0.52211052, 0.01022691, -0.46917427]]) = <function unary_to_nary.<locals>.nary_operator.<locals>.nary_f at 0x131754680>(array([[ 0.49671414, -0.1382643 , 0.64768857, 1.5230298 , -0.23415338,\n -0.23413695, 1.5792128 , 0.7674347 ....32766214, -0.39210814,\n -1.4635149 , 0.2961203 , 0.26105526, 0.00511346, -0.23458713]],\n dtype=float32))
E        +        where <function unary_to_nary.<locals>.nary_operator.<locals>.nary_f at 0x131754680> = grad(<function test_dtypes.<locals>.f at 0x1317540e0>)
E        +    and   <class 'numpy.float32'> = np.float32

tests/test_wrappers.py:272: AssertionError
```

Also reported in CI: https://github.com/HIPS/autograd/actions/runs/12663833691/job/35291033097
My understanding is that NumPy recently changed a type promotion rule. I checked whether any of our VJP safeguards around it are missing, but I haven't found the cause so far.
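One way to narrow this down is to replay the gradient computation one NumPy operation at a time and report the first step whose result dtype differs from the input's. This is a minimal sketch that assumes nothing about autograd's internals: `first_dtype_change` is a hypothetical helper, and the listed steps are a hand-written gradient of `f(x) = np.real(np.sum(x**2))`, not autograd's actual VJP definitions.

```python
import numpy as np


def first_dtype_change(x, steps):
    """Apply each (name, fn) pair in `steps` in order, feeding each output
    into the next, and return the name of the first step whose output dtype
    differs from the dtype of `x`. Return None if the dtype is preserved."""
    expected = np.asarray(x).dtype
    value = x
    for name, fn in steps:
        value = fn(value)
        if np.asarray(value).dtype != expected:
            return name
    return None


# Hypothetical, hand-written reverse pass for f(x) = real(sum(x**2)).
# These steps only mimic the kind of elementwise work a VJP does; they are
# NOT autograd's real VJP implementations.
x = np.random.randn(10, 10).astype(np.float32)
steps = [
    ("upstream_grad", lambda a: np.ones_like(a)),   # g, same dtype as x
    ("times_exponent", lambda g: g * 2),            # g * y, y is a Python int
    ("times_x_pow", lambda g: g * x ** (2 - 1)),    # g * x**(y - 1)
]

print(first_dtype_change(x, steps))  # → None
```

Swapping these steps for the operations a given VJP actually performs, or for operands such as `np.float64(2.0)` (which, under NumPy 2.x's NEP 50 promotion rules, promotes a float32 array to float64), shows which step introduces the upcast.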
All dtypes are being converted to float64 rather than preserved: a float32 input yields a float64 gradient. Yesterday's CI run passed, so the regression could come from numpy/numpy#27901 or numpy/numpy#28118, but I don't yet see how.
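Because the failure appeared between two CI runs, one option is to record how a few float32 operations that occur in `test_dtypes` promote under the last passing NumPy build and under the current nightly, then diff the two records. This is a sketch under stated assumptions: `promotion_report` and `diff_reports` are hypothetical helpers, and the probe list is illustrative rather than exhaustive.

```python
import numpy as np


def promotion_report():
    """Return the result dtype (as a string) of a few float32 operations
    that appear in the failing test, plus the running NumPy version."""
    x = np.ones((2, 2), dtype=np.float32)
    probes = {
        "f32 * python float": lambda: x * 2.0,
        "f32 * np.float64 scalar": lambda: x * np.float64(2.0),
        "f32 * 0-d float64 array": lambda: x * np.array(2.0),
        "f32 ** python int": lambda: x ** 2,
        "np.sum(f32)": lambda: np.sum(x),
        "np.real(f32)": lambda: np.real(x),
    }
    report = {name: str(np.asarray(fn()).dtype) for name, fn in probes.items()}
    report["numpy"] = np.__version__
    return report


def diff_reports(old, new):
    """Return {key: (old_value, new_value)} for every key whose value differs
    between the two reports (missing keys show up as None)."""
    keys = set(old) | set(new)
    return {
        k: (old.get(k), new.get(k))
        for k in sorted(keys)
        if old.get(k) != new.get(k)
    }


if __name__ == "__main__":
    print(promotion_report())
```

For example, write `promotion_report()` to a JSON file in each environment, load both files, and pass them to `diff_reports`; any probe other than `numpy` that changes between the two builds is a candidate for the behaviour change.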