Skip to content

Commit

Permalink
fix langevin corrector alpha to general num timesteps (#81)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Zuegner <[email protected]>
  • Loading branch information
danielzuegner and Daniel Zuegner authored Feb 24, 2025
1 parent 37f557b commit 87ccbce
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 21 deletions.
3 changes: 2 additions & 1 deletion mattergen/common/diffusion/predictors_correctors.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ def step_given_score(
batch_idx: torch.LongTensor | None,
score: torch.Tensor,
t: torch.Tensor,
dt: torch.Tensor,
) -> SampleAndMean:
assert isinstance(self.corruption, sde_lib.LatticeVPSDE)
alpha = self.get_alpha(t)
alpha = self.get_alpha(t, dt=dt)
snr = self.snr
noise = torch.randn_like(x)
noise = sde_lib.make_noise_symmetric_preserve_variance(noise)
Expand Down
6 changes: 3 additions & 3 deletions mattergen/common/gemnet/gemnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,8 @@ def get_triplets(

value = torch.arange(idx_s.size(0), device=idx_s.device, dtype=idx_s.dtype)
# Possibly contains multiple copies of the same edge (for periodic interactions)
pyg_device = get_pyg_device()
torch_device = get_device()
pyg_device = get_pyg_device() if idx_s.device != torch.device("cpu") else idx_s.device
torch_device = get_device() if idx_s.device != torch.device("cpu") else idx_s.device
adj = SparseTensor(
row=idx_t.to(pyg_device),
col=idx_s.to(pyg_device),
Expand Down Expand Up @@ -775,4 +775,4 @@ def forward(

@property
def num_params(self):
return sum(p.numel() for p in self.parameters())
return sum(p.numel() for p in self.parameters())
7 changes: 4 additions & 3 deletions mattergen/conf/data_module/mp_20.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ num_workers:
test: 0

batch_size:
train: 128
val: 128
test: 128
# total batch size of 512, adjust for number of devices, nodes, and gradient accumulation
train: ${eval:'(512 // ${trainer.accumulate_grad_batches}) // (${trainer.devices} * ${trainer.num_nodes})'}
val: ${eval:'(64 // (${trainer.devices} * ${trainer.num_nodes})'}
test: ${eval:'(64 // (${trainer.devices} * ${trainer.num_nodes})'}

max_epochs: 900
2 changes: 1 addition & 1 deletion mattergen/diffusion/sampling/pc_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def _denoise(
}
samples_means: dict[str, Tuple[torch.Tensor, torch.Tensor]] = apply(
fns=fns,
broadcast={"t": t},
broadcast={"t": t, "dt": dt},
x=batch,
score=score,
batch_idx=self._multi_corruption._get_batch_indices(batch),
Expand Down
21 changes: 10 additions & 11 deletions mattergen/diffusion/sampling/predictors_correctors.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,30 +72,29 @@ def is_compatible(cls, corruption: Corruption):
and not isinstance(corruption, WrappedSDEMixin)
)

def update_fn(self, *, x, t, batch_idx) -> SampleAndMean:
def update_fn(self, *, x, t, batch_idx, dt: torch.Tensor) -> SampleAndMean:
assert self.score_fn is not None, "Did you mean to use step_given_score?"
for _ in range(self.n_steps):
score = self.score_fn(x, t, batch_idx)
x, x_mean = self.step_given_score(
x=x,
batch_idx=batch_idx,
score=score,
t=t,
)
x, x_mean = self.step_given_score(x=x, batch_idx=batch_idx, score=score, t=t, dt=dt)

return x, x_mean

def get_alpha(self, t: torch.FloatTensor) -> torch.Tensor:
def get_alpha(self, t: torch.FloatTensor, dt: torch.FloatTensor) -> torch.Tensor:
sde = self.corruption

if isinstance(sde, VPSDE):
alpha = 1 - sde.beta(t) * sde.T / 1000
alpha_bar = sde._marginal_mean_coeff(t) ** 2
alpha_bar_before = sde._marginal_mean_coeff(t + dt) ** 2
alpha = alpha_bar / alpha_bar_before
else:
alpha = torch.ones_like(t)
return alpha

def step_given_score(self, *, x, batch_idx: torch.LongTensor | None, score, t) -> SampleAndMean:
alpha = self.get_alpha(t)
def step_given_score(
self, *, x, batch_idx: torch.LongTensor | None, score, t: torch.Tensor, dt: torch.Tensor
) -> SampleAndMean:
alpha = self.get_alpha(t, dt=dt)
snr = self.snr
noise = torch.randn_like(score)
grad_norm_square = torch.square(score).reshape(score.shape[0], -1).sum(dim=1)
Expand Down
3 changes: 2 additions & 1 deletion mattergen/diffusion/wrapped/wrapped_predictors_correctors.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def step_given_score(
batch_idx: torch.LongTensor,
score: torch.Tensor,
t: torch.Tensor,
dt: torch.Tensor,
) -> SampleAndMean:
# mypy
assert isinstance(self, pc.LangevinCorrector)
Expand All @@ -66,7 +67,7 @@ def step_given_score(
raise IncompatibleSampler(
f"{self.__class__.__name__} is not compatible with {self.corruption}."
)
sample, mean = _super.step_given_score(x=x, score=score, t=t, batch_idx=batch_idx)
sample, mean = _super.step_given_score(x=x, score=score, t=t, batch_idx=batch_idx, dt=dt)
return self.corruption.wrap(sample), self.corruption.wrap(mean)


Expand Down
2 changes: 1 addition & 1 deletion mattergen/tests/test_diffusion_instantiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytest

from mattergen.common.utils.globals import MODELS_PROJECT_ROOT
from scripts.run import mattergen_main
from mattergen.scripts.run import mattergen_main

CONFIG_DIR = os.path.join(MODELS_PROJECT_ROOT, "conf")

Expand Down

0 comments on commit 87ccbce

Please sign in to comment.