forked from tensorflow/neural-structured-learning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdynamic_embedding_ops.py
182 lines (153 loc) · 7.04 KB
/
dynamic_embedding_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DynamicEmbedding related ops."""
import typing
from research.carls import context
from research.carls import dynamic_embedding_config_pb2 as de_config_pb2
from research.carls.kernels import gen_carls_ops
import tensorflow as tf
class DynamicEmbeddingLookup(tf.keras.layers.Layer):
  """A Keras Layer for Dynamic Embedding Lookup.

  This is useful when the gradient descent update is required for embedding
  lookup. The input of this layer is a `Tensor` of string keys and it outputs
  the embedding output as a float `Tensor`.
  """

  def __init__(self,
               config: de_config_pb2.DynamicEmbeddingConfig,
               var_name: typing.Text,
               service_address: typing.Text = "",
               timeout_ms: int = -1,
               **kwargs):
    """Constructor for DynamicEmbeddingLookup.

    Args:
      config: A DynamicEmbeddingConfig proto that configures the embedding.
      var_name: A unique name for the given embedding.
      service_address: The address of a knowledge bank service. If empty, the
        value passed from --kbs_address (defined in
        .../carls/dynamic_embedding_manager.cc) flag will be used instead.
      timeout_ms: Timeout in milliseconds for the connection. If negative,
        never times out.
      **kwargs: Standard `tf.keras.layers.Layer` keyword arguments (such as
        `name`, `trainable` or `dtype`), forwarded unchanged to the base
        class so the layer can be named and configured like any Keras layer.

    Raises:
      ValueError: if var_name is `None` or empty.
    """
    super().__init__(**kwargs)
    if not var_name:
      raise ValueError("Must specify a non-empty var_name.")
    self.embedding_dimension = config.embedding_dimension
    # Register the config under var_name, then create the handle to the
    # embedding manager resource that every call() feeds to the lookup op.
    context.add_to_collection(var_name, config)
    self.resource = gen_carls_ops.dynamic_embedding_manager_resource(
        config.SerializeToString(), var_name, service_address, timeout_ms)

  def build(self, input_shape):
    """Creates the one-element weight used to route gradients to the op."""
    del input_shape  # Not used.
    # Creates a placeholder variable for the dynamic_embedding_lookup() such
    # that the gradients can be passed into _dynamic_embedding_lookup_grad().
    self.grad_placeholder = self.add_weight(
        name="grad_placeholder",
        shape=[1],
        dtype=tf.float32,
        trainable=True,
        initializer=tf.keras.initializers.zeros)

  def call(self, keys):
    """Looks up the embeddings of `keys`.

    Args:
      keys: A `Tensor` of string keys.

    Returns:
      The output of `gen_carls_ops.dynamic_embedding_lookup`, i.e. the
      embeddings of `keys` with last dimension `self.embedding_dimension`.
    """
    return gen_carls_ops.dynamic_embedding_lookup(keys, self.grad_placeholder,
                                                  self.resource,
                                                  self.embedding_dimension)
def dynamic_embedding_lookup(keys: tf.Tensor,
                             config: de_config_pb2.DynamicEmbeddingConfig,
                             var_name: typing.Text,
                             service_address: typing.Text = "",
                             skip_gradient_update: bool = False,
                             timeout_ms: int = -1) -> tf.Tensor:
  """Looks up the embeddings for the given keys.

  Args:
    keys: A string `Tensor` of shape [batch_size] or [batch_size,
      max_sequence_length]. An empty-string key maps to an all-zero embedding.
    config: A DynamicEmbeddingConfig proto that configures the embedding.
    var_name: A unique, non-empty name for the embedding.
    service_address: The address of a knowledge bank service. If empty, the
      value of the --kbs_address flag is used instead.
    skip_gradient_update: If True, a constant is fed to the lookup op instead
      of a trainable variable, so this lookup does not request a gradient
      update. If False (the default), a new scalar `tf.Variable` is created
      so gradients can be passed into the registered lookup gradient.
    timeout_ms: Connection timeout in milliseconds. A negative value means
      the connection never times out.

  Returns:
    A float `Tensor` of shape
    - [batch_size, config.embedding_dimension] if `keys` is 1-D, or
    - [batch_size, max_sequence_length, config.embedding_dimension] if `keys`
      is 2-D.

  Raises:
    ValueError: If `var_name` is `None` or empty.
  """
  if not var_name:
    raise ValueError("Must specify a valid var_name.")

  # The lookup op's second input serves to route gradients: a trainable
  # variable lets backprop reach the registered gradient function, whereas a
  # constant does not.
  # NOTE(review): a fresh tf.Variable is built on every call; inside a
  # tf.function this likely trips the "singleton tf.Variable" restriction, in
  # which case the DynamicEmbeddingLookup layer is the better fit. Confirm.
  grad_placeholder = (
      tf.constant(0.0) if skip_gradient_update else tf.Variable(0.0))

  context.add_to_collection(var_name, config)
  manager_handle = gen_carls_ops.dynamic_embedding_manager_resource(
      config.SerializeToString(), var_name, service_address, timeout_ms)
  return gen_carls_ops.dynamic_embedding_lookup(
      keys, grad_placeholder, manager_handle, config.embedding_dimension)
def dynamic_embedding_update(keys: tf.Tensor,
                             values: tf.Tensor,
                             config: de_config_pb2.DynamicEmbeddingConfig,
                             var_name: typing.Text,
                             service_address: typing.Text = "",
                             timeout_ms: int = -1):
  """Writes `values` as the embeddings of `keys`.

  Args:
    keys: A string `Tensor` of shape [batch_size] or [batch_size,
      max_sequence_length].
    values: A `Tensor` of shape [batch_size, embedding_dimension] or
      [batch_size, max_sequence_length, embedding_dimension].
    config: A DynamicEmbeddingConfig proto that configures the embedding.
    var_name: A unique, non-empty name for the embedding.
    service_address: The address of a dynamic embedding service. If empty,
      the value of the --kbs_address flag is used instead.
    timeout_ms: Connection timeout in milliseconds. A negative value means
      the connection never times out.

  Returns:
    The output of `gen_carls_ops.dynamic_embedding_update`: a `Tensor` of
    shape
    - [batch_size, config.embedding_dimension] if `keys` is 1-D, or
    - [batch_size, max_sequence_length, config.embedding_dimension] if `keys`
      is 2-D.

  Raises:
    TypeError: If `var_name` is `None` or empty. (Note: the lookup APIs in
      this module raise ValueError for the same condition.)
  """
  if not var_name:
    raise TypeError("Must specify a valid var_name.")

  # Register the config under var_name before building the manager handle.
  context.add_to_collection(var_name, config)
  serialized_config = config.SerializeToString()
  manager_handle = gen_carls_ops.dynamic_embedding_manager_resource(
      serialized_config, var_name, service_address, timeout_ms)
  embedding_dim = config.embedding_dimension
  return gen_carls_ops.dynamic_embedding_update(
      keys, values, manager_handle, embedding_dim)
@tf.RegisterGradient("DynamicEmbeddingLookup")
def _dynamic_embedding_lookup_grad(op, grad):
  """Gradient function registered for the "DynamicEmbeddingLookup" op.

  Args:
    op: The DynamicEmbeddingLookup op being differentiated. The wrappers in
      this module build it with inputs (keys, grad_placeholder, resource).
    grad: The gradient w.r.t. the output of that op.

  Returns:
    The result of `gen_carls_ops.dynamic_embedding_lookup_grad`, used as the
    gradient(s) w.r.t. the inputs of the lookup op.
  """
  # Collapse all leading axes so the gradient becomes 2-D, keeping the
  # (statically known) last axis: [total number of keys, last dim of grad].
  flat_grad = tf.reshape(grad, [-1, grad.shape[-1]])
  keys, resource = op.inputs[0], op.inputs[2]
  return gen_carls_ops.dynamic_embedding_lookup_grad(keys, flat_grad, resource)