fix ps+keras without unique #488

Open · wants to merge 3 commits into base: master
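For context: the PR title refers to the with_unique option of the dynamic_embedding Keras embedding layers under ParameterServerStrategy. A rough, hypothetical sketch (not part of this diff) of constructing such a layer with the de-duplicating lookup disabled, assuming with_unique is accepted as a constructor argument as the commented-out lines in the demo below suggest:

import tensorflow as tf
from tensorflow_recommenders_addons import dynamic_embedding as de

# Hypothetical example only: with_unique=False skips the unique()/gather
# round-trip during lookup, which is the code path this PR exercises.
user_embedding = de.keras.layers.Embedding(
    32,  # embedding_size
    initializer=tf.keras.initializers.RandomNormal(0.0, 0.5),
    with_unique=False,  # assumption: exposed as a constructor kwarg
    name='user_embedding')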
@@ -2,13 +2,16 @@
import tensorflow as tf
import tensorflow_datasets as tfds

from absl import flags
from absl import app
from tensorflow_recommenders_addons import dynamic_embedding as de

try:
  from tensorflow.keras.optimizers.legacy import Adam
  from tensorflow.keras.optimizers.legacy import Adagrad
except:
  from tensorflow.keras.optimizers import Adam
  from tensorflow.keras.optimizers import Adagrad

from tensorflow import distribute as tf_dist

flags = tf.compat.v1.app.flags
FLAGS = flags.FLAGS
@@ -34,6 +37,18 @@
], dtype=tf.int64, name='movie_id')
}

gpus = tf.config.list_physical_devices('GPU')
if gpus:
  try:
    # Currently, memory growth needs to be the same across GPUs
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized
    print(e)


class DualChannelsDeepModel(tf.keras.Model):

@@ -59,11 +74,13 @@ def __init__(self,
        user_embedding_size,
        initializer=embedding_initializer,
        devices=self.devices,
        # with_unique=False,
        name='user_embedding')
    self.movie_embedding = de.keras.layers.SquashedEmbedding(
        movie_embedding_size,
        initializer=embedding_initializer,
        devices=self.devices,
        # with_unique=False,
        name='movie_embedding')

    self.dnn1 = tf.keras.layers.Dense(
@@ -90,11 +107,11 @@ def __init__(self,
  @tf.function
  def call(self, features):
    user_id = tf.reshape(features['user_id'], (-1, 1))
    movie_id = tf.reshape(features['movie_id'], (-1, 1))
    # movie_id = tf.reshape(features['movie_id'], (-1, 1))
    user_latent = self.user_embedding(user_id)
    movie_latent = self.movie_embedding(movie_id)
    latent = tf.concat([user_latent, movie_latent], axis=1)

    # movie_latent = self.movie_embedding(movie_id)
    # latent = tf.concat([user_latent, movie_latent], axis=1)
    latent = user_latent
    x = self.dnn1(latent)
    x = self.dnn2(x)
    x = self.dnn3(x)
@@ -115,7 +132,7 @@ def __init__(self, strategy, train_bs, test_bs, epochs, steps_per_epoch,
        "/job:ps/replica:0/task:{}/device:CPU:0".format(idx)
        for idx in range(self.num_ps)
    ]
    self.embedding_size = 32
    self.embedding_size = 1
    self.train_bs = train_bs
    self.test_bs = test_bs
    self.epochs = epochs
@@ -133,7 +150,7 @@ def get_dataset(self, batch_size=1):
    ratings = dataset.map(
        lambda x: tf.one_hot(tf.cast(x['user_rating'] - 1, dtype=tf.int64), 5))
    dataset = dataset.zip((features, ratings))
    dataset = dataset.shuffle(4096, reshuffle_each_iteration=False)
    dataset = dataset.shuffle(4096, reshuffle_each_iteration=False).repeat()
    if batch_size > 1:
      dataset = dataset.batch(batch_size)
    return dataset
@@ -146,6 +163,8 @@ def train(self):
        self.ps_devices, self.embedding_size, self.embedding_size,
        tf.keras.initializers.RandomNormal(0.0, 0.5))
    optimizer = Adam(1E-3)

    # optimizer = Adagrad(1E-3)
    optimizer = de.DynamicEmbeddingOptimizer(optimizer)

    auc = tf.keras.metrics.AUC(num_thresholds=1000)
@@ -161,7 +180,7 @@ def train(self):
      model.load_weights(self.model_dir)

    model.fit(dataset, epochs=self.epochs, steps_per_epoch=self.steps_per_epoch)
    print(f"model: {model.trainable_variables}")
    if self.model_dir:
      save_options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA'])
      model.save(self.model_dir, options=save_options)
@@ -208,6 +227,7 @@ def test(self):

    dataset = self.get_dataset(batch_size=self.test_bs)
    dataset = self.strategy.experimental_distribute_dataset(dataset)

    with self.strategy.scope():
      model = tf.keras.models.load_model(self.export_dir)
      signature = model.signatures['serving_default']
@@ -237,13 +257,12 @@ def start_chief(config):
  cluster_spec = tf.train.ClusterSpec(config["cluster"])
  cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
      cluster_spec, task_type="chief", task_id=0)
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver)
  strategy = tf_dist.experimental.ParameterServerStrategy(cluster_resolver)
  runner = Runner(strategy=strategy,
                  train_bs=64,
                  train_bs=2,
                  test_bs=1,
                  epochs=2,
                  steps_per_epoch=10,
                  epochs=1,
                  steps_per_epoch=2,
                  model_dir=None,
                  export_dir=None)
  runner.train()
10 changes: 10 additions & 0 deletions demo/dynamic_embedding/movielens-1m-keras-ps/one.sh
@@ -0,0 +1,10 @@
#!/bin/bash
rm -rf ./ckpt
sh stop.sh
sleep 1
python movielens-1m-keras-ps.py --ps_list="localhost:2220" --worker_list="localhost:2231" --chief="localhost:2230" --task_mode="ps" --task_id=0 &
sleep 1
python movielens-1m-keras-ps.py --ps_list="localhost:2220" --worker_list="localhost:2231" --chief="localhost:2230" --task_mode="worker" --task_id=0 &
sleep 1
python movielens-1m-keras-ps.py --ps_list="localhost:2220" --worker_list="localhost:2231" --chief="localhost:2230" --task_mode="chief" --task_id=0
echo "ok"
@@ -6,9 +6,12 @@
from absl import app
from tensorflow_recommenders_addons import dynamic_embedding as de
try:
  from tensorflow.keras.legacy.optimizers import Adam
  from tensorflow.keras.optimizers.legacy import Adam
  from tensorflow.keras.optimizers.legacy import Adagrad
except:
  from tensorflow.keras.optimizers import Adam
  from tensorflow.keras.optimizers import Adagrad


flags.DEFINE_string('mode', 'train', 'Select the running mode: train or test.')
flags.DEFINE_string('model_dir', 'model_dir',
@@ -119,7 +122,8 @@ def train():
  dataset = get_dataset(batch_size=32)
  model = DualChannelsDeepModel(FLAGS.embedding_size, FLAGS.embedding_size,
                                tf.keras.initializers.RandomNormal(0.0, 0.5))
  optimizer = Adam(1E-3)
  # optimizer = Adam(1E-3)
  optimizer = Adagrad(1E-3)
  optimizer = de.DynamicEmbeddingOptimizer(optimizer)

  auc = tf.keras.metrics.AUC(num_thresholds=1000)
@@ -68,9 +68,9 @@ def on_batch_end(self, batch, logs=None):
    with ops.device(self.device):
      if hvd._executing_eagerly() and hasattr(self.model, 'variables'):
        # TensorFlow 2.0 or TensorFlow eager
        from tensorflow_recommenders_addons.dynamic_embedding.python.ops.shadow_embedding_ops import is_de_resource_variable
        filter_lambda = lambda x: (x.ref() not in self._local_vars) and (
            not isinstance(x, de.TrainableWrapper)) and (not isinstance(
                x, de.DEResourceVariable))
            not is_de_resource_variable(x))
        broadcast_vars = [
            var for var in self.model.variables if filter_lambda(var)
        ]
@@ -17,7 +17,6 @@
Dynamic Embedding is designed for Large-scale Sparse Weights Training.
See [Sparse Domain Isolation](https://github.com/tensorflow/community/pull/237)
"""

from packaging import version

import tensorflow as tf
@@ -28,7 +27,10 @@

from tensorflow.python.keras.utils import tf_utils

from tensorflow_recommenders_addons.dynamic_embedding.python.ops.parameter_server import create_ps_shadow_variable
from tensorflow_recommenders_addons.dynamic_embedding.python.ops.shadow_embedding_ops import HvdVariable
from tensorflow_recommenders_addons.dynamic_embedding.python.train.utils import \
    is_parameter_server_strategy

if version.parse(tf.__version__) >= version.parse("2.14"):
  from tensorflow.python.distribute import distribute_lib as distribute_ctx
@@ -225,7 +227,8 @@ def __init__(self,
    shadow_name = name + '-shadow' if name else 'ShadowVariable'
    if distribute_ctx.has_strategy():
      self.distribute_strategy = distribute_ctx.get_strategy()
    if self.distribute_strategy:
    if self.distribute_strategy and not is_parameter_server_strategy(
        self.distribute_strategy):
      strategy_devices = self.distribute_strategy.extended.worker_devices
      self.shadow_impl = tf_utils.ListWrapper([])
      for i, strategy_device in enumerate(strategy_devices):
@@ -242,12 +245,23 @@
              trainable=trainable,
              distribute_strategy=self.distribute_strategy))
    else:
      self.shadow_impl = tf_utils.ListWrapper([
          de.shadow_ops.ShadowVariable(self.params,
                                       name=shadow_name,
                                       max_norm=self.max_norm,
                                       trainable=trainable)
      ])
      if is_parameter_server_strategy(self.distribute_strategy):
        self.shadow_impl = tf_utils.ListWrapper([
            create_ps_shadow_variable(
                params=self.params,
                name=shadow_name,
                max_norm=self.max_norm,
                strategy=self.distribute_strategy,
                trainable=trainable)
        ])
      else:
        self.shadow_impl = tf_utils.ListWrapper([
            de.shadow_ops.ShadowVariable(self.params,
                                         name=shadow_name,
                                         max_norm=self.max_norm,
                                         trainable=trainable)
        ])

    if len(self.shadow_impl.as_list()) > 1:
      self._current_ids = data_structures.NoDependency(
          [shadow_i.ids for shadow_i in self.shadow_impl.as_list()])
@@ -261,24 +275,25 @@
      self._current_exists = data_structures.NoDependency(
          self.shadow_impl.as_list()[0].exists)
      self.optimizer_vars = self.shadow_impl.as_list()[0]._optimizer_vars
    if distribute_ctx.has_strategy(
    ) and self.distribute_strategy and 'OneDeviceStrategy' not in str(
        self.distribute_strategy) and not values_util.is_saving_non_distributed(
    ) and values_util.get_current_replica_id_as_int() is not None:
    if distribute_ctx.has_strategy() and self.distribute_strategy and \
        'OneDeviceStrategy' not in str(self.distribute_strategy) and \
        not values_util.is_saving_non_distributed() and \
        values_util.get_current_replica_id_as_int() is not None:
      self.shadow = de.DistributedVariableWrapper(
          self.distribute_strategy, self.shadow_impl.as_list(),
          VariableAggregation.NONE,
          TrainableWrapperDistributedPolicy(VariableAggregation.NONE))
    else:
      self.shadow = self.shadow_impl.as_list()[0]

    self.params._created_in_class = self  # To facilitate access to the primitive class through params
    super(Embedding, self).__init__(name=name,
                                    trainable=trainable,
                                    dtype=value_dtype)

  def call(self, ids):
    """
    Compute embedding output for feature ids. The output shape will be (shape(ids),
    embedding_size).

    Args:
@@ -289,10 +304,16 @@ def call(self, ids):
    Returns:
      An embedding output with shape (shape(ids), embedding_size).
    """
    return de.shadow_ops.embedding_lookup_unique(self.shadow, ids,

    r = de.shadow_ops.embedding_lookup_unique(self.shadow, ids,
                                              self.embedding_size,
                                              self.with_unique, self.name)

    # Debug output added by this PR: print the ids and the lookup result.
    tfprint = tf.print("ids_8a:", r, ids, self.shadow.ids,
                       output_stream=tf.compat.v1.logging.error)
    with tf.control_dependencies([tfprint]):
      pass
    return r

  def get_config(self):
    _initializer = self.params.initializer
    if _initializer is None:
@@ -72,6 +72,18 @@ py_test(
],
)

# This test is not run via pytest; run it with:
# bazel test //tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests:parameter_server_bzl
py_test(
    name = "parameter_server_bzl",
    srcs = ["parameter_server_bzl.py"],
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        "//tensorflow_recommenders_addons",
    ],
)

# Running this test on GitHub CI gets banned and can cause account violations; please run it manually on a local machine.
# py_test(
# name = "redis_table_variable_test",