@@ -378,12 +378,28 @@ def _load_inner_model(self, sess=None):
378
378
def _get_vocab_embedding_as_np_array (self , vocab_type : VocabType ) -> np .ndarray :
379
379
assert vocab_type in VocabType
380
380
vocab_tf_variable_name = self .vocab_type_to_tf_variable_name_mapping [vocab_type ]
381
- with tf .compat .v1 .variable_scope ('model' , reuse = None ):
382
- embeddings = tf .compat .v1 .get_variable (vocab_tf_variable_name )
383
- self .saver = tf .compat .v1 .train .Saver ()
384
- self ._load_inner_model (self .sess )
385
- vocab_embedding_matrix = self .sess .run (embeddings )
386
- return vocab_embedding_matrix
381
+
382
+ if self .eval_reader is None :
383
+ self .eval_reader = PathContextReader (vocabs = self .vocabs ,
384
+ model_input_tensors_former = _TFEvaluateModelInputTensorsFormer (),
385
+ config = self .config , estimator_action = EstimatorAction .Evaluate )
386
+ input_iterator = tf .compat .v1 .data .make_initializable_iterator (self .eval_reader .get_dataset ())
387
+ _ , _ , _ , _ , _ , _ , _ , _ = self ._build_tf_test_graph (input_iterator .get_next ())
388
+
389
+ if vocab_type is VocabType .Token :
390
+ shape = (self .vocabs .token_vocab .size , self .config .TOKEN_EMBEDDINGS_SIZE )
391
+ elif vocab_type is VocabType .Target :
392
+ shape = (self .vocabs .target_vocab .size , self .config .TARGET_EMBEDDINGS_SIZE )
393
+ elif vocab_type is VocabType .Path :
394
+ shape = (self .vocabs .path_vocab .size , self .config .PATH_EMBEDDINGS_SIZE )
395
+
396
+ with tf .compat .v1 .variable_scope ('model' , reuse = True ):
397
+ embeddings = tf .compat .v1 .get_variable (vocab_tf_variable_name , shape = shape )
398
+ self .saver = tf .compat .v1 .train .Saver ()
399
+ self ._initialize_session_variables ()
400
+ self ._load_inner_model (self .sess )
401
+ vocab_embedding_matrix = self .sess .run (embeddings )
402
+ return vocab_embedding_matrix
387
403
388
404
def get_should_reuse_variables (self ):
389
405
if self .config .TRAIN_DATA_PATH_PREFIX :
0 commit comments