 run_opt_file = os.path.join(ROOT_PATH, "generator/lib/calypso_run_opt.py")
 
 
+def _get_model_suffix(jdata) -> str:
+    """Return the model suffix based on the backend."""
+    suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"}
+    backend = jdata.get("train_backend", "tensorflow")
+    if backend in suffix_map:
+        suffix = suffix_map[backend]
+    else:
+        raise ValueError(
+            f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch'."
+        )
+    return suffix
+
+
 def get_job_names(jdata):
     jobkeys = []
     for ii in jdata.keys():
@@ -172,7 +185,7 @@ def _check_empty_iter(iter_index, max_v=0):
     return all(empty_sys)
 
 
-def copy_model(numb_model, prv_iter_index, cur_iter_index):
+def copy_model(numb_model, prv_iter_index, cur_iter_index, suffix=".pb"):
     cwd = os.getcwd()
     prv_train_path = os.path.join(make_iter_name(prv_iter_index), train_name)
     cur_train_path = os.path.join(make_iter_name(cur_iter_index), train_name)
@@ -184,7 +197,8 @@ def copy_model(numb_model, prv_iter_index, cur_iter_index):
         os.chdir(cur_train_path)
         os.symlink(os.path.relpath(prv_train_task), train_task_fmt % ii)
         os.symlink(
-            os.path.join(train_task_fmt % ii, "frozen_model.pb"), "graph.%03d.pb" % ii
+            os.path.join(train_task_fmt % ii, f"frozen_model{suffix}"),
+            "graph.%03d%s" % (ii, suffix),
         )
         os.chdir(cwd)
     with open(os.path.join(cur_train_path, "copied"), "w") as fp:
@@ -315,18 +329,19 @@ def make_train(iter_index, jdata, mdata):
     number_old_frames = 0
     number_new_frames = 0
 
+    suffix = _get_model_suffix(jdata)
     model_devi_engine = jdata.get("model_devi_engine", "lammps")
     if iter_index > 0 and _check_empty_iter(iter_index - 1, fp_task_min):
         log_task("prev data is empty, copy prev model")
-        copy_model(numb_models, iter_index - 1, iter_index)
+        copy_model(numb_models, iter_index - 1, iter_index, suffix)
         return
     elif (
         model_devi_engine != "calypso"
         and iter_index > 0
         and _check_skip_train(model_devi_jobs[iter_index - 1])
     ):
         log_task("skip training at step %d " % (iter_index - 1))
-        copy_model(numb_models, iter_index - 1, iter_index)
+        copy_model(numb_models, iter_index - 1, iter_index, suffix)
         return
     else:
         iter_name = make_iter_name(iter_index)
@@ -647,7 +662,9 @@ def make_train(iter_index, jdata, mdata):
             )
     if copied_models is not None:
         for ii in range(len(copied_models)):
-            _link_old_models(work_path, [copied_models[ii]], ii, basename="init.pb")
+            _link_old_models(
+                work_path, [copied_models[ii]], ii, basename=f"init{suffix}"
+            )
     # Copy user defined forward files
     symlink_user_forward_files(mdata=mdata, task_type="train", work_path=work_path)
     # HDF5 format for training data
@@ -699,6 +716,7 @@ def run_train(iter_index, jdata, mdata):
     # print("debug:run_train:mdata", mdata)
     # load json param
     numb_models = jdata["numb_models"]
+    suffix = _get_model_suffix(jdata)
     # train_param = jdata['train_param']
     train_input_file = default_train_input_file
     training_reuse_iter = jdata.get("training_reuse_iter")
@@ -730,7 +748,11 @@ def run_train(iter_index, jdata, mdata):
             "training_init_model, training_init_frozen_model, and training_finetune_model are mutually exclusive."
         )
 
-    train_command = mdata.get("train_command", "dp")
+    train_command = mdata.get("train_command", "dp").strip()
+    # assert train_command == "dp", "The 'train_command' should be 'dp'"  # the tests should be updated to run this command
+    if suffix == ".pth":
+        train_command += " --pt"
+
     train_resources = mdata["train_resources"]
 
     # paths
@@ -761,9 +783,9 @@ def run_train(iter_index, jdata, mdata):
     if training_init_model:
         init_flag = " --init-model old/model.ckpt"
     elif training_init_frozen_model is not None:
-        init_flag = " --init-frz-model old/init.pb"
+        init_flag = f" --init-frz-model old/init{suffix}"
     elif training_finetune_model is not None:
-        init_flag = " --finetune old/init.pb"
+        init_flag = f" --finetune old/init{suffix}"
     command = f"{train_command} train {train_input_file}{extra_flags}"
     command = f"{{ if [ ! -f model.ckpt.index ]; then {command}{init_flag}; else {command} --restart model.ckpt; fi }}"
     command = f"/bin/sh -c {shlex.quote(command)}"
@@ -792,23 +814,35 @@ def run_train(iter_index, jdata, mdata):
     if "srtab_file_path" in jdata.keys():
         forward_files.append(zbl_file)
     if training_init_model:
-        forward_files += [
-            os.path.join("old", "model.ckpt.meta"),
-            os.path.join("old", "model.ckpt.index"),
-            os.path.join("old", "model.ckpt.data-00000-of-00001"),
-        ]
+        if suffix == ".pb":
+            forward_files += [
+                os.path.join("old", "model.ckpt.meta"),
+                os.path.join("old", "model.ckpt.index"),
+                os.path.join("old", "model.ckpt.data-00000-of-00001"),
+            ]
+        elif suffix == ".pth":
+            forward_files += [os.path.join("old", "model.ckpt.pt")]
     elif training_init_frozen_model is not None or training_finetune_model is not None:
-        forward_files.append(os.path.join("old", "init.pb"))
+        forward_files.append(os.path.join("old", f"init{suffix}"))
 
-    backward_files = ["frozen_model.pb", "lcurve.out", "train.log"]
-    backward_files += [
-        "model.ckpt.meta",
-        "model.ckpt.index",
-        "model.ckpt.data-00000-of-00001",
+    backward_files = [
+        f"frozen_model{suffix}",
+        "lcurve.out",
+        "train.log",
         "checkpoint",
     ]
     if jdata.get("dp_compress", False):
-        backward_files.append("frozen_model_compressed.pb")
+        backward_files.append(f"frozen_model_compressed{suffix}")
+
+    if suffix == ".pb":
+        backward_files += [
+            "model.ckpt.meta",
+            "model.ckpt.index",
+            "model.ckpt.data-00000-of-00001",
+        ]
+    elif suffix == ".pth":
+        backward_files += ["model.ckpt.pt"]
+
     if not jdata.get("one_h5", False):
         init_data_sys_ = jdata["init_data_sys"]
         init_data_sys = []
@@ -879,13 +913,14 @@ def post_train(iter_index, jdata, mdata):
         log_task("copied model, do not post train")
         return
     # symlink models
+    suffix = _get_model_suffix(jdata)
     for ii in range(numb_models):
-        if not jdata.get("dp_compress", False):
-            model_name = "frozen_model.pb"
-        else:
-            model_name = "frozen_model_compressed.pb"
+        model_name = f"frozen_model{suffix}"
+        if jdata.get("dp_compress", False):
+            model_name = f"frozen_model_compressed{suffix}"
+
+        ofile = os.path.join(work_path, "graph.%03d%s" % (ii, suffix))
         task_file = os.path.join(train_task_fmt % ii, model_name)
-        ofile = os.path.join(work_path, "graph.%03d.pb" % ii)
         if os.path.isfile(ofile):
             os.remove(ofile)
         os.symlink(task_file, ofile)
@@ -1124,7 +1159,8 @@ def make_model_devi(iter_index, jdata, mdata):
     iter_name = make_iter_name(iter_index)
     train_path = os.path.join(iter_name, train_name)
     train_path = os.path.abspath(train_path)
-    models = sorted(glob.glob(os.path.join(train_path, "graph*pb")))
+    suffix = _get_model_suffix(jdata)
+    models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))
     work_path = os.path.join(iter_name, model_devi_name)
     create_path(work_path)
     if model_devi_engine == "calypso":
@@ -1305,7 +1341,8 @@ def _make_model_devi_revmat(iter_index, jdata, mdata, conf_systems):
     iter_name = make_iter_name(iter_index)
     train_path = os.path.join(iter_name, train_name)
     train_path = os.path.abspath(train_path)
-    models = sorted(glob.glob(os.path.join(train_path, "graph*pb")))
+    suffix = _get_model_suffix(jdata)
+    models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))
     task_model_list = []
     for ii in models:
         task_model_list.append(os.path.join("..", os.path.basename(ii)))
@@ -1502,7 +1539,8 @@ def _make_model_devi_native(iter_index, jdata, mdata, conf_systems):
     iter_name = make_iter_name(iter_index)
     train_path = os.path.join(iter_name, train_name)
     train_path = os.path.abspath(train_path)
-    models = glob.glob(os.path.join(train_path, "graph*pb"))
+    suffix = _get_model_suffix(jdata)
+    models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))
     task_model_list = []
     for ii in models:
         task_model_list.append(os.path.join("..", os.path.basename(ii)))
@@ -1644,7 +1682,8 @@ def _make_model_devi_native_gromacs(iter_index, jdata, mdata, conf_systems):
     iter_name = make_iter_name(iter_index)
     train_path = os.path.join(iter_name, train_name)
     train_path = os.path.abspath(train_path)
-    models = glob.glob(os.path.join(train_path, "graph*pb"))
+    suffix = _get_model_suffix(jdata)
+    models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))
     task_model_list = []
     for ii in models:
         task_model_list.append(os.path.join("..", os.path.basename(ii)))
@@ -1827,7 +1866,8 @@ def _make_model_devi_amber(
             .replace("@qm_theory@", jdata["low_level"])
             .replace("@rcut@", str(jdata["cutoff"]))
         )
-    models = sorted(glob.glob(os.path.join(train_path, "graph.*.pb")))
+    suffix = _get_model_suffix(jdata)
+    models = sorted(glob.glob(os.path.join(train_path, f"graph.*{suffix}")))
     task_model_list = []
     for ii in models:
         task_model_list.append(os.path.join("..", os.path.basename(ii)))
@@ -1935,7 +1975,9 @@ def run_md_model_devi(iter_index, jdata, mdata):
     run_tasks = [os.path.basename(ii) for ii in run_tasks_]
     # dlog.info("all_task is ", all_task)
     # dlog.info("run_tasks in run_model_deviation",run_tasks_)
-    all_models = glob.glob(os.path.join(work_path, "graph*pb"))
+
+    suffix = _get_model_suffix(jdata)
+    all_models = glob.glob(os.path.join(work_path, f"graph*{suffix}"))
     model_names = [os.path.basename(ii) for ii in all_models]
 
     model_devi_engine = jdata.get("model_devi_engine", "lammps")
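
For reference, a minimal standalone sketch of the behaviour added above, using hypothetical jdata/mdata fragments (the helper name and dicts below are illustrative, not part of the patch): the configured train backend selects the model-file suffix, and the PyTorch backend additionally appends "--pt" to the dp training command.

# Illustrative sketch only; mirrors the suffix selection and command handling in the diff above.
def get_model_suffix(jdata: dict) -> str:
    suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"}
    backend = jdata.get("train_backend", "tensorflow")
    if backend not in suffix_map:
        raise ValueError(
            f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch'."
        )
    return suffix_map[backend]


if __name__ == "__main__":
    jdata = {"train_backend": "pytorch"}  # hypothetical param-file fragment
    mdata = {"train_command": "dp"}       # hypothetical machine-file fragment
    suffix = get_model_suffix(jdata)      # ".pth"
    train_command = mdata.get("train_command", "dp").strip()
    if suffix == ".pth":
        train_command += " --pt"          # PyTorch models are trained via "dp --pt train"
    print(suffix, train_command)          # prints: .pth dp --pt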