import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Conv2DTranspose, Conv2D, Dense, Flatten, Reshape
import numpy as np


class Sampling(keras.layers.Layer):
    """Samples *z* from the *(z_mean, z_logvar)* produced by the encoder, to be fed to the decoder."""

    def call(self, inputs, **kwargs):
        z_mean, z_logvar = inputs
        # Reparameterization trick: z = mean + std * epsilon, with epsilon ~ N(0, I)
        epsilon = tf.random.normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_logvar) * epsilon


class Encoder(keras.layers.Layer):
    """Maps MNIST digits to the triplet (z_mean, z_logvar, z)."""

    def __init__(self, latent_dim, **kwargs):
        super(Encoder, self).__init__(**kwargs)
        self.conv1 = Conv2D(32, 3, (2, 2), activation='relu')
        self.conv2 = Conv2D(64, 3, (2, 2), activation='relu')
        self.flatten = Flatten()
        self.dense3_1 = Dense(latent_dim)  # predicts z_mean
        self.dense3_2 = Dense(latent_dim)  # predicts z_logvar
        self.sampling = Sampling()

    def call(self, inputs, **kwargs):
        x = self.conv1(inputs)
        x = self.conv2(x)
        x = self.flatten(x)
        z_mean = self.dense3_1(x)
        z_logvar = self.dense3_2(x)
        z = self.sampling((z_mean, z_logvar))
        return z_mean, z_logvar, z


class Decoder(keras.layers.Layer):
    """Reconstructs the image (as pixel logits) from the latent variable *z*."""

    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)
        self.dense1 = Dense(7 * 7 * 32, activation='relu')
        self.reshape = Reshape((7, 7, 32))
        # Upsample 7x7 -> 14x14 -> 28x28 with stride-2 transposed convolutions
        self.deconv1 = Conv2DTranspose(64, 3, 2, padding='same', activation='relu')
        self.deconv2 = Conv2DTranspose(32, 3, 2, padding='same', activation='relu')
        self.out = Conv2DTranspose(1, 3, 1, padding='same')

    def call(self, inputs, **kwargs):
        x = self.dense1(inputs)
        x = self.reshape(x)
        x = self.deconv1(x)
        x = self.deconv2(x)
        return self.out(x)


class VarAutoEncoder(keras.Model):
    """Convolutional Variational Autoencoder model for MNIST."""

    def __init__(self, latent_dim, **kwargs):
        super(VarAutoEncoder, self).__init__(**kwargs)
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder()

    # TODO: Use TFP library functions
    @tf.function
    def log_normal_pdf(self, z, mean, logvar):
        """Log-density of a diagonal Gaussian N(mean, exp(logvar)) at *z*, summed over the latent dimensions."""
        log2pi = tf.math.log(2. * np.pi)
        return tf.reduce_sum(
            -.5 * ((z - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
            axis=1)

    def call(self, inputs, **kwargs):
        z_mean, z_logvar, z = self.encoder(inputs)
        reconstructed = self.decoder(z)

        # Single-sample Monte Carlo estimate of the negative ELBO.
        # log p(x|z): negative pixel-wise sigmoid cross-entropy between the
        # input image and the reconstruction logits.
        logpx_z = -tf.reduce_sum(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=inputs, logits=reconstructed),
            axis=[1, 2, 3])
        # KL(q(z|x) || p(z)) estimated as log q(z|x) - log p(z) at the sampled z.
        kl_loss = self.log_normal_pdf(z, z_mean, z_logvar) - self.log_normal_pdf(z, 0., 0.)
        total_loss = -tf.reduce_mean(logpx_z - kl_loss)

        self.add_loss(total_loss)
        return reconstructed
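

# --- Usage sketch (illustrative, not part of the original commit) ---
# One way to train the model above on MNIST: because `call` registers the
# negative ELBO via `add_loss`, `compile` needs no explicit loss function.
# The latent size, optimizer, batch size and epoch count below are assumed
# values for demonstration only.
if __name__ == '__main__':
    (x_train, _), _ = keras.datasets.mnist.load_data()
    # Scale to [0, 1] and add the channel axis expected by Conv2D.
    x_train = x_train.astype('float32')[..., np.newaxis] / 255.

    vae = VarAutoEncoder(latent_dim=2)
    vae.compile(optimizer=keras.optimizers.Adam(1e-3))
    vae.fit(x_train, epochs=10, batch_size=64)

    # Generate new digits by decoding samples drawn from the prior z ~ N(0, I);
    # a sigmoid turns the decoder's logits into pixel intensities.
    z_samples = tf.random.normal(shape=(16, 2))
    generated = tf.sigmoid(vae.decoder(z_samples))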