Fix issues with split and split_dims #1828
```diff
@@ -1,23 +1,28 @@
 from collections.abc import Iterable, Sequence
 from itertools import pairwise
+from typing import TypeAlias

 import numpy as np
-from numpy.lib._array_utils_impl import normalize_axis_tuple
+from numpy.lib._array_utils_impl import normalize_axis_index, normalize_axis_tuple

 from pytensor import Variable
 from pytensor.gradient import DisconnectedType
 from pytensor.graph import Apply
 from pytensor.graph.op import Op
 from pytensor.graph.replace import _vectorize_node
+from pytensor.scalar import ScalarVariable
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor.basic import expand_dims, infer_static_shape, join, split
 from pytensor.tensor.extra_ops import squeeze
 from pytensor.tensor.math import prod
-from pytensor.tensor.shape import ShapeValueType, shape
 from pytensor.tensor.type import tensor
 from pytensor.tensor.variable import TensorVariable

+ShapeValueType: TypeAlias = (
+    int | np.integer | ScalarVariable | TensorVariable | np.ndarray
+)


 class JoinDims(Op):
     __props__ = (
         "start_axis",
```
@@ -81,16 +86,11 @@ def perform(self, node, inputs, outputs): | |
|
|
||
| out[0] = x.reshape(output_shape) | ||
|
|
||
| def L_op( | ||
| self, | ||
| inputs: Sequence[Variable], | ||
| outputs: Sequence[Variable], | ||
| output_grads: Sequence[Variable], | ||
| ) -> list[Variable]: | ||
| def L_op(self, inputs, outputs, output_grads): | ||
| (x,) = inputs | ||
| (g_out,) = output_grads | ||
|
|
||
| x_shape = shape(x) | ||
| x_shape = x.shape | ||
| packed_shape = [x_shape[i] for i in self.axis_range] | ||
| return [split_dims(g_out, shape=packed_shape, axis=self.start_axis)] | ||
|
|
||
|
|
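For context on this change: the gradient of `JoinDims` relies on `split_dims` being its inverse, so the output gradient (which carries the merged axis) is split back using the input's packed shape. Below is a minimal NumPy sketch of that relationship, assuming `start_axis=0` and two merged axes for concreteness; the variable names are illustrative, not taken from the codebase.

```python
import numpy as np

x = np.arange(2 * 3 * 4).reshape(2, 3, 4)   # original input
packed_shape = x.shape[0:2]                  # the axes that join_dims merges

joined = x.reshape(-1, *x.shape[2:])         # forward pass: (2, 3, 4) -> (6, 4)
g_out = np.ones_like(joined)                 # gradient flowing into the joined output

# What the L_op computes symbolically: split the merged axis back apart so the
# gradient has the same shape as the original input.
g_x = g_out.reshape(*packed_shape, *g_out.shape[1:])
assert g_x.shape == x.shape
```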
@@ -163,13 +163,18 @@ def __init__(self, axis: int): | |
| raise ValueError("SplitDims axis must be non-negative") | ||
| self.axis = axis | ||
|
|
||
| def make_node(self, x: Variable, shape: Variable) -> Apply: # type: ignore[override] | ||
| def make_node(self, x, shape): | ||
|
Member (Author): This was wrong, as x, shape may be TensorLike
```diff
         x = as_tensor_variable(x)
-        shape = as_tensor_variable(shape, dtype=int, ndim=1)
+        shape = as_tensor_variable(shape, dtype=int)
+
+        if shape.type.numpy_dtype.kind not in "iu":
+            raise TypeError("shape must be an integer tensor")
+
+        if shape.type.ndim != 1:
+            raise TypeError(
+                f"shape must be a 1-D tensor, got {shape} with {shape.type.ndim} dimensions"
+            )

         axis = self.axis
         _, constant_shape = infer_static_shape(shape)
```
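With the looser `make_node` signature, callers may pass anything `as_tensor_variable` understands (Python ints, lists, NumPy arrays), and the explicit checks then enforce that the result is a 1-D integer tensor. The following is a hedged sketch of what those checks accept or reject, expressed on concrete NumPy values rather than symbolic tensors; the helper name is made up for illustration.

```python
import numpy as np

def looks_like_valid_shape(shape) -> bool:
    """Mirror of the checks added in make_node, applied to a concrete array."""
    arr = np.asarray(shape)
    return arr.dtype.kind in "iu" and arr.ndim == 1   # signed/unsigned integer, 1-D

assert looks_like_valid_shape([2, 3])                          # list of ints -> ok
assert looks_like_valid_shape(np.array([2, 3], dtype="uint8")) # unsigned ints -> ok
assert not looks_like_valid_shape(np.array([2.0, 3.0]))        # float dtype rejected
assert not looks_like_valid_shape(np.array([[2, 3]]))          # 2-D shape rejected
```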
@@ -205,16 +210,11 @@ def perform(self, node, inputs, outputs): | |
| def connection_pattern(self, node): | ||
| return [[True], [False]] | ||
|
|
||
| def L_op( | ||
| self, | ||
| inputs: Sequence[Variable], | ||
| outputs: Sequence[Variable], | ||
| output_grads: Sequence[Variable], | ||
| ) -> list[Variable]: | ||
| def L_op(self, inputs, outputs, output_grads): | ||
| (x, _) = inputs | ||
| (g_out,) = output_grads | ||
|
|
||
| n_axes = g_out.ndim - x.ndim + 1 # type: ignore[attr-defined] | ||
| n_axes = g_out.ndim - x.ndim + 1 | ||
| axis_range = list(range(self.axis, self.axis + n_axes)) | ||
|
|
||
| return [join_dims(g_out, axis=axis_range), DisconnectedType()()] | ||
|
|
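As with `JoinDims` above, the gradient here is the opposite reshape: `SplitDims` expands one axis into `n_axes` axes, so the gradient joins them back, and `g_out.ndim - x.ndim + 1` recovers how many axes the split produced. A small NumPy sketch of that bookkeeping, with illustrative names:

```python
import numpy as np

axis = 1
x = np.arange(2 * 12).reshape(2, 12)          # input to SplitDims
split = x.reshape(2, 3, 4)                     # forward pass: axis 1 split into (3, 4)

g_out = np.ones_like(split)                    # gradient w.r.t. the split output
n_axes = g_out.ndim - x.ndim + 1               # 3 - 2 + 1 == 2 axes were produced
axis_range = list(range(axis, axis + n_axes))  # [1, 2]: the axes to join back

# What join_dims(g_out, axis=axis_range) does: merge those axes into one.
g_x = g_out.reshape(*g_out.shape[:axis], -1, *g_out.shape[axis + n_axes:])
assert g_x.shape == x.shape
```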
@@ -266,25 +266,21 @@ def split_dims( | |
| x = as_tensor_variable(x) | ||
|
|
||
| if axis is None: | ||
| if x.ndim != 1: | ||
| if x.type.ndim != 1: | ||
| raise ValueError( | ||
| "split_dims can only be called with axis=None for 1d inputs" | ||
| ) | ||
| axis = 0 | ||
|
|
||
| if isinstance(shape, int): | ||
| shape = [shape] | ||
| else: | ||
| shape = list(shape) # type: ignore[arg-type] | ||
|
|
||
| if not shape: | ||
| # If we get an empty shape, there is potentially a dummy dimension at the requested axis. This happens for | ||
| # example when splitting a packed tensor that had its dims expanded before packing (e.g. when packing shapes | ||
| # (3, ) and (3, 3) to (3, 4) | ||
| return squeeze(x, axis=axis) # type: ignore[no-any-return] | ||
| axis = normalize_axis_index(axis, x.ndim) | ||
|
Member (Author): it can only be an index, not a tuple, so be more pedantic
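To make the index-versus-tuple distinction concrete: `normalize_axis_tuple` always returns a tuple (hence the old `[axis] = ...` unpacking), while `normalize_axis_index` returns a single non-negative int and raises for out-of-range axes. A quick sketch, assuming NumPy 2.x, where the private `numpy.lib._array_utils_impl` module used in the diff exists:

```python
import numpy as np
from numpy.lib._array_utils_impl import normalize_axis_index, normalize_axis_tuple

ndim = 3
assert normalize_axis_tuple(-1, ndim) == (2,)   # tuple result, needs unpacking
assert normalize_axis_index(-1, ndim) == 2      # plain int, no unpacking

try:
    normalize_axis_index(3, ndim)               # out of range for a 3-D tensor
except np.exceptions.AxisError:
    pass
```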
```diff
-    [axis] = normalize_axis_tuple(axis, x.ndim)  # type: ignore[misc]
-    shape = as_tensor_variable(shape, dtype="int64", ndim=1)  # type: ignore[arg-type]
+    # Convert scalar shape to 1d tuple (shape,)
+    if not isinstance(shape, Sequence):
+        if isinstance(shape, TensorVariable | np.ndarray):
+            if shape.ndim == 0:
+                shape = (shape,)
+        elif isinstance(shape, int | np.integer | ScalarVariable):
+            shape = (shape,)

     return SplitDims(axis=axis)(x, shape)  # type: ignore[return-value]
```
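The remaining isinstance block only promotes scalar-like shape arguments to a length-one sequence; `make_node` then does the conversion to a 1-D integer tensor. A rough stand-in for that normalization on concrete values (the helper is hypothetical and omits the symbolic ScalarVariable/TensorVariable cases):

```python
from collections.abc import Sequence

import numpy as np

def normalize_shape_arg(shape):
    """Wrap scalar-like shape arguments into a one-element tuple, as the diff does."""
    if not isinstance(shape, Sequence):
        if isinstance(shape, np.ndarray) and shape.ndim == 0:
            return (shape,)
        if isinstance(shape, int | np.integer):
            return (shape,)
    return shape

assert normalize_shape_arg(4) == (4,)                       # plain Python int
assert normalize_shape_arg(np.int64(4)) == (4,)             # NumPy integer scalar
assert normalize_shape_arg((2, 3)) == (2, 3)                # sequences pass through untouched
assert isinstance(normalize_shape_arg(np.array(4)), tuple)  # 0-d array gets wrapped too
```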
@@ -372,7 +368,7 @@ def find_gaps(s): | |
|
|
||
| def pack( | ||
| *tensors: TensorLike, axes: Sequence[int] | int | None = None | ||
| ) -> tuple[TensorVariable, list[ShapeValueType]]: | ||
| ) -> tuple[TensorVariable, list[TensorVariable]]: | ||
|
Member (Author): We only return TensorVariable shapes, not the flexible input types
| """ | ||
| Combine multiple tensors by preserving the specified axes and raveling the rest into a single axis. | ||
|
|
||
|
|
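The annotation change reads more easily against what `pack` produces: one joined tensor plus one recorded shape per input (symbolically these are TensorVariables, hence the new return type). Below is a NumPy sketch of the packing idea for the (3,) and (3, 3) example mentioned in the diff comments, keeping axis 0 and raveling the rest; it mimics the semantics rather than calling the pytensor API.

```python
import numpy as np

a = np.zeros((3,))
b = np.ones((3, 3))

# Keep axis 0, ravel everything after it into one axis, then join along that axis.
flat_a = a.reshape(3, -1)                           # (3, 1)
flat_b = b.reshape(3, -1)                           # (3, 3)
packed = np.concatenate([flat_a, flat_b], axis=1)   # (3, 4)

# The shapes beyond the kept axis are recorded so unpack can restore each input.
packed_shapes = [a.shape[1:], b.shape[1:]]          # [(), (3,)]
assert packed.shape == (3, 4)
```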
@@ -458,7 +454,7 @@ def pack( | |
| n_before, n_after, min_axes = _analyze_axes_list(axes) | ||
|
|
||
| reshaped_tensors: list[Variable] = [] | ||
| packed_shapes: list[ShapeValueType] = [] | ||
| packed_shapes: list[TensorVariable] = [] | ||
|
|
||
| for i, input_tensor in enumerate(tensor_list): | ||
| n_dim = input_tensor.ndim | ||
|
|
@@ -492,7 +488,7 @@ def pack( | |
| def unpack( | ||
| packed_input: TensorLike, | ||
| axes: int | Sequence[int] | None, | ||
| packed_shapes: list[ShapeValueType], | ||
| packed_shapes: Sequence[ShapeValueType], | ||
| ) -> list[TensorVariable]: | ||
| """ | ||
| Unpack a packed tensor into multiple tensors by splitting along the specified axes and reshaping. | ||
|
|
||
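For completeness, the reverse direction: `unpack` splits the joined axis back into per-input blocks using the recorded packed shapes and reshapes each block. A NumPy sketch continuing the example above, again mimicking the semantics rather than the actual API:

```python
import numpy as np

packed = np.arange(12).reshape(3, 4)      # the packed tensor from before
packed_shapes = [(), (3,)]                # trailing shapes recorded by pack

# Width of each block along the packed axis is the product of its trailing shape.
sizes = [int(np.prod(s, dtype=int)) for s in packed_shapes]    # [1, 3]
splits = np.split(packed, np.cumsum(sizes)[:-1], axis=1)       # shapes (3, 1) and (3, 3)

unpacked = [block.reshape(3, *s) for block, s in zip(splits, packed_shapes)]
assert unpacked[0].shape == (3,) and unpacked[1].shape == (3, 3)
```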
Review comment: I strongly disagree with appeasing mypy here and pretending we don't know that we can only ever get and return TensorVariable