Skip to content

Commit 4185b7b

Browse files
authored
Add warning if a Minibatched variable is used without total_size (#7742)
Add warning if a Minibatched variable is used but total_size is forgotten. This should help catch a well-known Minibatch footgun.
1 parent 6ef135b commit 4185b7b

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

pymc/model/core.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from pytensor.tensor.variable import TensorConstant, TensorVariable
4141

4242
from pymc.blocking import DictToArrayBijection, RaveledVars
43-
from pymc.data import is_valid_observed
43+
from pymc.data import MinibatchOp, is_valid_observed
4444
from pymc.exceptions import (
4545
BlockModelAccessError,
4646
ImputationWarning,
@@ -1241,6 +1241,15 @@ def register_rv(
12411241
self.add_named_variable(rv_var, dims)
12421242
self.set_initval(rv_var, initval)
12431243
else:
1244+
if (
1245+
isinstance(observed, TensorVariable)
1246+
and observed.owner is not None
1247+
and isinstance(observed.owner.op, MinibatchOp)
1248+
and total_size is None
1249+
):
1250+
warnings.warn(
1251+
f"total_size not provided for observed variable `{name}` that uses pm.Minibatch"
1252+
)
12441253
if not is_valid_observed(observed):
12451254
raise TypeError(
12461255
"Variables that depend on other nodes cannot be used for observed data."

tests/variational/test_minibatch_rv.py

+7
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,13 @@ def test_random(self):
112112
assert mx is not x
113113
np.testing.assert_array_equal(draw(mx, random_seed=1), draw(x, random_seed=1))
114114

115+
def test_warning_on_missing_total_size(self):
116+
total_size = 1000
117+
with pytest.warns(match="total_size not provided"):
118+
with pm.Model() as m:
119+
MB = pm.Minibatch(np.arange(total_size, dtype="float64"), batch_size=100)
120+
pm.Normal("n", observed=MB)
121+
115122
@pytest.mark.filterwarnings("error")
116123
def test_minibatch_parameter_and_value(self):
117124
rng = np.random.default_rng(161)

0 commit comments

Comments
 (0)