Skip to content
This repository was archived by the owner on Dec 18, 2023. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions src/beanmachine/ppl/experimental/tests/vi/vi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,37 @@ def test_dirichlet(self, auto_guide_inference):
vi_estimate = world.get_guide_distribution(alpha()).sample((100,)).mean(dim=0)
assert vi_estimate.isclose(map_truth, atol=0.1).all().item()

@pytest.mark.parametrize("auto_guide_inference", [ADVI, MAP])
def test_ppca(self, auto_guide_inference):
    """Smoke-test auto-guide VI on probabilistic PCA.

    Builds a PPCA model (x = z @ A.T + mu + noise), runs 30 VI steps on
    synthetic data, and checks that the final loss is below the initial loss.
    """
    d, D = 2, 4  # latent (d) and observed (D) dimensionality
    n = 150  # number of observations

    @bm.random_variable
    def A():
        # factor-loading matrix, shape (D, d)
        return dist.Normal(torch.zeros((D, d)), 2.0 * torch.ones((D, d)))

    @bm.random_variable
    def mu():
        # observation mean, shape (D,)
        return dist.Normal(torch.zeros(D), 2.0 * torch.ones(D))

    @bm.random_variable
    def z():
        # latent codes, shape (n, d)
        return dist.Normal(torch.zeros(n, d), 1.0)

    @bm.random_variable
    def x():
        # observations, shape (n, D): linear map of latents plus mean
        return dist.Normal(z() @ A().T + mu(), 1.0)

    vi = auto_guide_inference(
        queries=[A(), mu()],
        # Fixed: `torch.random.randn` does not exist (torch.random holds only
        # RNG-state utilities) — use `torch.randn`. Also, elementwise `*` of
        # (n, d) and (d, D) tensors does not broadcast; a matrix product is
        # needed to produce an (n, D) observation matching x()'s event shape.
        observations={x(): torch.randn(n, d) @ torch.randn(d, D)},
    )
    losses = []
    for _ in range(30):
        loss, _ = vi.step()
        losses.append(loss)
    # Optimization should have reduced the loss from its starting value.
    assert losses[-1].item() < losses[0].item()


class TestStochasticVariationalInfer:
@pytest.fixture(autouse=True)
Expand Down
3 changes: 2 additions & 1 deletion src/beanmachine/ppl/experimental/vi/gradient_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def monte_carlo_approximate_reparam(
params=params,
queries_to_guides=queries_to_guides,
)
world = World.initialize_world(
world = VariationalWorld.initialize_world(
queries=[],
observations={
**{
Expand All @@ -50,6 +50,7 @@ def monte_carlo_approximate_reparam(
},
**observations,
},
params=params,
)

# form log density ratio logu = logp - logq
Expand Down
3 changes: 2 additions & 1 deletion src/beanmachine/ppl/experimental/vi/variational_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def __init__(
params={},
queries_to_guides=queries_to_guides,
)
for guide in queries_to_guides.values():
for query, guide in queries_to_guides.items():
world.call(query)
world.call(guide)
self.params = world._params
self._optimizer = optimizer(self.params.values())
Expand Down