From 9b5454e5f55f615b52173ce8d60e263844fcd9ac Mon Sep 17 00:00:00 2001 From: shraman Date: Mon, 16 Dec 2019 18:58:44 -0800 Subject: [PATCH] Blacklist custom ops while traversing graph. PiperOrigin-RevId: 285895559 --- morph_net/framework/op_handler_util.py | 2 ++ morph_net/tools/configurable_ops.py | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/morph_net/framework/op_handler_util.py b/morph_net/framework/op_handler_util.py index f846dd2..45ccf92 100644 --- a/morph_net/framework/op_handler_util.py +++ b/morph_net/framework/op_handler_util.py @@ -31,6 +31,8 @@ def get_input_ops(op, op_reg_manager, whitelist_indices=None): Returns: List of tf.Operation that are the inputs to op. """ + if 'GumbelPrefix' in op.type: + return [] # Ignore scalar or 1-D constant inputs. def is_const(tensor): return tensor.op.type == 'Const' diff --git a/morph_net/tools/configurable_ops.py b/morph_net/tools/configurable_ops.py index 6534467..f4fe19a 100644 --- a/morph_net/tools/configurable_ops.py +++ b/morph_net/tools/configurable_ops.py @@ -159,6 +159,11 @@ def __init__(self, self._default_to_zero = fallback_rule == FallbackRule.zero self._strict = fallback_rule == FallbackRule.strict + @property + def parameterization(self): + """Returns the parameterization dict mapping op names to num_outputs.""" + return self._parameterization + @tf.contrib.framework.add_arg_scope def conv2d(self, *args, **kwargs): """Masks num_outputs from the function pointed to by 'conv2d'.