21 changes: 11 additions & 10 deletions pytensor/tensor/basic.py
@@ -2254,18 +2254,19 @@ def infer_shape(self, fgraph, node, in_shapes):
out_shapes.append(temp)
return out_shapes

def connection_pattern(self, node):
n_out = len(node.outputs)
return [
[True] * n_out,
[True] * n_out,
[False] * n_out,
]

def L_op(self, inputs, outputs, g_outputs):
"""Join the gradients along the axis that was used to split x."""
_x, axis, n = inputs
_x, axis, _n = inputs

# If all the output gradients are disconnected, then so are the inputs
if builtins.all(isinstance(g.type, DisconnectedType) for g in g_outputs):
return [
DisconnectedType()(),
grad_undefined(self, 1, axis),
grad_undefined(self, 2, n),
]
# Else, we have to make them zeros before joining them
# We have to convert disconnected outputs to zeros before joining them
new_g_outputs = []
for o, g in zip(outputs, g_outputs, strict=True):
if isinstance(g.type, DisconnectedType):
@@ -2276,7 +2277,7 @@ def L_op(self, inputs, outputs, g_outputs):
return [
join(axis, *new_g_outputs),
grad_undefined(self, 1, axis),
grad_undefined(self, 2, n),
DisconnectedType()(),
]

def R_op(self, inputs, eval_points):
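A quick way to see the effect of this change on the Split gradient (my own sketch, not part of the diff; it assumes pytensor's split(x, splits_size, n_splits, axis) helper and the io_connection_pattern utility used in the tests below):

import pytensor.tensor as pt
from pytensor.graph.op import io_connection_pattern

x = pt.vector("x")
sizes = pt.vector("sizes", dtype="int64")
# Split x into two pieces of symbolic sizes along axis 0.
a, b = pt.split(x, sizes, 2, axis=0)
out = a.sum() + b.sum()
# With the connection_pattern above, only x is connected to the output;
# the expected result is [[True], [False]], so no gradient is ever requested for sizes.
print(io_connection_pattern([x, sizes], [out]))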
66 changes: 31 additions & 35 deletions pytensor/tensor/reshape.py
@@ -1,23 +1,28 @@
from collections.abc import Iterable, Sequence
from itertools import pairwise
from typing import TypeAlias

import numpy as np
from numpy.lib._array_utils_impl import normalize_axis_tuple
from numpy.lib._array_utils_impl import normalize_axis_index, normalize_axis_tuple

from pytensor import Variable
from pytensor.gradient import DisconnectedType
from pytensor.graph import Apply
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.scalar import ScalarVariable
from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor.basic import expand_dims, infer_static_shape, join, split
from pytensor.tensor.extra_ops import squeeze
from pytensor.tensor.math import prod
from pytensor.tensor.shape import ShapeValueType, shape
from pytensor.tensor.type import tensor
from pytensor.tensor.variable import TensorVariable


ShapeValueType: TypeAlias = (
int | np.integer | ScalarVariable | TensorVariable | np.ndarray
)


class JoinDims(Op):
__props__ = (
"start_axis",
@@ -81,16 +86,11 @@ def perform(self, node, inputs, outputs):

out[0] = x.reshape(output_shape)

def L_op(
self,
inputs: Sequence[Variable],
outputs: Sequence[Variable],
output_grads: Sequence[Variable],
) -> list[Variable]:
def L_op(self, inputs, outputs, output_grads):
Review comment (Member Author): I strongly disagree with appeasing mypy here and pretending we don't know that we can only ever get and return TensorVariable.

(x,) = inputs
(g_out,) = output_grads

x_shape = shape(x)
x_shape = x.shape
packed_shape = [x_shape[i] for i in self.axis_range]
return [split_dims(g_out, shape=packed_shape, axis=self.start_axis)]

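As a sanity check of the gradient above (my own sketch, based only on the signatures used in this diff): split_dims with the original packed shape undoes join_dims, so splitting g_out back at start_axis recovers x's layout.

import pytensor.tensor as pt
from pytensor.tensor.reshape import join_dims, split_dims

x = pt.tensor("x", shape=(2, 3, 4))
y = join_dims(x, axis=[0, 1])                 # merges the first two dims -> shape (6, 4)
x_back = split_dims(y, shape=(2, 3), axis=0)  # back to shape (2, 3, 4)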
@@ -163,13 +163,18 @@ def __init__(self, axis: int):
raise ValueError("SplitDims axis must be non-negative")
self.axis = axis

def make_node(self, x: Variable, shape: Variable) -> Apply: # type: ignore[override]
def make_node(self, x, shape):
Review comment (Member Author): This was wrong, as x and shape may be TensorLike.

x = as_tensor_variable(x)
shape = as_tensor_variable(shape, dtype=int, ndim=1)
shape = as_tensor_variable(shape, dtype=int)

if shape.type.numpy_dtype.kind not in "iu":
raise TypeError("shape must be an integer tensor")

if shape.type.ndim != 1:
raise TypeError(
f"shape must be a 1-D tensor, got {shape} with {shape.type.ndim} dimensions"
)

axis = self.axis
_, constant_shape = infer_static_shape(shape)

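A short illustration of the new eager validation (my example, not from the PR): a non-1-D shape argument is now rejected in make_node instead of failing later during rewriting or at runtime.

import pytensor.tensor as pt
from pytensor.tensor.reshape import split_dims

x = pt.tensor("x", shape=(6, 4))
ok = split_dims(x, axis=0, shape=[2, 3])   # valid: integer, 1-D -> result shape (2, 3, 4)
# split_dims(x, axis=0, shape=[[2, 3]])   # would raise TypeError: shape must be a 1-D tensor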
@@ -205,16 +210,11 @@ def perform(self, node, inputs, outputs):
def connection_pattern(self, node):
return [[True], [False]]

def L_op(
self,
inputs: Sequence[Variable],
outputs: Sequence[Variable],
output_grads: Sequence[Variable],
) -> list[Variable]:
def L_op(self, inputs, outputs, output_grads):
(x, _) = inputs
(g_out,) = output_grads

n_axes = g_out.ndim - x.ndim + 1 # type: ignore[attr-defined]
n_axes = g_out.ndim - x.ndim + 1
axis_range = list(range(self.axis, self.axis + n_axes))

return [join_dims(g_out, axis=axis_range), DisconnectedType()()]
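For example (my own worked case, not from the diff): if x has shape (6, 4) and is split at axis 0 into (2, 3), then g_out has ndim 3, so n_axes = 3 - 2 + 1 = 2 and axis_range = [0, 1], which are exactly the axes that the split introduced and that join_dims must merge back.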
@@ -266,25 +266,21 @@ def split_dims(
x = as_tensor_variable(x)

if axis is None:
if x.ndim != 1:
if x.type.ndim != 1:
raise ValueError(
"split_dims can only be called with axis=None for 1d inputs"
)
axis = 0

if isinstance(shape, int):
shape = [shape]
else:
shape = list(shape) # type: ignore[arg-type]

if not shape:
# If we get an empty shape, there is potentially a dummy dimension at the requested axis. This happens for
# example when splitting a packed tensor that had its dims expanded before packing (e.g. when packing shapes
# (3, ) and (3, 3) to (3, 4)
return squeeze(x, axis=axis) # type: ignore[no-any-return]
axis = normalize_axis_index(axis, x.ndim)
Review comment (Member Author): It can only be an index, not a tuple, so be more pedantic.


[axis] = normalize_axis_tuple(axis, x.ndim) # type: ignore[misc]
shape = as_tensor_variable(shape, dtype="int64", ndim=1) # type: ignore[arg-type]
# Convert scalar shape to 1d tuple (shape,)
if not isinstance(shape, Sequence):
if isinstance(shape, TensorVariable | np.ndarray):
if shape.ndim == 0:
shape = (shape,)
elif isinstance(shape, int | np.integer | ScalarVariable):
shape = (shape,)

return SplitDims(axis=axis)(x, shape) # type: ignore[return-value]

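Two equivalent calls that exercise the normalization above (my sketch): a scalar shape is promoted to a one-element shape and a negative axis is resolved by normalize_axis_index.

import pytensor.tensor as pt
from pytensor.tensor.reshape import split_dims

x = pt.tensor("x", shape=(6, 4, 6))
a = split_dims(x, axis=-1, shape=6)      # scalar shape, negative axis
b = split_dims(x, axis=2, shape=(6,))    # the same call spelled explicitly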
@@ -372,7 +368,7 @@ def find_gaps(s):

def pack(
*tensors: TensorLike, axes: Sequence[int] | int | None = None
) -> tuple[TensorVariable, list[ShapeValueType]]:
) -> tuple[TensorVariable, list[TensorVariable]]:
Review comment (Member Author): We only return TensorVariable shapes, not the flexible input types.

"""
Combine multiple tensors by preserving the specified axes and raveling the rest into a single axis.

@@ -458,7 +454,7 @@ def pack(
n_before, n_after, min_axes = _analyze_axes_list(axes)

reshaped_tensors: list[Variable] = []
packed_shapes: list[ShapeValueType] = []
packed_shapes: list[TensorVariable] = []

for i, input_tensor in enumerate(tensor_list):
n_dim = input_tensor.ndim
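To make the annotation change concrete (my own sketch; it assumes pack is exposed in the pytensor.tensor namespace alongside unpack, as the tests below suggest for unpack): whatever shape-like values go in, the returned packed_shapes are symbolic TensorVariables and can be fed straight back into unpack.

import pytensor.tensor as pt

x = pt.tensor("x", shape=(3,))
y = pt.tensor("y", shape=(3, 3))
# Preserving axis 0 and raveling the rest packs a (3,) and a (3, 3) tensor into a (3, 4) tensor.
packed, packed_shapes = pt.pack(x, y, axes=0)
x_out, y_out = pt.unpack(packed, axes=0, packed_shapes=packed_shapes)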
@@ -492,7 +488,7 @@ def unpack(
def unpack(
packed_input: TensorLike,
axes: int | Sequence[int] | None,
packed_shapes: list[ShapeValueType],
packed_shapes: Sequence[ShapeValueType],
) -> list[TensorVariable]:
"""
Unpack a packed tensor into multiple tensors by splitting along the specified axes and reshaping.
15 changes: 13 additions & 2 deletions tests/tensor/test_reshape.py
@@ -6,6 +6,7 @@
from pytensor import config, function
from pytensor import tensor as pt
from pytensor.graph import rewrite_graph, vectorize_graph
from pytensor.graph.op import io_connection_pattern
from pytensor.tensor.reshape import (
_analyze_axes_list,
join_dims,
@@ -61,9 +62,10 @@ def test_join_dims():
[
(0, pt.as_tensor([2, 3]), (2, 3, 4, 6)),
(2, [2, 3], (6, 4, 2, 3)),
(-1, pt.as_tensor(6), (6, 4, 6)),
(-1, 6, (6, 4, 6)),
],
ids=["tensor", "list", "integer"],
ids=["tensor list", "integer list", "tensor", "integer"],
)
def test_split_dims(axis, shape, expected_shape):
rng = np.random.default_rng()
@@ -95,7 +97,7 @@ def test_split_dims(axis, shape, expected_shape):

def test_split_size_zero_shape():
x = pt.tensor("x", shape=(1, 4, 6))
x_split = split_dims(x, axis=0, shape=pt.as_tensor(np.zeros((0,))))
x_split = split_dims(x, axis=0, shape=pt.as_tensor(np.zeros((0,), dtype="int32")))
assert x_split.type.shape == (4, 6)

x_value = np.empty((1, 4, 6), dtype=config.floatX)
@@ -288,3 +290,12 @@ def test_pack_unpack_round_trip(self, axes):

for input_val, output_val in zip(input_dict.values(), output_vals, strict=True):
np.testing.assert_allclose(input_val, output_val)


def test_unpack_connection():
x = pt.vector("x")
d0 = pt.scalar("d0", dtype=int)
d1 = pt.scalar("d1", dtype=int)
x0, x1 = pt.unpack(x, axes=None, packed_shapes=[d0, d1])
out = x0.sum() + x1.sum()
assert io_connection_pattern([x, d0, d1], [out]) == [[True], [False], [False]]