diff --git a/demo/dynamic_embedding/movielens-1m-keras-ps/movielens-1m-keras-ps.py b/demo/dynamic_embedding/movielens-1m-keras-ps/movielens-1m-keras-ps.py index 26e537a47..8e298440e 100644 --- a/demo/dynamic_embedding/movielens-1m-keras-ps/movielens-1m-keras-ps.py +++ b/demo/dynamic_embedding/movielens-1m-keras-ps/movielens-1m-keras-ps.py @@ -2,13 +2,16 @@ import tensorflow as tf import tensorflow_datasets as tfds -from absl import flags -from absl import app from tensorflow_recommenders_addons import dynamic_embedding as de + try: from tensorflow.keras.optimizers.legacy import Adam + from tensorflow.keras.optimizers.legacy import Adagrad except: from tensorflow.keras.optimizers import Adam + from tensorflow.keras.optimizers import Adagrad + +from tensorflow import distribute as tf_dist flags = tf.compat.v1.app.flags FLAGS = flags.FLAGS @@ -34,6 +37,18 @@ ], dtype=tf.int64, name='movie_id') } +gpus = tf.config.list_physical_devices('GPU') +if gpus: + try: + # Currently, memory growth needs to be the same across GPUs + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + logical_gpus = tf.config.list_logical_devices('GPU') + print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs") + except RuntimeError as e: + # Memory growth must be set before GPUs have been initialized + print(e) + class DualChannelsDeepModel(tf.keras.Model): @@ -59,11 +74,13 @@ def __init__(self, user_embedding_size, initializer=embedding_initializer, devices=self.devices, + # with_unique=False, name='user_embedding') self.movie_embedding = de.keras.layers.SquashedEmbedding( movie_embedding_size, initializer=embedding_initializer, devices=self.devices, + # with_unique=False, name='movie_embedding') self.dnn1 = tf.keras.layers.Dense( @@ -90,11 +107,11 @@ def __init__(self, @tf.function def call(self, features): user_id = tf.reshape(features['user_id'], (-1, 1)) - movie_id = tf.reshape(features['movie_id'], (-1, 1)) + # movie_id = tf.reshape(features['movie_id'], (-1, 1)) user_latent = self.user_embedding(user_id) - movie_latent = self.movie_embedding(movie_id) - latent = tf.concat([user_latent, movie_latent], axis=1) - + # movie_latent = self.movie_embedding(movie_id) + # latent = tf.concat([user_latent, movie_latent], axis=1) + latent = user_latent x = self.dnn1(latent) x = self.dnn2(x) x = self.dnn3(x) @@ -115,7 +132,7 @@ def __init__(self, strategy, train_bs, test_bs, epochs, steps_per_epoch, "/job:ps/replica:0/task:{}/device:CPU:0".format(idx) for idx in range(self.num_ps) ] - self.embedding_size = 32 + self.embedding_size = 1 self.train_bs = train_bs self.test_bs = test_bs self.epochs = epochs @@ -133,7 +150,7 @@ def get_dataset(self, batch_size=1): ratings = dataset.map( lambda x: tf.one_hot(tf.cast(x['user_rating'] - 1, dtype=tf.int64), 5)) dataset = dataset.zip((features, ratings)) - dataset = dataset.shuffle(4096, reshuffle_each_iteration=False) + dataset = dataset.shuffle(4096, reshuffle_each_iteration=False).repeat() if batch_size > 1: dataset = dataset.batch(batch_size) return dataset @@ -146,6 +163,8 @@ def train(self): self.ps_devices, self.embedding_size, self.embedding_size, tf.keras.initializers.RandomNormal(0.0, 0.5)) optimizer = Adam(1E-3) + + # optimizer = Adagrad(1E-3) optimizer = de.DynamicEmbeddingOptimizer(optimizer) auc = tf.keras.metrics.AUC(num_thresholds=1000) @@ -161,7 +180,7 @@ def train(self): model.load_weights(self.model_dir) model.fit(dataset, epochs=self.epochs, steps_per_epoch=self.steps_per_epoch) - + print(f"model: {model.trainable_variables}") if self.model_dir: save_options = tf.saved_model.SaveOptions(namespace_whitelist=['TFRA']) model.save(self.model_dir, options=save_options) @@ -208,6 +227,7 @@ def test(self): dataset = self.get_dataset(batch_size=self.test_bs) dataset = self.strategy.experimental_distribute_dataset(dataset) + with self.strategy.scope(): model = tf.keras.models.load_model(self.export_dir) signature = model.signatures['serving_default'] @@ -237,13 +257,12 @@ def start_chief(config): cluster_spec = tf.train.ClusterSpec(config["cluster"]) cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver( cluster_spec, task_type="chief", task_id=0) - strategy = tf.distribute.experimental.ParameterServerStrategy( - cluster_resolver) + strategy = tf_dist.experimental.ParameterServerStrategy(cluster_resolver) runner = Runner(strategy=strategy, - train_bs=64, + train_bs=2, test_bs=1, - epochs=2, - steps_per_epoch=10, + epochs=1, + steps_per_epoch=2, model_dir=None, export_dir=None) runner.train() diff --git a/demo/dynamic_embedding/movielens-1m-keras-ps/one.sh b/demo/dynamic_embedding/movielens-1m-keras-ps/one.sh new file mode 100644 index 000000000..9fa628811 --- /dev/null +++ b/demo/dynamic_embedding/movielens-1m-keras-ps/one.sh @@ -0,0 +1,10 @@ +#!/bin/bash +rm -rf ./ckpt +sh stop.sh +sleep 1 +python movielens-1m-keras-ps.py --ps_list="localhost:2220" --worker_list="localhost:2231" --chief="localhost:2230" --task_mode="ps" --task_id=0 & +sleep 1 +python movielens-1m-keras-ps.py --ps_list="localhost:2220" --worker_list="localhost:2231" --chief="localhost:2230" --task_mode="worker" --task_id=0 & +sleep 1 +python movielens-1m-keras-ps.py --ps_list="localhost:2220" --worker_list="localhost:2231" --chief="localhost:2230" --task_mode="chief" --task_id=0 +echo "ok" \ No newline at end of file diff --git a/demo/dynamic_embedding/movielens-1m-keras/movielens-1m-keras.py b/demo/dynamic_embedding/movielens-1m-keras/movielens-1m-keras.py index 5c32745c4..b9f44f7fe 100644 --- a/demo/dynamic_embedding/movielens-1m-keras/movielens-1m-keras.py +++ b/demo/dynamic_embedding/movielens-1m-keras/movielens-1m-keras.py @@ -6,9 +6,12 @@ from absl import app from tensorflow_recommenders_addons import dynamic_embedding as de try: - from tensorflow.keras.legacy.optimizers import Adam + from tensorflow.keras.optimizers.legacy import Adam + from tensorflow.keras.optimizers.legacy import Adagrad except: from tensorflow.keras.optimizers import Adam + from tensorflow.keras.optimizers import Adagrad + flags.DEFINE_string('mode', 'train', 'Select the running mode: train or test.') flags.DEFINE_string('model_dir', 'model_dir', @@ -119,7 +122,8 @@ def train(): dataset = get_dataset(batch_size=32) model = DualChannelsDeepModel(FLAGS.embedding_size, FLAGS.embedding_size, tf.keras.initializers.RandomNormal(0.0, 0.5)) - optimizer = Adam(1E-3) + # optimizer = Adam(1E-3) + optimizer = Adagrad(1E-3) optimizer = de.DynamicEmbeddingOptimizer(optimizer) auc = tf.keras.metrics.AUC(num_thresholds=1000) diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py index a6a6d9fcb..21e0cdde1 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/callbacks.py @@ -68,9 +68,9 @@ def on_batch_end(self, batch, logs=None): with ops.device(self.device): if hvd._executing_eagerly() and hasattr(self.model, 'variables'): # TensorFlow 2.0 or TensorFlow eager + from tensorflow_recommenders_addons.dynamic_embedding.python.ops.shadow_embedding_ops import is_de_resource_variable filter_lambda = lambda x: (x.ref() not in self._local_vars) and ( - not isinstance(x, de.TrainableWrapper)) and (not isinstance( - x, de.DEResourceVariable)) + not is_de_resource_variable(x)) broadcast_vars = [ var for var in self.model.variables if filter_lambda(var) ] diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py index a3aaa44f7..0e8fbcc2a 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/keras/layers/embedding.py @@ -17,7 +17,6 @@ Dynamic Embedding is designed for Large-scale Sparse Weights Training. See [Sparse Domain Isolation](https://github.com/tensorflow/community/pull/237) """ - from packaging import version import tensorflow as tf @@ -28,7 +27,10 @@ from tensorflow.python.keras.utils import tf_utils +from tensorflow_recommenders_addons.dynamic_embedding.python.ops.parameter_server import create_ps_shadow_variable from tensorflow_recommenders_addons.dynamic_embedding.python.ops.shadow_embedding_ops import HvdVariable +from tensorflow_recommenders_addons.dynamic_embedding.python.train.utils import \ + is_parameter_server_strategy if version.parse(tf.__version__) >= version.parse("2.14"): from tensorflow.python.distribute import distribute_lib as distribute_ctx @@ -225,7 +227,8 @@ def __init__(self, shadow_name = name + '-shadow' if name else 'ShadowVariable' if distribute_ctx.has_strategy(): self.distribute_strategy = distribute_ctx.get_strategy() - if self.distribute_strategy: + if self.distribute_strategy and not is_parameter_server_strategy( + self.distribute_strategy): strategy_devices = self.distribute_strategy.extended.worker_devices self.shadow_impl = tf_utils.ListWrapper([]) for i, strategy_device in enumerate(strategy_devices): @@ -242,12 +245,23 @@ def __init__(self, trainable=trainable, distribute_strategy=self.distribute_strategy)) else: - self.shadow_impl = tf_utils.ListWrapper([ - de.shadow_ops.ShadowVariable(self.params, - name=shadow_name, - max_norm=self.max_norm, - trainable=trainable) - ]) + if is_parameter_server_strategy(self.distribute_strategy): + self.shadow_impl = tf_utils.ListWrapper([ + create_ps_shadow_variable( + params=self.params, + name=shadow_name, + max_norm=self.max_norm, + strategy=self.distribute_strategy, + trainable=trainable) + ]) + else: + self.shadow_impl = tf_utils.ListWrapper([ + de.shadow_ops.ShadowVariable(self.params, + name=shadow_name, + max_norm=self.max_norm, + trainable=trainable) + ]) + if len(self.shadow_impl.as_list()) > 1: self._current_ids = data_structures.NoDependency( [shadow_i.ids for shadow_i in self.shadow_impl.as_list()]) @@ -261,16 +275,17 @@ def __init__(self, self._current_exists = data_structures.NoDependency( self.shadow_impl.as_list()[0].exists) self.optimizer_vars = self.shadow_impl.as_list()[0]._optimizer_vars - if distribute_ctx.has_strategy( - ) and self.distribute_strategy and 'OneDeviceStrategy' not in str( - self.distribute_strategy) and not values_util.is_saving_non_distributed( - ) and values_util.get_current_replica_id_as_int() is not None: + if distribute_ctx.has_strategy() and self.distribute_strategy and \ + 'OneDeviceStrategy' not in str(self.distribute_strategy) and \ + not values_util.is_saving_non_distributed() and \ + values_util.get_current_replica_id_as_int() is not None: self.shadow = de.DistributedVariableWrapper( self.distribute_strategy, self.shadow_impl.as_list(), VariableAggregation.NONE, TrainableWrapperDistributedPolicy(VariableAggregation.NONE)) else: self.shadow = self.shadow_impl.as_list()[0] + self.params._created_in_class = self # To facilitate access to the primitive class through params super(Embedding, self).__init__(name=name, trainable=trainable, @@ -278,7 +293,7 @@ def __init__(self, def call(self, ids): """ - Compute embedding output for feature ids. The output shape will be (shape(ids), + Compute embedding output for feature ids. The output shape will be (shape(ids), embedding_size). Args: @@ -289,10 +304,16 @@ def call(self, ids): Returns: A embedding output with shape (shape(ids), embedding_size). """ - return de.shadow_ops.embedding_lookup_unique(self.shadow, ids, + + r = de.shadow_ops.embedding_lookup_unique(self.shadow, ids, self.embedding_size, self.with_unique, self.name) + tfprint = tf.print("ids_8a:", r, ids, self.shadow.ids, output_stream=tf.compat.v1.logging.error) + with tf.control_dependencies([tfprint]): + pass + return r + def get_config(self): _initializer = self.params.initializer if _initializer is None: diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/BUILD b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/BUILD index ef00ee0b9..8a4c1d7e1 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/BUILD +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/BUILD @@ -72,6 +72,18 @@ py_test( ], ) +# This test is not for pytest, it requires +# bazel test //tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests:parameter_server_bzl +py_test( + name = "parameter_server_bzl", + srcs = ["parameter_server_bzl.py"], + python_version = "PY3", + srcs_version = "PY3", + deps = [ + "//tensorflow_recommenders_addons", + ], +) + # This test will be banned by GitHub and cause account violations, please run the test manually locally. # py_test( # name = "redis_table_variable_test", diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/parameter_server_bzl.py b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/parameter_server_bzl.py new file mode 100644 index 000000000..0e217dffb --- /dev/null +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests/parameter_server_bzl.py @@ -0,0 +1,222 @@ +# pytest: skip +import os +import sys + +from tensorflow.python.distribute import multi_process_lib +import multiprocessing +import tensorflow as tf +from tensorflow.python.framework import constant_op + +from tensorflow.python.training import server_lib + +from tensorflow_recommenders_addons import dynamic_embedding as de + +import numpy as np +from tensorflow.python.compat import v2_compat +from tensorflow.python.distribute import multi_process_runner +from tensorflow.python.distribute import multi_worker_test_base +from tensorflow.python.distribute import parameter_server_strategy_v2 +from tensorflow.python.distribute.cluster_resolver import cluster_resolver as cluster_resolver_lib + +from tensorflow.python.eager import test +from packaging import version +from tensorflow.python.distribute.coordinator import cluster_coordinator as coordinator_lib +from tensorflow.python.eager import def_function +from tensorflow.python.ops import variables + +if version.parse(tf.__version__) >= version.parse("2.16"): + from tf_keras import layers + from tf_keras import Sequential + from tf_keras.optimizers import Adam +else: + from tensorflow.python.keras import layers + from tensorflow.python.keras import Sequential + try: + from tensorflow.keras.optimizers import Adam + except: + from tensorflow.keras.optimizers.legacy import Adam + + +def create_multi_process_cluster(cluster_spec, + rpc_layer='grpc', + stream_output=False, + collective_leader=None): + + cluster = multi_worker_test_base.MultiProcessCluster( + cluster_resolver_lib.SimpleClusterResolver( + server_lib.ClusterSpec(cluster_spec), rpc_layer=rpc_layer), + stream_output=stream_output, + collective_leader=collective_leader) + cluster.start() + return cluster + + +class ParameterServerStrategyV2Test(test.TestCase): + + @classmethod + def setUpClass(cls): + super(ParameterServerStrategyV2Test, cls).setUpClass() + cluster_spec = { + "worker": ["localhost:2223", "localhost:2224"], + "ps": ["localhost:2222"] + } + cls.cluster = create_multi_process_cluster(cluster_spec) + cls.cluster_resolver = cls.cluster.cluster_resolver + # cls.strategy = DEParameterServerStrategy(cls.cluster_resolver) + cls.strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( + cls.cluster_resolver) + cls.coordinator = coordinator_lib.ClusterCoordinator(cls.strategy) + + @classmethod + def tearDownClass(cls): + super(ParameterServerStrategyV2Test, cls).tearDownClass() + cls.cluster.stop() + + def testPerWorkerTraining(self): + var_dtype = tf.dtypes.float32 + var_name = 'var' + shape = [1] + with self.strategy.scope(): + var = variables.Variable(initial_value=[0.0], + shape=shape, + dtype=var_dtype, + name=var_name, + per_worker_variable=True) + var._trainable = True + with backprop.GradientTape(persistent=True) as tape: + + # 定义训练步骤 + @tf.function + def train_step(): + with tf.GradientTape() as tape: + # var._maybe_create_per_worker_vars() + value = var.read_value() + # if not var.trainable: + tape.watch(value) # still need this with var._trainable = True set. + y = value * 2.0 + grad = tape.gradient(y, value) + return grad + + @tf.function + def train_step2(): + with tf.GradientTape() as tape: + var._maybe_create_per_worker_vars() + value = var.value() + # if not var.trainable: + tape.watch(value) # still need this with var._trainable = True set. + y = value * 2.0 + grad = tape.gradient(y, value) + return grad + + # 运行并检查结果 + grads = self.strategy.run(train_step2) + print(f"grads :{grads}") + print(f"var.read_all() {var.read_all()}") + #@parameterized.parameters(True, False) + # def testPerWorkerVariableCreation(self): + # var_dtype = tf.dtypes.float32 + # var_name = 'var' + # shape = [1] #if define_shape else None + # + # with self.strategy.scope(): + # var = variables.Variable(initial_value=[0.0], + # shape=shape, + # dtype=var_dtype, + # name=var_name, + # per_worker_de_variable=True) + # + # # Use per-worker variable as a capture + # @def_function.function + # def worker_fn(): + # var.assign_add(constant_op.constant([1.0])) + # return var + # + # num_closures = 10 + # for ix in range(num_closures): + # self.coordinator.schedule(worker_fn) + # # Read the PWV many times to ensure result is up-to-date + # self.coordinator.join() + # result_sum = sum(var.read_all()).numpy() + # self.assertEqual(result_sum, ix + 1) + # + # for _ in range(num_closures): + # self.coordinator.schedule(worker_fn) + # self.coordinator.join() + # + # # Verify placement of variables + # devices = [wv._get_values().device for wv in var._per_worker_vars._values] + # expected_devices = [ + # f'/job:worker/replica:0/task:{ix}/device:CPU:0' + # for ix in range(self.strategy._num_workers) + # ] # pylint: disable=protected-access + # self.assertAllEqual(devices, expected_devices) + # + # result_sum = sum(var.read_all()).numpy() + # self.assertEqual(result_sum, num_closures * 2) + + # def testKerasFit(self): + # embed_dim = 8 + # with self.strategy.scope(): + # model = Sequential([ + # layers.Input(shape=(1,), dtype=tf.int32), + # de.keras.layers.Embedding(embed_dim, key_dtype=tf.int32), + # layers.Flatten(), + # layers.Dense(1, activation='sigmoid') + # ]) + # optimizer = Adam(1E-3) + # optimizer = de.DynamicEmbeddingOptimizer(optimizer) + # model.compile(loss='binary_crossentropy', + # optimizer=optimizer, + # metrics=['accuracy']) + # + # ids = np.random.randint(0, 100, size=(64 * 2, 1)) + # labels = np.random.randint(0, 2, size=(64 * 2, 1)) + # + # def dataset_fn(input_context): + # global_batch_size = 32 + # batch_size = input_context.get_per_replica_batch_size(global_batch_size) + # dataset = tf.data.Dataset.from_tensor_slices((ids, labels)) + # dataset = dataset.shard(input_context.num_input_pipelines, + # input_context.input_pipeline_id) + # dataset = dataset.batch(batch_size).repeat() + # return dataset + # + # dataset = self.strategy.distribute_datasets_from_function(dataset_fn) + # + # history = model.fit(dataset, epochs=1, steps_per_epoch=len(ids) // 64) + # self.assertIn('loss', history.history) + + +# borrow from multi_process_lib._set_spawn_exe_path and modify it for tf_recommenders_addons +def custom_set_spawn_exe_path(): + if sys.argv[0].endswith('.py'): + + def guess_path(package_root): + # If all we have is a python module path, we'll need to make a guess for + # the actual executable path. + if 'bazel-out' in sys.argv[0] and package_root in sys.argv[0]: + package_root_base = sys.argv[0][:sys.argv[0].rfind(package_root)] + binary = os.environ['TEST_TARGET'][2:].replace(':', '/', 1) + print(f"package_root_base {package_root_base} binary {binary}") + possible_path = os.path.join(package_root_base, package_root, binary) + print('Guessed test binary path: %s', possible_path) + if os.access(possible_path, os.X_OK): + return possible_path + return None + + path = guess_path('tf_recommenders_addons') + if path is None: + print('Cannot determine binary path. sys.argv[0]=%s os.environ=%s', + sys.argv[0], os.environ) + raise RuntimeError('Cannot determine binary path') + sys.argv[0] = path + # Note that this sets the executable for *all* contexts. + multiprocessing.get_context().set_executable(sys.argv[0]) + + +# This is not for pytest bazel clean --expunge +# bazel test --test_output=all //tensorflow_recommenders_addons/dynamic_embedding/python/kernel_tests:parameter_server_bzl +if __name__ == "__main__": + multi_process_lib._set_spawn_exe_path = custom_set_spawn_exe_path + v2_compat.enable_v2_behavior() + multi_process_runner.test_main() diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/distributed_embedding_variable.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/distributed_embedding_variable.py index 353cc3088..d5cb0b70b 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/distributed_embedding_variable.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/distributed_embedding_variable.py @@ -10,10 +10,10 @@ class DistributedVariableWrapper(EmbeddingWeights, def __init__(self, strategy, values, aggregation, var_policy=None): super(DistributedVariableWrapper, self).__init__(strategy, values, aggregation, var_policy) - self.shadow = self._get_on_device_or_primary() + self._shadow = self._get_on_device_or_primary() def verify_embedding_weights(self, sparse_ids, sparse_weights=None): - EmbeddingWeights.verify_embedding_param_weights(self.shadow.params, + EmbeddingWeights.verify_embedding_param_weights(self._shadow.params, sparse_ids, sparse_weights) def embedding_lookup(self, diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_ops.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_ops.py index 47c62599e..7e706d0f3 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_ops.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_ops.py @@ -21,7 +21,7 @@ from packaging import version from tensorflow_recommenders_addons import dynamic_embedding as de -from tensorflow_recommenders_addons.dynamic_embedding.python.ops.shadow_embedding_ops import DEResourceVariable +from tensorflow_recommenders_addons.dynamic_embedding.python.ops.shadow_embedding_ops import is_de_resource_variable from tensorflow_recommenders_addons.dynamic_embedding.python.ops.embedding_weights import EmbeddingWeights from tensorflow import version as tf_version @@ -464,8 +464,7 @@ def trainable_wrapper_filter(iterable_object_in, dense_grads_and_vars_aggregated_out = [] sparse_grads_and_vars_unaggregated_out = [] if test_unaggregated_function is None: - test_unaggregated_function = lambda x: isinstance( - x, de.TrainableWrapper) or isinstance(x, DEResourceVariable) + test_unaggregated_function = is_de_resource_variable for item in iterable_object_in: if test_unaggregated_function(item): sparse_grads_and_vars_unaggregated_out.append(item) diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py index ff843e2a4..e8161a184 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_optimizer.py @@ -16,13 +16,20 @@ """patch on optimizers""" import functools + +import tensorflow as tf from packaging import version import six from tensorflow_recommenders_addons import dynamic_embedding as de from tensorflow import version as tf_version -from tensorflow.python.distribute import central_storage_strategy +from tensorflow.python.distribute import central_storage_strategy, ps_values + +from tensorflow_recommenders_addons.dynamic_embedding.python.ops.parameter_server import create_ps_shadow_variable +from tensorflow_recommenders_addons.dynamic_embedding.python.train.utils import worker_devices, \ + is_parameter_server_strategy + if version.parse(tf_version.VERSION) >= version.parse("2.14"): from tensorflow.python.distribute import distribute_lib as distribute_ctx else: @@ -138,7 +145,9 @@ def apply_grad_to_update_var(var, grad): """Apply gradient to variable.""" if isinstance(var, Tensor): raise NotImplementedError("Trying to update a Tensor ", var) - + tfprint = tf.print("_distributed_apply: ", var, grad, type(var)) + with tf.control_dependencies([tfprint]): + pass apply_kwargs = {} if not isinstance(var, de.TrainableWrapper): if isinstance(grad, IndexedSlices): @@ -180,6 +189,9 @@ def apply_grad_to_update_var(var, grad): "Cannot use a constraint function on a sparse variable.") if "apply_state" in self._sparse_apply_args: apply_kwargs["apply_state"] = apply_state + printop = tf.print("g_and_v_1 var:", var, output_stream=tf.compat.v1.logging.error) + with tf.control_dependencies([printop]): + pass with ops.control_dependencies(_before): _apply_op = self._resource_apply_sparse_duplicate_indices( grad.values, var, grad.indices, **apply_kwargs) @@ -692,9 +704,13 @@ def _dist_tw_grads_and_vars_filter(grads_and_vars_in): dense_grads_and_vars_aggregated_out = [] sparse_grads_and_vars_unaggregated_out = [] test_unaggregated_lambda = lambda x: isinstance( - x[1], de.DistributedVariableWrapper) - for g_and_v in grads_and_vars_in: + x[1], de.DistributedVariableWrapper) or isinstance(x[1], ps_values.PerWorkerVariable) + for g_and_v in grads_and_vars_in: # type(g_and_v[1]) + printop = tf.print("g_and_v_1 x[1]:", g_and_v[1], output_stream=tf.compat.v1.logging.error) + with tf.control_dependencies([printop]): + pass if test_unaggregated_lambda(g_and_v): + sparse_grads_and_vars_unaggregated_out.append(g_and_v) else: dense_grads_and_vars_aggregated_out.append(g_and_v) @@ -707,7 +723,9 @@ def apply_gradients_strategy_v2_lagacy(grads_and_vars, grads_and_vars = optimizer_v2_legacy_utils.filter_empty_gradients( grads_and_vars) var_list = [v for (_, v) in grads_and_vars] - + # tfprint = tf.print("g_and_v_1 in grads_and_vars:", output_stream=tf.compat.v1.logging.error) + # with tf.control_dependencies([tfprint]): + # pass with ops.name_scope_v2(self._name): # Create iteration if necessary. with ops.init_scope(): @@ -851,6 +869,9 @@ def compute_gradients_horovod_wrapper_impl(*args, **kwargs): self._compute_gradients = compute_gradients_horovod_wrapper( compute_gradients_horovod_v2) else: + tfprint = tf.print("g_and_v_1 apply_gradients_strategy_v2_lagacy", output_stream=tf.compat.v1.logging.error) + with tf.control_dependencies([tfprint]): + pass self.apply_gradients = apply_gradients_strategy_v2_lagacy elif hasattr(self, '_distributed_apply_gradients_fn'): # Latest Keras optimizer @@ -912,16 +933,28 @@ def create_slots(variable, init, slot_name, op_name, bp_v2): # for forward compatibility. slot_tw_name = slot_name - def slot_trainable_create_(var_impl, scope_store_params, full_name_in, - slot_tw_name_in): + def slot_trainable_create_(var_impl, + scope_store_params, + full_name_in, + slot_tw_name_in, + distribute_strategy=None): if isinstance(var_impl, de.shadow_ops.ShadowVariable): - slot_trainable = de.shadow_ops.ShadowVariable( + if is_parameter_server_strategy(distribute_strategy): + slot_trainable = create_ps_shadow_variable( params=scope_store_params, ids=var_impl.ids, exists=var_impl.exists, name=full_name_in, trainable=False, - ) + strategy=distribute_strategy) + else: + slot_trainable = de.shadow_ops.ShadowVariable( + params=scope_store_params, + ids=var_impl.ids, + exists=var_impl.exists, + name=full_name_in, + trainable=False, + distribute_strategy=distribute_strategy) else: _, slot_trainable = de.embedding_lookup( params=scope_store_params, @@ -951,8 +984,21 @@ def slot_trainable_create_(var_impl, scope_store_params, full_name_in, VariableAggregation.NONE, TrainableWrapperDistributedPolicy(VariableAggregation.NONE)) else: - slot_trainable = slot_trainable_create_(variable, - scope_store._vars[full_name], - full_name, slot_tw_name) + if hasattr( + variable, 'distribute_strategy' + ) and variable.distribute_strategy and is_parameter_server_strategy( + variable.distribute_strategy): + printop = tf.print("g_and_v_1 variab:", variable, + output_stream=tf.compat.v1.logging.error) + with tf.control_dependencies([printop]): + pass + slot_trainable = slot_trainable_create_(variable, + scope_store._vars[full_name], + full_name, slot_tw_name, + variable.distribute_strategy) + else: + slot_trainable = slot_trainable_create_(variable, + scope_store._vars[full_name], + full_name, slot_tw_name) return slot_trainable diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py index 146352357..db5a76fbd 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/dynamic_embedding_variable.py @@ -26,6 +26,8 @@ from tensorflow_recommenders_addons import dynamic_embedding as de from tensorflow_recommenders_addons.dynamic_embedding.python.ops.embedding_weights import EmbeddingWeights +from tensorflow_recommenders_addons.dynamic_embedding.python.ops.parameter_server import create_ps_trainable_wrapper, \ + create_ps_shadow_variable from tensorflow_recommenders_addons.utils.check_platform import is_macos, is_arm64 if version.parse(tf.__version__) >= version.parse("2.14"): @@ -1454,7 +1456,10 @@ def _create_or_get_trainable(trainable_name): trainable_name = ops.get_default_graph().unique_name( _ANONYMOUS_TRAINABLE_STORE_KEY) if not context.executing_eagerly() and not ops.inside_function(): - wrapper = de.TrainableWrapper(params=params, + distribute_strategy = distribute_ctx.get_strategy( + ) if distribute_ctx.has_strategy else None + wrapper = create_ps_trainable_wrapper(strategy=distribute_strategy, + params=params, ids=ids, max_norm=max_norm, initial_value=initial_value, @@ -1469,17 +1474,30 @@ def _create_or_get_trainable(trainable_name): with ops.init_scope(): shadow = params._trainable_store.get(trainable_name, None) if shadow is None: - shadow = de.shadow_ops.ShadowVariable( - params, - name=trainable_name, - max_norm=max_norm, - trainable=params.trainable, - model_mode=de.ModelMode.CURRENT_SETTING) + distribute_strategy = distribute_ctx.get_strategy( + ) if distribute_ctx.has_strategy else None + if is_parameter_server_strategy(distribute_strategy): + shadow = create_ps_shadow_variable( + params, + name=trainable_name, + max_norm=max_norm, + trainable=params.trainable, + distribute=distribute_strategy, + model_mode=de.ModelMode.CURRENT_SETTING) + else: + shadow = de.shadow_ops.ShadowVariable( + params, + name=trainable_name, + max_norm=max_norm, + trainable=params.trainable, + model_mode=de.ModelMode.CURRENT_SETTING) params._trainable_store[trainable_name] = shadow return shadow with ops.colocate_with(ids, ignore_existing=True): - if distribute_ctx.has_strategy(): + from tensorflow_recommenders_addons.dynamic_embedding.python.train.utils import is_parameter_server_strategy + if distribute_ctx.has_strategy() and not is_parameter_server_strategy( + distribute_ctx.get_strategy()): trainable_ = params._distribute_trainable_store.get(name, None) if trainable_ is None: strategy_devices = distribute_ctx.get_strategy( diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/embedding_weights.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/embedding_weights.py index 2d600f28e..e32ded1ad 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/embedding_weights.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/embedding_weights.py @@ -1,6 +1,14 @@ import abc import tensorflow as tf from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.distribute.coordinator.values import PerWorkerValues +from tensorflow.python.ops.resource_variable_ops import VariableSpec, ResourceVariableGradient, \ + eager_safe_variable_handle, _maybe_set_handle_data, get_eager_safe_handle_data +import contextlib +import weakref +from tensorflow.core.framework import variable_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape from tensorflow_recommenders_addons.utils.resource_loader import get_tf_version_triple try: @@ -13,8 +21,8 @@ # tf version >= 2.13.0 from tensorflow.python.eager import record as tape_record from tensorflow.python.keras.optimizer_v2 import optimizer_v2 -from tensorflow.python.ops import clip_ops -from tensorflow.python.framework import ops +from tensorflow.python.ops import clip_ops, handle_data_util +from tensorflow.python.framework import ops, composite_tensor try: # tf version >= 2.14.0 from tensorflow.python.framework.tensor import Tensor @@ -26,6 +34,8 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables +from tensorflow.python.framework import tensor as tensor_module + try: # tf version >= 2.14.0 from tensorflow.python.ops.array_ops_stack import stack except: @@ -41,7 +51,8 @@ from tensorflow.python.util import compat, dispatch try: # tf version >= 2.14.0 - from tensorflow.python.distribute import distribute_lib as distribute_ctx + from tensorflow.python.distribute import distribute_lib as distribute_ctx, ps_values + assert hasattr(distribute_ctx, 'has_strategy') except: from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx @@ -119,8 +130,565 @@ class ModelMode(object): # The default setting is training mode. CURRENT_SETTING = TRAIN +class TrainablePerWorker(ps_values.PerWorkerVariable): + def __init__(self, strategy, next_creator, **kwargs): + super(TrainablePerWorker, + self).__init__(strategy, next_creator, **kwargs) + self._trainable = kwargs.get("trainable", True) + def read_value(self): + tfprint = tf.print("TrainablePerWorker read_value:", self._coordinator_instance.device) + with tf.control_dependencies([tfprint]): + pass + with tf.GradientTape() as tape: + self._maybe_create_per_worker_vars() + value = super(ps_values.PerWorkerVariable).read_value() + tape.watch(value) + return value + def value(self): + tfprint = tf.print("TrainablePerWorker value:", self._coordinator_instance.device) + with tf.control_dependencies([tfprint]): + pass + with tf.GradientTape() as tape: + self._maybe_create_per_worker_vars() + value = super(ps_values.PerWorkerVariable).value() + tape.watch(value) + return value + def assign(self, value, use_locking=False, name=None, read_value=True): + self._maybe_create_per_worker_vars() + tfprint = tf.print("PerWorkerVariable:", value, use_locking, name, read_value) + with tf.control_dependencies([tfprint]): + pass + with ops.device(self._coordinator_instance.device): + return self._coordinator_instance.assign( + value, use_locking=use_locking, name=name, read_value=read_value) + + def assign_add(self, delta, use_locking=False, name=None, read_value=True): + self._maybe_create_per_worker_vars() + with ops.device(self._coordinator_instance.device): + return self._coordinator_instance.assign_add( + delta, use_locking=use_locking, name=name, read_value=read_value) + +class PerWorkerResourceVariable(ps_values.PerWorkerVariable, PerWorkerValues): + def __init__( + self, # pylint: disable=super-init-not-called + initial_value=None, + trainable=None, + collections=None, + validate_shape=True, # pylint: disable=unused-argument + caching_device=None, + name=None, + dtype=None, + variable_def=None, + import_scope=None, + constraint=None, + distribute_strategy=None, + synchronization=None, + aggregation=None, + shape=None, + handle=None, + experimental_enable_variable_lifting=None, + **kwargs + ): + """Creates a variable. + + Args: + initial_value: A `Tensor`, or Python object convertible to a `Tensor`, + which is the initial value for the Variable. Can also be a callable with + no argument that returns the initial value when called. (Note that + initializer functions from init_ops.py must first be bound to a shape + before being used here.) + trainable: If `True`, the default, also adds the variable to the graph + collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as + the default list of variables to use by the `Optimizer` classes. + Defaults to `True`, unless `synchronization` is set to `ON_READ`, in + which case it defaults to `False`. + collections: List of graph collections keys. The new variable is added to + these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. + validate_shape: If `False`, allows the variable to be initialized with a + value of unknown shape. If `True`, the default, the shape of + `initial_value` must be known. + caching_device: Optional device string or function describing where the + Variable should be cached for reading. Defaults to the Variable's + device. If not `None`, caches on another device. Typical use is to + cache on the device where the Ops using the Variable reside, to + deduplicate copying through `Switch` and other conditional statements. + name: Optional name for the variable. Defaults to `'Variable'` and gets + uniquified automatically. + dtype: If set, initial_value will be converted to the given type. If None, + either the datatype will be kept (if initial_value is a Tensor) or + float32 will be used (if it is a Python object convertible to a Tensor). + variable_def: `VariableDef` protocol buffer. If not None, recreates the + `ResourceVariable` object with its contents. `variable_def` and other + arguments (except for import_scope) are mutually exclusive. + import_scope: Optional `string`. Name scope to add to the + ResourceVariable. Only used when `variable_def` is provided. + constraint: An optional projection function to be applied to the variable + after being updated by an `Optimizer` (e.g. used to implement norm + constraints or value constraints for layer weights). The function must + take as input the unprojected Tensor representing the value of the + variable and return the Tensor for the projected value (which must have + the same shape). Constraints are not safe to use when doing asynchronous + distributed training. + distribute_strategy: The tf.distribute.Strategy this variable is being + created inside of. + synchronization: Indicates when a distributed a variable will be + aggregated. Accepted values are constants defined in the class + `tf.VariableSynchronization`. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses when to + synchronize. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + `tf.VariableAggregation`. + shape: (optional) The shape of this variable. If None, the shape of + `initial_value` will be used. When setting this argument to + `tf.TensorShape(None)` (representing an unspecified shape), the variable + can be assigned with values of different shapes. + handle: (optional) The handle of a `tf.Variable`. If provided, only + `trainable`, `shape`, `dtype`, and `handle` will be used to construct + this `tf.Variable`. + experimental_enable_variable_lifting: Whether to lift the variable out if + it's in a `tf.function`. Default is `True`. When this argument + is `True`, variable creation will follow the behavior and + restrictions described + [here](https://www.tensorflow.org/guide/function#creating_tfvariables). + If this argument is `False`, that description doesn't apply, + and you can freely create and use the variable in the + `tf.function`, as if it's a "mutable `tf.Tensor`". You can't + return the variable though. + + Raises: + ValueError: If the initial value is not specified, or does not have a + shape and `validate_shape` is `True`. + + @compatibility(eager) + When Eager Execution is enabled, the default for the `collections` argument + is `None`, which signifies that this `Variable` will not be added to any + collections. + @end_compatibility + """ + if variable_def: + if initial_value is not None: + raise ValueError(f"The variable_def and initial_value args to " + f"`tf.Variable` are mutually exclusive, but got both: " + f"variable_def={variable_def},\n" + f"initial_value={initial_value}") + if context.executing_eagerly(): + raise ValueError(f"Creating a `tf.Variable` with a `variable_def` arg " + f"is not supported when eager execution is enabled. " + f"Got: variable_def={variable_def}") + self._init_from_proto( + variable_def, + import_scope=import_scope, + validate_shape=validate_shape) + elif handle is not None: + self._init_from_handle(trainable=trainable, + shape=shape, + dtype=dtype, + handle=handle) + else: + self._init_from_args( + initial_value=initial_value, + trainable=trainable, + collections=collections, + caching_device=caching_device, + name=name, + dtype=dtype, + constraint=constraint, + synchronization=synchronization, + aggregation=aggregation, + shape=shape, + distribute_strategy=distribute_strategy, + validate_shape=validate_shape, + experimental_enable_variable_lifting=experimental_enable_variable_lifting, + ) + self._trainable = kwargs.get("trainable", True) + + +# CompositeTensor method + @property + def _type_spec(self): + return VariableSpec.from_value(self) + + # CompositeTensor method + def _shape_invariant_to_type_spec(self, shape): + return VariableSpec(shape, self.dtype, self.trainable) + + # CompositeTensorGradient protocol + __composite_gradient__ = ResourceVariableGradient() + + def _init_from_args( + self, + initial_value=None, + trainable=None, + collections=None, + caching_device=None, + name=None, + dtype=None, + constraint=None, + synchronization=None, + aggregation=None, + distribute_strategy=None, + shape=None, + validate_shape=True, + experimental_enable_variable_lifting=None, + ): + """Creates a variable. + + Args: + initial_value: A `Tensor`, or Python object convertible to a `Tensor`, + which is the initial value for the Variable. The initial value must have + a shape specified unless `validate_shape` is set to False. Can also be a + callable with no argument that returns the initial value when called. + (Note that initializer functions from init_ops.py must first be bound to + a shape before being used here.) + trainable: If `True`, the default, also adds the variable to the graph + collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as + the default list of variables to use by the `Optimizer` classes. + Defaults to `True`, unless `synchronization` is set to `ON_READ`, in + which case it defaults to `False`. + collections: List of graph collections keys. The new variable is added to + these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. + caching_device: Optional device string or function describing where the + Variable should be cached for reading. Defaults to the Variable's + device. If not `None`, caches on another device. Typical use is to + cache on the device where the Ops using the Variable reside, to + deduplicate copying through `Switch` and other conditional statements. + name: Optional name for the variable. Defaults to `'Variable'` and gets + uniquified automatically. + dtype: If set, initial_value will be converted to the given type. If None, + either the datatype will be kept (if initial_value is a Tensor) or + float32 will be used (if it is a Python object convertible to a Tensor). + constraint: An optional projection function to be applied to the variable + after being updated by an `Optimizer` (e.g. used to implement norm + constraints or value constraints for layer weights). The function must + take as input the unprojected Tensor representing the value of the + variable and return the Tensor for the projected value (which must have + the same shape). Constraints are not safe to use when doing asynchronous + distributed training. + synchronization: Indicates when a distributed a variable will be + aggregated. Accepted values are constants defined in the class + `tf.VariableSynchronization`. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses when to + synchronize. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + `tf.VariableAggregation`. + distribute_strategy: DistributionStrategy under which this variable was + created. + shape: (optional) The shape of this variable. If None, the shape of + `initial_value` will be used. When setting this argument to + `tf.TensorShape(None)` (representing an unspecified shape), the variable + can be assigned with values of different shapes. + validate_shape: If `False`, allows the variable to be initialized with a + value of unknown shape. If `True`, the default, the shape of + `initial_value` must be known. + experimental_enable_variable_lifting: Whether to lift the variable out if + it's in a `tf.function`. Default is `True`. When this argument + is `True`, variable creation will follow the behavior and + restrictions described + [here](https://www.tensorflow.org/guide/function#creating_tfvariables). + If this argument is `False`, that description doesn't apply, + and you can freely create and use the variable in the + `tf.function`, as if it's a "mutable `tf.Tensor`". You can't + return the variable though. + + Raises: + ValueError: If the initial value is not specified, or does not have a + shape and `validate_shape` is `True`. + + @compatibility(eager) + When Eager Execution is enabled, variables are never added to collections. + It is not implicitly added to the `GLOBAL_VARIABLES` or + `TRAINABLE_VARIABLES` collections, and the `collections` argument is + ignored. + @end_compatibility + """ + synchronization, aggregation, trainable = ( + variables.validate_synchronization_aggregation_trainable( + synchronization, aggregation, trainable, name)) + if experimental_enable_variable_lifting is None: + experimental_enable_variable_lifting = True + if initial_value is None: + raise ValueError("The `initial_value` arg to `tf.Variable` must " + "be specified except when you are not providing a " + "`variable_def`. You provided neither.") + init_from_fn = callable(initial_value) + + if isinstance(initial_value, tensor_module.Tensor) and hasattr( + initial_value, "graph") and initial_value.graph.building_function: + raise ValueError(f"Argument `initial_value` ({initial_value}) could not " + "be lifted out of a `tf.function`. " + f"(Tried to create variable with name='{name}'). " + "To avoid this error, when constructing `tf.Variable`s " + "inside of `tf.function` you can create the " + "`initial_value` tensor in a " + "`tf.init_scope` or pass a callable `initial_value` " + "(e.g., `tf.Variable(lambda : " + "tf.truncated_normal([10, 40]))`). " + "Please file a feature request if this " + "restriction inconveniences you.") + + if collections is None: + collections = [ops.GraphKeys.GLOBAL_VARIABLES] + if not isinstance(collections, (list, tuple, set)): + raise ValueError( + f"collections argument to Variable constructor must be a list, " + f"tuple, or set. Got {collections} of type {type(collections)}") + if constraint is not None and not callable(constraint): + raise ValueError(f"Argument `constraint` must be None or a callable. " + f"a callable. Got a {type(constraint)}: {constraint}") + + if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: + collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] + with ops.init_scope(): + self._in_graph_mode = not context.executing_eagerly() + if experimental_enable_variable_lifting: + maybe_init_scope = ops.init_scope + else: + maybe_init_scope = contextlib.nullcontext + with maybe_init_scope(): + with ops.name_scope( + name, + "Variable", [] if init_from_fn else [initial_value], + skip_on_eager=False) as name: + # pylint: disable=protected-access + handle_name = ops.name_from_scope_name(name) + if self._in_graph_mode: + shared_name = handle_name + unique_id = shared_name + else: + # When in eager mode, use a uid for the shared_name, to prevent + # accidental sharing. + unique_id = "%s_%d" % (handle_name, ops.uid()) + shared_name = None # Never shared + # Use attr_scope and device(None) to simulate the behavior of + # colocate_with when the variable we want to colocate with doesn't + # yet exist. + device_context_manager = ( + ops.device if self._in_graph_mode else ops.NullContextmanager) + attr = attr_value_pb2.AttrValue( + list=attr_value_pb2.AttrValue.ListValue( + s=[compat.as_bytes("loc:@%s" % handle_name)])) + with ops.get_default_graph()._attr_scope({"_class": attr}): + with ops.name_scope("Initializer"), device_context_manager(None): + if init_from_fn: + initial_value = initial_value() + if isinstance(initial_value, trackable.CheckpointInitialValue): + self._maybe_initialize_trackable() + self._update_uid = initial_value.checkpoint_position.restore_uid + initial_value = initial_value.wrapped_value + initial_value = ops.convert_to_tensor( + initial_value, name="initial_value", dtype=dtype) + if shape is not None: + if not initial_value.shape.is_compatible_with(shape): + raise ValueError( + f"In this `tf.Variable` creation, the initial value's shape " + f"({initial_value.shape}) is not compatible with " + f"the explicitly supplied `shape` argument ({shape}).") + else: + shape = initial_value.shape + handle = eager_safe_variable_handle( + initial_value=initial_value, + shape=shape, + shared_name=shared_name, + name=name, + graph_mode=self._in_graph_mode) + handle._parent_trackable = weakref.ref(self) + handle._name = handle_name + ":0" + handle._unique_id = unique_id + # pylint: disable=protected-access + if (self._in_graph_mode and initial_value is not None and + initial_value.op._get_control_flow_context() is not None): + raise ValueError( + f"The `initial_value` passed to `tf.Variable` {name} is from " + f"inside a control-flow construct, such as a loop or " + f"conditional. When creating a " + f"`tf.Variable` inside a loop or conditional, use a lambda as " + f"the `initial_value`. Got: initial_value=({initial_value})") + # pylint: enable=protected-access + dtype = initial_value.dtype.base_dtype + + if self._in_graph_mode: + with ops.name_scope("IsInitialized"): + is_initialized_op = ( + gen_resource_variable_ops.var_is_initialized_op(handle)) + if initial_value is not None: + # pylint: disable=g-backslash-continuation + with ops.name_scope("Assign") as n, \ + ops.colocate_with(None, ignore_existing=True), \ + ops.device(handle.device): + # pylint: disable=protected-access + initializer_op = ( + gen_resource_variable_ops.assign_variable_op( + handle, + variables._try_guard_against_uninitialized_dependencies( + name, initial_value), + name=n)) + # pylint: enable=protected-access + # pylint: enable=g-backslash-continuation + with ops.name_scope("Read"): + # Manually assign reads to the handle's device to avoid log + # messages. + with ops.device(handle.device): + value = gen_resource_variable_ops.read_variable_op(handle, dtype) + _maybe_set_handle_data(dtype, handle, value) + graph_element = value + if caching_device is not None: + # Variables may be created in a tf.device() or ops.colocate_with() + # context. At the same time, users would expect caching device to + # be independent of this context, and/or would not expect the + # current device context to be merged with the caching device + # spec. Therefore we reset the colocation stack before creating + # the cached value. Note that resetting the colocation stack will + # also reset the device stack. + with ops.colocate_with(None, ignore_existing=True): + with ops.device(caching_device): + cached_value = array_ops.identity(value) + else: + cached_value = None + else: + gen_resource_variable_ops.assign_variable_op(handle, initial_value) + is_initialized_op = None + initializer_op = None + graph_element = None + if caching_device: + with ops.device(caching_device): + cached_value = gen_resource_variable_ops.read_variable_op( + handle, dtype) + _maybe_set_handle_data(dtype, handle, cached_value) + else: + cached_value = None + + if cached_value is not None: + # Store the variable object so that the original variable can be + # accessed to generate functions that are compatible with SavedModel. + cached_value._cached_variable = weakref.ref(self) # pylint: disable=protected-access + + if self._in_graph_mode: + # Eager variables are only added to collections if they are part of an + # eager variable store (otherwise in an interactive session they would + # hog memory and cause OOM). This is done in ops/variable_scope.py. + ops.add_to_collections(collections, self) + elif ops.GraphKeys.GLOBAL_STEP in collections: + ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self) + initial_value = initial_value if self._in_graph_mode else None + super(PerWorkerResourceVariable, self).__init__( + trainable=trainable, + shape=shape, + dtype=dtype, + handle=handle, + synchronization=synchronization, + constraint=constraint, + aggregation=aggregation, + distribute_strategy=distribute_strategy, + name=name, + unique_id=unique_id, + handle_name=handle_name, + graph_element=graph_element, + initial_value=initial_value, + initializer_op=initializer_op, + is_initialized_op=is_initialized_op, + cached_value=cached_value, + caching_device=caching_device, + validate_shape=validate_shape, + ) + + def _init_from_proto(self, + variable_def, + import_scope=None, + validate_shape=True): + """Initializes from `VariableDef` proto.""" + # Note that init_from_proto is currently not supported in Eager mode. + assert not context.executing_eagerly() + self._in_graph_mode = True + assert isinstance(variable_def, variable_pb2.VariableDef) + if not variable_def.is_resource: + raise ValueError(f"The `variable_def` you passed to `tf.Variable` is " + f"Trying to restore a TF 1.x Reference Variable " + f"as a TF 2.x ResourceVariable. This is unsupported. " + f"Got variable_def={variable_def}") + + # Create from variable_def. + g = ops.get_default_graph() + self._handle = g.as_graph_element( + ops.prepend_name_scope( + variable_def.variable_name, import_scope=import_scope), + allow_operation=False) + self._shape = tensor_shape.TensorShape(self._handle.op.get_attr("shape")) + self._handle_name = self._handle.name + self._unique_id = self._handle_name + self._initializer_op = g.as_graph_element( + ops.prepend_name_scope( + variable_def.initializer_name, import_scope=import_scope)) + # Check whether initial_value_name exists for backwards compatibility. + if (hasattr(variable_def, "initial_value_name") and + variable_def.initial_value_name): + self._initial_value = g.as_graph_element( + ops.prepend_name_scope( + variable_def.initial_value_name, import_scope=import_scope)) + else: + self._initial_value = None + synchronization, aggregation, trainable = ( + variables.validate_synchronization_aggregation_trainable( + variable_def.synchronization, variable_def.aggregation, + variable_def.trainable, variable_def.variable_name)) + self._synchronization = synchronization + self._aggregation = aggregation + self._trainable = trainable + if variable_def.snapshot_name: + snapshot = g.as_graph_element( + ops.prepend_name_scope( + variable_def.snapshot_name, import_scope=import_scope)) + if snapshot.op.type != "ReadVariableOp": + self._cached_value = snapshot + else: + self._cached_value = None + while snapshot.op.type != "ReadVariableOp": + snapshot = snapshot.op.inputs[0] + self._graph_element = snapshot + else: + self._cached_value = None + # Legacy case for protos without the snapshot name; assume it's the + # following. + self._graph_element = g.get_tensor_by_name(self._handle.op.name + + "/Read/ReadVariableOp:0") + if variable_def.HasField("save_slice_info_def"): + self._save_slice_info = variables.Variable.SaveSliceInfo( + save_slice_info_def=variable_def.save_slice_info_def, + import_scope=import_scope) + else: + self._save_slice_info = None + self._caching_device = None + self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype")) + self._constraint = None + self._validate_shape = validate_shape + + def _init_from_handle(self, + trainable=None, + shape=None, + dtype=None, + handle=None): + handle_data = get_eager_safe_handle_data(handle) + if not handle_data.is_set: + # The handle may not have the handle shape and dtype if it was created + # using tf.placeholder. + handle_data = handle_data_util.create_handle_data(shape, dtype) + handle_data_util.set_handle_data(handle, handle_data) + # pylint: disable=protected-access + if hasattr(handle, "_name") and isinstance(handle._name, str): + handle_name = handle._name.rstrip(":0") + else: + handle_name = None + # pylint: enable=protected-access + unique_id = getattr(handle, "_unique_id", None) + super().__init__( + trainable=trainable, shape=shape, dtype=dtype, handle=handle, + unique_id=unique_id, handle_name=handle_name) + + -class TrainableWrapper(resource_variable_ops.ResourceVariable): + +class TrainableWrapper(PerWorkerResourceVariable): """ This class is a trainable wrapper of Dynamic Embedding, and the key role is recording the map relation between params and ids. @@ -130,10 +698,11 @@ class TrainableWrapper(resource_variable_ops.ResourceVariable): def __getattribute__(self, name): if name in ["sparse_read", "gather_nd"]: raise AttributeError("no such method: {}".format(name)) - return super(resource_variable_ops.ResourceVariable, + + return super(PerWorkerResourceVariable, self).__getattribute__(name) - def __init__(self, params, ids, max_norm, *args, **kwargs): + def __init__(self, *args, **kwargs): """Creates an empty `TrainableWrapper` object.© Creates a group of tables placed on devices, @@ -149,16 +718,24 @@ def __init__(self, params, ids, max_norm, *args, **kwargs): Returns: A `TrainableWrapper` object which is a subclass of ResourceVariable. """ - self.params = params - self.ids = ids + self.params = kwargs.pop("params") + self.ids = kwargs.get("ids") self.exists = None - self.max_norm = max_norm + self.max_norm = kwargs.get("max_norm") self.prefetch_values_op = None self.model_mode = kwargs.get("model_mode") kwargs.pop("model_mode") self._tracked_slots = [] self._optimizer_vars = data_structures.NoDependency([]) - super(TrainableWrapper, self).__init__(*args, **kwargs) + printop = tf.print("st_b:", kwargs, args, + output_stream=tf.compat.v1.logging.error) + with tf.control_dependencies([printop]): + pass + # strategy = kwargs.pop("distribute_strategy") + # next_creator = kwargs.pop("next_creator", None) + super(TrainableWrapper, self).__init__(**kwargs) + self._trainable = kwargs.get("trainable", True) + # self._handle = self.handle def prefetch_values(self, update=False): if update or (self.prefetch_values_op is None): @@ -172,11 +749,11 @@ def prefetch_values(self, update=False): def __repr__(self): if context.executing_eagerly() and not self._in_graph_mode: return "" % ( - self.name, self.get_shape(), self.dtype.name, - ops.numpy_text(self.read_value(), is_repr=True)) + self.name, self.get_shape(), self.dtype.name, + ops.numpy_text(self.read_value(), is_repr=True)) else: return "" % ( - self.name, self.get_shape(), self.dtype.name) + self.name, self.get_shape(), self.dtype.name) def _init_from_args(self, initial_value=None, @@ -253,11 +830,11 @@ def _init_from_args(self, @end_compatibility """ ( - synchronization, - aggregation, - trainable, + synchronization, + aggregation, + trainable, ) = variables.validate_synchronization_aggregation_trainable( - synchronization, aggregation, trainable, name) + synchronization, aggregation, trainable, name) if initial_value is None: raise ValueError("initial_value must be specified.") init_from_fn = callable(initial_value) @@ -275,8 +852,8 @@ def _init_from_args(self, collections = [ops.GraphKeys.GLOBAL_VARIABLES] if not isinstance(collections, (list, tuple, set)): raise ValueError( - "collections argument to Variable constructor must be a list, tuple, " - "or set. Got %s of type %s" % (collections, type(collections))) + "collections argument to Variable constructor must be a list, tuple, " + "or set. Got %s of type %s" % (collections, type(collections))) if constraint is not None and not callable(constraint): raise ValueError("The `constraint` argument must be a callable.") @@ -314,48 +891,48 @@ def _init_from_args(self, device_context_manager = (ops.device if self._in_graph_mode else ops.NullContextmanager) attr = attr_value_pb2.AttrValue(list=attr_value_pb2.AttrValue.ListValue( - s=[compat.as_bytes("loc:@%s" % handle_name)])) + s=[compat.as_bytes("loc:@%s" % handle_name)])) with ops.get_default_graph()._attr_scope({"_class": attr}): with ops.name_scope("Initializer"), device_context_manager(None): initial_value = ops.convert_to_tensor( - initial_value() if init_from_fn else initial_value, - name="initial_value", - dtype=dtype, + initial_value() if init_from_fn else initial_value, + name="initial_value", + dtype=dtype, ) if shape is None: shape = initial_value.shape handle = resource_variable_ops.eager_safe_variable_handle( - initial_value=initial_value, - shape=None, # shape, - shared_name=shared_name, - name=name, - graph_mode=self._in_graph_mode, + initial_value=initial_value, + shape=None, # shape, + shared_name=shared_name, + name=name, + graph_mode=self._in_graph_mode, ) # pylint: disable=protected-access if (self._in_graph_mode and initial_value is not None and initial_value.op._get_control_flow_context() is not None): raise ValueError( - "Initializer for variable %s is from inside a control-flow " - "construct, such as a loop or conditional. When creating a " - "variable inside a loop or conditional, use a lambda as the " - "initializer." % name) + "Initializer for variable %s is from inside a control-flow " + "construct, such as a loop or conditional. When creating a " + "variable inside a loop or conditional, use a lambda as the " + "initializer." % name) # pylint: enable=protected-access dtype = initial_value.dtype.base_dtype if self._in_graph_mode: with ops.name_scope("IsInitialized"): is_initialized_op = ( - gen_resource_variable_ops.var_is_initialized_op(handle)) + gen_resource_variable_ops.var_is_initialized_op(handle)) if initial_value is not None: # pylint: disable=g-backslash-continuation with ops.name_scope("Assign") as n, ops.colocate_with( None, ignore_existing=True), ops.device(handle.device): # pylint: disable=protected-access initializer_op = gen_resource_variable_ops.assign_variable_op( - handle, - variables._try_guard_against_uninitialized_dependencies( - name, initial_value), - name=n, + handle, + variables._try_guard_against_uninitialized_dependencies( + name, initial_value), + name=n, ) # pylint: enable=protected-access # pylint: enable=g-backslash-continuation @@ -364,14 +941,14 @@ def _init_from_args(self, # messages. with ops.device(handle.device): with ops.control_dependencies([ - gen_resource_variable_ops.assign_variable_op( - handle, - self.prefetch_values(), - name="AssignBeforeInitRead", - ) + gen_resource_variable_ops.assign_variable_op( + handle, + self.prefetch_values(), + name="AssignBeforeInitRead", + ) ]): value = gen_resource_variable_ops.read_variable_op( - handle, dtype) + handle, dtype) graph_element = value if caching_device is not None: # Variables may be created in a tf.device() or ops.colocate_with() @@ -394,14 +971,14 @@ def _init_from_args(self, if caching_device: with ops.device(caching_device): with ops.control_dependencies([ - gen_resource_variable_ops.assign_variable_op( - handle, - self.prefetch_values(), - name="AssignBeforeInitRead", - ) + gen_resource_variable_ops.assign_variable_op( + handle, + self.prefetch_values(), + name="AssignBeforeInitRead", + ) ]): cached_value = (gen_resource_variable_ops.read_variable_op( - handle, dtype)) + handle, dtype)) else: cached_value = None if not context.executing_eagerly(): @@ -412,23 +989,23 @@ def _init_from_args(self, elif ops.GraphKeys.GLOBAL_STEP in collections: ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self) initial_value = initial_value if self._in_graph_mode else None - super(resource_variable_ops.ResourceVariable, self).__init__( - trainable=trainable, - shape=shape, - dtype=dtype, - handle=handle, - synchronization=synchronization, - constraint=constraint, - aggregation=aggregation, - distribute_strategy=distribute_strategy, - name=name, - unique_id=unique_id, - handle_name=handle_name, - graph_element=graph_element, - initial_value=initial_value, - initializer_op=initializer_op, - is_initialized_op=is_initialized_op, - cached_value=cached_value, + super(ps_values.PerWorkerVariable, self).__init__( + trainable=trainable, + shape=shape, + dtype=dtype, + handle=handle, + synchronization=synchronization, + constraint=constraint, + aggregation=aggregation, + distribute_strategy=distribute_strategy, + name=name, + unique_id=unique_id, + handle_name=handle_name, + graph_element=graph_element, + initial_value=initial_value, + initializer_op=initializer_op, + is_initialized_op=is_initialized_op, + cached_value=cached_value, ) def update_op(self, v0=None): @@ -454,16 +1031,16 @@ def _read_variable_op(self, do_prefetch=True, no_copy=False): if self.model_mode == "train": if do_prefetch: with ops.control_dependencies([ - gen_resource_variable_ops.assign_variable_op( - self._handle, - self.prefetch_values(), - name="AssignBeforeReadVariable") + gen_resource_variable_ops.assign_variable_op( + self._handle, + self.prefetch_values(), + name="AssignBeforeReadVariable") ]): _result = gen_resource_variable_ops.read_variable_op( - self._handle, self._dtype) + self._handle, self._dtype) else: _result = gen_resource_variable_ops.read_variable_op( - self._handle, self._dtype) + self._handle, self._dtype) else: _result = self.prefetch_values() @@ -509,10 +1086,10 @@ def _rank(x): ids_rank, ids_static = _rank(ids) params_rank, params_static = _rank(params) return clip_ops.clip_by_norm( - params, - max_norm, - axes=(list(range(ids_rank, params_rank)) if ids_static and params_static - else math_ops.range(ids_rank, params_rank)), + params, + max_norm, + axes=(list(range(ids_rank, params_rank)) if ids_static and params_static + else math_ops.range(ids_rank, params_rank)), ) def transform(self, result): @@ -523,8 +1100,8 @@ def transform(self, result): def _track_optimizer_slots(self, slots): if not all(isinstance(s, TrainableWrapper) for s in slots): raise TypeError( - 'Can only track TrainableWrapper slots, but get {}'.format( - [type(s) for s in slots])) + 'Can only track TrainableWrapper slots, but get {}'.format( + [type(s) for s in slots])) identifiers = [optimizer_v2._var_key(s) for s in self._tracked_slots] for s in slots: if optimizer_v2._var_key(s) not in identifiers: diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/parameter_server.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/parameter_server.py new file mode 100644 index 000000000..000d90183 --- /dev/null +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/parameter_server.py @@ -0,0 +1,62 @@ +from tensorflow.python.distribute import ps_values, distribute_lib +from tensorflow.python.distribute.distribute_lib import _get_per_thread_mode +from tensorflow.python.distribute.parameter_server_strategy_v2 import ParameterServerStrategyV2, \ + ParameterServerStrategyV2Extended +from tensorflow.python.ops import variables +import tensorflow as tf + + + +class DEPerWorkerVariable(ps_values.PerWorkerVariable): + def __init__(self, *args, **kwargs): + super(DEPerWorkerVariable, self).__init__(*args, **kwargs) + +def create_per_worker_de_variable(strategy, name, dtype, shape): + # printop = tf.print("st_2:", strategy, + # tf.distribute.get_replica_context() , + # output_stream=tf.compat.v1.logging.error) + # with tf.control_dependencies([printop]): + with strategy.scope(): + return variables.Variable(initial_value=(), + shape=shape, dtype=dtype, name=name, + per_worker_de_variable=True) + +def create_ps_trainable_wrapper(strategy, params, ids, max_norm, initial_value, **kwargs): + with strategy.scope(): + return variables.Variable(initial_value=initial_value, + params=params, max_norm=max_norm, ids=ids, + ps_trainable_wrapper=True, **kwargs) + +def create_ps_shadow_variable(strategy, params, trainable, max_norm, **kwargs): + with strategy.scope(): + return variables.Variable(distribute_strategy=strategy, + params=params, max_norm=max_norm, trainable=trainable, + ps_shadow_variable=True, **kwargs) + +original_create_variable = ParameterServerStrategyV2Extended._create_variable + +def patched_create_variable(self, next_creator, **kwargs): + if kwargs.pop("per_worker_de_variable", False): + return _create_per_worker_de_variable(self, next_creator, **kwargs) + if kwargs.pop("ps_trainable_wrapper", False): + return _create_ps_trainable_wrapper(self, next_creator, **kwargs) + if kwargs.pop("ps_shadow_variable", False): + return _create_ps_shadow_variable(self, next_creator, **kwargs) + return original_create_variable(self, next_creator, **kwargs) + +def _create_ps_trainable_wrapper(strategy_extended, next_creator, **kwargs): + from tensorflow_recommenders_addons.dynamic_embedding import TrainableWrapper + return TrainableWrapper(strategy_extended._container_strategy(), next_creator, **kwargs) + +def _create_ps_shadow_variable(strategy_extended, next_creator, **kwargs): + from tensorflow_recommenders_addons.dynamic_embedding.python.ops.shadow_embedding_ops import ShadowVariable + return ShadowVariable(next_creator, **kwargs) + +def _create_per_worker_de_variable(strategy_extended, next_creator, **kwargs): + return DEPerWorkerVariable(strategy_extended._container_strategy(), next_creator, **kwargs) + +ParameterServerStrategyV2Extended._create_variable = patched_create_variable + +class DEParameterServerStrategy(ParameterServerStrategyV2): + def __init__(self, cluster_resolver, variable_partitioner=None): + super(DEParameterServerStrategy, self).__init__(cluster_resolver, variable_partitioner) \ No newline at end of file diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/shadow_embedding_ops.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/shadow_embedding_ops.py index 87a4bfe67..70a055e3f 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/shadow_embedding_ops.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/shadow_embedding_ops.py @@ -37,7 +37,7 @@ import tensorflow as tf -from tensorflow.python.distribute import distribute_lib +from tensorflow.python.distribute import distribute_lib, ps_values from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -49,6 +49,10 @@ from tensorflow_recommenders_addons import dynamic_embedding as de from tensorflow_recommenders_addons.dynamic_embedding.python.ops.embedding_weights import EmbeddingWeights, \ TrainableWrapper +from tensorflow_recommenders_addons.dynamic_embedding.python.ops.parameter_server import create_per_worker_de_variable, \ + DEPerWorkerVariable +from tensorflow_recommenders_addons.dynamic_embedding.python.train.utils import is_parameter_server_strategy +from tensorflow.python.ops import variables if version.parse(tf.__version__) >= version.parse("2.10"): from tensorflow.python.trackable import base as trackable @@ -70,12 +74,7 @@ class ShadowVariable(EmbeddingWeights, TrainableWrapper): and [tf.function](https://www.tensorflow.org/guide/function). """ - def __init__(self, - params, - name='ShadowVariable', - max_norm=None, - trainable=True, - distribute_strategy=None, + def __init__(self, next_creator, **kwargs): """ Create a ShadowVariable object. @@ -102,27 +101,35 @@ def __init__(self, if not context.executing_eagerly(): raise NotImplementedError('Currently ShadowVariable is only allowed' ' in eager mode.') - + name = kwargs.get("name", 'ShadowVariable') + params = kwargs.get("params") + distribute_strategy = kwargs.get("distribute_strategy", None) + max_norm = kwargs.get("max_norm", None) + trainable = kwargs.get("trainable", True) self._name = name - if not isinstance(params, de.Variable): - raise TypeError('params must be de.Variable, but get %s' % type(params)) self.params = params + if not isinstance(self.params, de.Variable): + raise TypeError('params must be de.Variable, but get %s' % type(params)) + collections = kwargs.get('collections', None) ids = kwargs.get('ids', None) if ids is not None: kwargs.pop('ids') ids_name = self._name + '-ids' if ids is None: - self.ids = DEResourceVariable((), - trainable=False, - collections=collections, - name=ids_name, - dtype=self.params.key_dtype, - distribute_strategy=distribute_strategy, - shape=tensor_shape.TensorShape(None)) + self.ids = get_de_resource_variable( + collections=collections, + name=ids_name, + dtype=self.params.key_dtype, + distribute_strategy=distribute_strategy, + shape=tensor_shape.TensorShape(None)) else: if not isinstance(ids, resource_variable_ops.ResourceVariable): - raise TypeError('If ids is set, it needs to be a ResourceVariable') + tfprint = tf.print("ids_8c:", ids, type(ids), ids.__class__.__name__, output_stream=tf.compat.v1.logging.error) + with tf.control_dependencies([tfprint]): + pass + # not isinstance(ids, variables.Variable)): + # raise TypeError('If ids is set, it needs to be a ResourceVariable or ps_values.PerWorkerVariable') self.ids = ids model_mode = kwargs.get('model_mode', None) @@ -136,31 +143,45 @@ def __init__(self, if (distribute_strategy is not None) and (not isinstance( distribute_strategy, distribute_lib.StrategyBase)): raise TypeError('distribute_strategy must inherit from StrategyBase.') - + printop = tf.print("st_a:", self.params, + output_stream=tf.compat.v1.logging.error) + with tf.control_dependencies([printop]): + pass super(ShadowVariable, - self).__init__(self.params, - self.ids, + self).__init__(params=self.params, + ids=self.ids, max_norm=max_norm, initial_value=initial_value, dtype=self.params.value_dtype, trainable=trainable, collections=collections, model_mode=model_mode, + strategy=distribute_strategy, + next_creator=next_creator, distribute_strategy=distribute_strategy, name=name) + printop = tf.print("st_c:", self.params, + output_stream=tf.compat.v1.logging.error) + with tf.control_dependencies([printop]): + pass exists = kwargs.get('exists', None) exists_name = self._name + '-exists' if exists is None: - self.exists = DEResourceVariable((), - trainable=False, - collections=collections, - name=exists_name, - dtype=dtypes.bool, - distribute_strategy=distribute_strategy, - shape=tensor_shape.TensorShape(None)) + self.exists = get_de_resource_variable( + collections=collections, + name=exists_name, + dtype=dtypes.bool, + distribute_strategy=distribute_strategy, + shape=tensor_shape.TensorShape(None)) self._track_trackable(self.exists, exists_name, overwrite=False) else: self.exists = exists + + printop = tf.print("st_d:", self.params, + output_stream=tf.compat.v1.logging.error) + with tf.control_dependencies([printop]): + pass + self.params._trainable_store[name] = self def verify_embedding_weights(self, sparse_ids, sparse_weights=None): @@ -261,7 +282,6 @@ def embedding_lookup( containing the values from the params tensor(s) for keys in ids. """ ids = ops.convert_to_tensor(ids) - if distribute_utils.is_distributed_variable(shadow): shadow_ = shadow._get_on_device_or_primary() else: @@ -273,11 +293,13 @@ def embedding_lookup( with ops.name_scope(name, "shadow_embedding_lookup"): with ops.colocate_with(None, ignore_existing=True): if de.ModelMode.CURRENT_SETTING == de.ModelMode.TRAIN: + # tfprint = tf.print("ids_8b:", shadow_.ids, ids, output_stream=tf.compat.v1.logging.error) + # with tf.control_dependencies([tfprint]): + # pass with ops.control_dependencies([shadow_._reset_ids(ids)]): result = shadow_.read_value(do_prefetch=True) else: result = shadow_.params.lookup(ids) - return result @@ -362,6 +384,29 @@ def __init__(self, *args, **kwargs): super(DEResourceVariable, self).__init__(*args, **kwargs) +def get_de_resource_variable( + collections, + name, + dtype, + distribute_strategy, + shape=tensor_shape.TensorShape(None)): + if is_parameter_server_strategy(distribute_strategy): + return create_per_worker_de_variable(distribute_strategy, name, dtype, shape) + else: + return DEResourceVariable((), + trainable=False, + collections=collections, + name=name, + dtype=dtype, + distribute_strategy=distribute_strategy, + shape=shape) + + +def is_de_resource_variable(var): + return isinstance(var, DEResourceVariable) or isinstance( + var, TrainableWrapper) or isinstance(var, DEPerWorkerVariable) + + class HvdVariable(EmbeddingWeights): def __init__( diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py index fbab5e38e..fe93b2f51 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_patch.py @@ -136,6 +136,10 @@ def update_op(self, optimizer, g): def _get_processor(v): """The processor of v.""" + import tensorflow as tf + tfprint = tf.print("_get_processor:", v, output_stream=tf.compat.v1.logging.error) + with tf.control_dependencies([tfprint]): + pass if isinstance(v, de.TrainableWrapper): return _DenseDynamicEmbeddingTrainableProcessor(v) if context.executing_eagerly(): diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py index 34e30c115..2503da97a 100644 --- a/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/ops/tf_save_restore_patch.py @@ -16,7 +16,6 @@ """patch on tensorflow""" import inspect -import functools import os.path from packaging import version import re @@ -32,7 +31,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.keras.saving.saved_model import save as tf_saved_model_save from tensorflow.python.keras.utils import tf_utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -42,6 +40,7 @@ from tensorflow.python.platform import tf_logging from tensorflow.python.training import saver from tensorflow.python.training import training_util + if version.parse(tf_version.VERSION) >= version.parse("2.10"): from tensorflow.python.checkpoint import checkpoint_management from tensorflow.python.checkpoint import checkpoint_options @@ -294,18 +293,17 @@ def _get_dynamic_embedding_restore_ops(self): return control_flow_ops.group(restore_ops.as_list()) def _build(self, checkpoint_path, build_save, build_restore): - # TrainableWrapper and DEResourceVariable should not be save or restore parameter. - from tensorflow_recommenders_addons.dynamic_embedding.python.ops.shadow_embedding_ops import DEResourceVariable - filter_lambda = lambda x: (isinstance(x, de.TrainableWrapper)) or ( - isinstance(x, DEResourceVariable)) + # TrainableWrapper DEResourceVariable should not be save or restore parameter. + from tensorflow_recommenders_addons.dynamic_embedding.python.ops.shadow_embedding_ops import is_de_resource_variable + if isinstance(self._var_list, dict): for key, value in self._var_list.items(): - if filter_lambda(value): + if is_de_resource_variable(value): self._var_list.pop(key) elif isinstance(self._var_list, list): _tmp_var_list = [] for value in self._var_list: - if not filter_lambda(value): + if not is_de_resource_variable(value): _tmp_var_list.append(value) self._var_list = _tmp_var_list diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/train/utils.py b/tensorflow_recommenders_addons/dynamic_embedding/python/train/utils.py new file mode 100644 index 000000000..e65af85bc --- /dev/null +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/train/utils.py @@ -0,0 +1,23 @@ +from typing import List + +from tensorflow import distribute as tf_dist +""" +worker_devices property of ParameterServerStrategyV2Extended returns only /device +while ParameterServerStrategy returns /job:worker/task:0/device:CPU:0 +override this property to return the same format as ParameterServerStrategy.worker_devices +""" + + +def worker_devices(devices: List[str], tasks: int, type: str) -> List[str]: + if tasks % len(devices) != 0: + raise ValueError( + "Number of tasks must be a multiple of the number of devices.") + + return [ + f"/job:{type}/task:{i}" + devices[i % len(devices)] for i in range(tasks) + ] + + +def is_parameter_server_strategy(strategy): + return strategy is not None and isinstance( + strategy, tf_dist.experimental.ParameterServerStrategy) diff --git a/tensorflow_recommenders_addons/dynamic_embedding/python/train/utils_test.py b/tensorflow_recommenders_addons/dynamic_embedding/python/train/utils_test.py new file mode 100644 index 000000000..decb5a140 --- /dev/null +++ b/tensorflow_recommenders_addons/dynamic_embedding/python/train/utils_test.py @@ -0,0 +1,36 @@ +import unittest + +from tensorflow.python.distribute import multi_worker_test_base + +from tensorflow_recommenders_addons.dynamic_embedding.python.train.utils import worker_devices, \ + is_parameter_server_strategy +from tensorflow import distribute as tf_dist +from tensorflow.python.distribute.cluster_resolver import cluster_resolver as cluster_resolver_lib +from tensorflow.python.training import server_lib +from tensorflow.python.distribute import multi_process_runner + + +class TestWorkerDevices(unittest.TestCase): + + def test_valid_cases(self): + self.assertEqual( + worker_devices(["/device:CPU:0", "/device:CPU:1"], 4, "worker"), [ + "/job:worker/task:0/device:CPU:0", + "/job:worker/task:1/device:CPU:1", + "/job:worker/task:2/device:CPU:0", + "/job:worker/task:3/device:CPU:1", + ]) + self.assertEqual(worker_devices(["/device:GPU:0"], 2, "worker"), [ + "/job:worker/task:0/device:GPU:0", + "/job:worker/task:1/device:GPU:0", + ]) + + def test_invalid_cases(self): + with self.assertRaises(ValueError): + worker_devices(["/device:CPU:0", "/device:CPU:1"], 3, "worker") + + def test_ps_strategy(self): + # create ParameterServerStrategy needs bzl test, so we just test the False Case + s2 = tf_dist.MirroredStrategy() + self.assertFalse(is_parameter_server_strategy(s2)) + self.assertFalse(is_parameter_server_strategy(None))