Skip to content

Commit c90954f

Browse files
committed
Avoid runtime broadcast error due to dot_to_mul rewrite
1 parent 8206eb4 commit c90954f

File tree

2 files changed

+37
-1
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@
106106
from pytensor.tensor.rewriting.blockwise import blockwise_of
107107
from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
108108
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
109-
from pytensor.tensor.shape import Shape, Shape_i
109+
from pytensor.tensor.shape import Shape, Shape_i, specify_shape
110110
from pytensor.tensor.slinalg import BlockDiagonal
111111
from pytensor.tensor.subtensor import Subtensor
112112
from pytensor.tensor.type import (
@@ -424,6 +424,13 @@ def local_dot_to_mul(fgraph, node):
424424
):
425425
return None
426426

427+
# Add specify_shape for unknown dimensions that must be 1
428+
# To avoid runtime broadcast error by multiply
429+
if a.type.shape[-1] != 1:
430+
a = specify_shape(a, (..., None, 1))
431+
if b.type.shape[-2] != 1:
432+
b = specify_shape(b, (..., 1, None))
433+
427434
new_out = mul(a, b)
428435
copy_stack_trace(node.out, new_out)
429436
return [new_out]

tests/tensor/rewriting/test_math.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4805,6 +4805,35 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
48054805
)
48064806

48074807

4808+
def test_local_dot_to_mul_unspecified_length_1():
4809+
# Regression test for https://github.com/pymc-devs/pytensor/issues/1782
4810+
x = matrix("x", shape=(5, 1), dtype="float64")
4811+
y = matrix("y", shape=(None, 1), dtype="float64")
4812+
out = x @ y
4813+
fn = function([x, y], out)
4814+
assert all(
4815+
isinstance(node.op, Elemwise | SpecifyShape)
4816+
for node in fn.maker.fgraph.apply_nodes
4817+
)
4818+
np.testing.assert_allclose(
4819+
fn(x=np.ones((5, 1)), y=np.ones((1, 1)) * 5),
4820+
np.ones((5, 1)) * 5,
4821+
)
4822+
4823+
x = matrix("x", shape=(1, None), dtype="float64")
4824+
y = matrix("y", shape=(1, 5), dtype="float64")
4825+
out = x @ y
4826+
fn = function([x, y], out)
4827+
assert all(
4828+
isinstance(node.op, Elemwise | SpecifyShape)
4829+
for node in fn.maker.fgraph.apply_nodes
4830+
)
4831+
np.testing.assert_allclose(
4832+
fn(x=np.ones((1, 1)) * 5, y=np.ones((1, 5))),
4833+
np.ones((1, 5)) * 5,
4834+
)
4835+
4836+
48084837
@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
48094838
@pytest.mark.parametrize(
48104839
"batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"]

Comments (0)