From 51ee4460db44668496d567aac348db1b56cd3212 Mon Sep 17 00:00:00 2001
From: Aishik Ghosh
Date: Thu, 21 Mar 2019 16:38:53 +0100
Subject: [PATCH] Swish activation saves beta as weight even if it is not
 trainable

---
 .../layers/advanced_activations/swish.py     | 36 +++++++++++++++++++-----------------
 1 file changed, 19 insertions(+), 17 deletions(-)

diff --git a/keras_contrib/layers/advanced_activations/swish.py b/keras_contrib/layers/advanced_activations/swish.py
index f2cac1ae6..0d7648fec 100644
--- a/keras_contrib/layers/advanced_activations/swish.py
+++ b/keras_contrib/layers/advanced_activations/swish.py
@@ -1,5 +1,7 @@
-from keras import backend as K
 from keras.layers import Layer
+from keras import backend as K
+from keras.layers import InputSpec
+from keras.initializers import Constant
 
 
 class Swish(Layer):
@@ -14,7 +16,7 @@ class Swish(Layer):
         Same shape as the input.
 
     # Arguments
-        beta: float >= 0. Scaling factor
+        initial_beta: float >= 0. Scaling factor
             if set to 1 and trainable set to False (default),
             Swish equals the SiLU activation (Elfwing et al., 2017)
         trainable: whether to learn the scaling factor during training or not
@@ -24,29 +26,29 @@ class Swish(Layer):
         - [Sigmoid-weighted linear units for neural network function
            approximation in reinforcement learning](https://arxiv.org/abs/1702.03118)
     """
 
-    def __init__(self, beta=1.0, trainable=False, **kwargs):
+    def __init__(self, initial_beta=1.0, trainable=False, **kwargs):
         super(Swish, self).__init__(**kwargs)
         self.supports_masking = True
-        self.beta = beta
         self.trainable = trainable
+        self.initial_beta = initial_beta
+        self.beta_initializer = Constant(value=self.initial_beta)
+        self.__name__ = 'swish'
 
     def build(self, input_shape):
-        self.scaling_factor = K.variable(self.beta,
-                                         dtype=K.floatx(),
-                                         name='scaling_factor')
-        if self.trainable:
-            self._trainable_weights.append(self.scaling_factor)
-        super(Swish, self).build(input_shape)
+        # Register beta via add_weight so it is saved with the model
+        # weights even when it is not trainable.
+        self.beta = self.add_weight(shape=(1,), name='beta',
+                                    initializer=self.beta_initializer,
+                                    trainable=self.trainable)
+        self.input_spec = InputSpec(ndim=len(input_shape))
+        self.built = True
 
-    def call(self, inputs, mask=None):
-        return inputs * K.sigmoid(self.scaling_factor * inputs)
+    def call(self, inputs):
+        return inputs * K.sigmoid(self.beta * inputs)
 
     def get_config(self):
-        config = {'beta': self.get_weights()[0] if self.trainable else self.beta,
-                  'trainable': self.trainable}
+        config = {'trainable': self.trainable,
+                  'initial_beta': self.initial_beta}
         base_config = super(Swish, self).get_config()
         return dict(list(base_config.items()) + list(config.items()))
-
-    def compute_output_shape(self, input_shape):
-        return input_shape
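
For reference, a quick way to exercise the new behaviour. This is a minimal sketch, not part of the patch: it assumes keras-contrib with this change applied, and the model and file name are illustrative.

    from keras.models import Sequential
    from keras.layers import Dense
    from keras_contrib.layers import Swish

    # A tiny model with a non-trainable Swish: beta should still appear
    # among the layer's weights and be serialized with them.
    model = Sequential([
        Dense(4, input_shape=(8,)),
        Swish(initial_beta=1.0, trainable=False),
    ])
    model.compile(optimizer='sgd', loss='mse')

    swish = model.layers[1]
    print(swish.get_weights())           # [array([1.], dtype=float32)]
    print(len(swish.trainable_weights))  # 0: beta is a non-trainable weight

    # Before this patch, a non-trainable beta lived in a bare K.variable
    # that was never registered as a layer weight, so save_weights
    # silently dropped it; with add_weight it is now saved either way.
    model.save_weights('swish_test.h5')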