Skip to content

Commit 484b835

Browse files
committed
passing tags for use when loading TF saved_model
Some saved_models have more than one tag, e.g., MobileBERT SQuAD 1.1 checkpoints. Need a way to specify tag
1 parent 1e117da commit 484b835

File tree

2 files changed

+7
-6
lines changed
  • coremltools/converters/mil/frontend

2 files changed

+7
-6
lines changed

coremltools/converters/mil/frontend/tensorflow/load.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def load(self):
5252

5353
logging.info("Loading TensorFlow model '{}'".format(self.model))
5454
outputs = self.kwargs.get("outputs", None)
55-
self._graph_def = self._graph_def_from_model(outputs)
55+
tags = self.kwargs.get("tags", None)
56+
self._graph_def = self._graph_def_from_model(outputs, tags)
5657

5758
if self._graph_def is not None and len(self._graph_def.node) == 0:
5859
msg = "tf.Graph should have at least 1 node, Got empty graph."
@@ -78,7 +79,7 @@ def load(self):
7879
return program
7980

8081
# @abstractmethod
81-
def _graph_def_from_model(self, outputs=None):
82+
def _graph_def_from_model(self, outputs=None, tags=None):
8283
"""Load TensorFlow model into GraphDef. Overwrite for different TF versions."""
8384
pass
8485

@@ -129,7 +130,7 @@ def __init__(self, model, debug=False, **kwargs):
129130
"""
130131
TFLoader.__init__(self, model, debug, **kwargs)
131132

132-
def _graph_def_from_model(self, outputs=None):
133+
def _graph_def_from_model(self, outputs=None, tags=None):
133134
"""Overwrites TFLoader._graph_def_from_model()"""
134135
msg = "Expected model format: [tf.Graph | .pb | SavedModel | tf.keras.Model | .h5], got {}"
135136
if isinstance(self.model, tf.Graph) and hasattr(self.model, "as_graph_def"):
@@ -160,7 +161,7 @@ def _graph_def_from_model(self, outputs=None):
160161
graph_def = self._from_tf_keras_model(self.model)
161162
return self.extract_sub_graph(graph_def, outputs)
162163
elif os.path.isdir(str(self.model)):
163-
graph_def = self._from_saved_model(self.model)
164+
graph_def = self._from_saved_model(self.model, tags=tags)
164165
return self.extract_sub_graph(graph_def, outputs)
165166
else:
166167
raise NotImplementedError(msg.format(self.model))

coremltools/converters/mil/frontend/tensorflow2/load.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __init__(self, model, debug=False, **kwargs):
6868
"""
6969
TFLoader.__init__(self, model, debug, **kwargs)
7070

71-
def _graph_def_from_model(self, outputs=None):
71+
def _graph_def_from_model(self, outputs=None, tags=None):
7272
"""Overwrites TFLoader._graph_def_from_model()"""
7373
msg = (
7474
"Expected model format: [SavedModel | [concrete_function] | "
@@ -92,7 +92,7 @@ def _graph_def_from_model(self, outputs=None):
9292
elif _os_path.isfile(self.model) and self.model.endswith(".h5"):
9393
cfs = self._concrete_fn_from_tf_keras_or_h5(self.model)
9494
elif _os_path.isdir(self.model):
95-
saved_model = _tf.saved_model.load(self.model)
95+
saved_model = _tf.saved_model.load(self.model, tags=tags)
9696
sv = saved_model.signatures.values()
9797
cfs = sv if isinstance(sv, list) else list(sv)
9898
else:

0 commit comments

Comments
 (0)