
Commit 1172ef7

compute likelihood in sampling
1 parent f28558a commit 1172ef7

File tree

flow_matching/solver/ode_solver.py
tests/solver/test_ode_solver.py

2 files changed: +93 -6 lines changed

flow_matching/solver/ode_solver.py

Lines changed: 44 additions & 6 deletions
@@ -37,6 +37,8 @@ def sample(
         time_grid: Tensor = torch.tensor([0.0, 1.0]),
         return_intermediates: bool = False,
         enable_grad: bool = False,
+        log_p0: Optional[Callable[[Tensor], Tensor]] = None,
+        exact_divergence: bool = False,
         **model_extras,
     ) -> Union[Tensor, Sequence[Tensor]]:
         r"""Solve the ODE with the velocity field.
@@ -73,6 +75,8 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
             time_grid (Tensor): The process is solved in the interval [min(time_grid), max(time_grid)], and if step_size is None then the time discretization is set by the time grid. May specify a descending time_grid to solve in the reverse direction. Defaults to torch.tensor([0.0, 1.0]).
             return_intermediates (bool, optional): If True then return intermediate time steps according to time_grid. Defaults to False.
             enable_grad (bool, optional): Whether to compute gradients during sampling. Defaults to False.
+            log_p0 (Optional[Callable[[Tensor], Tensor]]): Computes the log likelihood of the source distribution at :math:`t=0`. If provided, sample log likelihoods are returned alongside the samples. The velocity model must be differentiable with respect to x. Defaults to None.
+            exact_divergence (bool): Whether to compute the exact divergence or to use the Hutchinson estimator. Defaults to False.
             **model_extras: Additional input for the model.

         Returns:
@@ -81,27 +85,61 @@ def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
         time_grid = time_grid.to(x_init.device)

+        # Fix the random projection for the Hutchinson divergence estimator
+        if not exact_divergence:
+            z = (torch.randn_like(x_init).to(x_init.device) < 0) * 2.0 - 1.0
+
         def ode_func(t, x):
             return self.velocity_model(x=x, t=t, **model_extras)

+        def dynamics_func(t, states):
+            xt = states[0]
+            with torch.set_grad_enabled(True):
+                xt.requires_grad_()
+                ut = ode_func(t, xt)
+
+                # Compute exact divergence
+                if exact_divergence:
+                    div = 0
+                    for i in range(ut.flatten(1).shape[1]):
+                        div += gradient(ut[:, i], xt, create_graph=True)[:, i].detach()
+                else:
+                    # Compute Hutchinson divergence estimator E[z^T D_x(ut) z]
+                    ut_dot_z = torch.einsum(
+                        "ij,ij->i", ut.flatten(start_dim=1), z.flatten(start_dim=1)
+                    )
+                    grad_ut_dot_z = gradient(ut_dot_z, xt)
+                    div = torch.einsum(
+                        "ij,ij->i",
+                        grad_ut_dot_z.flatten(start_dim=1),
+                        z.flatten(start_dim=1),
+                    )
+
+            return ut.detach(), div.detach()
+
         ode_opts = {"step_size": step_size} if step_size is not None else {}

         with torch.set_grad_enabled(enable_grad):
             # Approximate ODE solution with numerical ODE solver
             sol = odeint(
-                ode_func,
-                x_init,
+                ode_func if log_p0 is None else dynamics_func,
+                (
+                    x_init
+                    if log_p0 is None
+                    else (x_init, torch.zeros(x_init.shape[0], device=x_init.device))
+                ),
                 time_grid,
                 method=method,
                 options=ode_opts,
                 atol=atol,
                 rtol=rtol,
             )

-        if return_intermediates:
-            return sol
-        else:
-            return sol[-1]
+        if log_p0 is not None:
+            sol, log_det = sol
+            log_p = log_p0(x_init) - log_det[-1]
+            return (sol, log_p) if return_intermediates else (sol[-1], log_p)
+        return sol if return_intermediates else sol[-1]

     def compute_likelihood(
         self,
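
For context on what this hunk implements: with log_p0 given, the solver integrates the instantaneous change-of-variables formula for continuous normalizing flows, d/dt log p_t(x_t) = -div u_t(x_t), alongside the state, so that log p_1(x_1) = log p_0(x_0) - \int_0^1 div u_t(x_t) dt. The divergence is taken either exactly (one backward pass per input dimension) or via the Hutchinson estimator E_z[z^T (du/dx) z] with a Rademacher vector z held fixed across solver steps, which keeps the dynamics deterministic for the solver (the same per-solve noise-fixing choice made in FFJORD). Below is a minimal standalone sketch of the two divergence computations on a toy linear field; it is plain PyTorch with invented names (u, A, div_exact, div_hutch), not library code.

import torch

torch.manual_seed(0)

A = torch.randn(3, 3)

def u(x):
    # Toy velocity field u(x) = x @ A.T; its divergence is trace(A) everywhere.
    return x @ A.T

x = torch.randn(5, 3, requires_grad=True)
ut = u(x)

# Exact divergence: one backward pass per dimension, summing the
# diagonal entries of the Jacobian du/dx.
div_exact = torch.zeros(x.shape[0])
for i in range(x.shape[1]):
    (grad_i,) = torch.autograd.grad(ut[:, i].sum(), x, retain_graph=True)
    div_exact += grad_i[:, i]

# Hutchinson estimator: E_z[z^T (du/dx) z] with Rademacher z is an
# unbiased estimate of the divergence; averaged here over many draws.
num_draws = 10_000
div_hutch = torch.zeros(x.shape[0])
for _ in range(num_draws):
    z = torch.randint(0, 2, x.shape).float() * 2 - 1  # entries in {-1, +1}
    (vjp,) = torch.autograd.grad((ut * z).sum(), x, retain_graph=True)
    div_hutch += (vjp * z).sum(dim=1)
div_hutch /= num_draws

print(torch.full((5,), torch.trace(A).item()))  # ground truth
print(div_exact)   # equals trace(A) for every sample
print(div_hutch)   # approaches div_exact as num_draws grows

The exact path costs one backward pass per coordinate, so the single-backward-pass Hutchinson path is the practical choice in high dimension.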

tests/solver/test_ode_solver.py

Lines changed: 49 additions & 0 deletions
@@ -185,6 +185,55 @@ def dummy_log_p(x: Tensor) -> Tensor:
             torch.allclose(x_1.grad, torch.tensor([1.0, 1.0]), atol=1e-2),
         )

+    def test_sample_with_likelihoods(self):
+        x_1 = torch.tensor([[0.0, 0.0]], requires_grad=True)
+        step_size = 0.1
+
+        # Define a dummy log probability function
+        def dummy_log_p(x: Tensor) -> Tensor:
+            return -0.5 * torch.sum(x**2, dim=1)
+
+        result, log_likelihood = self.dummy_solver.sample(
+            x_init=x_1,
+            step_size=step_size,
+            log_p0=dummy_log_p,
+            exact_divergence=True,
+        )
+        self.assertIsInstance(log_likelihood, Tensor)
+        self.assertEqual(x_1.shape[0], log_likelihood.shape[0])
+
+    def test_sample_with_likelihoods_to_posthoc_likelihoods(self):
+        x_0 = torch.tensor([[1.0, 0.0]], requires_grad=True)
+        step_size = 0.001
+
+        # Define a dummy log probability function
+        def dummy_log_p(x: Tensor) -> Tensor:
+            return -0.5 * torch.sum(x**2, dim=1)
+
+        x1, log_likelihood = self.dummy_solver.sample(
+            x_init=x_0,
+            step_size=step_size,
+            log_p0=dummy_log_p,
+            exact_divergence=True,
+        )
+        self.assertIsInstance(log_likelihood, Tensor)
+        self.assertEqual(x_0.shape[0], log_likelihood.shape[0])
+
+        # Check that the post-hoc likelihoods match the original log likelihoods
+        _, posthoc_log_likelihood = self.dummy_solver.compute_likelihood(
+            x_1=x1,
+            log_p0=dummy_log_p,
+            step_size=step_size,
+            exact_divergence=True,
+        )
+
+        self.assertTrue(
+            torch.allclose(log_likelihood, posthoc_log_likelihood, atol=1e-2),
+        )
+

 if __name__ == "__main__":
     unittest.main()
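
A usage sketch tying the two files together. This is not part of the commit: it assumes ODESolver is importable from flow_matching.solver (as the file path above suggests) and takes a velocity_model argument as in the diff; LinearVelocity and gaussian_log_p0 are invented stand-ins for a real trained model and source density.

import torch
from torch import Tensor

from flow_matching.solver import ODESolver


class LinearVelocity(torch.nn.Module):
    # Hypothetical velocity field u(x, t) = -x; it depends on x so the
    # divergence gradients are well defined.
    def forward(self, x: Tensor, t: Tensor) -> Tensor:
        return -x


def gaussian_log_p0(x: Tensor) -> Tensor:
    # Same unnormalized standard-Gaussian log density as the tests' dummy_log_p
    return -0.5 * torch.sum(x**2, dim=1)


solver = ODESolver(velocity_model=LinearVelocity())
x_0 = torch.randn(16, 2)

# With log_p0 given, sample() returns (samples, log likelihoods) instead of
# just samples; exact_divergence=False selects the Hutchinson estimator.
samples, log_p1 = solver.sample(
    x_init=x_0,
    step_size=0.05,
    method="euler",
    log_p0=gaussian_log_p0,
    exact_divergence=False,
)
print(samples.shape, log_p1.shape)  # torch.Size([16, 2]) torch.Size([16])

The new tests use exact_divergence=True, which is deterministic and therefore comparable against compute_likelihood up to ODE discretization error.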
