Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 55 additions & 21 deletions dpdata/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,10 @@ def add_atom_names(data, atom_names):


def sort_atom_names(data, type_map=None):
    """Sort atom_names of the system and reorder atom_numbs and atom_types according
    to atom_names.

    If type_map is not given, atom_names will be sorted by alphabetical order.
    If type_map is given, atom_names will be set to type_map, and zero-count
    elements are kept: names listed in type_map that have no atoms in the
    system get a count of 0. Names of the system that have a count of 0 and
    are absent from type_map are dropped.

    Parameters
    ----------
    data : dict
        system data; the entries "atom_names", "atom_numbs" and "atom_types"
        are read and overwritten in place
    type_map : sequence of str, optional
        the requested order of atom names; every atom name with a non-zero
        count must be listed in it

    Returns
    -------
    dict
        the same ``data`` object that was passed in, after the update

    Raises
    ------
    ValueError
        if type_map is given and an atom name with a non-zero count is
        missing from it
    """
    if type_map is not None:
        orig_names = data["atom_names"]
        orig_numbs = data["atom_numbs"]

        # position of every name in type_map (first occurrence wins, which
        # matches list.index) computed once instead of once per lookup
        new_index = {}
        for position, name in enumerate(type_map):
            new_index.setdefault(name, position)

        # only active (numb > 0) atom names must be in type_map
        active_names = {
            name for name, numb in zip(orig_names, orig_numbs) if numb > 0
        }
        missing = active_names.difference(new_index)
        if missing:
            raise ValueError(f"Active atom types {missing} not in provided type_map.")

        # type_map may be a proper superset of atom_names: names that are new
        # to the system are added with a count of 0
        name_to_old_idx = {name: i for i, name in enumerate(orig_names)}
        new_names = list(type_map)
        new_numbs = [
            orig_numbs[name_to_old_idx[name]] if name in name_to_old_idx else 0
            for name in new_names
        ]

        # remap atom_types: an atom of old type j becomes type
        # new_index[orig_names[j]]. The masks are taken from the untouched
        # old array, so an already remapped entry is never remapped again.
        old_types = np.asarray(data["atom_types"])
        new_types = np.empty_like(old_types)
        for old_idx, name in enumerate(orig_names):
            new_idx = new_index.get(name)
            if new_idx is not None:
                new_types[old_types == old_idx] = new_idx

        # update data in-place
        data["atom_names"] = new_names
        data["atom_numbs"] = new_numbs
        data["atom_types"] = new_types
    else:
        # index that will sort an array by alphabetical order
        # idx = argsort(atom_names) --> atom_names[idx] is sorted
        idx = np.argsort(data["atom_names"], kind="stable")
        # sort atom_names and atom_numbs by idx (exactly once)
        data["atom_names"] = list(np.array(data["atom_names"])[idx])
        data["atom_numbs"] = list(np.array(data["atom_numbs"])[idx])
        # to update atom_types we need the inverse permutation of idx:
        # if old_type = i and atom_names[i] moves to position j,
        # then the new type must be j.
        # inv_idx = argsort(idx) satisfies inv_idx[idx[j]] = j
        inv_idx = np.argsort(idx, kind="stable")
        data["atom_types"] = inv_idx[data["atom_types"]]

    return data


Expand Down