-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
44 lines (35 loc) · 1.28 KB
/
train.py
File metadata and controls
44 lines (35 loc) · 1.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
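"""
Training entry point: parse command-line arguments, apply the experiment
settings, set up device/logging, and run training with deepode's ``Trainer``.
"""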
import os
# Warning filters for child interpreters. PYTHONWARNINGS is read only at
# interpreter startup, so setting it here affects processes started after
# this point (they inherit os.environ), not the already-running interpreter.
# Several filters must be given as ONE comma-separated value: assigning the
# variable twice would keep only the last filter.
# NOTE(review): if these warnings should also be silenced in this process,
# warnings.filterwarnings(...) calls before the deepode imports are needed.
os.environ["PYTHONWARNINGS"] = ",".join(
    [
        "ignore::UserWarning:scipy",
        "ignore:pkg_resources is deprecated",
    ]
)
from deepode.nn import Trainer
from deepode.nn.utils import *
from deepode.nn.config import parser
def run_training_entry(args):
    """Create a ``Trainer`` and run the full training pipeline.

    The steps run in a fixed order: data loaders are initialised first, then
    the model is built, the hyper-parameters are recorded in the log file,
    and finally training is started.

    Args:
        args: Argument namespace (as produced by ``main``) passed unchanged
            to every step.
    """
    engine = Trainer()
    engine.init_dataloaders(args)
    engine.build_model(args)
    # Record the hyper-parameters in the log file before training begins.
    logging_args(args)
    engine.run_training(args)
def main(argv=None):
    """Parse arguments, apply this experiment's settings, and start training.

    Args:
        argv: Optional list of command-line tokens. When ``None`` (the
            default) the parser's ``parse_args()`` is called with no
            arguments, exactly as before; otherwise ``argv`` is forwarded to
            ``parse_args`` (useful for programmatic or test invocation).

    Note:
        The experiment settings below are assigned unconditionally, so they
        override whatever the parser produced for the same attributes.
    """
    cli = parser()
    args = cli.parse_args() if argv is None else cli.parse_args(argv)

    # Experiment-specific settings (override any CLI values for these fields).
    # NOTE(review): absolute, user-specific data path — consider exposing it
    # as a CLI option or environment variable for portability.
    data_dir = "/home/yiyuxiao/data/AI4S/DeePCK/Data/DRM19"
    args.model_root = "model"
    args.modelname = "DRM19-0D1DPert-ckv8-deepode"
    args.input_path = os.path.join(data_dir, "DRM19_2200wFlameMFPert_X.npy")
    args.label_path = os.path.join(data_dir, "DRM19_2200wFlameMFPert_Y.npy")
    args.mech_path = "mechanism/DRM19.cti"
    args.zero_input = ["ar"]
    args.zero_gradient = ["p", "N2"]

    # Run-environment setup; order preserved from the original script.
    setup_current_time(args)
    setup_device(args)
    create_model_path(args)
    setup_logging(args)

    run_training_entry(args)
if __name__ == '__main__':
    # Example invocations:
    #   DDP training on 8 GPUs:
    #     python train.py -cuda 0,1,2,3,4,5,6,7 -ddp --delta_t 1e-6 -note "this is a test of DDP training"
    #   Single-GPU training:
    #     python train.py --device="cuda:5" --delta_t 1e-6 -note "a test of single-GPU training"
    main()