diff --git a/tensorflow_ranking/python/model.py b/tensorflow_ranking/python/model.py index ea8a7f3..ce50162 100644 --- a/tensorflow_ranking/python/model.py +++ b/tensorflow_ranking/python/model.py @@ -50,6 +50,20 @@ def _get_params(mode, params): raise ValueError('Invalid mode: {}.'.format(mode)) return num_shuffles +def _get_params_shuffle_peritem(mode, params): + params = params or {} + # 'shuffle_peritem' should be bool + _SHUFFLE_PERITEM = 'shuffle_peritem' + if mode == tf.estimator.ModeKeys.TRAIN: + shuffle_peritem = bool(params.get(_SHUFFLE_PERITEM, None)) + elif mode == tf.estimator.ModeKeys.EVAL: + shuffle_peritem = False + elif mode == tf.estimator.ModeKeys.PREDICT: + shuffle_peritem = False + else: + raise ValueError('Invalid mode: {}.'.format(mode)) + return shuffle_peritem + class _RankingModel(object): """Interface for a ranking model.""" @@ -335,6 +349,32 @@ def _update_scatter_gather_indices(self, is_valid, mode, params): def _compute_logits_impl(self, context_features, example_features, labels, mode, params, config): + if _get_params_shuffle_peritem(mode, params): + with tf.compat.v1.name_scope("shuffle_peritem"): + # Shuffle labels and example features along list_size + # example_features are shape (batch, list_size, feature_space) + + first_example = next(iter(example_features.values())) + cur_list_size = tf.shape(input=first_example)[1] + + indicies = tf.range(start=0, limit=cur_list_size, dtype=tf.int32) + shuffled_indicies = tf.random.shuffle(indicies) + + for name, value in six.iteritems(example_features): + # Transpose to expose LIST_SIZE dimension on the 0th axis + transposed = tf.transpose(value, perm=[1,0,2]) + + # Shuffle along the new LIST_SIZE axis + shuffled_feature = tf.gather(transposed, shuffled_indicies) + + # Revert back to (Batch, LIST_SIZE, feature_space) + reverted = tf.transpose(shuffled_feature, perm=[1,0,2]) + example_features[name] = reverted + + transposed_label = tf.transpose(labels, perm=[1,0]) + shuffled_label = tf.gather(transposed_label, shuffled_indicies) + labels = tf.transpose(shuffled_label, perm=[1,0]) + # Scatter/Gather per-example scores through groupwise comparison. Each # instance in a mini-batch will form a number of groups. Each group of # examples are scored by `_score_fn` and scores for individual examples are