4 changes: 2 additions & 2 deletions .github/workflows/test.yml
@@ -161,9 +161,9 @@ jobs:
run: |

if [[ $OS == "macos-15" ]]; then
micromamba install --yes -q "python~=${PYTHON_VERSION}" numpy scipy "numba>=0.63" pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx libblas=*=*accelerate;
micromamba install --yes -q "python~=${PYTHON_VERSION}" numpy "scipy<1.17.0" "numba>=0.63" pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx libblas=*=*accelerate;
else
micromamba install --yes -q "python~=${PYTHON_VERSION}" numpy scipy "numba>=0.63" pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx mkl mkl-service;
micromamba install --yes -q "python~=${PYTHON_VERSION}" numpy "scipy<1.17.0" "numba>=0.63" pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx mkl mkl-service;
fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
6 changes: 3 additions & 3 deletions doc/extending/creating_an_op.rst
@@ -495,7 +495,7 @@ If both outputs are disconnected PyTensor will not bother calling the :meth:`L_o

from pytensor.graph.op import Op
from pytensor.graph.basic import Apply
from pytensor.gradient import DisconnectedType
from pytensor.gradient import DisconnectedType, disconnected_type

class TransposeAndSumOp(Op):
__props__ = ()
@@ -539,13 +539,13 @@ If both outputs are disconnected PyTensor will not bother calling the :meth:`L_o
out1_grad, out2_grad = output_grads

if isinstance(out1_grad.type, DisconnectedType):
x_grad = DisconnectedType()()
x_grad = disconnected_type()
else:
# Transpose the last two dimensions of the output gradient
x_grad = pt.swapaxes(out1_grad, -1, -2)

if isinstance(out2_grad.type, DisconnectedType):
y_grad = DisconnectedType()()
y_grad = disconnected_type()
else:
# Broadcast the output gradient to the same shape as y
y_grad = pt.broadcast_to(pt.expand_dims(out2_grad, -1), y.shape)
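Note: every change below follows the same pattern, replacing the `DisconnectedType()()` spelling with a call to the prebuilt `disconnected_type` instance. A minimal sketch of the equivalence this relies on, assuming `disconnected_type` is the module-level `DisconnectedType()` singleton that `pytensor.gradient` already exposes (as the imports in this diff suggest):

    from pytensor.gradient import DisconnectedType, disconnected_type

    # Calling the prebuilt Type instance creates a variable of that type,
    # exactly like the older DisconnectedType()() spelling did.
    g = disconnected_type()
    assert isinstance(g.type, DisconnectedType)

Ops can therefore keep using `isinstance(..., DisconnectedType)` checks on incoming output gradients while building their own disconnected gradients with `disconnected_type()`.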
4 changes: 2 additions & 2 deletions pytensor/breakpoint.py
@@ -1,6 +1,6 @@
import numpy as np

from pytensor.gradient import DisconnectedType
from pytensor.gradient import disconnected_type
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.op import Op
from pytensor.tensor.basic import as_tensor_variable
@@ -142,7 +142,7 @@ def perform(self, node, inputs, output_storage):
output_storage[i][0] = inputs[i + 1]

def grad(self, inputs, output_gradients):
return [DisconnectedType()(), *output_gradients]
return [disconnected_type(), *output_gradients]

def infer_shape(self, fgraph, inputs, input_shapes):
# Return the shape of every input but the condition (first input)
7 changes: 5 additions & 2 deletions pytensor/raise_op.py
@@ -2,7 +2,7 @@

from textwrap import indent

from pytensor.gradient import DisconnectedType
from pytensor.gradient import disconnected_type
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.replace import _vectorize_node
from pytensor.link.c.op import COp
@@ -89,7 +89,10 @@ def perform(self, node, inputs, outputs):
raise self.exc_type(self.msg)

def grad(self, input, output_gradients):
return output_gradients + [DisconnectedType()()] * (len(input) - 1)
return [
*output_gradients,
*(disconnected_type() for _ in range(len(input) - 1)),
]

def connection_pattern(self, node):
return [[1]] + [[0]] * (len(node.inputs) - 1)
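A hedged usage sketch for the `CheckAndRaise.grad` above: the gradient flows through the checked value untouched, while every condition input is disconnected. This assumes `Assert` is the user-facing `CheckAndRaise` subclass exported by `pytensor.raise_op`:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.raise_op import Assert  # assumed export of this module

    x = pt.dvector("x")
    checked = Assert("x must be non-negative")(x, pt.all(x >= 0))
    cost = (checked**2).sum()

    # The gradient passes straight through the check; the condition input
    # receives a disconnected gradient, so nothing is differentiated through
    # pt.all(x >= 0).
    g = pytensor.grad(cost, x)
    f = pytensor.function([x], g)
    print(f(np.array([1.0, 2.0])))  # expected: [2. 4.]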
6 changes: 3 additions & 3 deletions pytensor/scalar/basic.py
@@ -22,7 +22,7 @@
import pytensor
from pytensor import printing
from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, grad_undefined
from pytensor.gradient import disconnected_type, grad_undefined
from pytensor.graph.basic import Apply, Constant, Variable, clone
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import HasInnerGraph
@@ -2426,13 +2426,13 @@ def grad(self, inputs, gout):
(gz,) = gout
if y.type in continuous_types:
# x is disconnected because the elements of x are not used
return DisconnectedType()(), gz
return disconnected_type(), gz
else:
# when y is discrete, we assume the function can be extended
# to deal with real-valued inputs by rounding them to the
# nearest integer. f(x+eps) thus equals f(x) so the gradient
# is zero, not disconnected or undefined
return DisconnectedType()(), y.zeros_like(dtype=config.floatX)
return disconnected_type(), y.zeros_like(dtype=config.floatX)


second = Second(name="second")
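The comments above state the rule: `x` is disconnected because its values are never used, and a discrete `y` gets a zero gradient rather than an undefined one. A small sketch of the visible effect through the tensor-level wrapper, assuming `pt.second` (a.k.a. `fill`) dispatches to this scalar Op:

    import pytensor
    import pytensor.tensor as pt

    x = pt.dscalar("x")
    y = pt.dscalar("y")
    out = pt.second(x, y)  # broadcasts y to the shape of x; x's value is unused

    # x is disconnected, so we must opt out of the default error; the
    # disconnected gradient is then materialized as a zero.
    gx, gy = pytensor.grad(out, [x, y], disconnected_inputs="ignore")

    f = pytensor.function([x, y], [gx, gy])
    print(f(3.0, 5.0))  # expected: [array(0.), array(1.)]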
23 changes: 15 additions & 8 deletions pytensor/scan/op.py
@@ -63,7 +63,14 @@
from pytensor.compile.mode import Mode, get_mode
from pytensor.compile.profiling import register_profiler_printer
from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, NullType, Rop, grad, grad_undefined
from pytensor.gradient import (
DisconnectedType,
NullType,
Rop,
disconnected_type,
grad,
grad_undefined,
)
from pytensor.graph.basic import (
Apply,
Variable,
@@ -3073,7 +3080,7 @@ def compute_all_gradients(known_grads):
)
outputs = local_op(*outer_inputs, return_list=True)
# Re-order the gradients correctly
gradients = [DisconnectedType()()]
gradients = [disconnected_type()] # n_steps is disconnected

offset = info.n_mit_mot + info.n_mit_sot + info.n_sit_sot + n_sitsot_outs
for p, (x, t) in enumerate(
@@ -3098,7 +3105,7 @@ def compute_all_gradients(known_grads):
else:
gradients.append(x[::-1])
elif t == "disconnected":
gradients.append(DisconnectedType()())
gradients.append(disconnected_type())
elif t == "through_untraced":
gradients.append(
grad_undefined(
@@ -3126,7 +3133,7 @@ def compute_all_gradients(known_grads):
else:
gradients.append(x[::-1])
elif t == "disconnected":
gradients.append(DisconnectedType()())
gradients.append(disconnected_type())
elif t == "through_untraced":
gradients.append(
grad_undefined(
@@ -3149,15 +3156,15 @@ def compute_all_gradients(known_grads):
if not isinstance(dC_dout.type, DisconnectedType) and connected:
disconnected = False
if disconnected:
gradients.append(DisconnectedType()())
gradients.append(disconnected_type())
else:
gradients.append(
grad_undefined(
self, idx, inputs[idx], "Shared Variable with update"
)
)

gradients += [DisconnectedType()() for _ in range(info.n_nit_sot)]
gradients.extend(disconnected_type() for _ in range(info.n_nit_sot))
begin = end

end = begin + n_sitsot_outs
@@ -3167,7 +3174,7 @@ def compute_all_gradients(known_grads):
if t == "connected":
gradients.append(x[-1])
elif t == "disconnected":
gradients.append(DisconnectedType()())
gradients.append(disconnected_type())
elif t == "through_untraced":
gradients.append(
grad_undefined(
@@ -3195,7 +3202,7 @@ def compute_all_gradients(known_grads):
):
disconnected = False
if disconnected:
gradients[idx] = DisconnectedType()()
gradients[idx] = disconnected_type()
return gradients

def R_op(self, inputs, eval_points):
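The `Scan.L_op` changes above only affect how the internal gradient list is assembled (disconnected slots for `n_steps`, nit-sot outputs, and shared variables with updates); differentiating through `scan` from user code is unchanged. A minimal, purely illustrative sketch:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.dvector("x")

    # Running product over x; Scan.L_op builds the backward scan for us.
    result, _ = pytensor.scan(
        fn=lambda x_t, acc: acc * x_t,
        sequences=x,
        outputs_info=pt.ones((), dtype="float64"),
    )
    cost = result[-1]

    g = pytensor.grad(cost, x)
    f = pytensor.function([x], g)
    print(f(np.array([2.0, 3.0, 4.0])))  # expected: [12.  8.  6.]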
10 changes: 5 additions & 5 deletions pytensor/sparse/basic.py
@@ -18,7 +18,7 @@
from pytensor import _as_symbolic, as_symbolic
from pytensor import scalar as ps
from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, grad_undefined
from pytensor.gradient import DisconnectedType, disconnected_type, grad_undefined
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.op import Op
from pytensor.link.c.type import generic
@@ -480,9 +480,9 @@ def grad(self, inputs, gout):
)
return [
g_data,
DisconnectedType()(),
DisconnectedType()(),
DisconnectedType()(),
disconnected_type(),
disconnected_type(),
disconnected_type(),
]

def infer_shape(self, fgraph, node, shapes):
@@ -1940,7 +1940,7 @@ def grad(self, inputs, grads):
gx = g_output
gy = pytensor.tensor.subtensor.advanced_subtensor1(g_output, *idx_list)

return [gx, gy] + [DisconnectedType()()] * len(idx_list)
return [gx, gy, *(disconnected_type() for _ in range(len(idx_list)))]


construct_sparse_from_list = ConstructSparseFromList()
12 changes: 6 additions & 6 deletions pytensor/tensor/basic.py
@@ -22,7 +22,7 @@
from pytensor import config, printing
from pytensor import scalar as ps
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType, grad_undefined
from pytensor.gradient import DisconnectedType, disconnected_type, grad_undefined
from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
from pytensor.graph.fg import FunctionGraph, Output
@@ -1738,7 +1738,7 @@ def grad(self, inputs, grads):
# the inputs that specify the shape. If you grow the
# shape by epsilon, the existing elements do not
# change.
return [gx] + [DisconnectedType()() for i in inputs[1:]]
return [gx, *(disconnected_type() for _ in range(len(inputs) - 1))]

def R_op(self, inputs, eval_points):
if eval_points[0] is None:
@@ -2277,7 +2277,7 @@ def L_op(self, inputs, outputs, g_outputs):
return [
join(axis, *new_g_outputs),
grad_undefined(self, 1, axis),
DisconnectedType()(),
disconnected_type(),
]

def R_op(self, inputs, eval_points):
@@ -3340,14 +3340,14 @@ def L_op(self, inputs, outputs, grads):
if self.dtype in discrete_dtypes:
return [
start.zeros_like(dtype=config.floatX),
DisconnectedType()(),
disconnected_type(),
step.zeros_like(dtype=config.floatX),
]
else:
num_steps_taken = outputs[0].shape[0]
return [
gz.sum(),
DisconnectedType()(),
disconnected_type(),
(gz * arange(num_steps_taken, dtype=self.dtype)).sum(),
]

@@ -4374,7 +4374,7 @@ def connection_pattern(self, node):
return [[False] for i in node.inputs]

def grad(self, inputs, grads):
return [DisconnectedType()() for i in inputs]
return [disconnected_type() for _ in range(len(inputs))]

def R_op(self, inputs, eval_points):
return [zeros(inputs, self.dtype)]
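The `Alloc.grad` comment in the first hunk above gives the reasoning: perturbing a shape input does not change the existing elements, so only the fill value gets a real gradient. A tiny sketch of the visible effect:

    import pytensor
    import pytensor.tensor as pt

    v = pt.dscalar("v")
    a = pt.alloc(v, 3, 4)  # 3x4 tensor filled with v
    cost = a.sum()

    # The gradient flows only into the fill value; the shape arguments are
    # disconnected, so d(cost)/dv just counts the allocated elements.
    gv = pytensor.grad(cost, v)
    print(pytensor.function([v], gv)(2.0))  # expected: 12.0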
3 changes: 1 addition & 2 deletions pytensor/tensor/extra_ops.py
@@ -8,7 +8,6 @@
import pytensor
import pytensor.scalar.basic as ps
from pytensor.gradient import (
DisconnectedType,
_float_zeros_like,
disconnected_type,
grad_undefined,
@@ -716,7 +715,7 @@ def grad(self, inputs, gout):
gx_transpose = ptb.zeros_like(x_transpose)[repeated_arange].inc(gz_transpose)
gx = ptb.moveaxis(gx_transpose, 0, axis)

return [gx, DisconnectedType()()]
return [gx, disconnected_type()]

def infer_shape(self, fgraph, node, ins_shapes):
i0_shapes = ins_shapes[0]
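In the `Repeat.grad` above, the gradient w.r.t. `x` accumulates over the repeated copies while the `repeats` input stays disconnected. A short sketch, assuming the top-level `pt.repeat` wrapper reaches this Op for a scalar repeat count:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.dvector("x")
    y = pt.repeat(x, 3)  # each element repeated three times
    cost = y.sum()

    # Every x[i] feeds three outputs, so its gradient accumulates to 3; the
    # repeats argument receives a disconnected gradient.
    g = pytensor.grad(cost, x)
    print(pytensor.function([x], g)(np.array([1.0, 2.0])))  # expected: [3. 3.]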
6 changes: 3 additions & 3 deletions pytensor/tensor/fft.py
@@ -1,6 +1,6 @@
import numpy as np

from pytensor.gradient import DisconnectedType
from pytensor.gradient import disconnected_type
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor.basic import as_tensor_variable
@@ -59,7 +59,7 @@ def grad(self, inputs, output_grads):
+ [slice(None)]
)
gout = set_subtensor(gout[idx], gout[idx] * 0.5)
return [irfft_op(gout, s), DisconnectedType()()]
return [irfft_op(gout, s), disconnected_type()]

def connection_pattern(self, node):
# Specify that shape input parameter has no connection to graph and gradients.
@@ -121,7 +121,7 @@ def grad(self, inputs, output_grads):
+ [slice(None)]
)
gf = set_subtensor(gf[idx], gf[idx] * 2)
return [gf, DisconnectedType()()]
return [gf, disconnected_type()]

def connection_pattern(self, node):
# Specify that shape input parameter has no connection to graph and gradients.
6 changes: 3 additions & 3 deletions pytensor/tensor/nlinalg.py
@@ -8,7 +8,7 @@

from pytensor import scalar as ps
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType
from pytensor.gradient import DisconnectedType, disconnected_type
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor import TensorLike
@@ -652,8 +652,8 @@ def s_grad_only(
]
if all(is_disconnected):
# This should never actually be reached by Pytensor -- the SVD Op should be pruned from the gradient
# graph if its fully disconnected. It is included for completeness.
return [DisconnectedType()()] # pragma: no cover
# graph if it's fully disconnected. It is included for completeness.
return [disconnected_type()] # pragma: no cover

elif is_disconnected == [True, False, True]:
# This is the same as the compute_uv = False, so we can drop back to that simpler computation, without
4 changes: 2 additions & 2 deletions pytensor/tensor/reshape.py
@@ -6,7 +6,7 @@
from numpy.lib._array_utils_impl import normalize_axis_index, normalize_axis_tuple

from pytensor import Variable
from pytensor.gradient import DisconnectedType
from pytensor.gradient import disconnected_type
from pytensor.graph import Apply
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
@@ -217,7 +217,7 @@ def L_op(self, inputs, outputs, output_grads):
n_axes = g_out.ndim - x.ndim + 1
axis_range = list(range(self.axis, self.axis + n_axes))

return [join_dims(g_out, axis=axis_range), DisconnectedType()()]
return [join_dims(g_out, axis=axis_range), disconnected_type()]


@_vectorize_node.register(SplitDims)
11 changes: 6 additions & 5 deletions pytensor/tensor/shape.py
@@ -10,7 +10,7 @@
from numpy.lib.array_utils import normalize_axis_tuple

import pytensor
from pytensor.gradient import DisconnectedType
from pytensor.gradient import disconnected_type
from pytensor.graph import Op
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.replace import _vectorize_node
@@ -103,7 +103,7 @@ def grad(self, inp, grads):
# the elements of the tensor variable do not participate
# in the computation of the shape, so they are not really
# part of the graph
return [pytensor.gradient.DisconnectedType()()]
return [disconnected_type()]

def R_op(self, inputs, eval_points):
return [None]
@@ -474,8 +474,9 @@ def connection_pattern(self, node):
def grad(self, inp, grads):
_x, *shape = inp
(gz,) = grads
return [specify_shape(gz, shape)] + [
pytensor.gradient.DisconnectedType()() for _ in range(len(shape))
return [
specify_shape(gz, shape),
*(disconnected_type() for _ in range(len(shape))),
]

def R_op(self, inputs, eval_points):
@@ -725,7 +726,7 @@ def connection_pattern(self, node):
def grad(self, inp, grads):
x, _shp = inp
(g_out,) = grads
return [reshape(g_out, shape(x), ndim=x.ndim), DisconnectedType()()]
return [reshape(g_out, shape(x), ndim=x.ndim), disconnected_type()]

def R_op(self, inputs, eval_points):
if eval_points[0] is None:
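`Reshape.grad` above sends the output gradient back through a reshape to the input's shape and marks the shape argument as disconnected. A minimal sketch of that behavior from the user side:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.dvector("x")
    y = x.reshape((2, 3))
    cost = (y**2).sum()

    # The gradient is reshaped back to x's shape; the (2, 3) shape argument is
    # disconnected and contributes nothing.
    g = pytensor.grad(cost, x)
    print(pytensor.function([x], g)(np.arange(6.0)))  # expected: [ 0.  2.  4.  6.  8. 10.]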