Skip to content

Commit 06c21b6

Browse files
fix: standardize the deepmd/npy/mixed format (#425)
This PR has concated two commits together: 1. Update the dpdata.MultiSystems() when from_deepmd_npy_mixed method is called; dpdata.MultiSystems().from_deepmd_npy_mixed only returned the results before but did not change itself, which is fixed in this commit, to be consistent with other from methods. (another bug is also fixed: not using .copy() in data["atom_names"] may cause error when manually changing type_map for this system. UTs are added in the next commit.) 2. Allow multiple sets in mixed-type format; Now for maximum 50000 frames in one sys and 2000 frames in one set. The reason I did not use 5000 frames per set, is that I think maximum set frames will be much more often used in mixed-type format than other format, and 2000 will be enough for large batch and more friendly for memory. Add UTs for type_map changing and mixed_type dir check. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent aec7747 commit 06c21b6

File tree

5 files changed

+217
-131
lines changed

5 files changed

+217
-131
lines changed

dpdata/deepmd/mixed.py

+85-74
Original file line numberDiff line numberDiff line change
@@ -54,60 +54,80 @@ def to_system_data(folder, type_map=None, labels=True):
5454
if os.path.isfile(os.path.join(folder, "nopbc")):
5555
data["nopbc"] = True
5656
sets = sorted(glob.glob(os.path.join(folder, "set.*")))
57-
assert len(sets) == 1, "Mixed type must have only one set!"
58-
cells, coords, eners, forces, virs, real_atom_types = _load_set(
59-
sets[0], data.get("nopbc", False)
60-
)
61-
nframes = np.reshape(cells, [-1, 3, 3]).shape[0]
62-
cells = np.reshape(cells, [nframes, 3, 3])
63-
coords = np.reshape(coords, [nframes, -1, 3])
64-
real_atom_types = np.reshape(real_atom_types, [nframes, -1])
65-
natom = real_atom_types.shape[1]
66-
if labels:
67-
if eners is not None and eners.size > 0:
57+
all_cells = []
58+
all_coords = []
59+
all_eners = []
60+
all_forces = []
61+
all_virs = []
62+
all_real_atom_types = []
63+
for ii in sets:
64+
cells, coords, eners, forces, virs, real_atom_types = _load_set(
65+
ii, data.get("nopbc", False)
66+
)
67+
nframes = np.reshape(cells, [-1, 3, 3]).shape[0]
68+
all_cells.append(np.reshape(cells, [nframes, 3, 3]))
69+
all_coords.append(np.reshape(coords, [nframes, -1, 3]))
70+
all_real_atom_types.append(np.reshape(real_atom_types, [nframes, -1]))
71+
if eners is not None:
6872
eners = np.reshape(eners, [nframes])
69-
if forces is not None and forces.size > 0:
70-
forces = np.reshape(forces, [nframes, -1, 3])
71-
if virs is not None and virs.size > 0:
72-
virs = np.reshape(virs, [nframes, 3, 3])
73+
if labels:
74+
if eners is not None and eners.size > 0:
75+
all_eners.append(np.reshape(eners, [nframes]))
76+
if forces is not None and forces.size > 0:
77+
all_forces.append(np.reshape(forces, [nframes, -1, 3]))
78+
if virs is not None and virs.size > 0:
79+
all_virs.append(np.reshape(virs, [nframes, 3, 3]))
80+
all_cells_concat = np.concatenate(all_cells, axis=0)
81+
all_coords_concat = np.concatenate(all_coords, axis=0)
82+
all_real_atom_types_concat = np.concatenate(all_real_atom_types, axis=0)
83+
all_eners_concat = None
84+
all_forces_concat = None
85+
all_virs_concat = None
86+
if len(all_eners) > 0:
87+
all_eners_concat = np.concatenate(all_eners, axis=0)
88+
if len(all_forces) > 0:
89+
all_forces_concat = np.concatenate(all_forces, axis=0)
90+
if len(all_virs) > 0:
91+
all_virs_concat = np.concatenate(all_virs, axis=0)
7392
data_list = []
7493
while True:
75-
if real_atom_types.size == 0:
94+
if all_real_atom_types_concat.size == 0:
7695
break
7796
temp_atom_numbs = [
78-
np.count_nonzero(real_atom_types[0] == i)
97+
np.count_nonzero(all_real_atom_types_concat[0] == i)
7998
for i in range(len(data["atom_names"]))
8099
]
81100
# temp_formula = formula(data['atom_names'], temp_atom_numbs)
82-
temp_idx = np.arange(real_atom_types.shape[0])[
83-
(real_atom_types == real_atom_types[0]).all(-1)
101+
temp_idx = np.arange(all_real_atom_types_concat.shape[0])[
102+
(all_real_atom_types_concat == all_real_atom_types_concat[0]).all(-1)
84103
]
85-
rest_idx = np.arange(real_atom_types.shape[0])[
86-
(real_atom_types != real_atom_types[0]).any(-1)
104+
rest_idx = np.arange(all_real_atom_types_concat.shape[0])[
105+
(all_real_atom_types_concat != all_real_atom_types_concat[0]).any(-1)
87106
]
88107
temp_data = data.copy()
108+
temp_data["atom_names"] = data["atom_names"].copy()
89109
temp_data["atom_numbs"] = temp_atom_numbs
90-
temp_data["atom_types"] = real_atom_types[0]
91-
real_atom_types = real_atom_types[rest_idx]
92-
temp_data["cells"] = cells[temp_idx]
93-
cells = cells[rest_idx]
94-
temp_data["coords"] = coords[temp_idx]
95-
coords = coords[rest_idx]
110+
temp_data["atom_types"] = all_real_atom_types_concat[0]
111+
all_real_atom_types_concat = all_real_atom_types_concat[rest_idx]
112+
temp_data["cells"] = all_cells_concat[temp_idx]
113+
all_cells_concat = all_cells_concat[rest_idx]
114+
temp_data["coords"] = all_coords_concat[temp_idx]
115+
all_coords_concat = all_coords_concat[rest_idx]
96116
if labels:
97-
if eners is not None and eners.size > 0:
98-
temp_data["energies"] = eners[temp_idx]
99-
eners = eners[rest_idx]
100-
if forces is not None and forces.size > 0:
101-
temp_data["forces"] = forces[temp_idx]
102-
forces = forces[rest_idx]
103-
if virs is not None and virs.size > 0:
104-
temp_data["virials"] = virs[temp_idx]
105-
virs = virs[rest_idx]
117+
if all_eners_concat is not None and all_eners_concat.size > 0:
118+
temp_data["energies"] = all_eners_concat[temp_idx]
119+
all_eners_concat = all_eners_concat[rest_idx]
120+
if all_forces_concat is not None and all_forces_concat.size > 0:
121+
temp_data["forces"] = all_forces_concat[temp_idx]
122+
all_forces_concat = all_forces_concat[rest_idx]
123+
if all_virs_concat is not None and all_virs_concat.size > 0:
124+
temp_data["virials"] = all_virs_concat[temp_idx]
125+
all_virs_concat = all_virs_concat[rest_idx]
106126
data_list.append(temp_data)
107127
return data_list
108128

109129

110-
def dump(folder, data, comp_prec=np.float32, remove_sets=True):
130+
def dump(folder, data, set_size=2000, comp_prec=np.float32, remove_sets=True):
111131
os.makedirs(folder, exist_ok=True)
112132
sets = sorted(glob.glob(os.path.join(folder, "set.*")))
113133
if len(sets) > 0:
@@ -164,20 +184,29 @@ def dump(folder, data, comp_prec=np.float32, remove_sets=True):
164184
np.int64
165185
)
166186
# dump frame properties: cell, coord, energy, force and virial
167-
set_folder = os.path.join(folder, "set.%03d" % 0)
168-
os.makedirs(set_folder)
169-
np.save(os.path.join(set_folder, "box"), cells)
170-
np.save(os.path.join(set_folder, "coord"), coords)
171-
if eners is not None:
172-
np.save(os.path.join(set_folder, "energy"), eners)
173-
if forces is not None:
174-
np.save(os.path.join(set_folder, "force"), forces)
175-
if virials is not None:
176-
np.save(os.path.join(set_folder, "virial"), virials)
177-
if real_atom_types is not None:
178-
np.save(os.path.join(set_folder, "real_atom_types"), real_atom_types)
179-
if "atom_pref" in data:
180-
np.save(os.path.join(set_folder, "atom_pref"), atom_pref)
187+
nsets = nframes // set_size
188+
if set_size * nsets < nframes:
189+
nsets += 1
190+
for ii in range(nsets):
191+
set_stt = ii * set_size
192+
set_end = (ii + 1) * set_size
193+
set_folder = os.path.join(folder, "set.%06d" % ii)
194+
os.makedirs(set_folder)
195+
np.save(os.path.join(set_folder, "box"), cells[set_stt:set_end])
196+
np.save(os.path.join(set_folder, "coord"), coords[set_stt:set_end])
197+
if eners is not None:
198+
np.save(os.path.join(set_folder, "energy"), eners[set_stt:set_end])
199+
if forces is not None:
200+
np.save(os.path.join(set_folder, "force"), forces[set_stt:set_end])
201+
if virials is not None:
202+
np.save(os.path.join(set_folder, "virial"), virials[set_stt:set_end])
203+
if real_atom_types is not None:
204+
np.save(
205+
os.path.join(set_folder, "real_atom_types"),
206+
real_atom_types[set_stt:set_end],
207+
)
208+
if "atom_pref" in data:
209+
np.save(os.path.join(set_folder, "atom_pref"), atom_pref[set_stt:set_end])
181210
try:
182211
os.remove(os.path.join(folder, "nopbc"))
183212
except OSError:
@@ -187,61 +216,43 @@ def dump(folder, data, comp_prec=np.float32, remove_sets=True):
187216
pass
188217

189218

190-
def mix_system(*system, type_map, split_num=200, **kwargs):
191-
"""Mix the systems into mixed_type ones
219+
def mix_system(*system, type_map, **kwargs):
220+
"""Mix the systems into mixed_type ones according to the unified given type_map.
192221
193222
Parameters
194223
----------
195224
*system : System
196225
The systems to mix
197226
type_map : list of str
198227
Maps atom type to name
199-
split_num : int
200-
Number of frames in each system
201228
202229
Returns
203230
-------
204231
mixed_systems: dict
205-
dict of mixed system with key '{atom_numbs}/sys.xxx'
232+
dict of mixed system with key 'atom_numbs'
206233
"""
207234
mixed_systems = {}
208235
temp_systems = {}
209-
atom_numbs_sys_index = {} # index of sys
210236
atom_numbs_frame_index = {} # index of frames in cur sys
211237
for sys in system:
212238
tmp_sys = sys.copy()
213239
natom = tmp_sys.get_natoms()
214240
tmp_sys.convert_to_mixed_type(type_map=type_map)
215-
if str(natom) not in atom_numbs_sys_index:
216-
atom_numbs_sys_index[str(natom)] = 0
217241
if str(natom) not in atom_numbs_frame_index:
218242
atom_numbs_frame_index[str(natom)] = 0
219243
atom_numbs_frame_index[str(natom)] += tmp_sys.get_nframes()
220244
if str(natom) not in temp_systems or not temp_systems[str(natom)]:
221245
temp_systems[str(natom)] = tmp_sys
222246
else:
223247
temp_systems[str(natom)].append(tmp_sys)
224-
if atom_numbs_frame_index[str(natom)] >= split_num:
225-
while True:
226-
sys_split, temp_systems[str(natom)], rest_num = split_system(
227-
temp_systems[str(natom)], split_num=split_num
228-
)
229-
sys_name = (
230-
f"{str(natom)}/sys." + "%.6d" % atom_numbs_sys_index[str(natom)]
231-
)
232-
mixed_systems[sys_name] = sys_split
233-
atom_numbs_sys_index[str(natom)] += 1
234-
if rest_num < split_num:
235-
atom_numbs_frame_index[str(natom)] = rest_num
236-
break
237248
for natom in temp_systems:
238249
if atom_numbs_frame_index[natom] > 0:
239-
sys_name = f"{natom}/sys." + "%.6d" % atom_numbs_sys_index[natom]
250+
sys_name = f"{natom}"
240251
mixed_systems[sys_name] = temp_systems[natom]
241252
return mixed_systems
242253

243254

244-
def split_system(sys, split_num=100):
255+
def split_system(sys, split_num=10000):
245256
rest = sys.get_nframes() - split_num
246257
if rest <= 0:
247258
return sys, None, 0

dpdata/format.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def to_multi_systems(self, formulas, directory, **kwargs):
132132
"%s doesn't support MultiSystems.to" % (self.__class__.__name__)
133133
)
134134

135-
def mix_system(self, *system, type_map, split_num=200, **kwargs):
135+
def mix_system(self, *system, type_map, **kwargs):
136136
"""Mix the systems into mixed_type ones according to the unified given type_map.
137137
138138
Parameters
@@ -141,13 +141,11 @@ def mix_system(self, *system, type_map, split_num=200, **kwargs):
141141
The systems to mix
142142
type_map : list of str
143143
Maps atom type to name
144-
split_num : int
145-
Number of frames in each system
146144
147145
Returns
148146
-------
149147
mixed_systems: dict
150-
dict of mixed system with key '{atom_numbs}/sys.xxx'
148+
dict of mixed system with key 'atom_numbs'
151149
"""
152150
raise NotImplementedError(
153151
"%s doesn't support System.from" % (self.__class__.__name__)

dpdata/plugins/deepmd.py

+10-37
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def from_labeled_system_mix(self, file_name, type_map=None, **kwargs):
117117
file_name, type_map=type_map, labels=True
118118
)
119119

120-
def mix_system(self, *system, type_map, split_num=200, **kwargs):
120+
def mix_system(self, *system, type_map, **kwargs):
121121
"""Mix the systems into mixed_type ones according to the unified given type_map.
122122
123123
Parameters
@@ -126,49 +126,22 @@ def mix_system(self, *system, type_map, split_num=200, **kwargs):
126126
The systems to mix
127127
type_map : list of str
128128
Maps atom type to name
129-
split_num : int
130-
Number of frames in each system
131129
132130
Returns
133131
-------
134132
mixed_systems: dict
135-
dict of mixed system with key '{atom_numbs}/sys.xxx'
133+
dict of mixed system with key 'atom_numbs'
136134
"""
137-
return dpdata.deepmd.mixed.mix_system(
138-
*system, type_map=type_map, split_num=split_num, **kwargs
139-
)
135+
return dpdata.deepmd.mixed.mix_system(*system, type_map=type_map, **kwargs)
140136

141137
def from_multi_systems(self, directory, **kwargs):
142-
"""MultiSystems.from
143-
144-
Parameters
145-
----------
146-
directory : str
147-
directory of system
148-
149-
Returns
150-
-------
151-
filenames: list[str]
152-
list of filenames
153-
"""
154-
if self.MultiMode == self.MultiModes.Directory:
155-
level_1_dir = [
156-
os.path.join(directory, name)
157-
for name in os.listdir(directory)
158-
if os.path.isdir(os.path.join(directory, name))
159-
and os.path.isfile(os.path.join(directory, name, "type_map.raw"))
160-
]
161-
level_2_dir = [
162-
os.path.join(directory, name1, name2)
163-
for name1 in os.listdir(directory)
164-
for name2 in os.listdir(os.path.join(directory, name1))
165-
if os.path.isdir(os.path.join(directory, name1))
166-
and os.path.isdir(os.path.join(directory, name1, name2))
167-
and os.path.isfile(
168-
os.path.join(directory, name1, name2, "type_map.raw")
169-
)
170-
]
171-
return level_1_dir + level_2_dir
138+
sys_dir = []
139+
for root, dirs, files in os.walk(directory):
140+
if (
141+
"type_map.raw" in files
142+
): # mixed_type format systems must have type_map.raw
143+
sys_dir.append(root)
144+
return sys_dir
172145

173146
MultiMode = Format.MultiModes.Directory
174147

dpdata/system.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -1307,15 +1307,13 @@ def from_fmt_obj(self, fmtobj, directory, labeled=True, **kwargs):
13071307
if labeled:
13081308
data_list = fmtobj.from_labeled_system_mix(dd, **kwargs)
13091309
for data_item in data_list:
1310-
system_list.append(LabeledSystem(data=data_item))
1310+
system_list.append(LabeledSystem(data=data_item, **kwargs))
13111311
else:
13121312
data_list = fmtobj.from_system_mix(dd, **kwargs)
13131313
for data_item in data_list:
1314-
system_list.append(System(data=data_item))
1315-
return self.__class__(
1316-
*system_list,
1317-
type_map=kwargs["type_map"] if "type_map" in kwargs else None,
1318-
)
1314+
system_list.append(System(data=data_item, **kwargs))
1315+
self.append(*system_list)
1316+
return self
13191317

13201318
def to_fmt_obj(self, fmtobj, directory, *args, **kwargs):
13211319
if not isinstance(fmtobj, dpdata.plugins.deepmd.DeePMDMixedFormat):

0 commit comments

Comments
 (0)