-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrain.py
59 lines (46 loc) · 1.76 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""Script for training the model."""
import os
import numpy as np
import tensorflow as tf
from bgan.mnist.models import BinaryGAN, GAN
from config import CONFIG
def load_data():
    """Load and return the training data.

    The source is selected by ``CONFIG['data']['training_data_location']``:

    - ``'sa'``: attach to the SharedArray named by
      ``CONFIG['data']['training_data']``.
    - ``'hd'``: load ``CONFIG['data']['training_data']`` with ``np.load``.
      An absolute path is used as is; a relative one is resolved against
      the ``training_data`` directory located next to this script.

    Returns:
        The object returned by ``sa.attach`` or ``np.load``.

    Raises:
        ValueError: If the configured location is neither ``'sa'`` nor
            ``'hd'``.
    """
    print('[*] Loading data...')

    location = CONFIG['data']['training_data_location']
    if location not in ('sa', 'hd'):
        # Fail with a clear message instead of falling through to an
        # unbound local variable.
        raise ValueError(
            "Unrecognizable training data location: {!r}".format(location))

    training_data = CONFIG['data']['training_data']

    # Load data from SharedArray
    if location == 'sa':
        # Imported lazily so SharedArray is only required when it is used.
        import SharedArray as sa
        return sa.attach(training_data)

    # Load data from hard disk
    if os.path.isabs(training_data):
        return np.load(training_data)

    # Anchor relative names at the directory that contains this script.
    # Joining onto the script's own path would produce
    # '<script file>/training_data/<name>', which can never exist.
    script_dir = os.path.dirname(os.path.realpath(__file__))
    filepath = os.path.abspath(
        os.path.join(script_dir, 'training_data', training_data))
    return np.load(filepath)
def main():
    """Run the experiment described by ``CONFIG``.

    Validates the model name, loads the training data, builds the selected
    model inside a TensorFlow session, optionally restores a pretrained
    checkpoint, and then trains the model.

    Raises:
        ValueError: If ``CONFIG['exp']['model']`` is neither ``'binarygan'``
            nor ``'gan'``.
    """
    experiment = CONFIG['exp']
    model_name = experiment['model']

    # Reject unknown models before doing any expensive work.
    if model_name != 'binarygan' and model_name != 'gan':
        raise ValueError("Unrecognizable model name")

    print("Start experiment: {}".format(experiment['exp_name']))

    # Fetch the training set before the session is opened.
    training_set = load_data()

    with tf.Session(config=CONFIG['tensorflow']) as sess:
        # The name was validated above, so anything that is not 'gan'
        # must be 'binarygan'.
        model_class = GAN if model_name == 'gan' else BinaryGAN
        gan = model_class(sess, CONFIG['model'])

        # Initialize all variables
        gan.init_all()

        # Restore pretrained weights only when a directory is configured.
        pretrained_dir = experiment['pretrained_dir']
        if pretrained_dir is not None:
            gan.load_latest(pretrained_dir)

        # Train the model
        gan.train(training_set, CONFIG['train'])
# Start training only when this file is executed as a script.
if __name__ == '__main__':
    main()