Conversation

@ricardoV94 (Member)

Related to #1806 and #1827.

- Fix a bug when passing a simple Tensor shape to `split_dims`
- Change `grad_undefined` -> `grad_disconnected` for `split_sizes` in `SplitOp` (see #1827 for more context)
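For context, a hedged sketch of the difference between the two gradient conventions. Both names exist in pytensor.gradient, but the exact SplitOp change isn't shown in this excerpt:

```python
from pytensor.gradient import DisconnectedType, grad_undefined

# Before: the gradient w.r.t. split_sizes was declared undefined, so asking
# for it raised an error at grad time, roughly:
# g_sizes = grad_undefined(self, 1, split_sizes)

# After: the output is declared disconnected from split_sizes, so
# pytensor.grad treats that gradient as structurally absent instead of erroring
g_sizes = DisconnectedType()()
```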

@jessegrabowski (Member)

I reverted the changes to as_tensor_variable. At a minimum, they're out of scope for this PR. Implementing more careful checks of the shape argument (based on the analysis in the comment above) was sufficient to clear the test failures. We can revisit the ndims argument later.

Something else I noticed is that we're passing dtype to as_tensor_variable. This doesn't do anything in the Variable case, so I changed it to an explicit cast inside the Op's make_node (I left it in the wrapper to handle the Sequence case).
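A minimal sketch of the behavior being described, with illustrative variable names:

```python
import pytensor.tensor as pt

# A float64 shape variable, as a user might (incorrectly) pass it
shape = pt.vector("shape", dtype="float64")

# as_tensor_variable returns Variables unchanged, so a dtype passed
# alongside it cannot retype them
converted = pt.as_tensor_variable(shape)
assert converted.dtype == "float64"

# An explicit cast is what actually changes the dtype
shape_int = pt.cast(shape, "int64")
assert shape_int.dtype == "int64"
```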

@ricardoV94 (Member Author)

No, better not to cast variables in make_node but to raise like before; that's what shape Ops always do. If a user passes a float as a shape argument, it's likely a bug, and casting would mask it.
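A hedged sketch of the raise-instead-of-cast approach; the helper name is hypothetical, not the PR's actual code:

```python
import pytensor.tensor as pt
from pytensor.tensor.type import integer_dtypes

def as_int_shape(shape):
    # Hypothetical helper: convert TensorLike input, then refuse non-integer
    # dtypes instead of silently casting, so a float shape argument surfaces
    # as the bug it likely is
    shape = pt.as_tensor_variable(shape)
    if shape.type.dtype not in integer_dtypes:
        raise TypeError(f"shape must have an integer dtype, got {shape.type.dtype}")
    return shape
```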

@jessegrabowski (Member)

Someday I will merge a PR

```diff
 )

-if not shape:
+if empty_shape:
```
@ricardoV94 (Member Author) commented on Jan 10, 2026

What about just `shape.type.shape == (0,)` for the variable case? Also, if you standardize through as_tensor_variable first, you don't need the variable vs. non-variable case.
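A sketch of the suggested check, assuming the argument has already gone through as_tensor_variable:

```python
import pytensor.tensor as pt

# Constants and sequences become length-0 vectors with a fully static type,
# so a single static-shape check covers both the variable and non-variable case
shape = pt.as_tensor_variable(())
empty_shape = shape.type.shape == (0,)  # True only when statically known to be empty
assert empty_shape
```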

@ricardoV94 (Member Author)

But also, do we need the special squeeze branch, or would the Op do the right thing anyway?

@jessegrabowski (Member)

Tests pass without it (as long as I adjust the existing test_split_size_zero_shape test to pass dtype int to the shape argument), so I guess not.

@ricardoV94 (Member Author)

I'm happy with the PR. I'll fix the git history and merge

@ricardoV94 force-pushed the split_dims_tweak branch 2 times, most recently from 39f8dc4 to deaf670 on January 11, 2026 at 12:21
@ricardoV94 (Member Author)

I made some further simplifications and also cleaned up the type hints. The ShapeValueType from shape.py is not the right thing, because it also allows Ellipsis and None, which split_dims and unpack do not.
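A hedged sketch of the distinction; the alias definition below is an assumption about shape.py's intent, not its actual code:

```python
from types import EllipsisType

from pytensor.tensor.variable import TensorVariable

# Assumed rough shape of the alias in shape.py: entries may be None or Ellipsis
ShapeValueType = int | None | EllipsisType | TensorVariable

# What split_dims/unpack actually work with: concrete entries only
StrictShapeValueType = int | TensorVariable
```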

Remove cases where no type-hints are better than bad type-hints

```diff
-    outputs: Sequence[Variable],
-    output_grads: Sequence[Variable],
-) -> list[Variable]:
+def L_op(self, inputs, outputs, output_grads):
```
@ricardoV94 (Member Author)

I strongly disagree with appeasing mypy here and pretending we don't know that we can only ever get and return TensorVariable.

```diff
 self.axis = axis

-def make_node(self, x: Variable, shape: Variable) -> Apply:  # type: ignore[override]
+def make_node(self, x, shape):
```
@ricardoV94 (Member Author)

This was wrong, as x and shape may be TensorLike.
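A small sketch of why the old annotation was too narrow, with illustrative variable names:

```python
import numpy as np
import pytensor.tensor as pt

# TensorLike means make_node can receive raw Python/NumPy objects, not only Variables
x_raw = np.zeros((2, 3))
shape_raw = [2, 3]

# Inside make_node, these get converted along the lines of:
x = pt.as_tensor_variable(x_raw)
shape = pt.as_tensor_variable(shape_raw)
assert x.ndim == 2 and shape.ndim == 1
```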

```python
# example when splitting a packed tensor that had its dims expanded before packing (e.g. when packing shapes
# (3,) and (3, 3) to (3, 4))
return squeeze(x, axis=axis)  # type: ignore[no-any-return]
axis = normalize_axis_index(axis, x.ndim)
```
@ricardoV94 (Member Author)

It can only be an index, not a tuple, so be more pedantic.
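For context, normalize_axis_index accepts exactly one integer axis. The import path below is the NumPy >= 2.0 location; PyTensor wraps this internally, so the exact import in the PR may differ:

```python
from numpy.exceptions import AxisError
from numpy.lib.array_utils import normalize_axis_index  # NumPy >= 2.0 location

assert normalize_axis_index(-1, 3) == 2  # negative axes wrap around
try:
    normalize_axis_index(3, 3)  # out of range for ndim=3
except AxisError:
    pass
# A tuple is rejected outright; normalize_axis_tuple is the helper for that case
```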

```diff
 def pack(
     *tensors: TensorLike, axes: Sequence[int] | int | None = None
-) -> tuple[TensorVariable, list[ShapeValueType]]:
+) -> tuple[TensorVariable, list[TensorVariable]]:
```
@ricardoV94 (Member Author)

We only return TensorVariable shapes, not the flexible input types
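A usage sketch implied purely by the annotated signature above; the module hosting pack isn't visible in this excerpt, so the import is deliberately omitted rather than guessed:

```python
import pytensor.tensor as pt

x = pt.tensor3("x")
y = pt.matrix("y")

# `pack` is the function from the diff above; no import path is asserted here
packed, shapes = pack(x, y)  # axes defaults to None

# Per the corrected annotation, each element of `shapes` is a symbolic
# TensorVariable, never an int/None/Ellipsis ShapeValueType entry
```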

@ricardoV94 merged commit d8b51df into pymc-devs:main on Jan 11, 2026 (66 checks passed)
@ricardoV94 deleted the split_dims_tweak branch on January 11, 2026 at 19:47