-
Notifications
You must be signed in to change notification settings - Fork 147
/
Copy pathmain.py
executable file
·42 lines (34 loc) · 1.14 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import sys

from models.dcgan import DCGAN_MODEL
from models.gan import GAN
from models.wgan_clipping import WGAN_CP
from models.wgan_gradient_penalty import WGAN_GP
from utils.config import parse_args
from utils.data_loader import get_data_loader
def main(args, *, num_latent_walks=50):
    """Build the model selected by ``args.model``, then train or evaluate it.

    Args:
        args: Parsed command-line namespace (from ``parse_args``). This
            function reads ``model``, ``is_train``, ``load_D`` and ``load_G``;
            the whole namespace is also forwarded to the model constructor
            and to ``get_data_loader``.
        num_latent_walks: How many latent walks to generate after evaluation
            (ignored when training). Defaults to 50.

    Raises:
        SystemExit: with status -1 when ``args.model`` is not one of
            ``'GAN'``, ``'DCGAN'``, ``'WGAN-CP'`` or ``'WGAN-GP'``.
    """
    # Map the CLI model name to the class that implements it.
    model_classes = {
        'GAN': GAN,
        'DCGAN': DCGAN_MODEL,
        'WGAN-CP': WGAN_CP,
        'WGAN-GP': WGAN_GP,
    }
    model_class = model_classes.get(args.model)
    if model_class is None:
        # Errors go to stderr. sys.exit is used because the bare exit()
        # builtin is injected by the `site` module and is not guaranteed to
        # exist in non-interactive runs (e.g. `python -S`).
        print("Model type non-existing. Try again.", file=sys.stderr)
        sys.exit(-1)
    model = model_class(args)

    # Load datasets into train and test loaders.
    train_loader, test_loader = get_data_loader(args)

    # Compare as text so that both the string 'True' and a boolean True
    # select training; any other value selects evaluation, as before.
    if str(args.is_train) == 'True':
        model.train(train_loader)
    else:
        # Evaluate on test data. load_D / load_G presumably locate saved
        # discriminator/generator weights; confirm in the model classes.
        model.evaluate(test_loader, args.load_D, args.load_G)
        for i in range(num_latent_walks):
            model.generate_latent_walk(i)
if __name__ == '__main__':
    # Script entry point: parse CLI options, echo the CUDA flag, then run.
    cli_args = parse_args()
    print(cli_args.cuda)
    main(cli_args)