@@ -54,60 +54,80 @@ def to_system_data(folder, type_map=None, labels=True):
54
54
if os .path .isfile (os .path .join (folder , "nopbc" )):
55
55
data ["nopbc" ] = True
56
56
sets = sorted (glob .glob (os .path .join (folder , "set.*" )))
57
- assert len (sets ) == 1 , "Mixed type must have only one set!"
58
- cells , coords , eners , forces , virs , real_atom_types = _load_set (
59
- sets [0 ], data .get ("nopbc" , False )
60
- )
61
- nframes = np .reshape (cells , [- 1 , 3 , 3 ]).shape [0 ]
62
- cells = np .reshape (cells , [nframes , 3 , 3 ])
63
- coords = np .reshape (coords , [nframes , - 1 , 3 ])
64
- real_atom_types = np .reshape (real_atom_types , [nframes , - 1 ])
65
- natom = real_atom_types .shape [1 ]
66
- if labels :
67
- if eners is not None and eners .size > 0 :
57
+ all_cells = []
58
+ all_coords = []
59
+ all_eners = []
60
+ all_forces = []
61
+ all_virs = []
62
+ all_real_atom_types = []
63
+ for ii in sets :
64
+ cells , coords , eners , forces , virs , real_atom_types = _load_set (
65
+ ii , data .get ("nopbc" , False )
66
+ )
67
+ nframes = np .reshape (cells , [- 1 , 3 , 3 ]).shape [0 ]
68
+ all_cells .append (np .reshape (cells , [nframes , 3 , 3 ]))
69
+ all_coords .append (np .reshape (coords , [nframes , - 1 , 3 ]))
70
+ all_real_atom_types .append (np .reshape (real_atom_types , [nframes , - 1 ]))
71
+ if eners is not None :
68
72
eners = np .reshape (eners , [nframes ])
69
- if forces is not None and forces .size > 0 :
70
- forces = np .reshape (forces , [nframes , - 1 , 3 ])
71
- if virs is not None and virs .size > 0 :
72
- virs = np .reshape (virs , [nframes , 3 , 3 ])
73
+ if labels :
74
+ if eners is not None and eners .size > 0 :
75
+ all_eners .append (np .reshape (eners , [nframes ]))
76
+ if forces is not None and forces .size > 0 :
77
+ all_forces .append (np .reshape (forces , [nframes , - 1 , 3 ]))
78
+ if virs is not None and virs .size > 0 :
79
+ all_virs .append (np .reshape (virs , [nframes , 3 , 3 ]))
80
+ all_cells_concat = np .concatenate (all_cells , axis = 0 )
81
+ all_coords_concat = np .concatenate (all_coords , axis = 0 )
82
+ all_real_atom_types_concat = np .concatenate (all_real_atom_types , axis = 0 )
83
+ all_eners_concat = None
84
+ all_forces_concat = None
85
+ all_virs_concat = None
86
+ if len (all_eners ) > 0 :
87
+ all_eners_concat = np .concatenate (all_eners , axis = 0 )
88
+ if len (all_forces ) > 0 :
89
+ all_forces_concat = np .concatenate (all_forces , axis = 0 )
90
+ if len (all_virs ) > 0 :
91
+ all_virs_concat = np .concatenate (all_virs , axis = 0 )
73
92
data_list = []
74
93
while True :
75
- if real_atom_types .size == 0 :
94
+ if all_real_atom_types_concat .size == 0 :
76
95
break
77
96
temp_atom_numbs = [
78
- np .count_nonzero (real_atom_types [0 ] == i )
97
+ np .count_nonzero (all_real_atom_types_concat [0 ] == i )
79
98
for i in range (len (data ["atom_names" ]))
80
99
]
81
100
# temp_formula = formula(data['atom_names'], temp_atom_numbs)
82
- temp_idx = np .arange (real_atom_types .shape [0 ])[
83
- (real_atom_types == real_atom_types [0 ]).all (- 1 )
101
+ temp_idx = np .arange (all_real_atom_types_concat .shape [0 ])[
102
+ (all_real_atom_types_concat == all_real_atom_types_concat [0 ]).all (- 1 )
84
103
]
85
- rest_idx = np .arange (real_atom_types .shape [0 ])[
86
- (real_atom_types != real_atom_types [0 ]).any (- 1 )
104
+ rest_idx = np .arange (all_real_atom_types_concat .shape [0 ])[
105
+ (all_real_atom_types_concat != all_real_atom_types_concat [0 ]).any (- 1 )
87
106
]
88
107
temp_data = data .copy ()
108
+ temp_data ["atom_names" ] = data ["atom_names" ].copy ()
89
109
temp_data ["atom_numbs" ] = temp_atom_numbs
90
- temp_data ["atom_types" ] = real_atom_types [0 ]
91
- real_atom_types = real_atom_types [rest_idx ]
92
- temp_data ["cells" ] = cells [temp_idx ]
93
- cells = cells [rest_idx ]
94
- temp_data ["coords" ] = coords [temp_idx ]
95
- coords = coords [rest_idx ]
110
+ temp_data ["atom_types" ] = all_real_atom_types_concat [0 ]
111
+ all_real_atom_types_concat = all_real_atom_types_concat [rest_idx ]
112
+ temp_data ["cells" ] = all_cells_concat [temp_idx ]
113
+ all_cells_concat = all_cells_concat [rest_idx ]
114
+ temp_data ["coords" ] = all_coords_concat [temp_idx ]
115
+ all_coords_concat = all_coords_concat [rest_idx ]
96
116
if labels :
97
- if eners is not None and eners .size > 0 :
98
- temp_data ["energies" ] = eners [temp_idx ]
99
- eners = eners [rest_idx ]
100
- if forces is not None and forces .size > 0 :
101
- temp_data ["forces" ] = forces [temp_idx ]
102
- forces = forces [rest_idx ]
103
- if virs is not None and virs .size > 0 :
104
- temp_data ["virials" ] = virs [temp_idx ]
105
- virs = virs [rest_idx ]
117
+ if all_eners_concat is not None and all_eners_concat .size > 0 :
118
+ temp_data ["energies" ] = all_eners_concat [temp_idx ]
119
+ all_eners_concat = all_eners_concat [rest_idx ]
120
+ if all_forces_concat is not None and all_forces_concat .size > 0 :
121
+ temp_data ["forces" ] = all_forces_concat [temp_idx ]
122
+ all_forces_concat = all_forces_concat [rest_idx ]
123
+ if all_virs_concat is not None and all_virs_concat .size > 0 :
124
+ temp_data ["virials" ] = all_virs_concat [temp_idx ]
125
+ all_virs_concat = all_virs_concat [rest_idx ]
106
126
data_list .append (temp_data )
107
127
return data_list
108
128
109
129
110
- def dump (folder , data , comp_prec = np .float32 , remove_sets = True ):
130
+ def dump (folder , data , set_size = 2000 , comp_prec = np .float32 , remove_sets = True ):
111
131
os .makedirs (folder , exist_ok = True )
112
132
sets = sorted (glob .glob (os .path .join (folder , "set.*" )))
113
133
if len (sets ) > 0 :
@@ -164,20 +184,29 @@ def dump(folder, data, comp_prec=np.float32, remove_sets=True):
164
184
np .int64
165
185
)
166
186
# dump frame properties: cell, coord, energy, force and virial
167
- set_folder = os .path .join (folder , "set.%03d" % 0 )
168
- os .makedirs (set_folder )
169
- np .save (os .path .join (set_folder , "box" ), cells )
170
- np .save (os .path .join (set_folder , "coord" ), coords )
171
- if eners is not None :
172
- np .save (os .path .join (set_folder , "energy" ), eners )
173
- if forces is not None :
174
- np .save (os .path .join (set_folder , "force" ), forces )
175
- if virials is not None :
176
- np .save (os .path .join (set_folder , "virial" ), virials )
177
- if real_atom_types is not None :
178
- np .save (os .path .join (set_folder , "real_atom_types" ), real_atom_types )
179
- if "atom_pref" in data :
180
- np .save (os .path .join (set_folder , "atom_pref" ), atom_pref )
187
+ nsets = nframes // set_size
188
+ if set_size * nsets < nframes :
189
+ nsets += 1
190
+ for ii in range (nsets ):
191
+ set_stt = ii * set_size
192
+ set_end = (ii + 1 ) * set_size
193
+ set_folder = os .path .join (folder , "set.%06d" % ii )
194
+ os .makedirs (set_folder )
195
+ np .save (os .path .join (set_folder , "box" ), cells [set_stt :set_end ])
196
+ np .save (os .path .join (set_folder , "coord" ), coords [set_stt :set_end ])
197
+ if eners is not None :
198
+ np .save (os .path .join (set_folder , "energy" ), eners [set_stt :set_end ])
199
+ if forces is not None :
200
+ np .save (os .path .join (set_folder , "force" ), forces [set_stt :set_end ])
201
+ if virials is not None :
202
+ np .save (os .path .join (set_folder , "virial" ), virials [set_stt :set_end ])
203
+ if real_atom_types is not None :
204
+ np .save (
205
+ os .path .join (set_folder , "real_atom_types" ),
206
+ real_atom_types [set_stt :set_end ],
207
+ )
208
+ if "atom_pref" in data :
209
+ np .save (os .path .join (set_folder , "atom_pref" ), atom_pref [set_stt :set_end ])
181
210
try :
182
211
os .remove (os .path .join (folder , "nopbc" ))
183
212
except OSError :
@@ -187,61 +216,43 @@ def dump(folder, data, comp_prec=np.float32, remove_sets=True):
187
216
pass
188
217
189
218
190
def mix_system(*system, type_map, **kwargs):
    """Mix the systems into mixed_type ones according to the unified given type_map.

    Each input system is copied, converted to mixed type with ``type_map``
    and merged with the other systems that have the same number of atoms.
    The input systems themselves are only used through their ``copy()``.

    Parameters
    ----------
    *system : System
        The systems to mix
    type_map : list of str
        Maps atom type to name
    **kwargs
        Accepted and ignored, so callers that still pass the former
        ``split_num`` keyword keep working.

    Returns
    -------
    mixed_systems : dict
        dict of mixed systems keyed by the number of atoms as a string
        (``str(natom)``). Groups whose accumulated frame count is zero
        are left out.
    """
    merged = {}
    frame_counts = {}  # total number of frames collected per atom count
    for one_system in system:
        mixed = one_system.copy()
        # The atom count is read before the conversion, and its string form
        # is the grouping key used for both dictionaries.
        key = str(mixed.get_natoms())
        mixed.convert_to_mixed_type(type_map=type_map)
        frame_counts[key] = frame_counts.get(key, 0) + mixed.get_nframes()
        # A missing or falsy entry is replaced instead of appended to.
        # NOTE(review): this relies on the truth value of the system object
        # (presumably falsy when it holds no frames) - confirm.
        if key not in merged or not merged[key]:
            merged[key] = mixed
        else:
            merged[key].append(mixed)
    return {key: mixed for key, mixed in merged.items() if frame_counts[key] > 0}
242
253
243
254
244
- def split_system (sys , split_num = 100 ):
255
+ def split_system (sys , split_num = 10000 ):
245
256
rest = sys .get_nframes () - split_num
246
257
if rest <= 0 :
247
258
return sys , None , 0
0 commit comments