diff --git a/src/fairchem/core/models/uma/escn_md.py b/src/fairchem/core/models/uma/escn_md.py index 0555d6daee..858e6ce1ea 100644 --- a/src/fairchem/core/models/uma/escn_md.py +++ b/src/fairchem/core/models/uma/escn_md.py @@ -284,48 +284,6 @@ def _get_rotmat_and_wigner( ) return wigner_and_M_mapping, wigner_and_M_mapping_inv - def _get_displacement_and_cell( - self, data_dict: AtomicData - ) -> tuple[torch.Tensor | None, torch.Tensor | None]: - ############################################################### - # gradient-based forces/stress - ############################################################### - displacement = None - orig_cell = None - if self.regress_stress and not self.direct_forces: - displacement = torch.zeros( - (3, 3), - dtype=data_dict["pos"].dtype, - device=data_dict["pos"].device, - ) - num_batch = len(data_dict["natoms"]) - displacement = displacement.view(-1, 3, 3).expand(num_batch, 3, 3) - displacement.requires_grad = True - symmetric_displacement = 0.5 * ( - displacement + displacement.transpose(-1, -2) - ) - if data_dict["pos"].requires_grad is False: - data_dict["pos"].requires_grad = True - data_dict["pos_original"] = data_dict["pos"] - data_dict["pos"] = data_dict["pos"] + torch.bmm( - data_dict["pos"].unsqueeze(-2), - torch.index_select(symmetric_displacement, 0, data_dict["batch"]), - ).squeeze(-2) - - orig_cell = data_dict["cell"] - data_dict["cell"] = data_dict["cell"] + torch.bmm( - data_dict["cell"], symmetric_displacement - ) - - if ( - not self.regress_stress - and self.regress_forces - and not self.direct_forces - and data_dict["pos"].requires_grad is False - ): - data_dict["pos"].requires_grad = True - return displacement, orig_cell - def csd_embedding(self, charge, spin, dataset): with record_function("charge spin dataset embeddings"): # Add charge, spin, and dataset embeddings @@ -366,9 +324,16 @@ def _generate_graph(self, data_dict): assert ( "edge_index" in data_dict ), "otf_graph is false, need to provide edge_index as input!" - cell_per_edge = data_dict["cell"].repeat_interleave( - data_dict["nedges"], dim=0 - ) + + if data_dict["cell"].shape[0] == 1: + cell_per_edge = data_dict["cell"].expand( + data_dict["edge_index"].shape[1], -1, -1 + ) + else: + cell_per_edge = data_dict["cell"].repeat_interleave( + data_dict["nedges"], dim=0 + ) + shifts = torch.einsum( "ij,ijk->ik", data_dict["cell_offsets"].to(cell_per_edge.dtype), @@ -405,6 +370,10 @@ def _generate_graph(self, data_dict): @conditional_grad(torch.enable_grad()) def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: + if not data_dict["pos"].requires_grad: + data_dict["pos"].requires_grad = True + if not data_dict["cell"].requires_grad: + data_dict["cell"].requires_grad = True data_dict["atomic_numbers"] = data_dict["atomic_numbers"].long() data_dict["atomic_numbers_full"] = data_dict["atomic_numbers"] data_dict["batch_full"] = data_dict["batch"] @@ -421,9 +390,6 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: csd_mixed_emb=csd_mixed_emb, ) - with record_function("get_displacement_and_cell"): - displacement, orig_cell = self._get_displacement_and_cell(data_dict) - with record_function("generate_graph"): graph_dict = self._generate_graph(data_dict) @@ -513,8 +479,6 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: x_message = self.norm(x_message) out = { "node_embedding": x_message, - "displacement": displacement, - "orig_cell": orig_cell, "batch": data_dict["batch"], } return out @@ -658,7 +622,7 @@ def forward( if self.regress_stress: grads = torch.autograd.grad( [energy_part.sum()], - [data["pos_original"], emb["displacement"]], + [data["pos"], data["cell"]], create_graph=self.training, ) if gp_utils.initialized(): @@ -667,8 +631,12 @@ def forward( gp_utils.reduce_from_model_parallel_region(grads[1]), ) + forces_prod_pos = grads[0].T @ data["pos"] + virial = ( + grads[1] @ data["cell"] + (forces_prod_pos + forces_prod_pos.T) / 2 + ).view(-1, 3, 3) + forces = torch.neg(grads[0]) - virial = grads[1].view(-1, 3, 3) volume = torch.det(data["cell"]).abs().unsqueeze(-1) stress = virial / volume.view(-1, 1, 1) virial = torch.neg(virial) @@ -677,7 +645,6 @@ def forward( ) # NOTE to work better with current Multi-task trainer outputs[forces_key] = {"forces": forces} if self.wrap_property else forces outputs[stress_key] = {"stress": stress} if self.wrap_property else stress - data["cell"] = emb["orig_cell"] elif self.regress_forces: forces = ( -1 diff --git a/tests/core/units/mlip_unit/test_mlip_unit.py b/tests/core/units/mlip_unit/test_mlip_unit.py index 8c3861c3e2..b52f0d334a 100644 --- a/tests/core/units/mlip_unit/test_mlip_unit.py +++ b/tests/core/units/mlip_unit/test_mlip_unit.py @@ -454,7 +454,7 @@ def test_conserve_train_from_cli_aselmdb(mode, fake_uma_dataset, torch_determini "tests/core/units/mlip_unit/test_mlip_train_conserving.yaml", "datasets=aselmdb_conserving", f"datasets.data_root_dir={fake_uma_dataset}", - "+expected_loss=86.24614715576172", + "+expected_loss=86.25643157958984", ] if mode == "gp": sys_args += [