Skip to content

Commit 16cc64d

Browse files
committed
update a moving mnist implementation
1 parent d5f636e commit 16cc64d

File tree

4 files changed

+204
-0
lines changed

4 files changed

+204
-0
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
.cache/
44
__pycache__/*
55
.#*
6+
mnist.npz

src/environments/moving_mnist.py

+165
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
"""Moving MNIST.
2+
3+
This class produces a batch of moving MNIST experiments
4+
considered in the ICML2015 paper:
5+
Unsupervised Learning of Video Representations using LSTMs.
6+
7+
Author: Yuhuang Hu
8+
9+
"""
10+
from __future__ import print_function
11+
from builtins import range
12+
13+
import numpy as np
14+
15+
16+
def default_motion_fn(data_batch, data_pairs, config=None):
    """Render bouncing digits onto a background batch.

    Every digit starts at a randomly sampled top-left position, moves
    with a constant randomly sampled velocity and is reflected whenever
    the next step would push it outside the background. Pixels of
    overlapping digits are summed and saturated at 255.

    Parameters
    ----------
    data_batch : numpy.ndarray
        the background data batch, modified in place
        (batch_size x bg_size x bg_size x time_steps)
    data_pairs : numpy.ndarray
        the digits drawn into every sample
        (batch_size x num_digits x digit_size x digit_size)
    config : dictionary
        required despite the ``None`` default; must provide

        "x_vel", "y_vel"
            velocity range; with an integer ``v`` the velocity along
            that axis is drawn from ``[-v, v-1]``
        "x_init", "y_init"
            handed to ``numpy.random.choice`` to sample the initial
            column / row of the top-left corner of every digit

    Returns
    -------
    numpy.ndarray
        ``data_batch`` itself, with the digits drawn into it

    Raises
    ------
    TypeError
        if ``config`` is None
    ValueError
        if a sampled initial position does not fit in the background
    """
    if config is None:
        raise TypeError(
            "config must be a dictionary providing "
            "'x_vel', 'y_vel', 'x_init' and 'y_init'")

    # facts about the target data
    batch_size = data_batch.shape[0]
    bg_size = data_batch.shape[1]
    time_steps = data_batch.shape[3]
    num_digits = data_pairs.shape[1]
    digit_shape = data_pairs.shape[2]
    # largest top-left coordinate that keeps a digit inside the background
    max_pos = bg_size - digit_shape

    # load config
    # we use top left to coordinate the position
    x_vel = config["x_vel"]    # velocity range for x direction
    y_vel = config["y_vel"]    # velocity range for y direction
    x_init = config["x_init"]  # range for sample initial x position
    y_init = config["y_init"]  # range for sample initial y position

    # Frames are composited in a wider type: adding digits directly into
    # a uint8 background wraps around (200 + 200 -> 144) before any
    # clipping can take place.
    acc_dtype = np.result_type(data_batch.dtype, data_pairs.dtype, np.int32)

    for sample_idx in range(batch_size):
        # initial (row, column) of every digit, shape (num_digits, 2)
        curr_pos = np.vstack((np.random.choice(y_init, num_digits),
                              np.random.choice(x_init, num_digits))).T
        if curr_pos.size and (curr_pos.min() < 0 or
                              curr_pos.max() > max_pos):
            raise ValueError(
                "initial positions must lie in [0, {0}] so that every "
                "digit fits inside the background".format(max_pos))

        # initial (row, column) velocity of every digit
        vels = np.vstack((np.random.choice(y_vel*2, num_digits)-y_vel,
                          np.random.choice(x_vel*2, num_digits)-x_vel)).T

        for step in range(time_steps):
            # draw all digits of this frame, then saturate once
            frame = data_batch[sample_idx, :, :, step].astype(acc_dtype)
            for num_idx in range(num_digits):
                row, col = curr_pos[num_idx]
                frame[row:row+digit_shape,
                      col:col+digit_shape] += data_pairs[sample_idx, num_idx]
            data_batch[sample_idx, :, :, step] = np.clip(frame, 0, 255)

            # advance every digit; a coordinate that leaves [0, max_pos]
            # is mirrored back inside and its velocity is flipped
            next_pos = curr_pos + vels
            below = next_pos < 0
            above = next_pos > max_pos
            next_pos[below] = -next_pos[below]
            next_pos[above] = 2*max_pos - next_pos[above]
            vels[below | above] *= -1
            curr_pos = next_pos

    return data_batch
92+
93+
94+
class MovingMNIST(object):
    """The moving MNIST generator."""

    def __init__(self, data_path, batch_size=64,
                 num_digits=2, time_steps=40,
                 motion_fn=None, motion_fn_config=None, bg_size=64):
        """Generate a batch of moving MNIST samples.

        The output of the class is a 4-D tensor:
        batch_size x bg_size x bg_size x num_steps

        Parameters
        ----------
        data_path : str
            path of the MNIST ``.npz`` archive read by ``load_mnist``
        batch_size : int
            number of samples in this batch
        num_digits : int
            number of digits in the scene
        time_steps : int
            simulation steps
        motion_fn : function
            A function that controls a digit's movement. It is called
            as ``motion_fn(data_batch, data_pairs, motion_fn_config)``
            and has to be provided before ``get_batch`` is used.
        motion_fn_config : dictionary
            A dictionary that contains custom motion function configuration.
        bg_size : int
            the background size is bg_size x bg_size
        """
        self.batch_size = batch_size
        self.num_digits = num_digits
        self.time_steps = time_steps
        self.motion_fn = motion_fn
        self.motion_fn_config = motion_fn_config
        self.bg_size = bg_size
        self.data_path = data_path

        # load the dataset
        self.train_data, self.test_data = self.load_mnist()
        self.num_train_data = self.train_data[0].shape[0]
        self.num_test_data = self.test_data[0].shape[0]

    def load_mnist(self):
        """Load MNIST.

        We adopted the Keras copy of MNIST, the download address is here:
        https://s3.amazonaws.com/img-datasets/mnist.npz

        Download the dataset and put it somewhere in your file system.

        Returns
        -------
        tuple
            ``(x_train, y_train), (x_test, y_test)`` as stored under
            the keys of the same names in the archive
        """
        # the context manager closes the archive even if a key is missing
        with np.load(self.data_path) as f:
            x_train, y_train = f['x_train'], f['y_train']
            x_test, y_test = f['x_test'], f['y_test']

        return (x_train, y_train), (x_test, y_test)

    def get_batch(self):
        """Get a batch of moving MNIST data.

        Training digits are drawn without replacement, grouped into
        ``num_digits`` per sample and handed to ``motion_fn`` together
        with an all-zero uint8 background.

        Returns
        -------
        whatever ``motion_fn`` returns

        Raises
        ------
        TypeError
            if no callable ``motion_fn`` was supplied
        ValueError
            if the batch needs more distinct digits than the training
            set contains
        """
        if not callable(self.motion_fn):
            raise TypeError(
                "motion_fn must be a callable taking "
                "(data_batch, data_pairs, config)")

        # make sure we have enough unique digits; raised explicitly
        # because an assert would vanish under ``python -O``
        num_needed = self.num_digits*self.batch_size
        if num_needed > self.num_train_data:
            raise ValueError(
                "need {0} distinct digits but only {1} training samples "
                "are available".format(num_needed, self.num_train_data))

        # prepare the batch background
        data_batch = np.zeros(
            (self.batch_size, self.bg_size, self.bg_size,
             self.time_steps), dtype=np.uint8)

        # pick the digits and group them per sample
        digit_idx = np.random.choice(
            self.num_train_data, num_needed, replace=False)
        # take the digit size from the data instead of assuming 28 x 28
        digit_dims = tuple(self.train_data[0].shape[1:])
        data_pairs = self.train_data[0][digit_idx].reshape(
            (self.batch_size, self.num_digits) + digit_dims)

        return self.motion_fn(data_batch, data_pairs,
                              self.motion_fn_config)

src/moving_mnist.gif

806 KB
Loading

src/run_moving_mnist_example.py

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""Moving MNIST Example.
2+
3+
Author: Yuhuang Hu
4+
5+
"""
6+
7+
from __future__ import print_function
8+
9+
import matplotlib.pyplot as plt
10+
from matplotlib.animation import FuncAnimation
11+
12+
from environments.moving_mnist import MovingMNIST
13+
from environments.moving_mnist import default_motion_fn
14+
15+
# create moving MNIST environment
env = MovingMNIST(
    "./mnist.npz",
    motion_fn=default_motion_fn,
    motion_fn_config={"x_vel": 10, "y_vel": 10, "x_init": 100,
                      "y_init": 100},
    batch_size=2,
    bg_size=128,
    num_digits=5,
    time_steps=100)

# get a batch of data
data_batch = env.get_batch()

# visualise first sample
fig = plt.figure()
# Draw the image once with a fixed intensity range; the animation then
# only swaps the pixel data instead of stacking a new image per frame.
frame_image = plt.imshow(data_batch[0, :, :, 0], cmap="gray",
                         vmin=0, vmax=255)


def update(t):
    """Show time step ``t`` of the first sample."""
    frame_image.set_data(data_batch[0, :, :, t])
    return (frame_image,)


# one animation frame per simulated time step of the batch
anim = FuncAnimation(fig, update, frames=data_batch.shape[3], interval=100)
anim.save("moving_mnist.gif", writer="imagemagick")

0 commit comments

Comments
 (0)