From 33e934917111bfe0bc545bac3ff28fbd9c0313e8 Mon Sep 17 00:00:00 2001
From: Feynman Tsing-Yang Liang
Date: Tue, 16 Aug 2022 12:02:53 -0700
Subject: [PATCH] Support parameters in World

Summary:
To support the use of `bm.param`s in the generative model, beanmachine's VI now:

* Traces through `World` in addition to `VariationalWorld` so that `param`s used in `World` are added to the optimizer
* Utilizes a `VariationalWorld` for both the generative and guide distributions in `gradient_estimators` so that `params` can be used in the generative model

This is needed to support PPCA (N2358633)

Differential Revision: D38754091

fbshipit-source-id: c110e22f9a4b7c9c86e37b9884d8e931466a0abf
---
 .../ppl/experimental/tests/vi/vi_test.py      | 31 +++++++++++++++++++
 .../ppl/experimental/vi/gradient_estimator.py |  3 +-
 .../ppl/experimental/vi/variational_infer.py  |  3 +-
 3 files changed, 35 insertions(+), 2 deletions(-)

diff --git a/src/beanmachine/ppl/experimental/tests/vi/vi_test.py b/src/beanmachine/ppl/experimental/tests/vi/vi_test.py
index 01c10a6bb2..55a82bd12f 100644
--- a/src/beanmachine/ppl/experimental/tests/vi/vi_test.py
+++ b/src/beanmachine/ppl/experimental/tests/vi/vi_test.py
@@ -287,6 +287,37 @@ def test_dirichlet(self, auto_guide_inference):
         vi_estimate = world.get_guide_distribution(alpha()).sample((100,)).mean(dim=0)
         assert vi_estimate.isclose(map_truth, atol=0.1).all().item()
 
+    @pytest.mark.parametrize("auto_guide_inference", [ADVI, MAP])
+    def test_ppca(self, auto_guide_inference):
+        d, D = 2, 4
+        n = 150
+
+        @bm.random_variable
+        def A():
+            return dist.Normal(torch.zeros((D, d)), 2.0 * torch.ones((D, d)))
+
+        @bm.random_variable
+        def mu():
+            return dist.Normal(torch.zeros(D), 2.0 * torch.ones(D))
+
+        @bm.random_variable
+        def z():
+            return dist.Normal(torch.zeros(n, d), 1.0)
+
+        @bm.random_variable
+        def x():
+            return dist.Normal(z() @ A().T + mu(), 1.0)
+
+        vi = auto_guide_inference(
+            queries=[A(), mu()],
+            observations={x(): torch.randn(n, d) @ torch.randn(d, D)},
+        )
+        losses = []
+        for _ in range(30):
+            loss, _ = vi.step()
+            losses.append(loss)
+        assert losses[-1].item() < losses[0].item()
+
 
 class TestStochasticVariationalInfer:
     @pytest.fixture(autouse=True)
diff --git a/src/beanmachine/ppl/experimental/vi/gradient_estimator.py b/src/beanmachine/ppl/experimental/vi/gradient_estimator.py
index d1ac144a46..7a1bf55a2b 100644
--- a/src/beanmachine/ppl/experimental/vi/gradient_estimator.py
+++ b/src/beanmachine/ppl/experimental/vi/gradient_estimator.py
@@ -41,7 +41,7 @@ def monte_carlo_approximate_reparam(
         params=params,
         queries_to_guides=queries_to_guides,
     )
-    world = World.initialize_world(
+    world = VariationalWorld.initialize_world(
         queries=[],
         observations={
             **{
@@ -50,6 +50,7 @@
             },
             **observations,
         },
+        params=params,
     )
 
     # form log density ratio logu = logp - logq
diff --git a/src/beanmachine/ppl/experimental/vi/variational_infer.py b/src/beanmachine/ppl/experimental/vi/variational_infer.py
index c993afe3a4..75c17579a7 100644
--- a/src/beanmachine/ppl/experimental/vi/variational_infer.py
+++ b/src/beanmachine/ppl/experimental/vi/variational_infer.py
@@ -56,7 +56,8 @@ def __init__(
             params={},
             queries_to_guides=queries_to_guides,
         )
-        for guide in queries_to_guides.values():
+        for query, guide in queries_to_guides.items():
+            world.call(query)
             world.call(guide)
         self.params = world._params
         self._optimizer = optimizer(self.params.values())
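
Note: the sketch below is illustrative and not part of the patch. It shows the kind of usage this change enables: a `bm.param` declared inside the generative model is now traced during VI (via `world.call(query)` and the `VariationalWorld` used in `gradient_estimator`) and handed to the optimizer together with the guide parameters. The `ADVI` import path, the model, and the parameter name are assumptions made for the example; only the `queries=`/`observations=` constructor signature and `vi.step()` are taken from the test above.

    import beanmachine.ppl as bm
    import torch
    import torch.distributions as dist
    import torch.nn.functional as F
    from beanmachine.ppl.experimental.vi import ADVI  # assumed import path

    @bm.param
    def log_scale():
        # learnable parameter used directly in the generative model (hypothetical example);
        # the return value is the parameter's initialization
        return torch.tensor(0.0)

    @bm.random_variable
    def mu():
        return dist.Normal(0.0, 1.0)

    @bm.random_variable
    def x():
        # softplus keeps the learned observation scale positive
        return dist.Normal(mu(), F.softplus(log_scale()))

    vi = ADVI(queries=[mu()], observations={x(): torch.tensor(2.0)})
    for _ in range(100):
        # with this patch, log_scale is registered in vi's optimizer and updated each step
        loss, _ = vi.step()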