diff --git a/Time2Vec/layers.py b/Time2Vec/layers.py index c000547..b8c265c 100644 --- a/Time2Vec/layers.py +++ b/Time2Vec/layers.py @@ -3,7 +3,7 @@ class Time2Vec(Layer): - def __init__(self, kernel_size, periodic_activation='sin'): + def __init__(self, kernel_size, periodic_activation='sin', **kwargs): ''' :param kernel_size: The length of time vector representation. @@ -11,7 +11,8 @@ def __init__(self, kernel_size, periodic_activation='sin'): ''' super(Time2Vec, self).__init__( trainable=True, - name='Time2VecLayer_'+periodic_activation.upper() + name='Time2VecLayer_'+periodic_activation.upper(), + **kwargs ) self.k = kernel_size @@ -22,26 +23,30 @@ def build(self, input_shape): self.wb = self.add_weight( shape=(1, 1), initializer='uniform', - trainable=True + trainable=True, + name='wb_weight' ) self.bb = self.add_weight( shape=(1, 1), initializer='uniform', - trainable=True + trainable=True, + name='bb_weight' ) # Else needs to pass the periodic activation self.wa = self.add_weight( shape=(1, self.k), initializer='uniform', - trainable=True + trainable=True, + name='wa_weight' ) self.ba = self.add_weight( shape=(1, self.k), initializer='uniform', - trainable=True + trainable=True, + name='ba_weight' ) super(Time2Vec, self).build(input_shape) @@ -63,4 +68,10 @@ def call(self, inputs, **kwargs): return K.concatenate([bias, wgts], -1) def compute_output_shape(self, input_shape): - return (input_shape[0], input_shape[1], self.k + 1) \ No newline at end of file + return (input_shape[0], input_shape[1], self.k + 1) + + def get_config(self): + config = super(Time2Vec, self).get_config() + config.update({"kernel_size": self.k}) + config.update({"periodic_activation": self.p_activation}) + return config