3
3
import numpy as np
4
4
import matplotlib .pyplot as plt
5
5
import tensorflow_datasets as tfds
6
+ from tensorflow .data import Dataset
7
+ from tensorflow .keras .optimizers import Adam
6
8
from tensorflow .data .experimental import AUTOTUNE
9
+ from tensorflow .keras .callbacks import ModelCheckpoint
10
+ from tensorflow .config .experimental import set_memory_growth
11
+ from tensorflow .config import list_physical_devices
12
+
13
+ # Failed to get convolution algorithm
14
+ physical_devices = list_physical_devices ('GPU' )
15
+ set_memory_growth (physical_devices [0 ], True )
7
16
8
17
# Hyperparameters
9
18
from cycle .config import ORIG_IMG_SIZE
10
19
from cycle .config import INPUT_IMG_SIZE
11
20
from cycle .config import BUFFER_SIZE
12
21
from cycle .config import BATCH_SIZE
13
22
23
+ from cycle import CycleGAN
24
+ from cycle .callbacks import GANMonitor
25
+ from cycle .loss import generator_loss_fn
26
+ from cycle .loss import discriminator_loss_fn
27
+ from cycle .generator import get_resnet_generator
28
+ from cycle .discriminator import get_discriminator
29
+
14
30
from cycle .preprocessing import preprocess_train_image
15
31
from cycle .preprocessing import preprocess_test_image
16
32
64
80
zebra = (((samples [1 ][0 ] * 127.5 ) + 127.5 ).numpy ()).astype (np .uint8 )
65
81
ax_test [i , 0 ].imshow (horse )
66
82
ax_test [i , 1 ].imshow (zebra )
67
- plt .show ()
83
+ plt .show ()
84
+
85
+ # Putting all together
86
+ generator_G = get_resnet_generator (name = 'generator_G' )
87
+ generator_F = get_resnet_generator (name = 'generator_F' )
88
+
89
+ discriminator_X = get_discriminator (name = 'discriminator_X' )
90
+ discriminator_Y = get_discriminator (name = 'discriminator_Y' )
91
+
92
+ cycle_model = CycleGAN (
93
+ generator_G = generator_G ,
94
+ generator_F = generator_F ,
95
+ discriminator_X = discriminator_X ,
96
+ discriminator_Y = discriminator_Y )
97
+
98
+ cycle_model .compile (
99
+ generator_G_opt = Adam (learning_rate = 2e-4 , beta_1 = 0.5 ),
100
+ generator_F_opt = Adam (learning_rate = 2e-4 , beta_1 = 0.5 ),
101
+ discriminator_X_opt = Adam (learning_rate = 2e-4 , beta_1 = 0.5 ),
102
+ discriminator_Y_opt = Adam (learning_rate = 2e-4 , beta_1 = 0.5 ),
103
+ generator_loss_fn = generator_loss_fn ,
104
+ discriminator_loss_fn = discriminator_loss_fn )
105
+
106
+ plotter = GANMonitor (data = test_horses )
107
+ checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.{epoch:03d}"
108
+ model_checkpoint_callback = ModelCheckpoint (
109
+ filepath = checkpoint_filepath
110
+ )
111
+
112
+ cycle_model .fit (
113
+ Dataset .zip ((train_horses , train_zebras )),
114
+ epochs = 90 ,
115
+ callbacks = [plotter , model_checkpoint_callback ],
116
+ )
0 commit comments