Skip to content

Commit c3c1d94

Browse files
committed
Only allow measurable transpositions in univariate Elemwise chains or direct valued nodes
1 parent 37b0387 commit c3c1d94

File tree

3 files changed

+65
-20
lines changed

3 files changed

+65
-20
lines changed

pymc/logprob/tensor.py

+62-15
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,19 @@
4141
from pytensor.graph.rewriting.basic import node_rewriter
4242
from pytensor.tensor import TensorVariable
4343
from pytensor.tensor.basic import Join, MakeVector
44-
from pytensor.tensor.elemwise import DimShuffle
44+
from pytensor.tensor.elemwise import DimShuffle, Elemwise
4545
from pytensor.tensor.random.op import RandomVariable
4646
from pytensor.tensor.random.rewriting import (
4747
local_dimshuffle_rv_lift,
4848
)
4949

50-
from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper, promised_valued_rv
50+
from pymc.logprob.abstract import (
51+
MeasurableOp,
52+
ValuedRV,
53+
_logprob,
54+
_logprob_helper,
55+
promised_valued_rv,
56+
)
5157
from pymc.logprob.rewriting import (
5258
assume_valued_outputs,
5359
early_measurable_ir_rewrites_db,
@@ -57,6 +63,7 @@
5763
from pymc.logprob.utils import (
5864
check_potential_measurability,
5965
filter_measurable_variables,
66+
get_related_valued_nodes,
6067
replace_rvs_by_values,
6168
)
6269
from pymc.pytensorf import constant_fold
@@ -183,6 +190,9 @@ class MeasurableDimShuffle(MeasurableOp, DimShuffle):
183190
# find it locally and fails when a new `Op` is initialized
184191
c_func_file = str(DimShuffle.get_path(Path(DimShuffle.c_func_file))) # type: ignore[arg-type]
185192

193+
def __str__(self):
    """Return the inherited string form prefixed with ``Measurable``."""
    inherited_repr = super().__str__()
    return f"Measurable{inherited_repr}"
195+
186196

187197
@_logprob.register(MeasurableDimShuffle)
188198
def logprob_dimshuffle(op: MeasurableDimShuffle, values, base_var, **kwargs):
@@ -215,29 +225,66 @@ def logprob_dimshuffle(op: MeasurableDimShuffle, values, base_var, **kwargs):
215225
return raw_logp.dimshuffle(redo_ds)
216226

217227

228+
def _elemwise_univariate_chain(fgraph, node) -> bool:
    """Check whether ``node`` sits on an Elemwise-only chain from a univariate RV to a valued node.

    Upstream, the single input of ``node`` is traced back through
    ``MeasurableTransform`` ops (following each op's ``measurable_input_idx``)
    until a ``RandomVariable`` or ``SymbolicRandomVariable`` is reached. If any
    other op is met first, the result is ``False``.

    If the base variable found this way has ``ndim_supp != 0``, the chain check
    is skipped and the result is whether ``get_related_valued_nodes`` reports
    any valued node for ``node`` (i.e. the variable is directly valued).

    Downstream, the single output of ``node`` must reach a ``ValuedRV`` by
    passing only through ``Elemwise`` nodes with exactly one output, and every
    variable on the way must have exactly one client.
    """
    from pymc.distributions.distribution import SymbolicRandomVariable
    from pymc.logprob.transforms import MeasurableTransform

    [ds_input] = node.inputs
    [ds_output] = node.outputs

    # Walk upstream until the base random variable is found; only measurable
    # transforms are allowed in between.
    base = ds_input
    while not isinstance(base.owner.op, RandomVariable | SymbolicRandomVariable):
        upstream_op = base.owner.op
        if not isinstance(upstream_op, MeasurableTransform):
            return False
        base = base.owner.inputs[upstream_op.measurable_input_idx]

    if base.owner.op.ndim_supp != 0:
        # A multivariate base is still fine if the variable is directly valued
        return any(get_related_valued_nodes(fgraph, node))

    # Walk downstream: the path to the valued node may consist only of
    # single-output Elemwise operations, with no branching along the way.
    clients = fgraph.clients
    current = ds_output
    while True:
        consumers = clients[current]
        if len(consumers) != 1:
            return False
        [(consumer, _)] = consumers
        if isinstance(consumer.op, ValuedRV):
            return True
        if not (isinstance(consumer.op, Elemwise) and len(consumer.outputs) == 1):
            return False
        current = consumer.outputs[0]
266+
267+
218268
@node_rewriter([DimShuffle])
219269
def find_measurable_dimshuffles(fgraph, node) -> list[TensorVariable] | None:
220270
r"""Find `Dimshuffle`\s for which a `logprob` can be computed."""
221-
from pymc.distributions.distribution import SymbolicRandomVariable
222-
223271
if isinstance(node.op, MeasurableOp):
224272
return None
225273

226274
if not filter_measurable_variables(node.inputs):
227275
return None
228276

229-
base_var = node.inputs[0]
277+
# In cases where DimShuffle transposes dimensions, we only apply this rewrite when only Elemwise
278+
# operations separate it from the valued node. Further transformations likely need to know where
279+
# the support axes are for a correct implementation (and thus assume they are the rightmost axes).
280+
# TODO: When we include the support axis as meta information in each intermediate MeasurableVariable,
281+
# we can lift this restriction (see https://github.com/pymc-devs/pymc/issues/6360)
282+
if tuple(node.op.shuffle) != tuple(sorted(node.op.shuffle)) and not _elemwise_univariate_chain(
283+
fgraph, node
284+
):
285+
return None
230286

231-
# We can only apply this rewrite directly to `RandomVariable`s, as those are
232-
# the only `Op`s for which we always know the support axis. Other measurable
233-
# variables can have arbitrary support axes (e.g., if they contain separate
234-
# `MeasurableDimShuffle`s). Most measurable variables with `DimShuffle`s
235-
# should still be supported as long as the `DimShuffle`s can be merged/
236-
# lifted towards the base RandomVariable.
237-
# TODO: If we include the support axis as meta information in each
238-
# intermediate MeasurableVariable, we can lift this restriction.
239-
if not isinstance(base_var.owner.op, RandomVariable | SymbolicRandomVariable):
240-
return None # pragma: no cover
287+
base_var = node.inputs[0]
241288

242289
measurable_dimshuffle = MeasurableDimShuffle(node.op.input_broadcastable, node.op.new_order)(
243290
base_var

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ ignore = [
4848
"D101", # Missing docstring in public class
4949
"D102", # Missing docstring in public method
5050
"D103", # Missing docstring in public function
51+
"D105", # Missing docstring in magic method
5152
]
5253

5354
[tool.ruff.lint.pydocstyle]

tests/logprob/test_tensor.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -309,10 +309,7 @@ def test_join_mixed_ndim_supp():
309309
(1, 2, 0), # Swap
310310
(0, 1, 2, "x"), # Expand
311311
("x", 0, 1, 2), # Expand
312-
(
313-
0,
314-
2,
315-
), # Drop
312+
(0, 2), # Drop
316313
(2, 0), # Swap and drop
317314
(2, 1, "x", 0), # Swap and expand
318315
("x", 0, 2), # Expand and drop
@@ -338,7 +335,7 @@ def test_measurable_dimshuffle(ds_order, multivariate):
338335

339336
ref_logp = logp(base_rv, base_vv).dimshuffle(logp_ds_order)
340337

341-
# Disable local_dimshuffle_rv_lift to test fallback Aeppl rewrite
338+
# Disable local_dimshuffle_rv_lift to test fallback logprob rewrite
342339
ir_rewriter = logprob_rewrites_db.query(
343340
RewriteDatabaseQuery(include=["basic"]).excluding("dimshuffle_lift")
344341
)

0 commit comments

Comments
 (0)