-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpreprocessing.py
More file actions
70 lines (65 loc) · 2.95 KB
/
preprocessing.py
File metadata and controls
70 lines (65 loc) · 2.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
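Command-line entry point for preprocessing: extract patches from a CT dataset
and optionally visualize the extracted patches.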
"""
from patch_extraction.patch_extract import PatchExtract
from patch_extraction.aug_params import AUG_PARAMS
import argparse
from configs.config_vars import DS_ROOT_DIR
import os.path as osp
from datasets import CT_DATASETS
def param_parser():
    """Parse and normalize the command-line options of this script.

    Options are read from ``sys.argv`` and post-processed as follows:

    * ``size`` is parsed as an integer and expanded to the 3-tuple
      ``(size, size, size)``.
    * ``save_dir`` is joined onto ``DS_ROOT_DIR`` unless it starts with ``/``.
    * ``vis_dir`` is joined onto ``DS_ROOT_DIR`` unless it starts with ``/``
      or is empty. An empty value is kept empty so that callers testing the
      truthiness of ``vis_dir`` can actually skip visualization.

    Returns:
        argparse.Namespace: with attributes ``dataset``, ``aug_param``,
        ``save_dir``, ``vis_dir``, ``size``, ``multi`` and ``overwrite``.

    Raises:
        SystemExit: raised by argparse on unknown options or when ``--size``
            is not an integer.
    """
    parser = argparse.ArgumentParser(description='preprocess datasets')
    parser.add_argument('--dataset', '-r',
                        dest="dataset",
                        help='name of the dataset, defined in CT_DATASETS',
                        default='StanfordRadiogenomicsDataset')
    parser.add_argument('--aug_param', '-a',
                        dest="aug_param",
                        help='augmentation parameters (filename) for '
                             'aumgentation, should be in aug_param folder',
                        default='')  # in `aug_params`: rotate, shift or version1
    parser.add_argument('--save_dir', '-S',
                        dest="save_dir",
                        help="save directory of converted patches",
                        default='temp/')  # NOTE: e.g. "LNDb/LNDb-path32-aug"
    parser.add_argument('--vis_dir', '-V',
                        dest="vis_dir",
                        help="visualization directory of converted patches",
                        default='temp/')  # NOTE: e.g. StanfordRadiogenomics/patch-visualization-32-aug
    # `type=int` is also applied by argparse to the string default, and turns
    # a malformed value into a usage error instead of a ValueError traceback.
    parser.add_argument('--size', '-s',
                        dest="size",
                        type=int,
                        help="size of the patches",
                        default='32')
    parser.add_argument('--multi', '-M',
                        dest="multi",
                        action='store_true',
                        help="whether to use multiprocessing")
    parser.add_argument('--overwrite',
                        dest="overwrite",
                        help="whether to overwrite previously processed set",
                        action='store_true',)
    args = parser.parse_args()

    # The same edge length is used along all three patch axes.
    args.size = (args.size,) * 3

    if not args.save_dir.startswith('/'):
        args.save_dir = osp.join(DS_ROOT_DIR, args.save_dir)

    # Joining an empty string would yield a non-empty path and make it
    # impossible to turn visualization off, so leave an empty value untouched.
    if args.vis_dir and not args.vis_dir.startswith('/'):
        args.vis_dir = osp.join(DS_ROOT_DIR, args.vis_dir)
    return args
if __name__ == "__main__":
    cli_args = param_parser()

    # Look up the entry registered under the requested dataset name in
    # CT_DATASETS and call it with no arguments to obtain the dataset.
    dataset = CT_DATASETS[cli_args.dataset]()

    # Augmentation settings are looked up by key in AUG_PARAMS.
    augmentation = AUG_PARAMS[cli_args.aug_param]

    # A non-empty visualization directory both enables the extractor's debug
    # flag and triggers the visualization step below.
    visualize = bool(cli_args.vis_dir)

    extractor = PatchExtract(
        patch_size=cli_args.size,
        dataset=dataset,
        augmentation_params=augmentation,
        debug=visualize,
    )

    # Extract the patches into the save directory.
    extractor.load_extract_ds(
        save_dir=cli_args.save_dir,
        multi=cli_args.multi,
        overwrite=cli_args.overwrite,
    )

    # Visualize the patches that were just written.
    if visualize:
        extractor.vis_ds(
            dataset_dir=cli_args.save_dir,
            vis_dir=cli_args.vis_dir,
        )