|
13 | 13 | # limitations under the License.
|
14 | 14 | from collections.abc import Sequence
|
15 | 15 |
|
16 |
| -from pytensor import Variable |
| 16 | +from pytensor import Variable, clone_replace |
17 | 17 | from pytensor.graph import ancestors
|
| 18 | +from pytensor.graph.fg import FunctionGraph |
18 | 19 |
|
| 20 | +from pymc.data import MinibatchOp |
19 | 21 | from pymc.model.core import Model
|
20 | 22 | from pymc.model.fgraph import (
|
21 | 23 | ModelObservedRV,
|
@@ -58,3 +60,25 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l
|
58 | 60 | else:
|
59 | 61 | vars_seq = (vars,)
|
60 | 62 | return [model[var] if isinstance(var, str) else var for var in vars_seq]
|
| 63 | + |
| 64 | + |
| 65 | +def remove_minibatched_nodes(model: Model) -> Model: |
| 66 | + """Remove all uses of pm.Minibatch in the Model.""" |
| 67 | + fgraph, _ = fgraph_from_model(model) |
| 68 | + |
| 69 | + replacements = {} |
| 70 | + for var in fgraph.apply_nodes: |
| 71 | + if isinstance(var.op, MinibatchOp): |
| 72 | + for inp, out in zip(var.inputs, var.outputs): |
| 73 | + replacements[out] = inp |
| 74 | + |
| 75 | + old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths # type: ignore[attr-defined] |
| 76 | + # Using `rebuild_strict=False` means all coords, names, and dim information is lost |
| 77 | + # So we need to restore it from the old fgraph |
| 78 | + new_outs = clone_replace(old_outs, replacements, rebuild_strict=False) # type: ignore[arg-type] |
| 79 | + for old_out, new_out in zip(old_outs, new_outs): |
| 80 | + new_out.name = old_out.name |
| 81 | + fgraph = FunctionGraph(outputs=new_outs, clone=False) |
| 82 | + fgraph._coords = old_coords # type: ignore[attr-defined] |
| 83 | + fgraph._dim_lengths = old_dim_lengths # type: ignore[attr-defined] |
| 84 | + return model_from_fgraph(fgraph, mutate_fgraph=True) |
0 commit comments