Skip to content

Commit a5b309e

Browse files
committed
Initial v1
1 parent de1a4ec commit a5b309e

File tree

6 files changed

+299
-0
lines changed

6 files changed

+299
-0
lines changed

infer.ipynb

+181
Large diffs are not rendered by default.

models/vae-v1/saved_model.pb

364 KB
Binary file not shown.
Binary file not shown.
3.64 KB
Binary file not shown.

train.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import tensorflow as tf
2+
from tensorflow import keras
3+
import tensorflow_datasets as tfds
4+
from vae import VarAutoEncoder
5+
6+
7+
# Load Data
8+
def preprocess(image, _):
9+
"""Return normalized image for both input and output label"""
10+
image = tf.cast(image, tf.float32)/255.
11+
image = tf.where(image < 0.5, 0., 1.)
12+
return image, image
13+
14+
15+
ds_train, ds_test = tfds.load(
16+
'mnist',
17+
split=['train', 'test'],
18+
as_supervised=True
19+
)
20+
ds_train = ds_train.map(preprocess).shuffle(1024).batch(64)
21+
ds_test = ds_test.map(preprocess).shuffle(1024).batch(64)
22+
23+
24+
# Train model
25+
model = VarAutoEncoder(latent_dim=2)
26+
model.compile(optimizer=keras.optimizers.Adam(1e-4))
27+
history = model.fit(
28+
ds_train,
29+
validation_data=ds_test,
30+
epochs=50
31+
)
32+
33+
34+
# Save model
35+
model.save("models/vae-v1")

vae.py

+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import tensorflow as tf
2+
from tensorflow import keras
3+
from tensorflow.keras.layers import Conv2DTranspose, Conv2D, Dense, Flatten, Reshape
4+
import numpy as np
5+
6+
7+
class Sampling(keras.layers.Layer):
8+
"""Sample *z* from the *z_mean* and *z_logvar* from encoder to input in decoder"""
9+
10+
def call(self, inputs, **kwargs):
11+
z_mean, z_logvar = inputs
12+
epsilon = tf.random.normal(shape=tf.shape(z_mean))
13+
return z_mean + tf.exp(0.5 * z_logvar) * epsilon
14+
15+
16+
class Encoder(keras.layers.Layer):
17+
"""Maps MNIST digits to triplet (z_mean, z_logvar, z)"""
18+
19+
def __init__(self, latent_dim, **kwargs):
20+
super(Encoder, self).__init__(**kwargs)
21+
self.conv1 = Conv2D(32, 3, (2,2), activation='relu')
22+
self.conv2 = Conv2D(64, 3, (2,2), activation='relu')
23+
self.flatten = Flatten()
24+
self.dense3_1 = Dense(latent_dim)
25+
self.dense3_2 = Dense(latent_dim)
26+
self.sampling = Sampling()
27+
28+
def call(self, inputs, **kwargs):
29+
x = self.conv1(inputs)
30+
x = self.conv2(x)
31+
x = self.flatten(x)
32+
z_mean = self.dense3_1(x)
33+
z_logvar = self.dense3_2(x)
34+
z = self.sampling((z_mean, z_logvar))
35+
return z_mean, z_logvar, z
36+
37+
38+
class Decoder(keras.layers.Layer):
39+
"""Reconstructs the image from latent variable *z*"""
40+
41+
def __init__(self, **kwargs):
42+
super(Decoder, self).__init__(**kwargs)
43+
self.dense1 = Dense(7*7*32, activation='relu')
44+
self.reshape = Reshape((7, 7, 32))
45+
self.deconv1 = Conv2DTranspose(64, 3, 2, padding='same', activation='relu')
46+
self.deconv2 = Conv2DTranspose(32, 3, 2, padding='same', activation='relu')
47+
self.out = Conv2DTranspose(1, 3, 1, padding='same')
48+
49+
def call(self, inputs, **kwargs):
50+
x = self.dense1(inputs)
51+
x = self.reshape(x)
52+
x = self.deconv1(x)
53+
x = self.deconv2(x)
54+
return self.out(x)
55+
56+
57+
class VarAutoEncoder(keras.Model):
58+
"""Convolutional Variational AutoEncoder Model for MNIST"""
59+
60+
def __init__(self, latent_dim, **kwargs):
61+
super(VarAutoEncoder, self).__init__(**kwargs)
62+
self.encoder = Encoder(latent_dim)
63+
self.decoder = Decoder()
64+
65+
# TODO : Use TFP library functions
66+
@tf.function
67+
def log_normal_pdf(self, z, mean, logvar):
68+
log2pi = tf.math.log(2. * np.pi)
69+
return tf.reduce_sum(
70+
-.5 * ((z - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
71+
axis=1)
72+
73+
def call(self, inputs, **kwargs):
74+
z_mean, z_logvar, z = self.encoder(inputs)
75+
reconstructed = self.decoder(z)
76+
77+
# Compute loss
78+
cross_entropy_loss = -tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(inputs, reconstructed), axis=[1, 2, 3])
79+
kl_loss = self.log_normal_pdf(z, z_mean, z_logvar) - self.log_normal_pdf(z, 0., 0.)
80+
total_loss = -tf.reduce_mean(cross_entropy_loss - kl_loss)
81+
82+
self.add_loss(total_loss)
83+
return reconstructed

0 commit comments

Comments
 (0)