Skip to content

Commit 46d0e78

Browse files
committed
Pulling all toguether to training
1 parent 030a33e commit 46d0e78

File tree

5 files changed

+72
-15
lines changed

5 files changed

+72
-15
lines changed

.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -250,4 +250,8 @@ $RECYCLE.BIN/
250250
# Windows shortcuts
251251
*.lnk
252252

253+
254+
examples
255+
model_checkpoints
256+
253257
# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,macos,windows,linux,git,virtualenv,python

cycle/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .model import CycleGAN

cycle/callbacks.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,24 @@
11
import numpy as np
22
import matplotlib.pyplot as plt
33
from tensorflow.data import Dataset
4-
from tensorflow.keras.callbacks import CallBack
5-
from tensorflow.keras.preprocessing import array_to_img
4+
from tensorflow.keras.callbacks import Callback
5+
from tensorflow.keras.preprocessing.image import array_to_img
66

7-
class GANMonitor(CallBack):
8-
def __init__(self, img_num: int):
7+
class GANMonitor(Callback):
8+
def __init__(self, data: Dataset, img_num: int = 4):
99
super(GANMonitor, self).__init__()
1010
self.img_num: int = img_num
11+
self.data = data
1112

12-
def on_epoch_end(self, epoch: int, test_horses: Dataset, logs=None):
13+
def on_epoch_end(self, epoch: int, logs=None):
1314
_, ax = plt.subplots(self.img_num, 2, figsize=(12, 12))
14-
for i, img in enumerate(test_horses.take(self.img_num)):
15-
output = self.model.gen_G(img)[0].numpy()
15+
for i, img in enumerate(self.data.take(self.img_num)):
16+
output = self.model.generator_G(img)[0]
1617
output = (output * 127.5 + 127.5).numpy().astype(np.uint8)
1718
img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)
1819

1920
output = array_to_img(output)
2021
img = array_to_img(img)
2122

22-
output.save(f'generated_img_{i}_{epoch}.png')
23-
img.save('original_img_{i}_{epoch}.png')
23+
output.save(f'./examples/generated_img_{i}_{epoch}.png')
24+
img.save('./examples/original_img_{i}_{epoch}.png')

cycle/model_helper.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Tuple
2+
13
from tensorflow.keras import layers
24
from tensorflow_addons.layers import InstanceNormalization
35
from tensorflow.keras.initializers import Initializer
@@ -8,14 +10,14 @@
810
def residual_block(
911
x: layers.Layer,
1012
activation: layers.Activation,
11-
kernel_size: tuple[int] = (3, 3),
12-
strides: tuple[int] = (1, 1),
13+
kernel_size: Tuple[int] = (3, 3),
14+
strides: Tuple[int] = (1, 1),
1315
padding: str = 'valid',
1416
kernel_initializer: Initializer = None,
1517
gamma_initializer: Initializer = None,
1618
use_bias: bool = False) -> layers.Layer:
1719

18-
dim: int =- x.shape[-1]
20+
dim: int = x.shape[-1]
1921
input_tensor: layers.Layer = x
2022

2123
x = ReflectionPadding2D()(input_tensor)
@@ -69,8 +71,8 @@ def upsample(
6971
x: layers.Layer,
7072
filters: int,
7173
activation: layers.Activation,
72-
kernel_size: tuple[int] = (3, 3),
73-
strides: tuple[int] = (2, 2),
74+
kernel_size: Tuple[int] = (3, 3),
75+
strides: Tuple[int] = (2, 2),
7476
padding: str = 'same',
7577
kernel_initializer: Initializer = None,
7678
gamma_initializer: Initializer = None,

train.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,30 @@
33
import numpy as np
44
import matplotlib.pyplot as plt
55
import tensorflow_datasets as tfds
6+
from tensorflow.data import Dataset
7+
from tensorflow.keras.optimizers import Adam
68
from tensorflow.data.experimental import AUTOTUNE
9+
from tensorflow.keras.callbacks import ModelCheckpoint
10+
from tensorflow.config.experimental import set_memory_growth
11+
from tensorflow.config import list_physical_devices
12+
13+
# Failed to get convolution algorithm
14+
physical_devices = list_physical_devices('GPU')
15+
set_memory_growth(physical_devices[0], True)
716

817
# Hyperparameters
918
from cycle.config import ORIG_IMG_SIZE
1019
from cycle.config import INPUT_IMG_SIZE
1120
from cycle.config import BUFFER_SIZE
1221
from cycle.config import BATCH_SIZE
1322

23+
from cycle import CycleGAN
24+
from cycle.callbacks import GANMonitor
25+
from cycle.loss import generator_loss_fn
26+
from cycle.loss import discriminator_loss_fn
27+
from cycle.generator import get_resnet_generator
28+
from cycle.discriminator import get_discriminator
29+
1430
from cycle.preprocessing import preprocess_train_image
1531
from cycle.preprocessing import preprocess_test_image
1632

@@ -64,4 +80,37 @@
6480
zebra = (((samples[1][0] * 127.5) + 127.5).numpy()).astype(np.uint8)
6581
ax_test[i, 0].imshow(horse)
6682
ax_test[i, 1].imshow(zebra)
67-
plt.show()
83+
plt.show()
84+
85+
# Putting all together
86+
generator_G = get_resnet_generator(name='generator_G')
87+
generator_F = get_resnet_generator(name='generator_F')
88+
89+
discriminator_X = get_discriminator(name='discriminator_X')
90+
discriminator_Y = get_discriminator(name='discriminator_Y')
91+
92+
cycle_model = CycleGAN(
93+
generator_G=generator_G,
94+
generator_F=generator_F,
95+
discriminator_X=discriminator_X,
96+
discriminator_Y=discriminator_Y)
97+
98+
cycle_model.compile(
99+
generator_G_opt=Adam(learning_rate=2e-4, beta_1=0.5),
100+
generator_F_opt=Adam(learning_rate=2e-4, beta_1=0.5),
101+
discriminator_X_opt=Adam(learning_rate=2e-4, beta_1=0.5),
102+
discriminator_Y_opt=Adam(learning_rate=2e-4, beta_1=0.5),
103+
generator_loss_fn=generator_loss_fn,
104+
discriminator_loss_fn=discriminator_loss_fn)
105+
106+
plotter = GANMonitor(data=test_horses)
107+
checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.{epoch:03d}"
108+
model_checkpoint_callback = ModelCheckpoint(
109+
filepath=checkpoint_filepath
110+
)
111+
112+
cycle_model.fit(
113+
Dataset.zip((train_horses, train_zebras)),
114+
epochs=90,
115+
callbacks=[plotter, model_checkpoint_callback],
116+
)

0 commit comments

Comments
 (0)