Skip to content

Commit 98266c8

Browse files
Intermediate changes again
1 parent f0460f4 commit 98266c8

File tree

4 files changed

+18
-16
lines changed

4 files changed

+18
-16
lines changed

Diff for: pytensor/compile/function/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ def opt_log1p(node):
312312
else:
313313
# note: pfunc will also call orig_function -- orig_function is
314314
# a choke point that all compilation must pass through
315+
315316
fn = pfunc(
316317
params=inputs,
317318
outputs=outputs,

Diff for: pytensor/compile/function/types.py

+1
Original file line numberDiff line numberDiff line change
@@ -1758,6 +1758,7 @@ def orig_function(
17581758
name=name,
17591759
fgraph=fgraph,
17601760
)
1761+
print(m)
17611762
with config.change_flags(compute_test_value="off"):
17621763
fn = m.create(defaults)
17631764
finally:

Diff for: tests/tensor/test_math.py

+15-16
Original file line numberDiff line numberDiff line change
@@ -764,10 +764,10 @@ def setup_method(self):
764764
Max.debug = 0
765765
Argmax.debug = 0
766766

767-
def test_basic(self):
767+
def test_basic_0(self):
768768
# dbt: for some reason, Argmax does not work when I pass: n = as_tensor_variable(5.0)
769-
n = as_tensor_variable(5.0)
770-
v, i = eval_outputs(max_and_argmax(n))
769+
n = as_tensor_variable(5)
770+
v, i = eval_outputs(max_and_argmax(n, axis=()))
771771
assert v == 5.0
772772
assert i == 0
773773
assert i.dtype == "int64"
@@ -809,11 +809,7 @@ def test_basic_2(self, axis, np_axis):
809809
v_shape, i_shape = eval_outputs([vt.shape, it.shape])
810810
assert tuple(v_shape) == vt.type.shape
811811
assert tuple(i_shape) == it.type.shape
812-
# Test values
813-
v, i = eval_outputs([vt, it])
814-
assert i.dtype == "int64"
815-
assert np.all(v == np_max)
816-
assert np.all(i == np_argm)
812+
# Test valuesgi
817813

818814
@pytest.mark.parametrize(
819815
"axis,np_axis",
@@ -1372,27 +1368,30 @@ def _grad_list(self):
13721368
data = random(2, 3)
13731369
for fct in [max_and_argmax, max, min]:
13741370
utt.verify_grad(lambda v: fct(v, axis=[0, 1]), [data])
1375-
# n = as_tensor_variable(data)
1376-
# check_grad_max(data, eval_outputs(grad(max_and_argmax(n,
1377-
# axis=1)[0], n)),axis=1)
1371+
n = as_tensor_variable(data)
1372+
check_grad_max(
1373+
data, eval_outputs(grad(max_and_argmax(n, axis=1)[0], n)), axis=1
1374+
)
13781375

13791376
def test_uint(self):
13801377
for dtype in ("uint8", "uint16", "uint32", "uint64"):
13811378
itype = np.iinfo(dtype)
13821379
data = np.array([itype.min + 3, itype.min, itype.max - 5, itype.max], dtype)
13831380
n = as_tensor_variable(data)
13841381
assert min(n).dtype == dtype
1382+
# print(min(n).owner.inputs[1].acc_dtype)
13851383
i = eval_outputs(min(n))
13861384
# pytensor.dprint(n)
1387-
for x in n:
1388-
print(x.eval())
1385+
# for x in n:
1386+
# print(x.eval())
13891387
print(i)
13901388
print(itype.min)
13911389
print()
13921390
assert i == itype.min
1393-
assert max(n).dtype == dtype
1394-
i = eval_outputs(max(n))
1395-
assert i == itype.max
1391+
# assert max(n).dtype == dtype
1392+
# i = eval_outputs(max(n))
1393+
# assert i == itype.max
1394+
# assert 0
13961395

13971396
def test_bool(self):
13981397
data = np.array([True, False], "bool")

Diff for: tests/tensor/utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def inplace_func(
107107
def eval_outputs(outputs, ops=(), mode=None):
108108
f = inplace_func([], outputs, mode=mode)
109109
variables = f()
110+
110111
if ops:
111112
assert any(isinstance(node.op, ops) for node in f.maker.fgraph.apply_nodes)
112113
if isinstance(variables, tuple | list) and len(variables) == 1:

0 commit comments

Comments
 (0)