Commit 6410df5

batch norm, cleanup unused args (#539)
At least I assume they are unused. A user would likely not provide a tf.Tensor in a config, or the specified type was wrong and this was supposed to be a float. Either way, I still don't see how this could have been used. Maybe it was a relic of the Theano code conversion.
1 parent a41d02e commit 6410df5

File tree: 3 files changed (+127, -142 lines)

returnn/tf/layers/base.py

Lines changed: 39 additions & 49 deletions
@@ -1194,11 +1194,9 @@ def get_constraints_value(self):
   def batch_norm(self, data,
                  use_shift=True, use_std=True, use_sample=0.0, force_sample=False,
                  momentum=0.99, epsilon=1e-3,
-                 sample_mean=None, sample_variance=None,
                  update_sample_only_in_training=False,
                  delay_sample_update=False,
                  param_version=0,
-                 gamma=None, beta=None,
                  gamma_init=1.0, beta_init=0.0,
                  masked_time=True):
     """
@@ -1212,10 +1210,6 @@ def batch_norm(self, data,
     :param bool delay_sample_update:
     :param int param_version: 0 or 1
     :param float epsilon:
-    :param tf.Tensor sample_mean:
-    :param tf.Tensor sample_variance:
-    :param tf.Tensor gamma:
-    :param tf.Tensor beta:
     :param str|float gamma_init: see :func:`TFUtil.get_initializer`, for the scale
     :param str|float beta_init: see :func:`TFUtil.get_initializer`, for the mean
     :param bool masked_time: flatten and mask input tensor
@@ -1253,31 +1247,29 @@ def batch_norm(self, data,
       param_name_prefix = ""
     else:
       raise NotImplementedError("%s: batch_norm param_version %r" % (self, param_version))
-    if sample_mean is None:
-      with self.var_creation_scope():
-        sample_mean = self.add_param(tf_compat.v1.get_variable(
-          shape=data.get_bc_spatial_batch_shape(), initializer=tf_compat.v1.zeros_initializer(),
-          name="%smean" % param_name_prefix,
-          trainable=False))
-      # Use exponential moving average of batch mean.
-      # Note: We could also use cumulative moving average. Our Theano implementation does that for inference.
-      updated_sample_mean = tf_compat.v1.assign_add(sample_mean, (mean - sample_mean) * momentum)
-      if delay_sample_update:
-        delayed_ops.append(updated_sample_mean.op)
-      else:
-        sample_mean = updated_sample_mean
-    if sample_variance is None:
-      # Note: Our Theano implementation does not use a moving average for this.
-      with self.var_creation_scope():
-        sample_variance = self.add_param(tf_compat.v1.get_variable(
-          shape=data.get_bc_spatial_batch_shape(), initializer=tf_compat.v1.ones_initializer(),
-          name="%svariance" % param_name_prefix,
-          trainable=False))
-      updated_sample_variance = tf_compat.v1.assign_add(sample_variance, (variance - sample_variance) * momentum)
-      if delay_sample_update:
-        delayed_ops.append(updated_sample_variance.op)
-      else:
-        sample_variance = updated_sample_variance
+    with self.var_creation_scope():
+      sample_mean = self.add_param(tf_compat.v1.get_variable(
+        shape=data.get_bc_spatial_batch_shape(), initializer=tf_compat.v1.zeros_initializer(),
+        name="%smean" % param_name_prefix,
+        trainable=False))
+    # Use exponential moving average of batch mean.
+    # Note: We could also use cumulative moving average. Our Theano implementation does that for inference.
+    updated_sample_mean = tf_compat.v1.assign_add(sample_mean, (mean - sample_mean) * momentum)
+    if delay_sample_update:
+      delayed_ops.append(updated_sample_mean.op)
+    else:
+      sample_mean = updated_sample_mean
+    # Note: Our Theano implementation does not use a moving average for this.
+    with self.var_creation_scope():
+      sample_variance = self.add_param(tf_compat.v1.get_variable(
+        shape=data.get_bc_spatial_batch_shape(), initializer=tf_compat.v1.ones_initializer(),
+        name="%svariance" % param_name_prefix,
+        trainable=False))
+    updated_sample_variance = tf_compat.v1.assign_add(sample_variance, (variance - sample_variance) * momentum)
+    if delay_sample_update:
+      delayed_ops.append(updated_sample_variance.op)
+    else:
+      sample_variance = updated_sample_variance
     # If train or if force_sample, use default use_sample=0.0, otherwise use_sample=1.0.
     if self.network.train_flag is not False or force_sample:
       if force_sample:
@@ -1295,26 +1287,24 @@ def batch_norm(self, data,
         tf_util.add_control_input(op, control_input=bn.op)
       self.network.register_post_control_dependencies(delayed_ops)
     if use_std:
-      if gamma is None:
-        with self.var_creation_scope():
-          from returnn.tf.util.basic import get_initializer
-          gamma_initializer = get_initializer(
-            gamma_init, seed=self.network.random.randint(2 ** 31) if gamma_init else 0, eval_local_ns={"layer": self})
-          gamma = self.add_param(tf_compat.v1.get_variable(
-            shape=data.get_bc_spatial_batch_shape(), initializer=gamma_initializer,
-            name="%sgamma" % param_name_prefix,
-            trainable=True))
+      with self.var_creation_scope():
+        from returnn.tf.util.basic import get_initializer
+        gamma_initializer = get_initializer(
+          gamma_init, seed=self.network.random.randint(2 ** 31) if gamma_init else 0, eval_local_ns={"layer": self})
+        gamma = self.add_param(tf_compat.v1.get_variable(
+          shape=data.get_bc_spatial_batch_shape(), initializer=gamma_initializer,
+          name="%sgamma" % param_name_prefix,
+          trainable=True))
       bn *= gamma
     if use_shift:
-      if beta is None:
-        with self.var_creation_scope():
-          from returnn.tf.util.basic import get_initializer
-          beta_initializer = get_initializer(
-            beta_init, seed=self.network.random.randint(2 ** 31) if beta_init else 0, eval_local_ns={"layer": self})
-          beta = self.add_param(tf_compat.v1.get_variable(
-            shape=data.get_bc_spatial_batch_shape(), initializer=beta_initializer,
-            name="%sbeta" % param_name_prefix,
-            trainable=True))
+      with self.var_creation_scope():
+        from returnn.tf.util.basic import get_initializer
+        beta_initializer = get_initializer(
+          beta_init, seed=self.network.random.randint(2 ** 31) if beta_init else 0, eval_local_ns={"layer": self})
+        beta = self.add_param(tf_compat.v1.get_variable(
+          shape=data.get_bc_spatial_batch_shape(), initializer=beta_initializer,
+          name="%sbeta" % param_name_prefix,
+          trainable=True))
       bn += beta
     return bn
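
The update that survives the cleanup, tf_compat.v1.assign_add(sample_mean, (mean - sample_mean) * momentum), is a plain exponential moving average of the batch statistics. A minimal NumPy sketch of that update rule, outside TensorFlow and with made-up batch shapes (illustrative only, not RETURNN code):

import numpy

# Sketch of the moving-average update kept in batch_norm(), done in NumPy.
# Shapes and the number of steps are arbitrary; momentum matches the default.
momentum = 0.99
sample_mean = numpy.zeros(16)       # like the "mean" variable, zeros-initialized
sample_variance = numpy.ones(16)    # like the "variance" variable, ones-initialized

rnd = numpy.random.RandomState(0)
for step in range(5):
  batch = rnd.rand(32, 16)          # hypothetical batch of feature vectors
  mean = batch.mean(axis=0)
  variance = batch.var(axis=0)
  # Same form as tf_compat.v1.assign_add(sample_mean, (mean - sample_mean) * momentum):
  sample_mean += (mean - sample_mean) * momentum
  sample_variance += (variance - sample_variance) * momentum

With each update, the stored sample statistics move a fraction (momentum) of the way toward the current batch statistics, which is what the non-trainable "mean" and "variance" variables track during training.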

returnn/tf/layers/basic.py

Lines changed: 0 additions & 6 deletions
@@ -575,11 +575,9 @@ class BatchNormLayer(CopyLayer):
 
   def __init__(self, use_shift=NotSpecified, use_std=NotSpecified, use_sample=NotSpecified, force_sample=NotSpecified,
                momentum=NotSpecified, epsilon=NotSpecified,
-               sample_mean=NotSpecified, sample_variance=NotSpecified,
                update_sample_only_in_training=NotSpecified,
                delay_sample_update=NotSpecified,
                param_version=NotSpecified,
-               gamma=NotSpecified, beta=NotSpecified,
                gamma_init=NotSpecified, beta_init=NotSpecified,
                masked_time=NotSpecified, **kwargs):
     """
@@ -592,10 +590,6 @@ def __init__(self, use_shift=NotSpecified, use_std=NotSpecified, use_sample=NotS
     :param bool delay_sample_update:
     :param int param_version: 0 or 1
     :param float epsilon:
-    :param tf.Tensor sample_mean:
-    :param tf.Tensor sample_variance:
-    :param tf.Tensor gamma:
-    :param tf.Tensor beta:
     :param str|float gamma_init: see :func:`TFUtil.get_initializer`, for the scale
     :param str|float beta_init: see :func:`TFUtil.get_initializer`, for the mean
     :param bool masked_time: flatten and mask input tensor
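
As the commit message notes, these options come from a config as plain Python values, which is why tf.Tensor-typed arguments were unlikely to ever be used. A hedged sketch of how such a layer might be configured after this change; the values are illustrative, and "batch_norm" is assumed here to be the registered layer name of BatchNormLayer:

# Hypothetical RETURNN network config snippet, only for illustration.
# All batch-norm options are plain floats/bools, matching the remaining arguments above.
network = {
  "bn": {
    "class": "batch_norm", "from": "data",
    "momentum": 0.99, "epsilon": 1e-3,
    "use_shift": True, "use_std": True,
    "gamma_init": 1.0, "beta_init": 0.0,
    "masked_time": True},
}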

tests/test_TFNetworkLayer.py

Lines changed: 88 additions & 87 deletions
@@ -243,127 +243,128 @@ def test_batch_norm_vars():
 
 def test_batch_norm():
   with make_scope() as session:
-    import numpy as np
-    net = TFNetwork(extern_data=ExternData())
-    net.train_flag = True
+    net = TFNetwork(extern_data=ExternData(), train_flag=True)
     with tf_compat.v1.variable_scope("src_nchw"):
-      src_nhwc = InternalLayer(name="src_nchw", network=net, out_type={"dim": 16,
-                                                                       "shape": (None, 16, 16),
-                                                                       "batch_dim_axis": 0,
-                                                                       "time_dim_axis": 1,
-                                                                       "feature_dim_axis": 3,
-                                                                       "sparse": False
-                                                                       })
+      src_nhwc = InternalLayer(
+        name="src_nchw", network=net,
+        out_type={
+          "dim": 16,
+          "shape": (None, 16, 16),
+          "batch_dim_axis": 0,
+          "time_dim_axis": 1,
+          "feature_dim_axis": 3,
+          "sparse": False})
       src_nhwc.output.placeholder = tf_compat.v1.placeholder(shape=(None, None, 16, 16), dtype=tf.float32)
       src_nhwc.output.size_placeholder = {0: tf_compat.v1.placeholder(shape=(None,), dtype=tf.int32)}
 
-    rnd = np.random.RandomState(42)
-    mean = tf.constant(rnd.rand(1, 1, 1, 16), name="rand_mean", dtype=tf.float32)
-    variance = tf.constant(rnd.rand(1, 1, 1, 16), name="rand_var", dtype=tf.float32)
+    rnd = numpy.random.RandomState(42)
     input_data = rnd.rand(10, 11, 16, 16)
-    seq_lens = np.array([11, 11, 11, 11, 11, 11, 11, 11, 11, 11])
+    seq_lens = numpy.array([11] * 10)
 
     with tf_compat.v1.variable_scope("batch_norm_masked_nchw"):
-      batch_norm_1 = BatchNormLayer(name="batch_norm_masked_nchw", network=net, masked_time=True,
-                                    sample_mean=mean, sample_variance=variance,
-                                    sources=[src_nhwc],
-                                    output=BatchNormLayer.get_out_data_from_opts(name="batch_norm_masked_nchw",
-                                                                                 sources=[src_nhwc],
-                                                                                 network=net))
+      batch_norm_1 = BatchNormLayer(
+        name="batch_norm_masked_nchw", network=net, masked_time=True,
+        sources=[src_nhwc],
+        output=BatchNormLayer.get_out_data_from_opts(
+          name="batch_norm_masked_nchw",
+          sources=[src_nhwc],
+          network=net))
       batch_norm_1.post_init(layer_desc=None)
     with tf_compat.v1.variable_scope("batch_norm_nonmasked_nchw"):
-      batch_norm_2 = BatchNormLayer(name="batch_norm_nonmasked_nchw", network=net, masked_time=False,
-                                    sample_mean=mean, sample_variance=variance,
-                                    sources=[src_nhwc],
-                                    output=BatchNormLayer.get_out_data_from_opts(name="batch_norm_nonmasked_nchw",
-                                                                                 sources=[src_nhwc],
-                                                                                 network=net))
+      batch_norm_2 = BatchNormLayer(
+        name="batch_norm_nonmasked_nchw", network=net, masked_time=False,
+        sources=[src_nhwc],
+        output=BatchNormLayer.get_out_data_from_opts(
+          name="batch_norm_nonmasked_nchw",
+          sources=[src_nhwc],
+          network=net))
       batch_norm_2.post_init(layer_desc=None)
-    tf_compat.v1.global_variables_initializer().run()
-    out_1, seq_lens_1 = session.run([batch_norm_1.output.placeholder,
-                                     batch_norm_1.output.size_placeholder[0]],
-                                    feed_dict={src_nhwc.output.placeholder: input_data,
-                                               src_nhwc.output.size_placeholder[0]: seq_lens}
-                                    )
-    out_2, seq_lens_2 = session.run([batch_norm_2.output.placeholder,
-                                     batch_norm_2.output.size_placeholder[0]],
-                                    feed_dict={src_nhwc.output.placeholder: input_data,
-                                               src_nhwc.output.size_placeholder[0]: seq_lens}
-                                    )
-    assert np.array_equal(out_1, out_2)
-    print(np.sum(out_1 - out_2))
+    tf_compat.v1.global_variables_initializer().run(session=session)
+    out_1, seq_lens_1 = session.run(
+      [batch_norm_1.output.placeholder, batch_norm_1.output.size_placeholder[0]],
+      feed_dict={
+        src_nhwc.output.placeholder: input_data,
+        src_nhwc.output.size_placeholder[0]: seq_lens})
+    out_2, seq_lens_2 = session.run(
+      [batch_norm_2.output.placeholder, batch_norm_2.output.size_placeholder[0]],
+      feed_dict={
+        src_nhwc.output.placeholder: input_data,
+        src_nhwc.output.size_placeholder[0]: seq_lens})
+    assert numpy.array_equal(out_1, out_2)
+    print(numpy.sum(out_1 - out_2))
 
 
 def test_batch_norm_unequal_seq_len():
   with make_scope() as session:
-    import numpy as np
-    import numpy.testing as npt
-    net = TFNetwork(extern_data=ExternData())
-    net.train_flag = True
+    net = TFNetwork(extern_data=ExternData(), train_flag=True)
     with tf_compat.v1.variable_scope("src_nhwc"):
-      src_nhwc = InternalLayer(name="src_nhwc", network=net, out_type={"dim": 16,
-                                                                       "shape": (None, 16, 16),
-                                                                       "batch_dim_axis": 0,
-                                                                       "time_dim_axis": 1,
-                                                                       "feature_dim_axis": 3,
-                                                                       "sparse": False
-                                                                       })
+      src_nhwc = InternalLayer(
+        name="src_nhwc", network=net,
+        out_type={
+          "dim": 16,
+          "shape": (None, 16, 16),
+          "batch_dim_axis": 0,
+          "time_dim_axis": 1,
+          "feature_dim_axis": 3,
+          "sparse": False})
      src_nhwc.output.placeholder = tf_compat.v1.placeholder(shape=(None, None, 16, 16), dtype=tf.float32)
      src_nhwc.output.size_placeholder = {0: tf_compat.v1.placeholder(shape=(None,), dtype=tf.int32)}
 
-    rnd = np.random.RandomState(42)
-    mean = tf.constant(rnd.rand(1, 1, 1, 16), name="rand_mean", dtype=tf.float32)
-    variance = tf.constant(rnd.rand(1, 1, 1, 16), name="rand_var", dtype=tf.float32)
+    rnd = numpy.random.RandomState(42)
     input_data = rnd.rand(10, 11, 16, 16).astype('f')
     input_data[2, 5:, :, :] = 0
-    data_mean = np.mean(input_data, axis=(0, 1, 2), keepdims=True, dtype=np.float32)
-    data_var = np.var(input_data, axis=(0, 1, 2), keepdims=True, dtype=np.float32)
-    input_data_masked = np.copy(input_data)
-    seq_lens = np.array([11, 11, 5, 11, 11, 11, 11, 11, 11, 11], dtype=np.float32)
+    input_data_masked = numpy.copy(input_data)
+    seq_lens = numpy.array([11, 11, 5, 11, 11, 11, 11, 11, 11, 11], dtype=numpy.float32)
     n1 = 9 * 11 * 16 + 5 * 16
     n2 = 10 * 11 * 16
 
     with tf_compat.v1.variable_scope("batch_norm_masked_nchw"):
-      batch_norm_1 = BatchNormLayer(name="batch_norm_masked_nchw", network=net, masked_time=True,
-                                    sample_mean=mean, sample_variance=variance,
-                                    use_shift=False, use_std=False, epsilon=0.0,
-                                    sources=[src_nhwc],
-                                    output=BatchNormLayer.get_out_data_from_opts(name="batch_norm_masked_nchw",
-                                                                                 sources=[src_nhwc],
-                                                                                 network=net))
+      batch_norm_1 = BatchNormLayer(
+        name="batch_norm_masked_nchw", network=net, masked_time=True,
+        use_shift=False, use_std=False, epsilon=0.0,
+        sources=[src_nhwc],
+        output=BatchNormLayer.get_out_data_from_opts(
+          name="batch_norm_masked_nchw",
+          sources=[src_nhwc],
+          network=net))
       batch_norm_1.post_init(layer_desc=None)
     with tf_compat.v1.variable_scope("batch_norm_nonmasked_nchw"):
-      batch_norm_2 = BatchNormLayer(name="batch_norm_nonmasked_nchw", network=net, masked_time=False,
-                                    sample_mean=mean, sample_variance=variance,
-                                    use_shift=False, use_std=False, epsilon=0,
-                                    sources=[src_nhwc],
-                                    output=BatchNormLayer.get_out_data_from_opts(name="batch_norm_nonmasked_nchw",
-                                                                                 sources=[src_nhwc],
-                                                                                 network=net))
+      batch_norm_2 = BatchNormLayer(
+        name="batch_norm_nonmasked_nchw", network=net, masked_time=False,
+        use_shift=False, use_std=False, epsilon=0,
+        sources=[src_nhwc],
+        output=BatchNormLayer.get_out_data_from_opts(
+          name="batch_norm_nonmasked_nchw",
+          sources=[src_nhwc],
+          network=net))
       batch_norm_2.post_init(layer_desc=None)
-    tf_compat.v1.global_variables_initializer().run()
-    out_1, seq_lens_1 = session.run([batch_norm_1.output.placeholder,
-                                     batch_norm_1.output.size_placeholder[0]],
-                                    feed_dict={src_nhwc.output.placeholder: input_data,
-                                               src_nhwc.output.size_placeholder[0]: seq_lens}
-                                    )
-    out_2, seq_lens_2 = session.run([batch_norm_2.output.placeholder,
-                                     batch_norm_2.output.size_placeholder[0]],
-                                    feed_dict={src_nhwc.output.placeholder: input_data_masked,
-                                               src_nhwc.output.size_placeholder[0]: seq_lens}
-                                    )
+    tf_compat.v1.global_variables_initializer().run(session=session)
+    out_1, seq_lens_1 = session.run(
+      [batch_norm_1.output.placeholder, batch_norm_1.output.size_placeholder[0]],
+      feed_dict={
+        src_nhwc.output.placeholder: input_data,
+        src_nhwc.output.size_placeholder[0]: seq_lens})
+    out_2, seq_lens_2 = session.run(
+      [batch_norm_2.output.placeholder, batch_norm_2.output.size_placeholder[0]],
+      feed_dict={
+        src_nhwc.output.placeholder: input_data_masked,
+        src_nhwc.output.size_placeholder[0]: seq_lens})
+
     # Manually calculating batch_norm and compare to the tf output
-    np_bn2 = (input_data - data_mean) * (1.0 / np.sqrt(data_var))
-    npt.assert_array_almost_equal(np_bn2, out_2, decimal=5)
+    data_mean = numpy.mean(input_data, axis=(0, 1, 2), keepdims=True, dtype=numpy.float32)
+    data_var = numpy.var(input_data, axis=(0, 1, 2), keepdims=True, dtype=numpy.float32)
+    np_bn2 = (input_data - data_mean) * (1.0 / numpy.sqrt(data_var))
+    numpy.testing.assert_array_almost_equal(np_bn2, out_2, decimal=5)
     # Manually calculating batch_norm with different seq_lens, having:
     # Mean_1 = n2 / n1 * Mean_2
     # Var_1 = n2 / n1 * (Var_2 + Mean_2 ^ 2 (1 - n2 / n1))
     # bn_1 = (x - Mean_1) * 1 / sqrt(Var_1)
     # Substituting Mean_1 and Var_1:
-    np_bn1 = (input_data - n2 / n1 * data_mean) * \
-             (1.0 / np.sqrt(n2 / n1 * (data_var + data_mean ** 2 * (1 - n2 / n1))))
+    np_bn1 = (
+      (input_data - n2 / n1 * data_mean) *
+      (1.0 / numpy.sqrt(n2 / n1 * (data_var + data_mean ** 2 * (1 - n2 / n1)))))
     # Check with tf output.
-    npt.assert_array_almost_equal(np_bn1, out_1, decimal=5)
+    numpy.testing.assert_array_almost_equal(np_bn1, out_1, decimal=5)
 
 
 def test_activation_layer_net_construct():
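
The identities in the test comments above (Mean_1 = n2 / n1 * Mean_2 and Var_1 = n2 / n1 * (Var_2 + Mean_2 ** 2 * (1 - n2 / n1))) hold because the padded positions are zero: they contribute to the sums over all n2 positions but not to the sums over the n1 valid ones. A small standalone NumPy check of the identity on synthetic data (the shapes and mask here are arbitrary and unrelated to the test):

import numpy

# Verify the masked-vs-unmasked statistics identity from the test comments
# on synthetic 2D data with a few zeroed ("padded") positions.
rnd = numpy.random.RandomState(0)
x = rnd.rand(10, 11)
mask = numpy.ones_like(x, dtype=bool)
mask[2, 5:] = False          # pretend the tail of one sequence is padding
x[~mask] = 0.0               # padded positions are zero, as in the masked test input
n2 = x.size                  # all positions, padding included
n1 = int(mask.sum())         # valid positions only

mean_2 = x.mean()            # statistics over all n2 positions (zeros included)
var_2 = x.var()
mean_1 = x[mask].mean()      # statistics over the n1 valid positions
var_1 = x[mask].var()

numpy.testing.assert_allclose(mean_1, n2 / n1 * mean_2)
numpy.testing.assert_allclose(var_1, n2 / n1 * (var_2 + mean_2 ** 2 * (1 - n2 / n1)))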
