diff --git a/edward2/tensorflow/layers/__init__.py b/edward2/tensorflow/layers/__init__.py
index 777f64be..329d57a3 100644
--- a/edward2/tensorflow/layers/__init__.py
+++ b/edward2/tensorflow/layers/__init__.py
@@ -31,6 +31,7 @@
 from edward2.tensorflow.layers.convolutional import Conv2DVariationalDropout
 from edward2.tensorflow.layers.convolutional import DepthwiseCondConv2D
 from edward2.tensorflow.layers.convolutional import DepthwiseConv2DBatchEnsemble
+from edward2.tensorflow.layers.dense import CondDense
 from edward2.tensorflow.layers.dense import DenseBatchEnsemble
 from edward2.tensorflow.layers.dense import DenseDVI
 from edward2.tensorflow.layers.dense import DenseFlipout
@@ -70,6 +71,7 @@
 from edward2.tensorflow.layers.recurrent import LSTMCellFlipout
 from edward2.tensorflow.layers.recurrent import LSTMCellRank1
 from edward2.tensorflow.layers.recurrent import LSTMCellReparameterization
+from edward2.tensorflow.layers.routing import RoutingLayer
 from edward2.tensorflow.layers.stochastic_output import MixtureLogistic
 
 __all__ = [
@@ -77,6 +79,7 @@
     "Attention",
     "BayesianLinearModel",
     "CondConv2D",
+    "CondDense",
     "Conv1DBatchEnsemble",
     "Conv1DFlipout",
     "Conv1DRank1",
@@ -122,6 +125,7 @@
     "NeuralProcess",
     "RandomFeatureGaussianProcess",
     "Reverse",
+    "RoutingLayer",
     "SinkhornAutoregressiveFlow",
     "SparseGaussianProcess",
     "SpectralNormalization",
diff --git a/edward2/tensorflow/layers/routing.py b/edward2/tensorflow/layers/routing.py
new file mode 100644
index 00000000..46519d96
--- /dev/null
+++ b/edward2/tensorflow/layers/routing.py
@@ -0,0 +1,134 @@
+# coding=utf-8
+# Copyright 2021 The Edward2 Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Routing layer for mixture of experts.""" + +import tensorflow as tf +from edward2.tensorflow.layers import routing_utils + + +class RoutingLayer(tf.keras.layers.Layer): + + def __init__(self, num_experts, routing_pooling, routing_fn, k, + normalize_routing, noise_epsilon, **kwargs): + super().__init__(**kwargs) + self.num_experts = num_experts + self.routing_pooling = routing_pooling + self.routing_fn = routing_fn + self.k = k + self.normalize_routing = normalize_routing + self.noise_epsilon = noise_epsilon + self.use_noisy_routing = 'noisy' in routing_fn + self.use_softmax_top_k = routing_fn in [ + 'softmax_top_k', 'noisy_softmax_top_k' + ] + self.use_onehot_top_k = routing_fn in ['onehot_top_k', 'noisy_onehot_top_k'] + self.use_sigmoid_activation = routing_fn == 'sigmoid' + self.use_softmax_routing = routing_fn in ['softmax', 'noisy_softmax'] + + def build(self, input_shape): + input_shape = tf.TensorShape(input_shape) + self.input_size = input_shape[1] + self.kernel_shape = [self.input_size, self.num_experts] + + self.w_gate = self.add_weight( + name='w_gate', + shape=self.kernel_shape, + initializer=tf.keras.initializers.Zeros(), + regularizer=None, + constraint=None, + trainable=True, + dtype=self.dtype) + + if self.use_noisy_routing: + self.w_noise = self.add_weight( + name='w_gate', + shape=self.kernel_shape, + initializer=tf.keras.initializers.Zeros(), + regularizer=None, + constraint=None, + trainable=True, + dtype=self.dtype) + + if self.routing_pooling == 'global_average': + self.pooling_layer = tf.keras.layers.GlobalAveragePooling2D() + elif self.routing_pooling == 'global_max': + self.pooling_layer = tf.keras.layers.GlobalMaxPool2D() + elif self.routing_pooling == 'average_8': + self.pooling_layer = tf.keras.Sequential([ + tf.keras.layers.AveragePooling2D(pool_size=8), + tf.keras.layers.Flatten(), + ]) + elif self.routing_pooling == 'max_8': + self.pooling_layer = tf.keras.Sequential([ + tf.keras.layers.MaxPool2D(pool_size=8), + tf.keras.layers.Flatten(), + ]) + else: + self.pooling_layer = tf.keras.layers.Flatten() + + self.built = True + + def call(self, inputs, training=None): + pooled_inputs = self.pooling_layer(inputs) + routing_weights = tf.linalg.matmul(pooled_inputs, self.w_gate) + + if self.use_noisy_routing and training: + raw_noise_stddev = tf.linalg.matmul(pooled_inputs, self.w_noise) + noise_stddev = tf.nn.softplus(raw_noise_stddev) + self.noise_epsilon + routing_weights += tf.random.normal( + tf.shape(routing_weights)) * noise_stddev + + if self.use_sigmoid_activation: + routing_weights = tf.nn.sigmoid(routing_weights) + elif self.use_softmax_routing: + routing_weights = tf.nn.softmax(routing_weights) + elif self.use_softmax_top_k: + top_values, top_indices = tf.math.top_k(routing_weights, + min(self.k + 1, self.num_experts)) + # top k logits has shape [batch, k] + top_k_values = tf.slice(top_values, [0, 0], [-1, self.k]) + top_k_indices = tf.slice(top_indices, [0, 0], [-1, self.k]) + top_k_gates = tf.nn.softmax(top_k_values) + # This returns a [batch, n] Tensor with 0's in the positions of non-top-k + # expert values. 
+      routing_weights = routing_utils.rowwise_unsorted_segment_sum(
+          top_k_gates, top_k_indices, self.num_experts)
+    elif self.use_onehot_top_k:
+      top_values, top_indices = tf.math.top_k(routing_weights, k=self.k)
+      one_hot_tensor = tf.one_hot(top_indices, depth=self.num_experts)
+      mask = tf.reduce_sum(one_hot_tensor, axis=1)
+      routing_weights *= mask
+
+    if self.normalize_routing:
+      normalization = tf.math.reduce_sum(
+          routing_weights, axis=-1, keepdims=True)
+      routing_weights /= normalization
+
+    return routing_weights
+
+  def get_config(self):
+    config = {
+        'num_experts': self.num_experts,
+        'routing_pooling': self.routing_pooling,
+        'routing_fn': self.routing_fn,
+        'k': self.k,
+        'normalize_routing': self.normalize_routing,
+        'noise_epsilon': self.noise_epsilon,
+    }
+    new_config = super().get_config()
+    new_config.update(config)
+    return new_config
diff --git a/edward2/tensorflow/layers/routing_utils.py b/edward2/tensorflow/layers/routing_utils.py
new file mode 100644
index 00000000..18bdef46
--- /dev/null
+++ b/edward2/tensorflow/layers/routing_utils.py
@@ -0,0 +1,102 @@
+# coding=utf-8
+# Copyright 2021 The Edward2 Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Routing utils."""
+
+import tensorflow as tf
+
+
+def rowwise_unsorted_segment_sum(values, indices, n):
+  """UnsortedSegmentSum on each row.
+
+  Args:
+    values: a `Tensor` with shape `[batch_size, k]`.
+    indices: an integer `Tensor` with shape `[batch_size, k]`.
+    n: an integer.
+
+  Returns:
+    A `Tensor` with the same type as `values` and shape `[batch_size, n]`.
+  """
+  batch, k = tf.unstack(tf.shape(indices), num=2)
+  # Offset row i's indices by i * n so that each row scatters into its own
+  # length-n segment of the flattened output.
+  indices_flat = tf.reshape(indices, [-1]) + tf.math.floordiv(
+      tf.range(batch * k), k) * n
+  ret_flat = tf.math.unsorted_segment_sum(
+      tf.reshape(values, [-1]), indices_flat, batch * n)
+  return tf.reshape(ret_flat, [batch, n])
+
+
+def normal_distribution_cdf(x, stddev):
+  """Evaluates the CDF of the normal distribution.
+
+  The normal distribution has mean 0 and standard deviation `stddev`,
+  and is evaluated elementwise at `x`. Input and output `Tensor`s
+  have matching shapes.
+
+  Args:
+    x: a `Tensor`.
+    stddev: a `Tensor` with the same shape as `x`.
+
+  Returns:
+    a `Tensor` with the same shape as `x`.
+  """
+  return 0.5 * (1.0 + tf.math.erf(x / (tf.math.sqrt(2.) * stddev + 1e-20)))
+
+
+def prob_in_top_k(clean_values, noisy_values, noise_stddev, noisy_top_values,
+                  k):
+  """Helper function for noisy top-k gating.
+
+  Computes the probability that a value is in the top k, given different
+  random noise. This gives us a way of backpropagating from a loss that
+  balances the number of times each expert is in the top k experts per
+  example. In the case of no noise, pass in None for noise_stddev, and the
+  result will not be differentiable.
+
+  Args:
+    clean_values: a `Tensor` of shape [batch, n].
+    noisy_values: a `Tensor` of shape [batch, n]. Equal to clean_values plus
+      normally distributed noise with standard deviation noise_stddev.
+    noise_stddev: a `Tensor` of shape [batch, n], or None.
+    noisy_top_values: a `Tensor` of shape [batch, m], the "values" output of
+      tf.math.top_k applied to noisy_values, where m >= k + 1.
+    k: an integer.
+
+  Returns:
+    a `Tensor` of shape [batch, n].
+  """
+  batch = tf.shape(clean_values)[0]
+  m = tf.shape(noisy_top_values)[1]
+  top_values_flat = tf.reshape(noisy_top_values, [-1])
+  # We want to compute the threshold that a particular value would have to
+  # exceed in order to make the top k. This computation differs depending
+  # on whether the value is already in the top k.
+  threshold_positions_if_in = tf.range(batch) * m + k
+  threshold_if_in = tf.expand_dims(
+      tf.gather(top_values_flat, threshold_positions_if_in), 1)
+  is_in = tf.greater(noisy_values, threshold_if_in)
+  if noise_stddev is None:
+    return tf.cast(is_in, tf.float32)
+  threshold_positions_if_out = threshold_positions_if_in - 1
+  threshold_if_out = tf.expand_dims(
+      tf.gather(top_values_flat, threshold_positions_if_out), 1)
+  # Probability that each clean value clears its threshold under the noise.
+  prob_if_in = normal_distribution_cdf(clean_values - threshold_if_in,
+                                       noise_stddev)
+  prob_if_out = normal_distribution_cdf(clean_values - threshold_if_out,
+                                        noise_stddev)
+  prob = tf.where(is_in, prob_if_in, prob_if_out)
+  return prob
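
A minimal usage sketch of the new layer (illustrative shapes and hyperparameter values, not part of the change; it assumes the top-level edward2 package re-exports layers as usual, so the layer is reachable as ed.layers.RoutingLayer after the __init__.py change above):

    import edward2 as ed
    import tensorflow as tf

    # Route a batch of 32-dimensional features across 4 experts, keeping the
    # top 2 experts per example with noisy top-k softmax gating. Any
    # routing_pooling value outside the named options falls back to Flatten,
    # which matches 2-D [batch, features] inputs.
    router = ed.layers.RoutingLayer(
        num_experts=4,
        routing_pooling='flatten',
        routing_fn='noisy_softmax_top_k',
        k=2,
        normalize_routing=False,
        noise_epsilon=1e-2)
    features = tf.random.normal([16, 32])              # [batch, features]
    routing_weights = router(features, training=True)  # [batch, num_experts]
    # Each row has at most k=2 nonzero gates that sum to 1, since the softmax
    # is applied to the top-k logits before they are scattered back.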
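
The one non-obvious step in the top-k branch is the scatter performed by routing_utils.rowwise_unsorted_segment_sum. A small worked example of its semantics, with made-up gate values:

    import tensorflow as tf
    from edward2.tensorflow.layers import routing_utils

    gates = tf.constant([[0.7, 0.3],
                         [0.6, 0.4]])   # top-k gate values, [batch=2, k=2]
    experts = tf.constant([[2, 0],
                           [1, 3]])     # chosen expert ids, [batch=2, k=2]
    dense = routing_utils.rowwise_unsorted_segment_sum(gates, experts, 4)
    # dense == [[0.3, 0.0, 0.7, 0.0],
    #           [0.0, 0.6, 0.0, 0.4]]   # [batch=2, n=4]

Row i's indices are offset by i * n before a single flat unsorted_segment_sum, which is what makes the per-row scatter work without a loop.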
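
prob_in_top_k is not called by RoutingLayer in this change; it exists to support a load-balancing auxiliary loss in the style of sparsely-gated mixture-of-experts gating. A hypothetical sketch of how it could be wired up, with all values illustrative:

    import tensorflow as tf
    from edward2.tensorflow.layers import routing_utils

    k = 2
    clean_logits = tf.random.normal([16, 4])               # [batch, n]
    noise_stddev = tf.nn.softplus(tf.random.normal([16, 4])) + 1e-2
    noisy_logits = clean_logits + tf.random.normal([16, 4]) * noise_stddev
    # m >= k + 1 top values are needed to locate both thresholds.
    noisy_top_values, _ = tf.math.top_k(noisy_logits, k=k + 1)
    probs = routing_utils.prob_in_top_k(
        clean_logits, noisy_logits, noise_stddev, noisy_top_values, k)
    # probs is [batch, n] and differentiable; reducing over the batch gives a
    # smooth estimate of per-expert load for an auxiliary balancing loss.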