14 changes: 7 additions & 7 deletions infogan/algos/infogan_trainer.py
@@ -121,7 +121,7 @@ def init_opt(self):
self.generator_trainer = pt.apply_optimizer(generator_optimizer, losses=[generator_loss], var_list=g_vars)

for k, v in self.log_vars:
- tf.scalar_summary(k, v)
+ tf.summary.scalar(k, v)

with pt.defaults_scope(phase=pt.Phase.test):
with tf.variable_scope("model", reuse=True) as scope:
@@ -199,23 +199,23 @@ def visualize_all_factors(self):
row_img = []
for col in xrange(rows):
row_img.append(imgs[row, col, :, :, :])
- stacked_img.append(tf.concat(1, row_img))
- imgs = tf.concat(0, stacked_img)
+ stacked_img.append(tf.concat(axis=1, values=row_img))
+ imgs = tf.concat(axis=0, values=stacked_img)
imgs = tf.expand_dims(imgs, 0)
- tf.image_summary("image_%d_%s" % (dist_idx, dist.__class__.__name__), imgs)
+ tf.summary.image("image_%d_%s" % (dist_idx, dist.__class__.__name__), imgs)


def train(self):

self.init_opt()

- init = tf.initialize_all_variables()
+ init = tf.global_variables_initializer()

with tf.Session() as sess:
sess.run(init)

- summary_op = tf.merge_all_summaries()
- summary_writer = tf.train.SummaryWriter(self.log_dir, sess.graph)
+ summary_op = tf.summary.merge_all()
+ summary_writer = tf.summary.FileWriter(self.log_dir, sess.graph)

saver = tf.train.Saver()

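Side note on the trainer changes above: TensorFlow 1.x folded the old summary helpers (`tf.scalar_summary`, `tf.image_summary`, `tf.merge_all_summaries`, `tf.train.SummaryWriter`) into the `tf.summary` module, and `tf.initialize_all_variables()` became `tf.global_variables_initializer()`. A minimal sketch of the new pattern, independent of this repo (the log directory and the toy loss are made up for illustration):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None], name="x")
loss = tf.reduce_mean(tf.square(x))           # toy "loss", only here to feed a summary

tf.summary.scalar("loss", loss)               # was: tf.scalar_summary("loss", loss)
summary_op = tf.summary.merge_all()           # was: tf.merge_all_summaries()

init = tf.global_variables_initializer()      # was: tf.initialize_all_variables()

with tf.Session() as sess:
    sess.run(init)
    # was: tf.train.SummaryWriter(log_dir, sess.graph)
    writer = tf.summary.FileWriter("/tmp/logs", sess.graph)
    summary_str, _ = sess.run([summary_op, loss], feed_dict={x: [1.0, 2.0, 3.0]})
    writer.add_summary(summary_str, global_step=0)
```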
38 changes: 13 additions & 25 deletions infogan/misc/custom_ops.py
@@ -9,31 +9,19 @@ class conv_batch_norm(pt.VarStoreMethod):

def __call__(self, input_layer, epsilon=1e-5, momentum=0.1, name="batch_norm",
in_dim=None, phase=Phase.train):
- self.ema = tf.train.ExponentialMovingAverage(decay=0.9)

shape = input_layer.shape
shp = in_dim or shape[-1]
with tf.variable_scope(name) as scope:
- self.gamma = self.variable("gamma", [shp], init=tf.random_normal_initializer(1., 0.02))
- self.beta = self.variable("beta", [shp], init=tf.constant_initializer(0.))
-
- self.mean, self.variance = tf.nn.moments(input_layer.tensor, [0, 1, 2])
- # sigh...tf's shape system is so..
- self.mean.set_shape((shp,))
- self.variance.set_shape((shp,))
- self.ema_apply_op = self.ema.apply([self.mean, self.variance])
-
- if phase == Phase.train:
- with tf.control_dependencies([self.ema_apply_op]):
- normalized_x = tf.nn.batch_norm_with_global_normalization(
- input_layer.tensor, self.mean, self.variance, self.beta, self.gamma, epsilon,
- scale_after_normalization=True)
- else:
- normalized_x = tf.nn.batch_norm_with_global_normalization(
- x, self.ema.average(self.mean), self.ema.average(self.variance), self.beta,
- self.gamma, epsilon,
- scale_after_normalization=True)
- return input_layer.with_tensor(normalized_x, parameters=self.vars)
+ self.gamma = self.variable("gamma", [shp], init=tf.random_normal_initializer(1., 0.02))
+ self.beta = self.variable("beta", [shp], init=tf.constant_initializer(0.))
+
+ self.mean, self.variance = tf.nn.moments(input_layer, [0, 1, 2])
+ # sigh...tf's shape system is so..
+ self.mean.set_shape((shp,))
+ self.variance.set_shape((shp,))
+
+ normalized_x = tf.nn.batch_normalization(input_layer, self.mean,
+ self.variance, None, None, epsilon)
+ return input_layer.with_tensor(normalized_x, parameters=self.vars)


pt.Register(assign_defaults=('phase'))(conv_batch_norm)
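Note on the batch-norm rewrite above: the old code kept an `ExponentialMovingAverage` of the batch statistics and switched between batch and moving statistics by phase via `tf.nn.batch_norm_with_global_normalization`; the new code always normalizes with the current batch's moments through `tf.nn.batch_normalization`, and it passes `None` for the offset/scale arguments, so the `gamma`/`beta` variables it still creates do not appear to be applied. A minimal sketch of the new call, assuming a 4-D NHWC input; the function and variable names here are illustrative, not the repo's:

```python
import tensorflow as tf

def batch_norm_sketch(x, epsilon=1e-5):
    # x is assumed to be a 4-D NHWC tensor with a statically known channel count
    depth = int(x.get_shape()[-1])
    gamma = tf.get_variable("gamma", [depth],
                            initializer=tf.random_normal_initializer(1., 0.02))
    beta = tf.get_variable("beta", [depth],
                           initializer=tf.constant_initializer(0.))
    mean, variance = tf.nn.moments(x, [0, 1, 2])   # per-channel batch statistics
    # offset/scale may also be passed as None (as in the diff) to skip the affine step
    return tf.nn.batch_normalization(x, mean, variance, beta, gamma, epsilon)
```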
@@ -79,7 +67,7 @@ def __call__(self, input_layer, output_shape,
k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
name="deconv2d"):
output_shape[0] = input_layer.shape[0]
- ts_output_shape = tf.pack(output_shape)
+ ts_output_shape = tf.stack(output_shape)
with tf.variable_scope(name):
# filter : [height, width, output_channels, in_channels]
w = self.variable('w', [k_h, k_w, output_shape[-1], input_layer.shape[-1]],
@@ -108,7 +96,7 @@ def __call__(self, input_layer, output_size, scope=None, in_dim=None, stddev=0.0
input_ = input_layer.tensor
try:
if len(shape) == 4:
- input_ = tf.reshape(input_, tf.pack([tf.shape(input_)[0], np.prod(shape[1:])]))
+ input_ = tf.reshape(input_, tf.stack([tf.shape(input_)[0], np.prod(shape[1:])]))
input_.set_shape([None, np.prod(shape[1:])])
shape = input_.get_shape().as_list()

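The remaining changes in this file are the `tf.pack` → `tf.stack` rename from TF 1.0; the semantics are unchanged (stack a list of tensors along a new axis), and it is commonly used, as here, to build a shape tensor from mixed Python ints and scalar tensors. A quick illustrative sketch:

```python
import tensorflow as tf

a = tf.constant([1, 2])
b = tf.constant([3, 4])

stacked = tf.stack([a, b])                 # shape (2, 2); formerly tf.pack([a, b])
dyn_shape = tf.stack([tf.shape(a)[0], 2])  # dynamic shape tensor, as in the reshape above
```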
2 changes: 1 addition & 1 deletion infogan/misc/datasets.py
@@ -84,4 +84,4 @@ def transform(self, data):
return data

def inverse_transform(self, data):
- return data
\ No newline at end of file
+ return data
26 changes: 13 additions & 13 deletions infogan/misc/distributions.py
@@ -122,7 +122,7 @@ def effective_dim(self):

def logli(self, x_var, dist_info):
prob = dist_info["prob"]
- return tf.reduce_sum(tf.log(prob + TINY) * x_var, reduction_indices=1)
+ return tf.reduce_sum(tf.log(prob + TINY) * x_var, axis=1)

def prior_dist_info(self, batch_size):
prob = tf.ones([batch_size, self.dim]) * floatX(1.0 / self.dim)
@@ -131,8 +131,8 @@ def prior_dist_info(self, batch_size):
def marginal_logli(self, x_var, dist_info):
prob = dist_info["prob"]
avg_prob = tf.tile(
- tf.reduce_mean(prob, reduction_indices=0, keep_dims=True),
- tf.pack([tf.shape(prob)[0], 1])
+ tf.reduce_mean(prob, axis=0, keep_dims=True),
+ tf.stack([tf.shape(prob)[0], 1])
)
return self.logli(x_var, dict(prob=avg_prob))

@@ -149,7 +149,7 @@ def kl(self, p, q):
q_prob = q["prob"]
return tf.reduce_sum(
p_prob * (tf.log(p_prob + TINY) - tf.log(q_prob + TINY)),
- reduction_indices=1
+ axis=1
)

def sample(self, dist_info):
@@ -163,13 +163,13 @@ def activate_dist(self, flat_dist):

def entropy(self, dist_info):
prob = dist_info["prob"]
- return -tf.reduce_sum(prob * tf.log(prob + TINY), reduction_indices=1)
+ return -tf.reduce_sum(prob * tf.log(prob + TINY), axis=1)

def marginal_entropy(self, dist_info):
prob = dist_info["prob"]
avg_prob = tf.tile(
- tf.reduce_mean(prob, reduction_indices=0, keep_dims=True),
- tf.pack([tf.shape(prob)[0], 1])
+ tf.reduce_mean(prob, axis=0, keep_dims=True),
+ tf.stack([tf.shape(prob)[0], 1])
)
return self.entropy(dict(prob=avg_prob))

@@ -201,7 +201,7 @@ def logli(self, x_var, dist_info):
epsilon = (x_var - mean) / (stddev + TINY)
return tf.reduce_sum(
- 0.5 * np.log(2 * np.pi) - tf.log(stddev + TINY) - 0.5 * tf.square(epsilon),
- reduction_indices=1,
+ axis=1,
)

def prior_dist_info(self, batch_size):
@@ -225,7 +225,7 @@ def kl(self, p, q):
denominator = 2. * tf.square(q_stddev)
return tf.reduce_sum(
numerator / (denominator + TINY) + tf.log(q_stddev + TINY) - tf.log(p_stddev + TINY),
- reduction_indices=1
+ axis=1
)

def sample(self, dist_info):
@@ -291,7 +291,7 @@ def logli(self, x_var, dist_info):
p = dist_info["p"]
return tf.reduce_sum(
x_var * tf.log(p + TINY) + (1.0 - x_var) * tf.log(1.0 - p + TINY),
- reduction_indices=1
+ axis=1
)

def nonreparam_logli(self, x_var, dist_info):
@@ -397,7 +397,7 @@ def join_vars(self, xs):
"""
Join the per component tensor variables into a whole tensor
"""
- return tf.concat(1, xs)
+ return tf.concat(axis=1, values=xs)

def split_dist_flat(self, dist_flat):
"""
@@ -434,13 +434,13 @@ def sample(self, dist_info):
ret = []
for dist_info_i, dist_i in zip(self.split_dist_info(dist_info), self.dists):
ret.append(tf.cast(dist_i.sample(dist_info_i), tf.float32))
- return tf.concat(1, ret)
+ return tf.concat(axis=1, values=ret)

def sample_prior(self, batch_size):
ret = []
for dist_i in self.dists:
ret.append(tf.cast(dist_i.sample_prior(batch_size), tf.float32))
- return tf.concat(1, ret)
+ return tf.concat(axis=1, values=ret)

def logli(self, x_var, dist_info):
ret = tf.constant(0.)
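All of the edits in distributions.py are mechanical renames for TF 1.x: `reduction_indices` becomes `axis` in the reduction ops, `tf.pack` becomes `tf.stack`, and `tf.concat` now takes the axis as a keyword (or second positional) argument rather than the first. A small sketch with toy tensors (the values are made up):

```python
import tensorflow as tf

prob = tf.constant([[0.2, 0.8],
                    [0.5, 0.5]])

# was: tf.reduce_sum(tf.log(prob), reduction_indices=1)
row_logsum = tf.reduce_sum(tf.log(prob), axis=1)

# was: tf.reduce_mean(prob, reduction_indices=0, keep_dims=True)
avg_prob = tf.reduce_mean(prob, axis=0, keep_dims=True)

# was: tf.concat(1, [prob, prob]) -- the axis used to come first
both = tf.concat(axis=1, values=[prob, prob])
```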
2 changes: 1 addition & 1 deletion infogan/misc/utils.py
@@ -11,4 +11,4 @@ def mkdir_p(path):
if exc.errno == errno.EEXIST and os.path.isdir(path):
pass
else:
- raise
\ No newline at end of file
+ raise
6 changes: 4 additions & 2 deletions infogan/models/regularized_gan.py
@@ -13,6 +13,7 @@ def __init__(self, output_dist, latent_spec, batch_size, image_shape, network_ty
:type batch_size: int
:type network_type: string
"""
+ self.reuse = False
self.output_dist = output_dist
self.latent_spec = latent_spec
self.latent_dist = Product([x for x, _ in latent_spec])
@@ -28,7 +29,7 @@ def __init__(self, output_dist, latent_spec, batch_size, image_shape, network_ty

image_size = image_shape[0]
if network_type == "mnist":
with tf.variable_scope("d_net"):
with tf.variable_scope("d_net", reuse=self.reuse):
shared_template = \
(pt.template("input").
reshape([-1] + list(image_shape)).
@@ -48,7 +49,7 @@ def __init__(self, output_dist, latent_spec, batch_size, image_shape, network_ty
apply(leaky_rectify).
custom_fully_connected(self.reg_latent_dist.dist_flat_dim))

with tf.variable_scope("g_net"):
with tf.variable_scope("g_net", reuse=self.reuse):
self.generator_template = \
(pt.template("input").
custom_fully_connected(1024).
@@ -63,6 +64,7 @@ def __init__(self, output_dist, latent_spec, batch_size, image_shape, network_ty
apply(tf.nn.relu).
custom_deconv2d([0] + list(image_shape), k_h=4, k_w=4).
flatten())
+ self.reuse = True
else:
raise NotImplementedError

Expand Down