Skip to content

Commit 9d29459

Browse files
thangckt, pre-commit-ci[bot], and njzjz
authored
add option to select backends TF/PT (#1545)
reopen PR #1541 due to branch is deleted add a new key in `param.json` file ``` "train_backend": "pytorch"/"tensorflow", ``` relate to this issue #1462 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Improved model management by dynamically generating model suffixes based on the selected backend, enhancing compatibility. - **Enhancements** - Updated model-related functions to incorporate backend-specific model suffixes for accurate file handling during training processes. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: C. Thang Nguyen <[email protected]> Signed-off-by: Jinzhe Zeng <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jinzhe Zeng <[email protected]>
1 parent e13c186 commit 9d29459

File tree

3 files changed

+87
-32
lines changed

3 files changed

+87
-32
lines changed

dpgen/generator/arginfo.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ def training_args() -> list[Argument]:
8787
list[dargs.Argument]
8888
List of training arguments.
8989
"""
90+
doc_train_backend = (
91+
"The backend of the training. Currently only support tensorflow and pytorch."
92+
)
9093
doc_numb_models = "Number of models to be trained in 00.train. 4 is recommend."
9194
doc_training_iter0_model_path = "The model used to init the first iter training. Number of element should be equal to numb_models."
9295
doc_training_init_model = "Iteration > 0, the model parameters will be initilized from the model trained at the previous iteration. Iteration == 0, the model parameters will be initialized from training_iter0_model_path."
@@ -123,6 +126,13 @@ def training_args() -> list[Argument]:
123126
doc_training_finetune_model = "At interation 0, finetune the model parameters from the given frozen models. Number of element should be equal to numb_models."
124127

125128
return [
129+
Argument(
130+
"train_backend",
131+
str,
132+
optional=True,
133+
default="tensorflow",
134+
doc=doc_train_backend,
135+
),
126136
Argument("numb_models", int, optional=False, doc=doc_numb_models),
127137
Argument(
128138
"training_iter0_model_path",

dpgen/generator/run.py

Lines changed: 73 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,19 @@
125125
run_opt_file = os.path.join(ROOT_PATH, "generator/lib/calypso_run_opt.py")
126126

127127

128+
def _get_model_suffix(jdata) -> str:
129+
"""Return the model suffix based on the backend."""
130+
suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"}
131+
backend = jdata.get("train_backend", "tensorflow")
132+
if backend in suffix_map:
133+
suffix = suffix_map[backend]
134+
else:
135+
raise ValueError(
136+
f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch'."
137+
)
138+
return suffix
139+
140+
128141
def get_job_names(jdata):
129142
jobkeys = []
130143
for ii in jdata.keys():
@@ -172,7 +185,7 @@ def _check_empty_iter(iter_index, max_v=0):
172185
return all(empty_sys)
173186

174187

175-
def copy_model(numb_model, prv_iter_index, cur_iter_index):
188+
def copy_model(numb_model, prv_iter_index, cur_iter_index, suffix=".pb"):
176189
cwd = os.getcwd()
177190
prv_train_path = os.path.join(make_iter_name(prv_iter_index), train_name)
178191
cur_train_path = os.path.join(make_iter_name(cur_iter_index), train_name)
@@ -184,7 +197,8 @@ def copy_model(numb_model, prv_iter_index, cur_iter_index):
184197
os.chdir(cur_train_path)
185198
os.symlink(os.path.relpath(prv_train_task), train_task_fmt % ii)
186199
os.symlink(
187-
os.path.join(train_task_fmt % ii, "frozen_model.pb"), "graph.%03d.pb" % ii
200+
os.path.join(train_task_fmt % ii, f"frozen_model{suffix}"),
201+
"graph.%03d%s" % (ii, suffix),
188202
)
189203
os.chdir(cwd)
190204
with open(os.path.join(cur_train_path, "copied"), "w") as fp:
@@ -315,18 +329,19 @@ def make_train(iter_index, jdata, mdata):
315329
number_old_frames = 0
316330
number_new_frames = 0
317331

332+
suffix = _get_model_suffix(jdata)
318333
model_devi_engine = jdata.get("model_devi_engine", "lammps")
319334
if iter_index > 0 and _check_empty_iter(iter_index - 1, fp_task_min):
320335
log_task("prev data is empty, copy prev model")
321-
copy_model(numb_models, iter_index - 1, iter_index)
336+
copy_model(numb_models, iter_index - 1, iter_index, suffix)
322337
return
323338
elif (
324339
model_devi_engine != "calypso"
325340
and iter_index > 0
326341
and _check_skip_train(model_devi_jobs[iter_index - 1])
327342
):
328343
log_task("skip training at step %d " % (iter_index - 1))
329-
copy_model(numb_models, iter_index - 1, iter_index)
344+
copy_model(numb_models, iter_index - 1, iter_index, suffix)
330345
return
331346
else:
332347
iter_name = make_iter_name(iter_index)
@@ -647,7 +662,9 @@ def make_train(iter_index, jdata, mdata):
647662
)
648663
if copied_models is not None:
649664
for ii in range(len(copied_models)):
650-
_link_old_models(work_path, [copied_models[ii]], ii, basename="init.pb")
665+
_link_old_models(
666+
work_path, [copied_models[ii]], ii, basename=f"init{suffix}"
667+
)
651668
# Copy user defined forward files
652669
symlink_user_forward_files(mdata=mdata, task_type="train", work_path=work_path)
653670
# HDF5 format for training data
@@ -699,6 +716,7 @@ def run_train(iter_index, jdata, mdata):
699716
# print("debug:run_train:mdata", mdata)
700717
# load json param
701718
numb_models = jdata["numb_models"]
719+
suffix = _get_model_suffix(jdata)
702720
# train_param = jdata['train_param']
703721
train_input_file = default_train_input_file
704722
training_reuse_iter = jdata.get("training_reuse_iter")
@@ -730,7 +748,11 @@ def run_train(iter_index, jdata, mdata):
730748
"training_init_model, training_init_frozen_model, and training_finetune_model are mutually exclusive."
731749
)
732750

733-
train_command = mdata.get("train_command", "dp")
751+
train_command = mdata.get("train_command", "dp").strip()
752+
# assert train_command == "dp", "The 'train_command' should be 'dp'" # the tests should be updated to run this command
753+
if suffix == ".pth":
754+
train_command += " --pt"
755+
734756
train_resources = mdata["train_resources"]
735757

736758
# paths
@@ -761,9 +783,9 @@ def run_train(iter_index, jdata, mdata):
761783
if training_init_model:
762784
init_flag = " --init-model old/model.ckpt"
763785
elif training_init_frozen_model is not None:
764-
init_flag = " --init-frz-model old/init.pb"
786+
init_flag = f" --init-frz-model old/init{suffix}"
765787
elif training_finetune_model is not None:
766-
init_flag = " --finetune old/init.pb"
788+
init_flag = f" --finetune old/init{suffix}"
767789
command = f"{train_command} train {train_input_file}{extra_flags}"
768790
command = f"{{ if [ ! -f model.ckpt.index ]; then {command}{init_flag}; else {command} --restart model.ckpt; fi }}"
769791
command = f"/bin/sh -c {shlex.quote(command)}"
@@ -792,23 +814,35 @@ def run_train(iter_index, jdata, mdata):
792814
if "srtab_file_path" in jdata.keys():
793815
forward_files.append(zbl_file)
794816
if training_init_model:
795-
forward_files += [
796-
os.path.join("old", "model.ckpt.meta"),
797-
os.path.join("old", "model.ckpt.index"),
798-
os.path.join("old", "model.ckpt.data-00000-of-00001"),
799-
]
817+
if suffix == ".pb":
818+
forward_files += [
819+
os.path.join("old", "model.ckpt.meta"),
820+
os.path.join("old", "model.ckpt.index"),
821+
os.path.join("old", "model.ckpt.data-00000-of-00001"),
822+
]
823+
elif suffix == ".pth":
824+
forward_files += [os.path.join("old", "model.ckpt.pt")]
800825
elif training_init_frozen_model is not None or training_finetune_model is not None:
801-
forward_files.append(os.path.join("old", "init.pb"))
826+
forward_files.append(os.path.join("old", f"init{suffix}"))
802827

803-
backward_files = ["frozen_model.pb", "lcurve.out", "train.log"]
804-
backward_files += [
805-
"model.ckpt.meta",
806-
"model.ckpt.index",
807-
"model.ckpt.data-00000-of-00001",
828+
backward_files = [
829+
f"frozen_model{suffix}",
830+
"lcurve.out",
831+
"train.log",
808832
"checkpoint",
809833
]
810834
if jdata.get("dp_compress", False):
811-
backward_files.append("frozen_model_compressed.pb")
835+
backward_files.append(f"frozen_model_compressed{suffix}")
836+
837+
if suffix == ".pb":
838+
backward_files += [
839+
"model.ckpt.meta",
840+
"model.ckpt.index",
841+
"model.ckpt.data-00000-of-00001",
842+
]
843+
elif suffix == ".pth":
844+
backward_files += ["model.ckpt.pt"]
845+
812846
if not jdata.get("one_h5", False):
813847
init_data_sys_ = jdata["init_data_sys"]
814848
init_data_sys = []
@@ -879,13 +913,14 @@ def post_train(iter_index, jdata, mdata):
879913
log_task("copied model, do not post train")
880914
return
881915
# symlink models
916+
suffix = _get_model_suffix(jdata)
882917
for ii in range(numb_models):
883-
if not jdata.get("dp_compress", False):
884-
model_name = "frozen_model.pb"
885-
else:
886-
model_name = "frozen_model_compressed.pb"
918+
model_name = f"frozen_model{suffix}"
919+
if jdata.get("dp_compress", False):
920+
model_name = f"frozen_model_compressed{suffix}"
921+
922+
ofile = os.path.join(work_path, "graph.%03d%s" % (ii, suffix))
887923
task_file = os.path.join(train_task_fmt % ii, model_name)
888-
ofile = os.path.join(work_path, "graph.%03d.pb" % ii)
889924
if os.path.isfile(ofile):
890925
os.remove(ofile)
891926
os.symlink(task_file, ofile)
@@ -1124,7 +1159,8 @@ def make_model_devi(iter_index, jdata, mdata):
11241159
iter_name = make_iter_name(iter_index)
11251160
train_path = os.path.join(iter_name, train_name)
11261161
train_path = os.path.abspath(train_path)
1127-
models = sorted(glob.glob(os.path.join(train_path, "graph*pb")))
1162+
suffix = _get_model_suffix(jdata)
1163+
models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))
11281164
work_path = os.path.join(iter_name, model_devi_name)
11291165
create_path(work_path)
11301166
if model_devi_engine == "calypso":
@@ -1305,7 +1341,8 @@ def _make_model_devi_revmat(iter_index, jdata, mdata, conf_systems):
13051341
iter_name = make_iter_name(iter_index)
13061342
train_path = os.path.join(iter_name, train_name)
13071343
train_path = os.path.abspath(train_path)
1308-
models = sorted(glob.glob(os.path.join(train_path, "graph*pb")))
1344+
suffix = _get_model_suffix(jdata)
1345+
models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))
13091346
task_model_list = []
13101347
for ii in models:
13111348
task_model_list.append(os.path.join("..", os.path.basename(ii)))
@@ -1502,7 +1539,8 @@ def _make_model_devi_native(iter_index, jdata, mdata, conf_systems):
15021539
iter_name = make_iter_name(iter_index)
15031540
train_path = os.path.join(iter_name, train_name)
15041541
train_path = os.path.abspath(train_path)
1505-
models = glob.glob(os.path.join(train_path, "graph*pb"))
1542+
suffix = _get_model_suffix(jdata)
1543+
models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))
15061544
task_model_list = []
15071545
for ii in models:
15081546
task_model_list.append(os.path.join("..", os.path.basename(ii)))
@@ -1644,7 +1682,8 @@ def _make_model_devi_native_gromacs(iter_index, jdata, mdata, conf_systems):
16441682
iter_name = make_iter_name(iter_index)
16451683
train_path = os.path.join(iter_name, train_name)
16461684
train_path = os.path.abspath(train_path)
1647-
models = glob.glob(os.path.join(train_path, "graph*pb"))
1685+
suffix = _get_model_suffix(jdata)
1686+
models = sorted(glob.glob(os.path.join(train_path, f"graph*{suffix}")))
16481687
task_model_list = []
16491688
for ii in models:
16501689
task_model_list.append(os.path.join("..", os.path.basename(ii)))
@@ -1827,7 +1866,8 @@ def _make_model_devi_amber(
18271866
.replace("@qm_theory@", jdata["low_level"])
18281867
.replace("@rcut@", str(jdata["cutoff"]))
18291868
)
1830-
models = sorted(glob.glob(os.path.join(train_path, "graph.*.pb")))
1869+
suffix = _get_model_suffix(jdata)
1870+
models = sorted(glob.glob(os.path.join(train_path, f"graph.*{suffix}")))
18311871
task_model_list = []
18321872
for ii in models:
18331873
task_model_list.append(os.path.join("..", os.path.basename(ii)))
@@ -1935,7 +1975,9 @@ def run_md_model_devi(iter_index, jdata, mdata):
19351975
run_tasks = [os.path.basename(ii) for ii in run_tasks_]
19361976
# dlog.info("all_task is ", all_task)
19371977
# dlog.info("run_tasks in run_model_deviation",run_tasks_)
1938-
all_models = glob.glob(os.path.join(work_path, "graph*pb"))
1978+
1979+
suffix = _get_model_suffix(jdata)
1980+
all_models = glob.glob(os.path.join(work_path, f"graph*{suffix}"))
19391981
model_names = [os.path.basename(ii) for ii in all_models]
19401982

19411983
model_devi_engine = jdata.get("model_devi_engine", "lammps")

dpgen/simplify/simplify.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
record_iter,
3232
)
3333
from dpgen.generator.run import (
34+
_get_model_suffix,
3435
data_system_fmt,
3536
fp_name,
3637
fp_task_fmt,
@@ -186,7 +187,9 @@ def make_model_devi(iter_index, jdata, mdata):
186187
# link the model
187188
train_path = os.path.join(iter_name, train_name)
188189
train_path = os.path.abspath(train_path)
189-
models = glob.glob(os.path.join(train_path, "graph*pb"))
190+
suffix = _get_model_suffix(jdata)
191+
models = glob.glob(os.path.join(train_path, f"graph*{suffix}"))
192+
190193
for mm in models:
191194
model_name = os.path.basename(mm)
192195
os.symlink(mm, os.path.join(work_path, model_name))

0 commit comments

Comments
 (0)