178 changes: 82 additions & 96 deletions pytensor/tensor/reshape.py
@@ -12,7 +12,7 @@
from pytensor.graph.replace import _vectorize_node
from pytensor.scalar import ScalarVariable
from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor.basic import expand_dims, infer_static_shape, join, split
from pytensor.tensor.basic import infer_static_shape, join, split
from pytensor.tensor.math import prod
from pytensor.tensor.type import tensor
from pytensor.tensor.variable import TensorVariable
@@ -24,10 +24,7 @@


class JoinDims(Op):
__props__ = (
"start_axis",
"n_axes",
)
__props__ = ("start_axis", "n_axes")
view_map = {0: [0]}

def __init__(self, start_axis: int, n_axes: int):
Expand Down Expand Up @@ -55,6 +52,11 @@ def make_node(self, x: Variable) -> Apply: # type: ignore[override]

static_shapes = x.type.shape
axis_range = self.axis_range
if (self.start_axis + self.n_axes) > x.type.ndim:
raise ValueError(
f"JoinDims was asked to join dimensions {self.start_axis} to {self.n_axes}, "
f"but input {x} has only {x.type.ndim} dimensions."
)

joined_shape = (
int(np.prod([static_shapes[i] for i in axis_range]))
@@ -69,9 +71,7 @@ def make_node(self, x: Variable) -> Apply: # type: ignore[override]

def infer_shape(self, fgraph, node, shapes):
[input_shape] = shapes
axis_range = self.axis_range

joined_shape = prod([input_shape[i] for i in axis_range])
joined_shape = prod([input_shape[i] for i in self.axis_range], dtype=int)
return [self.output_shapes(input_shape, joined_shape)]

def perform(self, node, inputs, outputs):
@@ -98,23 +98,24 @@ def L_op(self, inputs, outputs, output_grads):
@_vectorize_node.register(JoinDims)
def _vectorize_joindims(op, node, x):
[old_x] = node.inputs

batched_ndims = x.type.ndim - old_x.type.ndim
start_axis = op.start_axis
n_axes = op.n_axes

return JoinDims(start_axis + batched_ndims, n_axes).make_node(x)
return JoinDims(op.start_axis + batched_ndims, op.n_axes).make_node(x)


def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorVariable:
def join_dims(
x: TensorLike, start_axis: int = 0, n_axes: int | None = None
) -> TensorVariable:
"""Join consecutive dimensions of a tensor into a single dimension.

Parameters
----------
x : TensorLike
The input tensor.
axis : int or sequence of int, optional
The dimensions to join. If None, all dimensions are joined.
start_axis : int, default 0
The axis from which to start joining dimensions.
n_axes : int, optional
The number of axes to join, starting at `start_axis`. If `None`, all remaining axes are joined.
If 0, a new dimension of length 1 is inserted at `start_axis`.

Returns
-------
@@ -125,33 +126,32 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorV
--------
>>> import pytensor.tensor as pt
>>> x = pt.tensor("x", shape=(2, 3, 4, 5))
>>> y = pt.join_dims(x, axis=(1, 2))
>>> y = pt.join_dims(x)
>>> y.type.shape
(120,)
>>> y = pt.join_dims(x, start_axis=1)
>>> y.type.shape
(2, 60)
>>> y = pt.join_dims(x, start_axis=1, n_axes=2)
>>> y.type.shape
(2, 12, 5)
"""
x = as_tensor_variable(x)
ndim = x.type.ndim

if axis is None:
axis = list(range(x.ndim))
elif isinstance(axis, int):
axis = [axis]
elif not isinstance(axis, list | tuple):
raise TypeError("axis must be an int, a list/tuple of ints, or None")

axis = normalize_axis_tuple(axis, x.ndim)
if start_axis < 0:
# We treat scalars as if they had a single axis
start_axis += max(1, ndim)

if len(axis) <= 1:
return x # type: ignore[unreachable]

if np.diff(axis).max() > 1:
raise ValueError(
f"join_dims axis must be consecutive, got normalized axis: {axis}"
if not 0 <= start_axis <= ndim:
raise IndexError(
f"Axis {start_axis} is out of bounds for array of dimension {ndim}"
)

start_axis = min(axis)
n_axes = len(axis)
if n_axes is None:
n_axes = ndim - start_axis

return JoinDims(start_axis=start_axis, n_axes=n_axes)(x) # type: ignore[return-value]
return JoinDims(start_axis, n_axes)(x) # type: ignore[return-value]
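
For orientation, the new ``(start_axis, n_axes)`` convention corresponds to a single reshape. A minimal NumPy reference sketch (``join_dims_ref`` is a hypothetical helper, written only to mirror the doctest above):

.. code-block:: python

    import numpy as np


    def join_dims_ref(x, start_axis=0, n_axes=None):
        # Collapse `n_axes` consecutive dimensions, starting at `start_axis`,
        # into one dimension holding their product. n_axes=0 inserts a new
        # length-1 dimension, matching the docstring.
        if n_axes is None:
            n_axes = x.ndim - start_axis
        return x.reshape((*x.shape[:start_axis], -1, *x.shape[start_axis + n_axes :]))


    x = np.empty((2, 3, 4, 5))
    assert join_dims_ref(x).shape == (120,)
    assert join_dims_ref(x, start_axis=1).shape == (2, 60)
    assert join_dims_ref(x, start_axis=1, n_axes=2).shape == (2, 12, 5)
    assert join_dims_ref(x, start_axis=1, n_axes=0).shape == (2, 1, 3, 4, 5)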


class SplitDims(Op):
@@ -213,11 +213,11 @@ def connection_pattern(self, node):
def L_op(self, inputs, outputs, output_grads):
(x, _) = inputs
(g_out,) = output_grads

n_axes = g_out.ndim - x.ndim + 1
axis_range = list(range(self.axis, self.axis + n_axes))

return [join_dims(g_out, axis=axis_range), disconnected_type()]
return [
join_dims(g_out, start_axis=self.axis, n_axes=n_axes),
disconnected_type(),
]


@_vectorize_node.register(SplitDims)
@@ -230,14 +230,13 @@ def _vectorize_splitdims(op, node, x, shape):
if as_tensor_variable(shape).type.ndim != 1:
return vectorize_node_fallback(op, node, x, shape)

axis = op.axis
return SplitDims(axis=axis + batched_ndims).make_node(x, shape)
return SplitDims(axis=op.axis + batched_ndims).make_node(x, shape)


def split_dims(
x: TensorLike,
shape: ShapeValueType | Sequence[ShapeValueType],
axis: int | None = None,
axis: int = 0,
) -> TensorVariable:
"""Split a dimension of a tensor into multiple dimensions.

@@ -247,8 +246,8 @@ def split_dims(
The input tensor.
shape : int or sequence of int
The new shape to split the specified dimension into.
axis : int, optional
The dimension to split. If None, the input is assumed to be 1D and axis 0 is used.
axis : int, default 0
The dimension to split.

Returns
-------
@@ -259,22 +258,18 @@ def split_dims(
--------
>>> import pytensor.tensor as pt
>>> x = pt.tensor("x", shape=(6, 4, 6))
>>> y = pt.split_dims(x, shape=(2, 3), axis=0)
>>> y = pt.split_dims(x, shape=(2, 3))
>>> y.type.shape
(2, 3, 4, 6)
>>> y = pt.split_dims(x, shape=(2, 3), axis=-1)
>>> y.type.shape
(6, 4, 2, 3)
"""
x = as_tensor_variable(x)

if axis is None:
if x.type.ndim != 1:
raise ValueError(
"split_dims can only be called with axis=None for 1d inputs"
)
axis = 0
else:
axis = normalize_axis_index(axis, x.ndim)
axis = normalize_axis_index(axis, x.ndim)

# Convert scalar shape to 1d tuple (shape,)
# Which is basically a specify_shape
if not isinstance(shape, Sequence):
if isinstance(shape, TensorVariable | np.ndarray):
if shape.ndim == 0:
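
Assuming the round-trip behaviour implied by the doctests, ``split_dims`` inverts ``join_dims`` on the same axis block; a small sketch:

.. code-block:: python

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(2, 3, 4, 5))
    joined = pt.join_dims(x, start_axis=1, n_axes=2)  # static shape (2, 12, 5)
    restored = pt.split_dims(joined, shape=(3, 4), axis=1)
    # restored should have static shape (2, 3, 4, 5) again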
Expand Down Expand Up @@ -313,8 +308,6 @@ def _analyze_axes_list(axes) -> tuple[int, int, int]:
elif not isinstance(axes, Iterable):
raise TypeError("axes must be an int, an iterable of ints, or None")

axes = tuple(axes)

if len(axes) == 0:
raise ValueError("axes=[] is ambiguous; use None to ravel all")

@@ -367,7 +360,7 @@ def find_gaps(s):


def pack(
*tensors: TensorLike, axes: Sequence[int] | int | None = None
*tensors: TensorLike, keep_axes: Sequence[int] | int | None = None
) -> tuple[TensorVariable, list[TensorVariable]]:
"""
Combine multiple tensors by preserving the specified axes and raveling the rest into a single axis.
@@ -401,28 +394,29 @@ def pack(

Examples
--------
The easiest way to understand pack is through examples. The simplest case is using axes=None, which is equivalent
to ``join(0, *[t.ravel() for t in tensors])``:
The easiest way to understand pack is through examples.
The simplest case is using the default keep_axes=None, which is equivalent to ``concatenate([t.ravel() for t in tensors])``:

.. code-block:: python
import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, 3))
y = pt.tensor("y", shape=(4, 5, 6))

packed_tensor, packed_shapes = pt.pack(x, y, axes=None)
packed_tensor, packed_shapes = pt.pack(x, y)
# packed_tensor has shape (6 + 120,) == (126,)
# packed_shapes is [(2, 3), (4, 5, 6)]

If we want to preserve a single axis, we can use either positive or negative indexing. Notice that all tensors
must have the same size along the preserved axis. For example, using axes=0:
If we want to preserve a single axis, we can use either positive or negative indexing.
Notice that all tensors must have the same size along the preserved axis.
For example, using keep_axes=0:

.. code-block:: python
import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, 3))
y = pt.tensor("y", shape=(2, 5, 6))
packed_tensor, packed_shapes = pt.pack(x, y, axes=0)
packed_tensor, packed_shapes = pt.pack(x, y, keep_axes=0)
# packed_tensor has shape (2, 3 + 30) == (2, 33)
# packed_shapes is [(3,), (5, 6)]

@@ -434,7 +428,7 @@ def pack(

x = pt.tensor("x", shape=(4, 2, 3))
y = pt.tensor("y", shape=(5, 2, 3))
packed_tensor, packed_shapes = pt.pack(x, y, axes=(-2, -1))
packed_tensor, packed_shapes = pt.pack(x, y, keep_axes=(-2, -1))
# packed_tensor has shape (4 + 5, 2, 3) == (9, 2, 3)
# packed_shapes is [(4,), (5,)]

@@ -445,13 +439,13 @@ def pack(

x = pt.tensor("x", shape=(2, 4, 3))
y = pt.tensor("y", shape=(2, 5, 3))
packed_tensor, packed_shapes = pt.pack(x, y, axes=(0, -1))
packed_tensor, packed_shapes = pt.pack(x, y, keep_axes=(0, -1))
# packed_tensor has shape (2, 4 + 5, 3) == (2, 9, 3)
# packed_shapes is [(4,), (5,)]
"""
tensor_list = [as_tensor_variable(t) for t in tensors]

n_before, n_after, min_axes = _analyze_axes_list(axes)
n_before, n_after, min_axes = _analyze_axes_list(keep_axes)

reshaped_tensors: list[Variable] = []
packed_shapes: list[TensorVariable] = []
@@ -462,33 +456,21 @@ def pack(
if n_dim < min_axes:
raise ValueError(
f"Input {i} (zero indexed) to pack has {n_dim} dimensions, "
f"but axes={axes} assumes at least {min_axes} dimension{'s' if min_axes != 1 else ''}."
f"but {keep_axes=} assumes at least {min_axes} dimension{'s' if min_axes != 1 else ''}."
)
n_after_packed = n_dim - n_after
packed_shapes.append(input_tensor.shape[n_before:n_after_packed])

if n_dim == min_axes:
# If an input has the minimum number of axes, pack implicitly inserts a new axis based on the pattern
# implied by the axes.
input_tensor = expand_dims(input_tensor, axis=n_before)
reshaped_tensors.append(input_tensor)
continue

# The reshape we want is (shape[:before], -1, shape[n_after_packed:]). join_dims does (shape[:min(axes)], -1,
# shape[max(axes)+1:]). So this will work if we choose axes=(n_before, n_after_packed - 1). Because of the
# rules on the axes input, we will always have n_before <= n_after_packed - 1. A set is used here to cover the
# corner case when n_before == n_after_packed - 1 (i.e., when there is only one axis to ravel --> do nothing).
join_axes = range(n_before, n_after_packed)
joined = join_dims(input_tensor, tuple(join_axes))

n_packed = n_dim - n_after - n_before
packed_shapes.append(input_tensor.shape[n_before : n_before + n_packed])
joined = join_dims(input_tensor, n_before, n_packed)
reshaped_tensors.append(joined)

return join(n_before, *reshaped_tensors), packed_shapes
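
The loop above reduces every input to ``(kept leading axes, raveled middle block, kept trailing axes)`` before a single ``join``. A NumPy sketch of the same bookkeeping for one kept axis on each side (``pack_ref`` is a hypothetical helper):

.. code-block:: python

    import numpy as np


    def pack_ref(*tensors, n_before=1, n_after=1):
        # Mirror pack's core loop: record the shape of the raveled middle
        # block, collapse it to a single axis, then concatenate along it.
        reshaped, packed_shapes = [], []
        for t in tensors:
            n_packed = t.ndim - n_before - n_after
            packed_shapes.append(t.shape[n_before : n_before + n_packed])
            reshaped.append(
                t.reshape((*t.shape[:n_before], -1, *t.shape[t.ndim - n_after :]))
            )
        return np.concatenate(reshaped, axis=n_before), packed_shapes


    x = np.empty((2, 4, 3))
    y = np.empty((2, 5, 3))
    packed, shapes = pack_ref(x, y)  # same layout as pack(x, y, keep_axes=(0, -1))
    assert packed.shape == (2, 9, 3) and shapes == [(4,), (5,)]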


def unpack(
packed_input: TensorLike,
axes: int | Sequence[int] | None,
packed_shapes: Sequence[ShapeValueType],
keep_axes: int | Sequence[int] | None = None,
) -> list[TensorVariable]:
"""
Unpack a packed tensor into multiple tensors by splitting along the specified axes and reshaping.
@@ -504,39 +486,43 @@ def unpack(
----------
packed_input : TensorLike
The packed tensor to be unpacked.
axes : int, sequence of int, or None
Axes that were preserved during packing. If None, the input is assumed to be 1D and axis 0 is used.
packed_shapes : list of ShapeValueType
A list containing the shapes of the raveled dimensions for each output tensor.
keep_axes : int, sequence of int, optional
Axes that were preserved during packing. If None, the packed input must be one-dimensional.

Returns
-------
unpacked_tensors : list of TensorVariable
A list of unpacked tensors with their original shapes restored.
"""
packed_input = as_tensor_variable(packed_input)

if axes is None:
if keep_axes is None:
if packed_input.ndim != 1:
raise ValueError(
"unpack can only be called with keep_axis=None for 1d inputs"
)
split_axis = 0
else:
axes = normalize_axis_tuple(axes, ndim=packed_input.ndim)
keep_axes = normalize_axis_tuple(keep_axes, ndim=packed_input.ndim)
try:
[split_axis] = (i for i in range(packed_input.ndim) if i not in axes)
[split_axis] = (i for i in range(packed_input.ndim) if i not in keep_axes)
except ValueError as err:
raise ValueError(
"Unpack must have exactly one more dimension that implied by axes"
f"unpack input must have exactly one more dimension that implied by keep_axes. "
f"{packed_input} has {packed_input.type.ndim} dimensions, expected {len(keep_axes) + 1}"
) from err

split_inputs = split(
packed_input,
splits_size=[prod(shape, dtype=int) for shape in packed_shapes],
n_splits=len(packed_shapes),
axis=split_axis,
)
n_splits = len(packed_shapes)
if n_splits == 1:
# If there is only one tensor to unpack, no need to split
split_inputs = [packed_input]
else:
split_inputs = split(
packed_input,
splits_size=[prod(shape, dtype=int) for shape in packed_shapes],
axis=split_axis,
)

return [
split_dims(inp, shape, split_axis)
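
Together with ``pack``, the intended round trip looks like this (a sketch; assumes ``unpack`` is exported on the ``pt`` namespace alongside ``pack``):

.. code-block:: python

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(2, 4, 3))
    y = pt.tensor("y", shape=(2, 5, 3))
    packed, packed_shapes = pt.pack(x, y, keep_axes=(0, -1))
    x2, y2 = pt.unpack(packed, packed_shapes, keep_axes=(0, -1))
    # x2 and y2 should recover the original static shapes (2, 4, 3) and (2, 5, 3)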
5 changes: 3 additions & 2 deletions pytensor/tensor/rewriting/reshape.py
@@ -34,8 +34,9 @@ def local_join_dims_to_reshape(fgraph, node):
"""

(x,) = node.inputs
start_axis = node.op.start_axis
n_axes = node.op.n_axes
op = node.op
start_axis = op.start_axis
n_axes = op.n_axes

output_shape = [
*x.shape[:start_axis],
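
The rewrite lowers ``JoinDims`` to the single reshape sketched below, with ``-1`` standing in for the product of the joined extents; this is what the updated test exercises:

.. code-block:: python

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(2, 2, 5, 1, 3))
    start_axis, n_axes = 1, 3
    output_shape = [*x.shape[:start_axis], -1, *x.shape[start_axis + n_axes :]]
    y = x.reshape(output_shape)  # equivalent to join_dims(x, start_axis=1, n_axes=3)
    # the joined block 2 * 5 * 1 collapses, giving shape (2, 10, 3)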
2 changes: 1 addition & 1 deletion tests/tensor/rewriting/test_reshape.py
@@ -21,7 +21,7 @@ def test_local_split_dims_to_reshape():

def test_local_join_dims_to_reshape():
x = tensor("x", shape=(2, 2, 5, 1, 3))
x_join = join_dims(x, axis=(1, 2, 3))
x_join = join_dims(x, start_axis=1, n_axes=3)

fg = FunctionGraph(inputs=[x], outputs=[x_join])
