Skip to content

Commit d34ed95

Browse files
authored
Transform to remove Minibatch from model (#7746)
1 parent 2705a5e commit d34ed95

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

pymc/model/transform/basic.py

+25-1
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
# limitations under the License.
1414
from collections.abc import Sequence
1515

16-
from pytensor import Variable
16+
from pytensor import Variable, clone_replace
1717
from pytensor.graph import ancestors
18+
from pytensor.graph.fg import FunctionGraph
1819

20+
from pymc.data import MinibatchOp
1921
from pymc.model.core import Model
2022
from pymc.model.fgraph import (
2123
ModelObservedRV,
@@ -58,3 +60,25 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l
5860
else:
5961
vars_seq = (vars,)
6062
return [model[var] if isinstance(var, str) else var for var in vars_seq]
63+
64+
65+
def remove_minibatched_nodes(model: Model) -> Model:
    """Build a model in which every ``pm.Minibatch`` node is bypassed.

    The model is converted to a function graph, each output of a
    ``MinibatchOp`` node is substituted by the matching input of that node,
    and a model is rebuilt from the resulting graph.

    Parameters
    ----------
    model : Model
        Model that may contain ``MinibatchOp`` nodes.

    Returns
    -------
    Model
        The model produced by ``model_from_fgraph`` from the rewritten graph.
    """
    old_fgraph, _ = fgraph_from_model(model)

    # Pair every output of a MinibatchOp node with the input at the same position.
    minibatch_to_input = {
        out: inp
        for node in old_fgraph.apply_nodes
        if isinstance(node.op, MinibatchOp)
        for inp, out in zip(node.inputs, node.outputs)
    }

    # Capture what has to survive the rewrite before cloning anything.
    old_outs = old_fgraph.outputs
    old_coords = old_fgraph._coords  # type: ignore[attr-defined]
    old_dim_lengths = old_fgraph._dim_lengths  # type: ignore[attr-defined]

    # With `rebuild_strict=False` the coords, names, and dim information are all
    # lost, so each piece is copied back from the old fgraph below.
    new_outs = clone_replace(old_outs, minibatch_to_input, rebuild_strict=False)  # type: ignore[arg-type]
    for old_out, new_out in zip(old_outs, new_outs):
        new_out.name = old_out.name

    new_fgraph = FunctionGraph(outputs=new_outs, clone=False)
    new_fgraph._coords = old_coords  # type: ignore[attr-defined]
    new_fgraph._dim_lengths = old_dim_lengths  # type: ignore[attr-defined]
    return model_from_fgraph(new_fgraph, mutate_fgraph=True)

tests/model/transform/test_basic.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import numpy as np
15+
1416
import pymc as pm
1517

16-
from pymc.model.transform.basic import prune_vars_detached_from_observed
18+
from pymc.model.transform.basic import prune_vars_detached_from_observed, remove_minibatched_nodes
1719

1820

1921
def test_prune_vars_detached_from_observed():
@@ -30,3 +32,20 @@ def test_prune_vars_detached_from_observed():
3032
assert set(m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs", "d0", "d1"}
3133
pruned_m = prune_vars_detached_from_observed(m)
3234
assert set(pruned_m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs"}
35+
36+
37+
def test_remove_minibatches():
    """Check that removing minibatches restores full-size data and keeps dims.

    The observed variable of the original model must have the batch length,
    the one in the transformed model must have the full data length, and
    both models must agree on coords and on the length of dim ``"d"``.
    """
    n_obs = 100
    n_batch = 10
    observations = np.zeros((n_obs,))

    with pm.Model(coords={"d": range(5)}) as minibatched_model:
        batch = pm.Minibatch(observations, batch_size=n_batch)
        # "mu" is only here so that the model carries a variable using dim "d".
        pm.Normal("mu", dims="d")
        x = pm.Normal("x")
        pm.Normal("y", x, observed=batch, total_size=n_obs)

    full_model = remove_minibatched_nodes(minibatched_model)

    # The original model is still minibatched; the transformed one is not.
    assert minibatched_model.y.shape[0].eval() == n_batch
    assert full_model.y.shape[0].eval() == n_obs

    # Coordinate and dimension information must carry over unchanged.
    assert minibatched_model.coords == full_model.coords
    assert minibatched_model.dim_lengths["d"].eval() == full_model.dim_lengths["d"].eval()

0 commit comments

Comments
 (0)