-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathrun.py
34 lines (29 loc) · 1020 Bytes
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# -*- coding: utf-8 -*-
"""
@Author : Fei Wang
@Contact : [email protected]
@Time : 2020/12/10 0:30
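@Description: Command-line entry point that runs cross validation, train/test, or pretraining using a JSON config file.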
"""
import json
import argparse
from src.train.run_cross_validation import cross_validation
from src.train.run_train_test import train_and_test
from src.train.run_pretraining import pretrain
def load_config(path):
    """Read an experiment configuration from a JSON file.

    Args:
        path: Path to the JSON config file.

    Returns:
        The decoded top-level JSON value stored in the file.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # The context manager closes the handle as soon as parsing finishes
    # (a bare ``json.load(open(...))`` leaves it to the garbage collector).
    # JSON is read as UTF-8 explicitly so the result does not depend on the
    # platform's locale encoding.
    with open(path, encoding='utf-8') as config_file:
        return json.load(config_file)


def main(argv=None):
    """Parse the command line and run the selected experiment.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read the process arguments, which is what the script entry
            point uses.

    Raises:
        SystemExit: Raised by argparse (exit status 2) when a required
            option is missing or ``--exp`` is not a known experiment.
    """
    parser = argparse.ArgumentParser()
    # ``choices`` rejects an unknown experiment name with a usage error
    # instead of letting the script exit successfully without running anything.
    parser.add_argument('--exp', type=str, required=True,
                        choices=('cross_validation', 'train_test', 'pretrain'),
                        help='experiment to run')
    parser.add_argument('--config', type=str, required=True, help='path to config file')
    args = parser.parse_args(argv)

    config = load_config(args.config)

    # Built after argument parsing so the config is loaded exactly once and
    # each experiment name maps to the function that runs it.
    runners = {
        # cross validation on wikitables
        'cross_validation': cross_validation,
        # train and test on webquerytable
        'train_test': train_and_test,
        # pretrain
        'pretrain': pretrain,
    }
    runners[args.exp](config)


if __name__ == '__main__':
    main()