|
41 | 41 | from pytensor.graph.rewriting.basic import node_rewriter
|
42 | 42 | from pytensor.tensor import TensorVariable
|
43 | 43 | from pytensor.tensor.basic import Join, MakeVector
|
44 |
| -from pytensor.tensor.elemwise import DimShuffle |
| 44 | +from pytensor.tensor.elemwise import DimShuffle, Elemwise |
45 | 45 | from pytensor.tensor.random.op import RandomVariable
|
46 | 46 | from pytensor.tensor.random.rewriting import (
|
47 | 47 | local_dimshuffle_rv_lift,
|
48 | 48 | )
|
49 | 49 |
|
50 |
| -from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper, promised_valued_rv |
| 50 | +from pymc.logprob.abstract import ( |
| 51 | + MeasurableOp, |
| 52 | + ValuedRV, |
| 53 | + _logprob, |
| 54 | + _logprob_helper, |
| 55 | + promised_valued_rv, |
| 56 | +) |
51 | 57 | from pymc.logprob.rewriting import (
|
52 | 58 | assume_valued_outputs,
|
53 | 59 | early_measurable_ir_rewrites_db,
|
|
57 | 63 | from pymc.logprob.utils import (
|
58 | 64 | check_potential_measurability,
|
59 | 65 | filter_measurable_variables,
|
| 66 | + get_related_valued_nodes, |
60 | 67 | replace_rvs_by_values,
|
61 | 68 | )
|
62 | 69 | from pymc.pytensorf import constant_fold
|
@@ -183,6 +190,9 @@ class MeasurableDimShuffle(MeasurableOp, DimShuffle):
|
183 | 190 | # find it locally and fails when a new `Op` is initialized
|
184 | 191 | c_func_file = str(DimShuffle.get_path(Path(DimShuffle.c_func_file))) # type: ignore[arg-type]
|
185 | 192 |
|
| 193 | + def __str__(self): |
| 194 | + return f"Measurable{super().__str__()}" |
| 195 | + |
186 | 196 |
|
187 | 197 | @_logprob.register(MeasurableDimShuffle)
|
188 | 198 | def logprob_dimshuffle(op: MeasurableDimShuffle, values, base_var, **kwargs):
|
@@ -215,29 +225,66 @@ def logprob_dimshuffle(op: MeasurableDimShuffle, values, base_var, **kwargs):
|
215 | 225 | return raw_logp.dimshuffle(redo_ds)
|
216 | 226 |
|
217 | 227 |
|
def _elemwise_univariate_chain(fgraph, node) -> bool:
    """Check the chain of ops around a single-input, single-output ``node``.

    Upstream: starting at the node's input, follow ``MeasurableTransform`` ops
    (via their ``measurable_input_idx``) until a ``RandomVariable`` or
    ``SymbolicRandomVariable`` is reached. Any other op on the way makes the
    result ``False``.

    If the root found that way has ``ndim_supp != 0``, the result is whether
    ``get_related_valued_nodes(fgraph, node)`` yields any truthy entry.

    Downstream: starting at the node's output, every variable must have exactly
    one client. Single-output ``Elemwise`` clients are followed; reaching a
    ``ValuedRV`` client gives ``True``; anything else gives ``False``.
    """
    from pymc.distributions.distribution import SymbolicRandomVariable
    from pymc.logprob.transforms import MeasurableTransform

    [node_input] = node.inputs
    [node_output] = node.outputs

    # Walk towards the base random variable, allowing only measurable transforms in between.
    root = node_input
    while not isinstance(root.owner.op, RandomVariable | SymbolicRandomVariable):
        upstream_op = root.owner.op
        if not isinstance(upstream_op, MeasurableTransform):
            return False
        root = root.owner.inputs[upstream_op.measurable_input_idx]

    if root.owner.op.ndim_supp != 0:
        # A root with non-zero support dimensions is still fine if the variable is directly valued
        return any(get_related_valued_nodes(fgraph, node))

    # Walk towards the valued node, allowing only single-client, single-output Elemwise ops.
    clients = fgraph.clients
    current = node_output
    while True:
        consumers = clients[current]
        if len(consumers) != 1:
            return False
        [(consumer, _)] = consumers
        if isinstance(consumer.op, ValuedRV):
            return True
        if not (isinstance(consumer.op, Elemwise) and len(consumer.outputs) == 1):
            return False
        current = consumer.outputs[0]
| 266 | + |
| 267 | + |
218 | 268 | @node_rewriter([DimShuffle])
|
219 | 269 | def find_measurable_dimshuffles(fgraph, node) -> list[TensorVariable] | None:
|
220 | 270 | r"""Find `Dimshuffle`\s for which a `logprob` can be computed."""
|
221 |
| - from pymc.distributions.distribution import SymbolicRandomVariable |
222 |
| - |
223 | 271 | if isinstance(node.op, MeasurableOp):
|
224 | 272 | return None
|
225 | 273 |
|
226 | 274 | if not filter_measurable_variables(node.inputs):
|
227 | 275 | return None
|
228 | 276 |
|
229 |
| - base_var = node.inputs[0] |
| 277 | + # In cases where DimShuffle transposes dimensions, we only apply this rewrite when only Elemwise |
| 278 | + # operations separate it from the valued node. Further transformations likely need to know where |
| 279 | + # the support axes are for a correct implementation (and thus assume they are the rightmost axes). |
| 280 | + # TODO: When we include the support axis as meta information in each intermediate MeasurableVariable, |
| 281 | + # we can lift this restriction (see https://github.com/pymc-devs/pymc/issues/6360) |
| 282 | + if tuple(node.op.shuffle) != tuple(sorted(node.op.shuffle)) and not _elemwise_univariate_chain( |
| 283 | + fgraph, node |
| 284 | + ): |
| 285 | + return None |
230 | 286 |
|
231 |
| - # We can only apply this rewrite directly to `RandomVariable`s, as those are |
232 |
| - # the only `Op`s for which we always know the support axis. Other measurable |
233 |
| - # variables can have arbitrary support axes (e.g., if they contain separate |
234 |
| - # `MeasurableDimShuffle`s). Most measurable variables with `DimShuffle`s |
235 |
| - # should still be supported as long as the `DimShuffle`s can be merged/ |
236 |
| - # lifted towards the base RandomVariable. |
237 |
| - # TODO: If we include the support axis as meta information in each |
238 |
| - # intermediate MeasurableVariable, we can lift this restriction. |
239 |
| - if not isinstance(base_var.owner.op, RandomVariable | SymbolicRandomVariable): |
240 |
| - return None # pragma: no cover |
| 287 | + base_var = node.inputs[0] |
241 | 288 |
|
242 | 289 | measurable_dimshuffle = MeasurableDimShuffle(node.op.input_broadcastable, node.op.new_order)(
|
243 | 290 | base_var
|
|
0 commit comments