Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 20 additions & 53 deletions src/fairchem/core/models/uma/escn_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,48 +284,6 @@ def _get_rotmat_and_wigner(
)
return wigner_and_M_mapping, wigner_and_M_mapping_inv

def _get_displacement_and_cell(
self, data_dict: AtomicData
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
###############################################################
# gradient-based forces/stress
###############################################################
displacement = None
orig_cell = None
if self.regress_stress and not self.direct_forces:
displacement = torch.zeros(
(3, 3),
dtype=data_dict["pos"].dtype,
device=data_dict["pos"].device,
)
num_batch = len(data_dict["natoms"])
displacement = displacement.view(-1, 3, 3).expand(num_batch, 3, 3)
displacement.requires_grad = True
symmetric_displacement = 0.5 * (
displacement + displacement.transpose(-1, -2)
)
if data_dict["pos"].requires_grad is False:
data_dict["pos"].requires_grad = True
data_dict["pos_original"] = data_dict["pos"]
data_dict["pos"] = data_dict["pos"] + torch.bmm(
data_dict["pos"].unsqueeze(-2),
torch.index_select(symmetric_displacement, 0, data_dict["batch"]),
).squeeze(-2)

orig_cell = data_dict["cell"]
data_dict["cell"] = data_dict["cell"] + torch.bmm(
data_dict["cell"], symmetric_displacement
)

if (
not self.regress_stress
and self.regress_forces
and not self.direct_forces
and data_dict["pos"].requires_grad is False
):
data_dict["pos"].requires_grad = True
return displacement, orig_cell

def csd_embedding(self, charge, spin, dataset):
with record_function("charge spin dataset embeddings"):
# Add charge, spin, and dataset embeddings
Expand Down Expand Up @@ -366,9 +324,16 @@ def _generate_graph(self, data_dict):
assert (
"edge_index" in data_dict
), "otf_graph is false, need to provide edge_index as input!"
cell_per_edge = data_dict["cell"].repeat_interleave(
data_dict["nedges"], dim=0
)

if data_dict["cell"].shape[0] == 1:
cell_per_edge = data_dict["cell"].expand(
data_dict["edge_index"].shape[1], -1, -1
)
else:
cell_per_edge = data_dict["cell"].repeat_interleave(
data_dict["nedges"], dim=0
)

shifts = torch.einsum(
"ij,ijk->ik",
data_dict["cell_offsets"].to(cell_per_edge.dtype),
Expand Down Expand Up @@ -405,6 +370,10 @@ def _generate_graph(self, data_dict):

@conditional_grad(torch.enable_grad())
def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]:
if not data_dict["pos"].requires_grad:
data_dict["pos"].requires_grad = True
if not data_dict["cell"].requires_grad:
data_dict["cell"].requires_grad = True
data_dict["atomic_numbers"] = data_dict["atomic_numbers"].long()
data_dict["atomic_numbers_full"] = data_dict["atomic_numbers"]
data_dict["batch_full"] = data_dict["batch"]
Expand All @@ -421,9 +390,6 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]:
csd_mixed_emb=csd_mixed_emb,
)

with record_function("get_displacement_and_cell"):
displacement, orig_cell = self._get_displacement_and_cell(data_dict)

with record_function("generate_graph"):
graph_dict = self._generate_graph(data_dict)

Expand Down Expand Up @@ -513,8 +479,6 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]:
x_message = self.norm(x_message)
out = {
"node_embedding": x_message,
"displacement": displacement,
"orig_cell": orig_cell,
"batch": data_dict["batch"],
}
return out
Expand Down Expand Up @@ -658,7 +622,7 @@ def forward(
if self.regress_stress:
grads = torch.autograd.grad(
[energy_part.sum()],
[data["pos_original"], emb["displacement"]],
[data["pos"], data["cell"]],
create_graph=self.training,
)
if gp_utils.initialized():
Expand All @@ -667,8 +631,12 @@ def forward(
gp_utils.reduce_from_model_parallel_region(grads[1]),
)

forces_prod_pos = grads[0].T @ data["pos"]
virial = (
grads[1] @ data["cell"] + (forces_prod_pos + forces_prod_pos.T) / 2
).view(-1, 3, 3)

forces = torch.neg(grads[0])
virial = grads[1].view(-1, 3, 3)
volume = torch.det(data["cell"]).abs().unsqueeze(-1)
stress = virial / volume.view(-1, 1, 1)
virial = torch.neg(virial)
Expand All @@ -677,7 +645,6 @@ def forward(
) # NOTE to work better with current Multi-task trainer
outputs[forces_key] = {"forces": forces} if self.wrap_property else forces
outputs[stress_key] = {"stress": stress} if self.wrap_property else stress
data["cell"] = emb["orig_cell"]
elif self.regress_forces:
forces = (
-1
Expand Down
2 changes: 1 addition & 1 deletion tests/core/units/mlip_unit/test_mlip_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def test_conserve_train_from_cli_aselmdb(mode, fake_uma_dataset, torch_determini
"tests/core/units/mlip_unit/test_mlip_train_conserving.yaml",
"datasets=aselmdb_conserving",
f"datasets.data_root_dir={fake_uma_dataset}",
"+expected_loss=86.24614715576172",
"+expected_loss=86.25643157958984",
]
if mode == "gp":
sys_args += [
Expand Down