From 2c5b95f3df9f36fdae30d93025e852636b378146 Mon Sep 17 00:00:00 2001 From: SchrodingersCattt Date: Tue, 28 Oct 2025 09:40:06 +0000 Subject: [PATCH 1/4] feat: support zero-count elements in type_map for sort_atom_names --- dpdata/utils.py | 76 +++++++++++++++++++++++++++++++++++-------------- 1 file changed, 55 insertions(+), 21 deletions(-) diff --git a/dpdata/utils.py b/dpdata/utils.py index 58a908cc7..4c72bec94 100644 --- a/dpdata/utils.py +++ b/dpdata/utils.py @@ -62,9 +62,10 @@ def add_atom_names(data, atom_names): def sort_atom_names(data, type_map=None): - """Sort atom_names of the system and reorder atom_numbs and atom_types accoarding + """Sort atom_names of the system and reorder atom_numbs and atom_types according to atom_names. If type_map is not given, atom_names will be sorted by - alphabetical order. If type_map is given, atom_names will be type_map. + alphabetical order. If type_map is given, atom_names will be set to type_map, + and zero-count elements are kept. Parameters ---------- @@ -74,28 +75,61 @@ def sort_atom_names(data, type_map=None): type_map """ if type_map is not None: - # assign atom_names index to the specify order - # atom_names must be a subset of type_map - assert set(data["atom_names"]).issubset(set(type_map)) - # for the condition that type_map is a proper superset of atom_names - # new_atoms = set(type_map) - set(data["atom_names"]) - new_atoms = [e for e in type_map if e not in data["atom_names"]] - if new_atoms: - data = add_atom_names(data, new_atoms) - # index that will sort an array by type_map - # a[as[a]] == b[as[b]] as == argsort - # as[as[b]] == as^{-1}[b] - # a[as[a][as[as[b]]]] = b[as[b][as^{-1}[b]]] = b[id] - idx = np.argsort(data["atom_names"], kind="stable")[ - np.argsort(np.argsort(type_map, kind="stable"), kind="stable") - ] + # assign atom_names index to the specified order + # only active (numb > 0) atom names must be in type_map + orig_names = data["atom_names"] + orig_numbs = data["atom_numbs"] + active_names = {name for name, numb in zip(orig_names, orig_numbs) if numb > 0} + type_map_set = set(type_map) + if not active_names.issubset(type_map_set): + missing = active_names - type_map_set + raise ValueError(f"Active atom types {missing} not in provided type_map.") + + # for the condition that type_map is a proper superset of atom_names, + # we allow new elements with atom_numb = 0 + new_names = list(type_map) + new_numbs = [] + name_to_old_idx = {name: i for i, name in enumerate(orig_names)} + + for name in new_names: + if name in name_to_old_idx: + new_numbs.append(orig_numbs[name_to_old_idx[name]]) + else: + new_numbs.append(0) + + # build mapping from old atom type index to new one + # old_types[i] = j --> new_types[i] = type_map.index(atom_names[j]) + old_to_new_index = {} + for old_idx, name in enumerate(orig_names): + if name in type_map_set: + new_idx = type_map.index(name) + old_to_new_index[old_idx] = new_idx + + # remap atom_types using the index mapping + old_types = np.array(data["atom_types"]) + new_types = np.empty_like(old_types) + for old_idx, new_idx in old_to_new_index.items(): + new_types[old_types == old_idx] = new_idx + + # update data in-place + data["atom_names"] = new_names + data["atom_numbs"] = new_numbs + data["atom_types"] = new_types + else: # index that will sort an array by alphabetical order + # idx = argsort(atom_names) --> atom_names[idx] is sorted idx = np.argsort(data["atom_names"], kind="stable") - # sort atom_names, atom_numbs, atom_types by idx - data["atom_names"] = list(np.array(data["atom_names"])[idx]) - data["atom_numbs"] = list(np.array(data["atom_numbs"])[idx]) - data["atom_types"] = np.argsort(idx, kind="stable")[data["atom_types"]] + # sort atom_names and atom_numbs by idx + data["atom_names"] = list(np.array(data["atom_names"])[idx]) + data["atom_numbs"] = list(np.array(data["atom_numbs"])[idx]) + # to update atom_types: we need the inverse permutation of idx + # because if old_type = i, and atom_names[i] moves to position j, + # then the new type should be j. + # inv_idx = argsort(idx) satisfies: inv_idx[idx[i]] = i + inv_idx = np.argsort(idx, kind="stable") + data["atom_types"] = inv_idx[data["atom_types"]] + return data From f162c552cb760a2b0a162f25bc8a4cae9cb69dd9 Mon Sep 17 00:00:00 2001 From: SchrodingersCattt Date: Fri, 31 Oct 2025 00:18:53 +0000 Subject: [PATCH 2/4] test: add unit tests for atom type remapping --- tests/test_type_map_utils.py | 86 ++++++++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 tests/test_type_map_utils.py diff --git a/tests/test_type_map_utils.py b/tests/test_type_map_utils.py new file mode 100644 index 000000000..9e1022cfc --- /dev/null +++ b/tests/test_type_map_utils.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import unittest + +import numpy as np +from context import dpdata + +from dpdata.utils import sort_atom_names + + +class TestSortAtomNames(unittest.TestCase): + def test_sort_atom_names_type_map(self): + # Test basic functionality with type_map + data = { + "atom_names": ["H", "O"], + "atom_numbs": [2, 1], + "atom_types": np.array([1, 0, 0]), + } + type_map = ["O", "H"] + result = sort_atom_names(data, type_map=type_map) + + self.assertEqual(result["atom_names"], ["O", "H"]) + self.assertEqual(result["atom_numbs"], [1, 2]) + np.testing.assert_array_equal(result["atom_types"], np.array([0, 1, 1])) + + def test_sort_atom_names_type_map_with_zero_atoms(self): + # Test with type_map that includes elements with zero atoms + data = { + "atom_names": ["H", "O"], + "atom_numbs": [2, 1], + "atom_types": np.array([1, 0, 0]), + } + type_map = ["O", "H", "C"] # C is not in atom_names but in type_map + result = sort_atom_names(data, type_map=type_map) + + self.assertEqual(result["atom_names"], ["O", "H", "C"]) + self.assertEqual(result["atom_numbs"], [1, 2, 0]) + np.testing.assert_array_equal(result["atom_types"], np.array([0, 1, 1])) + + def test_sort_atom_names_type_map_missing_active_types(self): + # Test that ValueError is raised when active atom types are missing from type_map + data = { + "atom_names": ["H", "O"], + "atom_numbs": [2, 1], # Both H and O are active (numb > 0) + "atom_types": np.array([1, 0, 0]), + } + type_map = ["H"] # O is active but missing from type_map + + with self.assertRaises(ValueError) as cm: + sort_atom_names(data, type_map=type_map) + + self.assertIn("Active atom types", str(cm.exception)) + self.assertIn("not in provided type_map", str(cm.exception)) + self.assertIn("O", str(cm.exception)) + + def test_sort_atom_names_without_type_map(self): + # Test sorting without type_map (alphabetical order) + data = { + "atom_names": ["Zn", "O", "H"], + "atom_numbs": [1, 1, 2], + "atom_types": np.array([0, 1, 2, 2]), + } + result = sort_atom_names(data) + + self.assertEqual(result["atom_names"], ["H", "O", "Zn"]) + self.assertEqual(result["atom_numbs"], [2, 1, 1]) + np.testing.assert_array_equal(result["atom_types"], np.array([2, 1, 0, 0])) + + def test_sort_atom_names_with_zero_count_elements_removed(self): + # Test the case where original elements are A B C, but counts are 0 1 2, + # which should be able to map to B C (removing A which has count 0) + data = { + "atom_names": ["Cl", "O", "C"], + "atom_numbs": [0, 1, 2], + "atom_types": np.array([1, 2, 2]), + } + type_map = ["O", "C"] # A is omitted because it has 0 atoms + result = sort_atom_names(data, type_map=type_map) + + self.assertEqual(result["atom_names"], ["O", "C"]) + self.assertEqual(result["atom_numbs"], [1, 2]) + np.testing.assert_array_equal(result["atom_types"], np.array([0, 1, 1])) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From cfde1f49e5bbb82dd0fe2d7c03e5472e0d52c6ad Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 31 Oct 2025 00:19:21 +0000 Subject: [PATCH 3/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_type_map_utils.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/tests/test_type_map_utils.py b/tests/test_type_map_utils.py index 9e1022cfc..6f5108f07 100644 --- a/tests/test_type_map_utils.py +++ b/tests/test_type_map_utils.py @@ -3,7 +3,6 @@ import unittest import numpy as np -from context import dpdata from dpdata.utils import sort_atom_names @@ -18,11 +17,11 @@ def test_sort_atom_names_type_map(self): } type_map = ["O", "H"] result = sort_atom_names(data, type_map=type_map) - + self.assertEqual(result["atom_names"], ["O", "H"]) self.assertEqual(result["atom_numbs"], [1, 2]) np.testing.assert_array_equal(result["atom_types"], np.array([0, 1, 1])) - + def test_sort_atom_names_type_map_with_zero_atoms(self): # Test with type_map that includes elements with zero atoms data = { @@ -32,11 +31,11 @@ def test_sort_atom_names_type_map_with_zero_atoms(self): } type_map = ["O", "H", "C"] # C is not in atom_names but in type_map result = sort_atom_names(data, type_map=type_map) - + self.assertEqual(result["atom_names"], ["O", "H", "C"]) self.assertEqual(result["atom_numbs"], [1, 2, 0]) np.testing.assert_array_equal(result["atom_types"], np.array([0, 1, 1])) - + def test_sort_atom_names_type_map_missing_active_types(self): # Test that ValueError is raised when active atom types are missing from type_map data = { @@ -45,14 +44,14 @@ def test_sort_atom_names_type_map_missing_active_types(self): "atom_types": np.array([1, 0, 0]), } type_map = ["H"] # O is active but missing from type_map - + with self.assertRaises(ValueError) as cm: sort_atom_names(data, type_map=type_map) - + self.assertIn("Active atom types", str(cm.exception)) self.assertIn("not in provided type_map", str(cm.exception)) self.assertIn("O", str(cm.exception)) - + def test_sort_atom_names_without_type_map(self): # Test sorting without type_map (alphabetical order) data = { @@ -61,7 +60,7 @@ def test_sort_atom_names_without_type_map(self): "atom_types": np.array([0, 1, 2, 2]), } result = sort_atom_names(data) - + self.assertEqual(result["atom_names"], ["H", "O", "Zn"]) self.assertEqual(result["atom_numbs"], [2, 1, 1]) np.testing.assert_array_equal(result["atom_types"], np.array([2, 1, 0, 0])) @@ -76,11 +75,11 @@ def test_sort_atom_names_with_zero_count_elements_removed(self): } type_map = ["O", "C"] # A is omitted because it has 0 atoms result = sort_atom_names(data, type_map=type_map) - + self.assertEqual(result["atom_names"], ["O", "C"]) self.assertEqual(result["atom_numbs"], [1, 2]) np.testing.assert_array_equal(result["atom_types"], np.array([0, 1, 1])) if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 3ec40471b8dfaa2c857d9912e1f1f5a716757bd1 Mon Sep 17 00:00:00 2001 From: SchrodingersCattt Date: Fri, 31 Oct 2025 00:40:17 +0000 Subject: [PATCH 4/4] style: enhance comments --- tests/test_type_map_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_type_map_utils.py b/tests/test_type_map_utils.py index 6f5108f07..4ae9567d5 100644 --- a/tests/test_type_map_utils.py +++ b/tests/test_type_map_utils.py @@ -68,12 +68,13 @@ def test_sort_atom_names_without_type_map(self): def test_sort_atom_names_with_zero_count_elements_removed(self): # Test the case where original elements are A B C, but counts are 0 1 2, # which should be able to map to B C (removing A which has count 0) + # Example: A, B, C = Cl, O, C data = { "atom_names": ["Cl", "O", "C"], "atom_numbs": [0, 1, 2], "atom_types": np.array([1, 2, 2]), } - type_map = ["O", "C"] # A is omitted because it has 0 atoms + type_map = ["O", "C"] # Cl is omitted because it has 0 atoms result = sort_atom_names(data, type_map=type_map) self.assertEqual(result["atom_names"], ["O", "C"])