|
| 1 | +"""Moving MNIST. |
| 2 | +
|
| 3 | +This class produces a batch of moving MNIST experiments |
| 4 | +considered in the ICML2015 paper: |
| 5 | + Unsupervised Learning of Video Representations using LSTMs. |
| 6 | +
|
| 7 | +Author: Yuhuang Hu |
| 8 | + |
| 9 | +""" |
| 10 | +from __future__ import print_function |
| 11 | +from builtins import range |
| 12 | + |
| 13 | +import numpy as np |
| 14 | + |
| 15 | + |
| 16 | +def default_motion_fn(data_batch, data_pairs, config=None): |
| 17 | + """A default motion function for moving MNIST. |
| 18 | +
|
| 19 | + Parameters |
| 20 | + ---------- |
| 21 | + data_batch : numpy.ndarray |
| 22 | + the background data batch |
| 23 | + (batch_size x bg_size x bg_size x time_steps) |
| 24 | + data_pairs : numpy.ndarray |
| 25 | + the numpy array that takes the digits pairs |
| 26 | + (batch_size x num_digits x 28 x 28) |
| 27 | + config : dictionary |
| 28 | + configuration dictionary for extra configs |
| 29 | + """ |
| 30 | + # facts about the target data |
| 31 | + batch_size = data_batch.shape[0] |
| 32 | + bg_size = data_batch.shape[1] |
| 33 | + time_steps = data_batch.shape[3] |
| 34 | + num_digits = data_pairs.shape[1] |
| 35 | + digit_shape = data_pairs.shape[2] |
| 36 | + bound_size = bg_size-digit_shape+1 |
| 37 | + |
| 38 | + # load config |
| 39 | + # we use top left to coordinate the position |
| 40 | + x_vel = config["x_vel"] # velocity range for x direction |
| 41 | + y_vel = config["y_vel"] # velocity range for y direction |
| 42 | + # make sure you chose some reasonable range |
| 43 | + x_init = config["x_init"] # range for sample initial x position |
| 44 | + y_init = config["y_init"] # range for sample initial y position |
| 45 | + |
| 46 | + # the most intense way, looping everything |
| 47 | + for sample_idx in range(batch_size): |
| 48 | + # sample initial position (num_digits, 2) |
| 49 | + curr_pos = np.vstack((np.random.choice(y_init, num_digits), |
| 50 | + np.random.choice(x_init, num_digits))).T |
| 51 | + # sample initial velocity |
| 52 | + vels = np.vstack((np.random.choice(y_vel*2, num_digits)-y_vel, |
| 53 | + np.random.choice(x_vel*2, num_digits)-x_vel)).T |
| 54 | + |
| 55 | + # process every step |
| 56 | + for step in range(time_steps): |
| 57 | + # set the image at current position |
| 58 | + for num_idx in range(num_digits): |
| 59 | + data_batch[ |
| 60 | + sample_idx, |
| 61 | + curr_pos[num_idx, 0]:curr_pos[num_idx, 0]+digit_shape, |
| 62 | + curr_pos[num_idx, 1]:curr_pos[num_idx, 1]+digit_shape, |
| 63 | + step] += data_pairs[sample_idx, num_idx] |
| 64 | + data_batch[sample_idx] = np.clip(data_batch[sample_idx], 0, 255) |
| 65 | + |
| 66 | + # update current position based on the velocity |
| 67 | + # while consider boundary issues |
| 68 | + for num_idx in range(num_digits): |
| 69 | + # consider x direction |
| 70 | + x_pos = curr_pos[num_idx, 0]+vels[num_idx, 0] |
| 71 | + y_pos = curr_pos[num_idx, 1]+vels[num_idx, 1] |
| 72 | + # free space at x direction |
| 73 | + if x_pos >= 0 and x_pos <= bound_size-1: |
| 74 | + curr_pos[num_idx, 0] = x_pos |
| 75 | + elif x_pos < 0: |
| 76 | + curr_pos[num_idx, 0] = -x_pos |
| 77 | + vels[num_idx, 0] = -vels[num_idx, 0] |
| 78 | + elif x_pos > bound_size-1: |
| 79 | + curr_pos[num_idx, 0] = 2*bound_size-x_pos-2 |
| 80 | + vels[num_idx, 0] = -vels[num_idx, 0] |
| 81 | + # free space at y direction |
| 82 | + if y_pos >= 0 and y_pos <= bound_size-1: |
| 83 | + curr_pos[num_idx, 1] = y_pos |
| 84 | + elif y_pos < 0: |
| 85 | + curr_pos[num_idx, 1] = -y_pos |
| 86 | + vels[num_idx, 1] = -vels[num_idx, 1] |
| 87 | + elif y_pos > bound_size-1: |
| 88 | + curr_pos[num_idx, 1] = 2*bound_size-y_pos-2 |
| 89 | + vels[num_idx, 1] = -vels[num_idx, 1] |
| 90 | + |
| 91 | + return data_batch |
| 92 | + |
| 93 | + |
| 94 | +class MovingMNIST(object): |
| 95 | + "The moving MNIST generator.""" |
| 96 | + def __init__(self, data_path, batch_size=64, |
| 97 | + num_digits=2, time_steps=40, |
| 98 | + motion_fn=None, motion_fn_config=None, bg_size=64): |
| 99 | + """Generate a batch of moving MNIST samples. |
| 100 | +
|
| 101 | + The output of the class is a 4-D tensor: |
| 102 | + batch_size x bg_size x bg_size x num_steps |
| 103 | +
|
| 104 | + Parameters |
| 105 | + ---------- |
| 106 | + batch_size : int |
| 107 | + number of samples in this batch |
| 108 | + num_digits : int |
| 109 | + number of digits in the scene |
| 110 | + time_steps : int |
| 111 | + simulation steps |
| 112 | + motion_fn : function |
| 113 | + A function that controls a digit's movement. |
| 114 | + motion_fn_config : dictionary |
| 115 | + A dictionary that contains custom motion function configuration. |
| 116 | + bg_size : int |
| 117 | + the background size is bg_size x bg_size |
| 118 | + """ |
| 119 | + self.batch_size = batch_size |
| 120 | + self.num_digits = num_digits |
| 121 | + self.time_steps = time_steps |
| 122 | + self.motion_fn = motion_fn |
| 123 | + self.motion_fn_config = motion_fn_config |
| 124 | + self.bg_size = bg_size |
| 125 | + self.data_path = data_path |
| 126 | + |
| 127 | + # load the dataset |
| 128 | + self.train_data, self.test_data = self.load_mnist() |
| 129 | + self.num_train_data = self.train_data[0].shape[0] |
| 130 | + self.num_test_data = self.test_data[0].shape[0] |
| 131 | + |
| 132 | + def load_mnist(self): |
| 133 | + """Load MNIST. |
| 134 | +
|
| 135 | + We adopted Keras copy of MNSIT, the download address is here: |
| 136 | + https://s3.amazonaws.com/img-datasets/mnist.npz |
| 137 | +
|
| 138 | + Download the dataset and put it somewhere in your file system. |
| 139 | + """ |
| 140 | + f = np.load(self.data_path) |
| 141 | + x_train, y_train = f['x_train'], f['y_train'] |
| 142 | + x_test, y_test = f['x_test'], f['y_test'] |
| 143 | + f.close() |
| 144 | + |
| 145 | + return (x_train, y_train), (x_test, y_test) |
| 146 | + |
| 147 | + def get_batch(self): |
| 148 | + """Get a batch of moving MNIST data.""" |
| 149 | + # prepare the batch background |
| 150 | + data_batch = np.zeros( |
| 151 | + (self.batch_size, self.bg_size, self.bg_size, |
| 152 | + self.time_steps), dtype=np.uint8) |
| 153 | + |
| 154 | + # generate pairs of digits |
| 155 | + # make sure we have enough unique pairs |
| 156 | + assert self.num_digits*self.batch_size < self.num_train_data |
| 157 | + num_pairs = np.random.choice( |
| 158 | + self.num_train_data, |
| 159 | + self.num_digits*self.batch_size, |
| 160 | + replace=False) |
| 161 | + data_pairs = self.train_data[0][num_pairs].reshape( |
| 162 | + self.batch_size, self.num_digits, 28, 28) |
| 163 | + |
| 164 | + return self.motion_fn(data_batch, data_pairs, |
| 165 | + self.motion_fn_config) |
0 commit comments