@@ -52,7 +52,8 @@ def load(self):
5252
5353 logging .info ("Loading TensorFlow model '{}'" .format (self .model ))
5454 outputs = self .kwargs .get ("outputs" , None )
55- self ._graph_def = self ._graph_def_from_model (outputs )
55+ tags = self .kwargs .get ("tags" , None )
56+ self ._graph_def = self ._graph_def_from_model (outputs , tags )
5657
5758 if self ._graph_def is not None and len (self ._graph_def .node ) == 0 :
5859 msg = "tf.Graph should have at least 1 node, Got empty graph."
@@ -78,7 +79,7 @@ def load(self):
7879 return program
7980
8081 # @abstractmethod
81- def _graph_def_from_model (self , outputs = None ):
82+ def _graph_def_from_model (self , outputs = None , tags = None ):
8283 """Load TensorFlow model into GraphDef. Overwrite for different TF versions."""
8384 pass
8485
@@ -129,7 +130,7 @@ def __init__(self, model, debug=False, **kwargs):
129130 """
130131 TFLoader .__init__ (self , model , debug , ** kwargs )
131132
132- def _graph_def_from_model (self , outputs = None ):
133+ def _graph_def_from_model (self , outputs = None , tags = None ):
133134 """Overwrites TFLoader._graph_def_from_model()"""
134135 msg = "Expected model format: [tf.Graph | .pb | SavedModel | tf.keras.Model | .h5], got {}"
135136 if isinstance (self .model , tf .Graph ) and hasattr (self .model , "as_graph_def" ):
@@ -160,7 +161,7 @@ def _graph_def_from_model(self, outputs=None):
160161 graph_def = self ._from_tf_keras_model (self .model )
161162 return self .extract_sub_graph (graph_def , outputs )
162163 elif os .path .isdir (str (self .model )):
163- graph_def = self ._from_saved_model (self .model )
164+ graph_def = self ._from_saved_model (self .model , tags = tags )
164165 return self .extract_sub_graph (graph_def , outputs )
165166 else :
166167 raise NotImplementedError (msg .format (self .model ))
0 commit comments