diff --git a/deeptables/models/layers.py b/deeptables/models/layers.py index eef76ef..285f946 100644 --- a/deeptables/models/layers.py +++ b/deeptables/models/layers.py @@ -923,9 +923,17 @@ def get_config(self): class VarLenColumnEmbedding(Layer): - def __init__(self, emb_vocab_size, emb_output_dim, dropout_rate=0. , **kwargs): + def __init__(self, emb_vocab_size, emb_output_dim, + embeddings_initializer, + embeddings_regularizer, + activity_regularizer, + dropout_rate=0., + **kwargs): self.emb_vocab_size = emb_vocab_size self.emb_output_dim = emb_output_dim + self.embeddings_initializer = embeddings_initializer + self.embeddings_regularizer = embeddings_regularizer + self.activity_regularizer = activity_regularizer self.dropout_rate = dropout_rate super(VarLenColumnEmbedding, self).__init__(**kwargs) self.dropout = None @@ -937,7 +945,11 @@ def compute_output_shape(self, input_shape): def build(self, input_shape=None): super(VarLenColumnEmbedding, self).build(input_shape) - self.emb_layer = Embedding(input_dim=self.emb_vocab_size, output_dim=self.emb_output_dim) + self.emb_layer = Embedding(input_dim=self.emb_vocab_size, + output_dim=self.emb_output_dim, + embeddings_initializer=self.embeddings_initializer, + embeddings_regularizer=self.embeddings_regularizer, + activity_regularizer=self.activity_regularizer) if self.dropout_rate > 0: self.dropout = SpatialDropout1D(self.dropout_rate, name='var_len_emb_dropout') else: @@ -945,20 +957,25 @@ def build(self, input_shape=None): self.built = True def call(self, inputs): - embedding_output = self.emb_layer.call(inputs) - embedding_output = embedding_output.reshape((embedding_output[0], 1, -1)) + embedding_output = self.emb_layer(inputs) + embedding_output_reshape = tf.reshape(embedding_output, [embedding_output.shape[0], 1, -1]) if self.dropout is not None: - dropout_output = self.dropout(embedding_output) + dropout_output = self.dropout(embedding_output_reshape) else: - dropout_output = embedding_output + dropout_output = embedding_output_reshape return dropout_output def compute_mask(self, inputs, mask=None): return None def get_config(self, ): - config = { 'dropout_rate': self.dropout_rate, - 'emb_layer': self.emb_layer.get_config()} + config = { 'dropout_rate': self.dropout_rate, + 'emb_layer': self.emb_layer.get_config(), + 'embeddings_initializer': self.embeddings_initializer, + 'embeddings_regularizer': self.embeddings_regularizer, + 'emb_vocab_size': self.emb_vocab_size, + 'emb_output_dim': self.emb_output_dim + } base_config = super(VarLenColumnEmbedding, self).get_config() return dict(list(base_config.items()) + list(config.items()))