-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathdb_config.py
executable file
·73 lines (63 loc) · 2.62 KB
/
db_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
from easydict import EasyDict as edict
# Root configuration object. It is an EasyDict so values are read and written
# as attributes (cfg.MEANS, cfg.TRAIN.IMG_DIR, ...).
cfg = edict()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~inference~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Per-channel pixel means. Presumably the ImageNet RGB means subtracted from
# input images -- TODO confirm channel order against the preprocessing code.
cfg.MEANS = [123.68, 116.78, 103.94]
# NOTE(review): presumably the maximum input side length in pixels -- confirm.
cfg.INPUT_MAX_SIZE = 640
# NOTE(review): these inference constants are consumed outside this file.
# Their exact semantics are unconfirmed here (e.g. K as a binarization
# amplification factor, SHRINK_RATIO for polygon shrink/expand,
# THRESH_MIN/THRESH_MAX as threshold-map bounds, FILTER_MIN_AREA as a
# relative area filter) -- verify against the model/post-processing code.
cfg.K = 10
cfg.EPSILON_RATIO = 0.001
cfg.SHRINK_RATIO = 0.4
cfg.THRESH_MIN = 0.3
cfg.THRESH_MAX = 0.7
cfg.FILTER_MIN_AREA = 1e-4
# Backbone options listed by the author:
# ['resnet_v1_50', 'resnet_v1_18', 'resnet_v2_50', 'resnet_v2_18', 'mobilenet_v2', 'mobilenet_v3']
cfg.BACKBONE = 'resnet_v1_50'
# Toggle for an ASPP layer in the network (presumably atrous spatial pyramid
# pooling -- confirm with the model code).
cfg.ASPP_LAYER = False
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~train config~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Populate the TRAIN section on a temporary handle, then hang it off cfg.
# The temporary name is removed so the module namespace is unchanged.
_train = edict()
_train.VERSION = 'aspp'
# Multi-GPU training: GPU ids as a comma-separated string (presumably fed to
# CUDA_VISIBLE_DEVICES -- confirm in the training script).
_train.VIS_GPU = '3,4'
_train.BATCH_SIZE_PER_GPU = 2
# Loss weighting coefficients, consumed by the training code.
_train.LOSS_ALPHA = 1.0
_train.LOSS_BETA = 10.0
cfg.TRAIN = _train
del _train
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~dataload & aug~~~~~~~~~~~~~~~~~~~~~~~~~~
# CTW1500 training split: images and curve labels are sibling directories
# under one root. The temporary root is deleted once both paths are built.
_ctw1500_train = '/hostpersistent/zzh/dataset/open_data/ctw1500/train/'
cfg.TRAIN.IMG_DIR = _ctw1500_train + 'text_image'
cfg.TRAIN.LABEL_DIR = _ctw1500_train + 'text_label_curve'
del _ctw1500_train
cfg.TRAIN.IMG_SIZE = 640
# NOTE(review): minimum text size / area thresholds used by the data loader
# to drop tiny instances (units unconfirmed).
cfg.TRAIN.MIN_TEXT_SIZE = 1
cfg.TRAIN.MIN_AREA = 1
# Candidate rescale factors; 1 appears three times, presumably so that the
# original scale is sampled more often.
cfg.TRAIN.IMG_SCALE = [0.5, 1, 1, 1, 1.5, 2.0]
cfg.TRAIN.CROP_PROB = 0.9
cfg.TRAIN.MIN_CROP_SIDE_RATIO = 0.001
cfg.TRAIN.NUM_READERS = 20
# NOTE(review): presumably the probability of applying AUG_TOOL; at 0.0 the
# augmenters below would never fire -- confirm in the data pipeline.
cfg.TRAIN.DATA_AUG_PROB = 0.0
# Blur augmenters. 'ElasticTransformation' and 'PerspectiveTransform' were
# also listed by the author but are commented out.
cfg.TRAIN.AUG_TOOL = ['GaussianBlur', 'AverageBlur', 'MedianBlur',
                      'BilateralBlur', 'MotionBlur']
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~save ckpt and log~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Project working directory. Every log/checkpoint path below is derived from
# it, so relocating the project needs a single edit. The trailing slash is kept
# so os.path.join yields the same strings as the previous hard-coded paths.
_project_dir = '/hostpersistent/zzh/lab/DB-tf/'
# One checkpoint directory, used both for saving and for restoring, so the two
# settings cannot drift apart.
_ckpt_dir = os.path.join(_project_dir, 'ckpt')
cfg.TRAIN.MAX_STEPS = 10000000
cfg.TRAIN.SAVE_CHECKPOINT_STEPS = 2000
cfg.TRAIN.SAVE_SUMMARY_STEPS = 100
cfg.TRAIN.SAVE_MAX = 20
cfg.TRAIN.TRAIN_LOGS = os.path.join(_project_dir, 'tf_logs')
cfg.TRAIN.CHECKPOINTS_OUTPUT_DIR = _ckpt_dir
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~restore and pretrain~~~~~~~~~~~~~~~~~~~~~
# NOTE(review): None presumably means "do not restore" -- confirm with the
# training script.
cfg.TRAIN.RESTORE = None
cfg.TRAIN.RESTORE_CKPT_PATH = _ckpt_dir
# Checkpoint prefix inside the checkpoint directory (the "-121201" suffix is
# presumably the global step of the saved TF checkpoint).
cfg.TRAIN.PRETRAINED_MODEL_PATH = os.path.join(
    _ckpt_dir, 'DB_resnet_v1_50_1223_model.ckpt-121201')
# Drop the helpers so the module namespace matches the original.
del _project_dir, _ckpt_dir
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~hyper-parameters~~~~~~~~~~~~~~~~~~~~~~~~~~
cfg.TRAIN.LEARNING_RATE = 0.0001
# Optimizer name; the author left 'momentum' as the alternative choice.
cfg.TRAIN.OPT = 'adam'  # alternative: 'momentum'
# NOTE(review): presumably the decay of an exponential moving average over the
# model weights -- confirm with the training script.
cfg.TRAIN.MOVING_AVERAGE_DECAY = 0.997
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ eval ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Populate the EVAL section on a temporary handle, then hang it off cfg.
# Both temporaries are deleted so the module namespace is unchanged.
_eval = edict()
_ctw1500_test = '/hostpersistent/zzh/dataset/open_data/ctw1500/test/'
_eval.IMG_DIR = _ctw1500_test + 'text_image'
# NOTE(review): training reads 'text_label_curve' while evaluation reads
# 'text_label_circum'; confirm the differing label format is intended.
_eval.LABEL_DIR = _ctw1500_test + 'text_label_circum'
_eval.NUM_READERS = 1
# NOTE(review): presumably the step interval between evaluations -- confirm.
_eval.TEST_STEP = 5000
cfg.EVAL = _eval
del _eval, _ctw1500_test