Fix issues with split and split_dims #1828
Conversation
Force-pushed from 4f38402 to 579566d
I reverted the changes to […]. Something else I noticed was that we're passing […].
No, better not to cast variables in node but raise like before. That's what shape ops always do. If a user passes a float as a shape argument it's likely a bug and this would mask it |
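To illustrate the policy being described (the helper name and error message below are made up for illustration, not the PR's code), a minimal sketch of converting a shape argument and raising on non-integer dtypes instead of silently casting:

```python
import pytensor.tensor as pt


def as_int_shape(shape):
    """Hypothetical helper: convert a shape argument, but refuse non-integer dtypes."""
    shape = pt.as_tensor_variable(shape)
    if not str(shape.dtype).startswith(("int", "uint")):
        # Raising here surfaces likely bugs; casting would silently hide them.
        raise TypeError(f"shape must have an integer dtype, got {shape.dtype}")
    return shape


as_int_shape([2, 3])        # ok: becomes an integer vector constant
# as_int_shape([2.0, 3.0])  # TypeError: a float shape is almost certainly a bug
```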
Someday I will merge a PR […]
pytensor/tensor/reshape.py (Outdated)
  )
- if not shape:
+ if empty_shape:
What about just shape.type.shape == (0,), for the variable case? Also if you standardize as_tensor_variable you don't need the variable vs non-variable case
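For illustration, a sketch of that suggestion (usage assumed, not the final code): after standardizing the argument with as_tensor_variable, the static type alone says whether the shape vector is empty, so no separate variable vs. non-variable branch is needed.

```python
import numpy as np
from pytensor.tensor import as_tensor_variable

# A symbolic (but constant) empty shape vector with an integer dtype
shape = as_tensor_variable(np.array([], dtype="int64"))

# The static type records the length of the shape vector, so the emptiness
# check works the same way for constants and for symbolic variables.
empty_shape = shape.type.shape == (0,)
print(empty_shape)  # True
```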
But also do we need the special squeeze branch or would the Op do the right thing anyway?
Tests pass without it (as long as I adjust the existing test_split_size_zero_shape test to pass dtype int to the shape argument), so I guess not.
I'm happy with the PR. I'll fix the git history and merge.
Force-pushed from 39f8dc4 to deaf670
I made some further simplifications, and also cleaned type hints. The […]
Remove cases where no type-hints are better than bad type-hints
Force-pushed from deaf670 to b4b7d8f
Force-pushed from b4b7d8f to f0fbe9a
-     outputs: Sequence[Variable],
-     output_grads: Sequence[Variable],
- ) -> list[Variable]:
+ def L_op(self, inputs, outputs, output_grads):
I strongly disagree with appeasing mypy here and pretending we don't know that we can only ever get and return TensorVariable.
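For context, spelling the types out the way the comment suggests might look like the sketch below (the class name is a hypothetical stand-in, and this is an assumption about style, not the merged change, which simply drops the annotations):

```python
from collections.abc import Sequence

from pytensor.tensor.variable import TensorVariable


class ExampleSplitOp:  # hypothetical stand-in for the real Op class
    def L_op(
        self,
        inputs: Sequence[TensorVariable],
        outputs: Sequence[TensorVariable],
        output_grads: Sequence[TensorVariable],
    ) -> list[TensorVariable]:
        # In practice every input, output, and gradient here is a TensorVariable,
        # even though the base Op signature is written in terms of Variable.
        raise NotImplementedError
```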
  self.axis = axis

- def make_node(self, x: Variable, shape: Variable) -> Apply:  # type: ignore[override]
+ def make_node(self, x, shape):
This was wrong, as x and shape may be TensorLike.
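A minimal sketch of why the annotation was too narrow (the class name is assumed): make_node is routinely called with tensor-like values and converts them itself.

```python
from pytensor.tensor import as_tensor_variable


class ExampleOp:  # hypothetical stand-in for the real Op class
    def make_node(self, x, shape):
        # x and shape may be TensorLike (lists, numpy arrays, Python scalars, ...),
        # so convert them here instead of annotating the parameters as Variable.
        x = as_tensor_variable(x)
        shape = as_tensor_variable(shape, ndim=1)
        ...  # build and return the Apply node from the converted inputs
```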
-     # example when splitting a packed tensor that had its dims expanded before packing (e.g. when packing shapes
-     # (3, ) and (3, 3) to (3, 4)
-     return squeeze(x, axis=axis)  # type: ignore[no-any-return]
+ axis = normalize_axis_index(axis, x.ndim)
It can only be an index, not a tuple, so be more pedantic.
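A small illustration of the distinction (not PR code; the import path assumes numpy >= 2.0, while older releases expose the function from numpy.core.multiarray):

```python
from numpy.lib.array_utils import normalize_axis_index

# normalize_axis_index handles exactly one integer axis, resolving negatives
# and raising on out-of-range values; normalize_axis_tuple would also accept tuples.
print(normalize_axis_index(-1, 3))  # 2

try:
    normalize_axis_index((0, 1), 3)  # a tuple of axes is rejected outright
except TypeError as exc:
    print(f"rejected: {exc}")
```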
  def pack(
      *tensors: TensorLike, axes: Sequence[int] | int | None = None
- ) -> tuple[TensorVariable, list[ShapeValueType]]:
+ ) -> tuple[TensorVariable, list[TensorVariable]]:
We only return TensorVariable shapes, not the flexible input types
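A hedged usage sketch, based only on the signature above and the "(3,) and (3, 3) to (3, 4)" docstring example visible earlier in the diff; the exact axes semantics and call pattern are assumptions, not confirmed by this thread.

```python
import pytensor.tensor as pt
from pytensor.tensor.reshape import pack  # module path as it appears in this PR
from pytensor.tensor.variable import TensorVariable

x = pt.vector("x")  # e.g. runtime shape (3,)
y = pt.matrix("y")  # e.g. runtime shape (3, 3)

# Keep axis 0 and pack the remaining dims (assumed meaning of axes=0).
packed, packed_shapes = pack(x, y, axes=0)

# Whatever was passed in, the returned shapes are symbolic TensorVariables.
print(all(isinstance(s, TensorVariable) for s in packed_shapes))
```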
Related to #1806 and #1827
Fix bug when passing simple Tensor shape to split_dims
Change grad_undefined -> grad_disconnected for split_sizes in SplitOp (see #1827 for more context)
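A tiny sketch of what that change means (the helper name is made up; the real change lives in the Op's gradient method): instead of an "undefined gradient" marker that errors if a gradient is ever requested, the split_sizes input gets a disconnected gradient that pytensor.grad simply does not propagate through.

```python
from pytensor.gradient import DisconnectedType, grad_undefined


def split_sizes_grads(op, split_sizes):
    """Hypothetical helper contrasting the old and new gradient markers for split_sizes."""
    # Before: asking for a gradient w.r.t. split_sizes produced an undefined-grad error.
    before = grad_undefined(op, 2, split_sizes, "split_sizes gradient was undefined")
    # After: the gradient is declared disconnected -- the outputs do not depend
    # smoothly on the integer sizes, so no gradient flows through this input.
    after = DisconnectedType()()
    return before, after
```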