Skip to content

Commit c606ee6

Browse files
committed
Fixing --save_w2v and --save_t2v options
1 parent 278b91f commit c606ee6

File tree

2 files changed

+24
-10
lines changed

2 files changed

+24
-10
lines changed

Diff for: config.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def arguments_parser(cls) -> ArgumentParser:
1313
parser.add_argument("-d", "--data", dest="data_path",
1414
help="path to preprocessed dataset", required=False)
1515
parser.add_argument("-te", "--test", dest="test_path",
16-
help="path to test file", metavar="FILE", required=False)
16+
help="path to test file", metavar="FILE", required=False, default='')
1717
parser.add_argument("-s", "--save", dest="save_path",
1818
help="path to save the model file", metavar="FILE", required=False)
1919
parser.add_argument("-w2v", "--save_word2v", dest="save_w2v",
@@ -117,7 +117,7 @@ def __init__(self, set_defaults: bool = False, load_from_args: bool = False, ver
117117
self.MODEL_SAVE_PATH: Optional[str] = None
118118
self.MODEL_LOAD_PATH: Optional[str] = None
119119
self.TRAIN_DATA_PATH_PREFIX: Optional[str] = None
120-
self.TEST_DATA_PATH: Optional[str] = None
120+
self.TEST_DATA_PATH: Optional[str] = ''
121121
self.RELEASE: bool = False
122122
self.EXPORT_CODE_VECTORS: bool = False
123123
self.SAVE_W2V: Optional[str] = None # TODO: update README;
@@ -171,8 +171,6 @@ def test_steps(self) -> int:
171171
return ceil(self.NUM_TEST_EXAMPLES / self.TEST_BATCH_SIZE) if self.TEST_BATCH_SIZE else 0
172172

173173
def data_path(self, is_evaluating: bool = False):
174-
if self.RELEASE:
175-
return ''
176174
return self.TEST_DATA_PATH if is_evaluating else self.train_data_path
177175

178176
def batch_size(self, is_evaluating: bool = False):

Diff for: tensorflow_model.py

+22-6
Original file line numberDiff line numberDiff line change
@@ -378,12 +378,28 @@ def _load_inner_model(self, sess=None):
378378
def _get_vocab_embedding_as_np_array(self, vocab_type: VocabType) -> np.ndarray:
379379
assert vocab_type in VocabType
380380
vocab_tf_variable_name = self.vocab_type_to_tf_variable_name_mapping[vocab_type]
381-
with tf.compat.v1.variable_scope('model', reuse=None):
382-
embeddings = tf.compat.v1.get_variable(vocab_tf_variable_name)
383-
self.saver = tf.compat.v1.train.Saver()
384-
self._load_inner_model(self.sess)
385-
vocab_embedding_matrix = self.sess.run(embeddings)
386-
return vocab_embedding_matrix
381+
382+
if self.eval_reader is None:
383+
self.eval_reader = PathContextReader(vocabs=self.vocabs,
384+
model_input_tensors_former=_TFEvaluateModelInputTensorsFormer(),
385+
config=self.config, estimator_action=EstimatorAction.Evaluate)
386+
input_iterator = tf.compat.v1.data.make_initializable_iterator(self.eval_reader.get_dataset())
387+
_, _, _, _, _, _, _, _ = self._build_tf_test_graph(input_iterator.get_next())
388+
389+
if vocab_type is VocabType.Token:
390+
shape = (self.vocabs.token_vocab.size, self.config.TOKEN_EMBEDDINGS_SIZE)
391+
elif vocab_type is VocabType.Target:
392+
shape = (self.vocabs.target_vocab.size, self.config.TARGET_EMBEDDINGS_SIZE)
393+
elif vocab_type is VocabType.Path:
394+
shape = (self.vocabs.path_vocab.size, self.config.PATH_EMBEDDINGS_SIZE)
395+
396+
with tf.compat.v1.variable_scope('model', reuse=True):
397+
embeddings = tf.compat.v1.get_variable(vocab_tf_variable_name, shape=shape)
398+
self.saver = tf.compat.v1.train.Saver()
399+
self._initialize_session_variables()
400+
self._load_inner_model(self.sess)
401+
vocab_embedding_matrix = self.sess.run(embeddings)
402+
return vocab_embedding_matrix
387403

388404
def get_should_reuse_variables(self):
389405
if self.config.TRAIN_DATA_PATH_PREFIX:

0 commit comments

Comments (0)