forked from tensorflow/neural-structured-learning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdynamic_embedding_manager.h
131 lines (108 loc) · 5.87 KB
/
dynamic_embedding_manager.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
/*Copyright 2020 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef NEURAL_STRUCTURED_LEARNING_RESEARCH_CARLS_DYNAMIC_EMBEDDING_MANAGER_H_
#define NEURAL_STRUCTURED_LEARNING_RESEARCH_CARLS_DYNAMIC_EMBEDDING_MANAGER_H_
#include <memory>
#include <string>

#include "absl/status/status.h"
#include "research/carls/dynamic_embedding_config.pb.h" // proto to pb
#include "research/carls/knowledge_bank_grpc_service.h"
#include "tensorflow/core/framework/tensor.h"
namespace carls {
// Responsible for communicating with a KnowledgeBankService stub within
// Tensorflow C++ Operation code. Each instance of DynamicEmbeddingManager only
// works for one session.
//
// Every RPC-style method reports failure through the returned absl::Status and
// writes its results into caller-provided output pointers.
class DynamicEmbeddingManager {
 public:
  // Connects to a KBS server at `kbs_address` and starts a session.
  // Returns a nullptr if input parameters are invalid.
  //
  // `timeout` defaults to absl::InfiniteDuration().
  // NOTE(review): presumably `timeout` bounds the connection/session start and
  // `name` identifies the embedding within the session -- confirm against the
  // implementation file.
  static std::unique_ptr<DynamicEmbeddingManager> Create(
      const DynamicEmbeddingConfig& config, const std::string& name,
      const std::string& kbs_address,
      absl::Duration timeout = absl::InfiniteDuration());

  // Builds a manager from an already created KnowledgeBankService stub, a
  // config and a session handle. Ownership of `stub` is transferred to the
  // callee (it is passed by value as a std::unique_ptr).
  DynamicEmbeddingManager(
      std::unique_ptr</*grpc_gen::*/KnowledgeBankService::Stub> stub,
      const DynamicEmbeddingConfig& config, const std::string& session_handle);

  // Prepares KnowledgeBankService::LookupRequest from given input and
  // calls DES server.
  // If a given key is empty, the output tensor is filled with zero values.
  absl::Status Lookup(const tensorflow::Tensor& keys, bool update,
                      tensorflow::Tensor* output);

  // Updates the embedding values of given keys by calling
  // KnowledgeBankService::UpdateRequest.
  // If there are duplicated keys, it only updates the value of the last seen
  // one.
  absl::Status UpdateValues(const tensorflow::Tensor& keys,
                            const tensorflow::Tensor& values);

  // Updates the gradients of the embeddings for given keys.
  absl::Status UpdateGradients(const tensorflow::Tensor& keys,
                               const tensorflow::Tensor& grads);

  // Looks up the mean and variance for each input tensor.
  // `mode` must be consistent with MemoryLookupRequest::LookupMode.
  absl::Status LookupGaussianCluster(const tensorflow::Tensor& inputs, int mode,
                                     tensorflow::Tensor* mean,
                                     tensorflow::Tensor* variance,
                                     tensorflow::Tensor* output_distance,
                                     tensorflow::Tensor* output_cluster_id);

  // Returns DynamicEmbeddingConfig.
  // Const-qualified so the accessor is usable through a const reference or
  // pointer to the manager; it only reads `config_`.
  const DynamicEmbeddingConfig& config() const { return config_; }

  // Samples negative keys from given positive keys.
  //
  // If update = true, new embeddings are dynamically allocated for new
  // positive keys, which is often used in training.
  //
  // Note that for a logit layer with activation x in the last layer, one needs
  // to append an extra 1 to the input activations to obtain wx + b, where [w,
  // b] is the embedding of a particular output key.
  //
  // The `output_labels` indicates if the corresponding `output_keys` is a
  // positive or negative sample, and the `output_expected_counts` represents
  // the sampling probability. Please refer to
  // carls.candidate_sampling.NegativeSamplingResult for details.
  //
  // `output_masks` indicates whether `positive_keys` of an entry in the input
  // batch are all invalid (empty).
  //
  // `output_embeddings` returns the embeddings of the sampled keys. It should
  // be allocated as [batch_size, num_samples, embed_dim].
  absl::Status NegativeSampling(const tensorflow::Tensor& positive_keys,
                                const tensorflow::Tensor& input_activations,
                                int num_samples, bool update,
                                tensorflow::Tensor* output_keys,
                                tensorflow::Tensor* output_labels,
                                tensorflow::Tensor* output_expected_counts,
                                tensorflow::Tensor* output_masks,
                                tensorflow::Tensor* output_embeddings);

  // Returns the top k closest embeddings to each of the input activations.
  // Note that for a logit layer with activation x, one needs to append an
  // extra 1 to the input activations to obtain wx + b, where [w, b] is the
  // embedding of a particular output key.
  absl::Status TopK(const tensorflow::Tensor& input_activations, int k,
                    tensorflow::Tensor* output_keys,
                    tensorflow::Tensor* output_logits);

  // Calls the KnowledgeBankService::Export RPC.
  // NOTE(review): `exported_path` is presumably filled with the location the
  // server wrote to under `output_dir` -- confirm against the implementation.
  absl::Status Export(const std::string& output_dir,
                      std::string* exported_path);

  // Calls the KnowledgeBankService::Import RPC.
  absl::Status Import(const std::string& saved_path);

 private:
  // Checks validity of input for both UpdateValues() and UpdateGradients().
  absl::Status CheckInputForUpdate(const tensorflow::Tensor& keys,
                                   const tensorflow::Tensor& values);

  // Internal implementation of the Lookup() method.
  absl::Status LookupInternal(const tensorflow::Tensor& keys, bool update,
                              LookupResponse* response);

  // Stub used to talk to the KnowledgeBankService; owned by this object.
  std::unique_ptr</*grpc_gen::*/KnowledgeBankService::Stub> stub_;
  // Immutable after construction (both members are const).
  const DynamicEmbeddingConfig config_;
  const std::string session_handle_;
};
} // namespace carls
#endif // NEURAL_STRUCTURED_LEARNING_RESEARCH_CARLS_DYNAMIC_EMBEDDING_MANAGER_H_