@@ -378,12 +378,28 @@ def _load_inner_model(self, sess=None):
     def _get_vocab_embedding_as_np_array(self, vocab_type: VocabType) -> np.ndarray:
         assert vocab_type in VocabType
         vocab_tf_variable_name = self.vocab_type_to_tf_variable_name_mapping[vocab_type]
-        with tf.compat.v1.variable_scope('model', reuse=None):
-            embeddings = tf.compat.v1.get_variable(vocab_tf_variable_name)
-            self.saver = tf.compat.v1.train.Saver()
-            self._load_inner_model(self.sess)
-            vocab_embedding_matrix = self.sess.run(embeddings)
-            return vocab_embedding_matrix
+
+        if self.eval_reader is None:
+            self.eval_reader = PathContextReader(vocabs=self.vocabs,
+                                                 model_input_tensors_former=_TFEvaluateModelInputTensorsFormer(),
+                                                 config=self.config, estimator_action=EstimatorAction.Evaluate)
+            input_iterator = tf.compat.v1.data.make_initializable_iterator(self.eval_reader.get_dataset())
+            _, _, _, _, _, _, _, _ = self._build_tf_test_graph(input_iterator.get_next())
+
+        if vocab_type is VocabType.Token:
+            shape = (self.vocabs.token_vocab.size, self.config.TOKEN_EMBEDDINGS_SIZE)
+        elif vocab_type is VocabType.Target:
+            shape = (self.vocabs.target_vocab.size, self.config.TARGET_EMBEDDINGS_SIZE)
+        elif vocab_type is VocabType.Path:
+            shape = (self.vocabs.path_vocab.size, self.config.PATH_EMBEDDINGS_SIZE)
+
+        with tf.compat.v1.variable_scope('model', reuse=True):
+            embeddings = tf.compat.v1.get_variable(vocab_tf_variable_name, shape=shape)
+            self.saver = tf.compat.v1.train.Saver()
+            self._initialize_session_variables()
+            self._load_inner_model(self.sess)
+            vocab_embedding_matrix = self.sess.run(embeddings)
+            return vocab_embedding_matrix
 
     def get_should_reuse_variables(self):
         if self.config.TRAIN_DATA_PATH_PREFIX:
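
The substance of this change: the old code opened the `'model'` scope with `reuse=None`, under which `tf.compat.v1.get_variable` inherits the outer scope's reuse setting and, at the top level, tries to *create* the variable; with no shape given and no graph built, that call fails. The rewrite first builds the evaluation graph if needed, then reopens the scope with `reuse=True` and the expected `shape`, so the existing embedding variable is looked up and shape-checked instead. A minimal, self-contained sketch of those TF1 reuse semantics (the variable name and shape here are illustrative, not taken from this repository):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # TF1-style variable scoping needs graph mode

# Creating pass: the first get_variable call registers the variable.
with tf.compat.v1.variable_scope('model'):
    created = tf.compat.v1.get_variable('WORDS_VOCAB', shape=(100, 128))

# Reusing pass: with reuse=True the same call looks the variable up instead
# of creating it; a mismatched shape here would raise a ValueError.
with tf.compat.v1.variable_scope('model', reuse=True):
    fetched = tf.compat.v1.get_variable('WORDS_VOCAB', shape=(100, 128))

assert created is fetched  # same underlying tf.Variable
```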
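For context, this is roughly how a caller might use the rewritten method to export an embedding matrix. The `export_embeddings` helper and the `model` object are hypothetical; only `_get_vocab_embedding_as_np_array` and `VocabType` come from the diff above:

```python
import numpy as np

def export_embeddings(model, vocab_type, out_path: str) -> None:
    """Hypothetical helper: `model` is an already-constructed instance of the
    class modified above (with sess, vocabs, and config initialized), and
    `vocab_type` is a VocabType member such as VocabType.Token."""
    matrix: np.ndarray = model._get_vocab_embedding_as_np_array(vocab_type)
    # Row i holds the embedding vector of the vocabulary entry with index i.
    np.save(out_path, matrix)
```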