Commit: update var len embedding

oaksharks committed Nov 21, 2024
1 parent 4643439 commit b7ca9c6
Showing 7 changed files with 32 additions and 32 deletions.
2 changes: 1 addition & 1 deletion deeptables/models/config.py
@@ -132,7 +132,7 @@ def __new__(cls,
gpu_usage_strategy=consts.GPU_USAGE_STRATEGY_GROWTH,
distribute_strategy=None,
var_len_categorical_columns=None,
# a tuple3, format is (column_name, separator, pool_strategy), pool_strategy is one of max,avg; e.g. [('genres', '|', 'avg' )]
# a tuple2, format is (column_name, separator); e.g. [('genres', '|')]
):

if var_len_categorical_columns is not None and len(var_len_categorical_columns) > 0:
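As a hedged usage sketch of the new format (assuming ModelConfig and DeepTable are imported from deeptables.models.deeptable, as in the project's tests), each var-len column is now declared as a (column_name, separator) pair with no pooling strategy:

```python
from deeptables.models import deeptable

# Hedged sketch: 'genres' is a hypothetical multi-valued column whose cells
# look like "Action|Comedy|Drama"; only (name, separator) is supplied now.
conf = deeptable.ModelConfig(
    nets=['dnn_nets'],
    var_len_categorical_columns=[('genres', '|')],
)
dt = deeptable.DeepTable(config=conf)
```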
7 changes: 3 additions & 4 deletions deeptables/models/deepmodel.py
@@ -271,7 +271,7 @@ def __build_model(self, task, num_classes, nets, categorical_columns, continuous
if len(embeddings) == 1:
flatten_emb_layer = Flatten(name='flatten_embeddings')(embeddings[0])
else:
flatten_emb_layer = Flatten(name='flatten_embeddings')(Concatenate(name='concat_embeddings_axis_0', axis=1)(embeddings))
flatten_emb_layer = Flatten(name='flatten_embeddings')(Concatenate(name='concat_embeddings_axis_0', axis=-1)(embeddings))

self.model_desc.nets = nets
self.model_desc.stacking = config.stacking_op
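For context on the axis change above, here is a standalone Keras shape sketch (not the project's code): concatenating embedding outputs on the last axis and then flattening still yields one feature vector per sample, and, unlike axis=1, it also tolerates embeddings whose last dimensions differ, which is presumably relevant once var-len embeddings are reshaped to a wider last axis.

```python
import tensorflow as tf
from tensorflow.keras.layers import Concatenate, Flatten

# Two stand-in embedding outputs with different widths on the last axis,
# e.g. a fixed-length embedding and a reshaped var-len embedding.
a = tf.random.normal((4, 1, 8))
b = tf.random.normal((4, 1, 32))
merged = Concatenate(name='concat_embeddings_axis_0', axis=-1)([a, b])  # (4, 1, 40)
flat = Flatten(name='flatten_embeddings')(merged)                       # (4, 40)
print(merged.shape, flat.shape)
```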
@@ -407,9 +407,8 @@ def __build_embeddings(self, categorical_columns, categorical_inputs,
for column in var_len_categorical_columns:
# todo add var len embedding description
input_layer = var_len_inputs[column.name]
var_len_embeddings = VarLenColumnEmbedding(pooling_strategy=column.pooling_strategy,
input_dim=column.vocabulary_size,
output_dim=column.embeddings_output_dim,
var_len_embeddings = VarLenColumnEmbedding(emb_vocab_size=column.vocabulary_size,
emb_output_dim=column.embeddings_output_dim,
dropout_rate=embedding_dropout,
name=consts.LAYER_PREFIX_EMBEDDING + column.name,
embeddings_initializer=self.config.embeddings_initializer,
35 changes: 20 additions & 15 deletions deeptables/models/layers.py
@@ -922,38 +922,43 @@ def get_config(self):
return dict(list(base_config.items()) + list(config.items()))


class VarLenColumnEmbedding(Embedding):
def __init__(self, pooling_strategy='max', dropout_rate=0., **kwargs):
if pooling_strategy not in ['mean', 'max']:
raise ValueError("Param strategy should is one of mean, max")
self.pooling_strategy = pooling_strategy
self.dropout_rate = dropout_rate # supports dropout
class VarLenColumnEmbedding(Layer):
def __init__(self, emb_vocab_size, emb_output_dim, dropout_rate=0., **kwargs):
self.emb_vocab_size = emb_vocab_size
self.emb_output_dim = emb_output_dim
self.dropout_rate = dropout_rate
super(VarLenColumnEmbedding, self).__init__(**kwargs)
self._dropout = None
self.dropout = None
self.emb_layer = None

def compute_output_shape(self, input_shape):
n_dim = input_shape[1]
return input_shape[0], self.emb_output_dim * n_dim

def build(self, input_shape=None):
super(VarLenColumnEmbedding, self).build(input_shape)
self.emb_layer = Embedding(input_dim=self.emb_vocab_size, output_dim=self.emb_output_dim)
if self.dropout_rate > 0:
self._dropout = SpatialDropout1D(self.dropout_rate, name='var_len_emb_dropout')
self.dropout = SpatialDropout1D(self.dropout_rate, name='var_len_emb_dropout')
else:
self._dropout = None
self.dropout = None
self.built = True

def call(self, inputs):
embedding_output = super(VarLenColumnEmbedding, self).call(inputs)

if self._dropout is not None:
dropout_output = self._dropout(embedding_output)
embedding_output = self.emb_layer(inputs)
# flatten the per-row token embeddings into shape (batch, 1, max_len * emb_output_dim)
embedding_output = tf.reshape(embedding_output, (tf.shape(embedding_output)[0], 1, -1))
if self.dropout is not None:
dropout_output = self.dropout(embedding_output)
else:
dropout_output = embedding_output

return dropout_output

def compute_mask(self, inputs, mask=None):
return None

def get_config(self, ):
config = {'pooling_strategy': self.pooling_strategy}
config = {'dropout_rate': self.dropout_rate,
'emb_layer': self.emb_layer.get_config()}
base_config = super(VarLenColumnEmbedding, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

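A minimal sketch of how the rewritten layer could be exercised on its own, assuming the reshape in call() produces a (batch, 1, max_len * emb_output_dim) tensor as written above; the ids below are hypothetical padded category indices:

```python
import tensorflow as tf
from deeptables.models.layers import VarLenColumnEmbedding

# Hypothetical batch of 3 rows, each a padded sequence of 4 category ids.
ids = tf.constant([[1, 2, 3, 0],
                   [4, 5, 0, 0],
                   [6, 0, 0, 0]])
emb = VarLenColumnEmbedding(emb_vocab_size=10, emb_output_dim=4, dropout_rate=0.1)
out = emb(ids)
print(out.shape)  # expected: (3, 1, 16)
```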
7 changes: 3 additions & 4 deletions deeptables/models/metainfo.py
@@ -55,21 +55,20 @@ class VarLenCategoricalColumn(collections.namedtuple('VarLenCategoricalColumn',
'embeddings_output_dim',
'dtype',
'input_name',
'sep',
'pooling_strategy',
'sep'
])):

def __hash__(self):
return self.name.__hash__()

def __new__(cls, name, vocabulary_size, embeddings_output_dim=10, dtype='int32', input_name=None, sep="|", pooling_strategy='max'):
def __new__(cls, name, vocabulary_size, embeddings_output_dim=10, dtype='int32', input_name=None, sep="|"):
if input_name is None:
input_name = consts.INPUT_PREFIX_CAT + name
if embeddings_output_dim == 0:
embeddings_output_dim = int(round(vocabulary_size ** 0.25))
# max_elements_length need a variable not const
return super(VarLenCategoricalColumn, cls).__new__(cls, name, vocabulary_size, embeddings_output_dim, dtype,
input_name, sep, pooling_strategy)
input_name, sep)


class ContinuousColumn(collections.namedtuple('ContinuousColumn',
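As a hedged illustration of the slimmed-down column spec (assuming the module path deeptables.models.metainfo), the constructor now takes only the separator, and the embedding width is derived from the vocabulary when it is left at 0:

```python
from deeptables.models.metainfo import VarLenCategoricalColumn

# 'genres' is a hypothetical column; embeddings_output_dim=0 requests the
# derived width round(vocabulary_size ** 0.25).
col = VarLenCategoricalColumn('genres', vocabulary_size=20, embeddings_output_dim=0, sep='|')
print(col.input_name)             # consts.INPUT_PREFIX_CAT + 'genres'
print(col.embeddings_output_dim)  # 2
```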
10 changes: 3 additions & 7 deletions deeptables/models/preprocessor.py
@@ -279,10 +279,8 @@ def _prepare_features(self, X):
else:
var_len_column_names.append(v[0])
var_len_col_sep_dict = {v[0]: v[1] for v in var_len_categorical_columns}
var_len_col_pooling_strategy_dict = {v[0]: v[2] for v in var_len_categorical_columns}
else:
var_len_col_sep_dict = {}
var_len_col_pooling_strategy_dict = {}

X_shape = self._get_shape(X)
unique_upper_limit = round(X_shape[0] ** self.config.cat_exponent)
@@ -299,8 +297,7 @@ def _prepare_features(self, X):

# handle var len feature
if c in var_len_column_names:
self.__append_var_len_categorical_col(c, nunique, var_len_col_sep_dict[c],
var_len_col_pooling_strategy_dict[c])
self.__append_var_len_categorical_col(c, nunique, var_len_col_sep_dict[c])
continue

if self.config.categorical_columns is not None and isinstance(self.config.categorical_columns, list):
@@ -454,7 +451,7 @@ def _gbm_features_to_continuous_cols(self, X, gbmencoder):
# return [name for name in gbmencoder.new_columns]
return gbmencoder.new_columns

def __append_var_len_categorical_col(self, name, voc_size, sep, pooling_strategy):
def __append_var_len_categorical_col(self, name, voc_size, sep):
logger.debug(f'Var len categorical variables {name} appended.')

if self.config.fixed_embedding_dim:
@@ -470,8 +467,7 @@ def __append_var_len_categorical_col(self, name, voc_size, sep, pooling_strategy
voc_size,
embedding_output_dim if embedding_output_dim > 0 else min(
4 * int(pow(voc_size, 0.25)), 20),
sep=sep,
pooling_strategy=pooling_strategy)
sep=sep)

self.var_len_categorical_columns.append(vc)

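A small sketch of the unpacking that remains after the pooling-strategy dictionary was dropped; 'genres' and 'tags' are hypothetical column names:

```python
# Each configured entry is now a (column_name, separator) pair.
var_len_categorical_columns = [('genres', '|'), ('tags', ',')]
var_len_column_names = [v[0] for v in var_len_categorical_columns]
var_len_col_sep_dict = {v[0]: v[1] for v in var_len_categorical_columns}
print(var_len_column_names)   # ['genres', 'tags']
print(var_len_col_sep_dict)   # {'genres': '|', 'tags': ','}
```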
1 change: 1 addition & 0 deletions deeptables/tests/models/config_test.py
@@ -23,4 +23,5 @@ def test_embeddings_output_dim(self):
dt = deeptable.DeepTable(config=conf)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

model, history = dt.fit(X_train, y_train, epochs=1)
2 changes: 1 addition & 1 deletion deeptables/utils/dataset_generator.py
@@ -52,7 +52,7 @@ def __call__(self, X, y=None, *, batch_size, shuffle, drop_remainder):
train_data.append(tf.constant(np.array(X[col.name].tolist()).astype(consts.DATATYPE_TENSOR_FLOAT).tolist()))

if y is None:
ds = tf.data.Dataset.from_tensor_slices(train_data, name='train_x')
ds = tf.data.Dataset.from_tensor_slices((tuple(train_data), ), name='train_x')
else:
y = tf.constant(np.array(y).tolist())
if self.task == consts.TASK_MULTICLASS:
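For the dataset change above, a hedged standalone sketch (not the project's code) of why the feature list is wrapped in a tuple: a plain Python list is converted to a single tensor, which requires every feature array to share the same shape, while a tuple keeps each array as a separate dataset component.

```python
import numpy as np
import tensorflow as tf

# Two hypothetical feature arrays for 3 samples, with different widths.
x_a = np.arange(6, dtype=np.float32).reshape(3, 2)
x_b = np.arange(12, dtype=np.float32).reshape(3, 4)
train_data = [tf.constant(x_a), tf.constant(x_b)]

ds = tf.data.Dataset.from_tensor_slices((tuple(train_data),), name='train_x')
for (features,) in ds.take(1):
    print([t.shape for t in features])  # [TensorShape([2]), TensorShape([4])]
```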
