From bd0510f9642e24ef033e9e73a1e1f2d96919f8da Mon Sep 17 00:00:00 2001 From: xuziqi Date: Thu, 9 Nov 2023 15:35:01 +0800 Subject: [PATCH 1/3] Fix torch listconstruct errors when dependent on inputs flexible shapes --- .../converters/mil/frontend/torch/ops.py | 77 ++++++++++++++++--- 1 file changed, 66 insertions(+), 11 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py index 79938ba4d..e4a5d0ca1 100644 --- a/coremltools/converters/mil/frontend/torch/ops.py +++ b/coremltools/converters/mil/frontend/torch/ops.py @@ -626,23 +626,72 @@ def unflatten(context, node): def _array_construct(context, node, array_type): assert len(node.outputs) == 1 inputs = _get_inputs(context, node) - scalar_inputs = [ - inp - for inp in inputs - if isinstance(inp, Var) and inp.can_be_folded_to_const() and len(inp.shape) == 0 - ] - if len(scalar_inputs) == len(inputs): + nodes = {n.name : n for n in context.torch_graph.nodes} + is_known_name = lambda name : name in nodes + def dfs_graph_input_dependent(inputs, non_const=None): + ''' + inputs would be [] if all constant + otherwise further depend on each of their inputs, all the way to the root + + if some name is not in context.torch_graph.nodes, then it should be a symbolic in graph input + ''' + if non_const is None: + # effectively only at dfs.layer[0], mutable obj throughout dfs + non_const = set() + + # len(inputs) == 0 is dfs base + for i in inputs: + if is_known_name(i): + dfs_graph_input_dependent(nodes[i].inputs, non_const) + else: + non_const.add(i) + return non_const + any_inheriting = dfs_graph_input_dependent(node.inputs) + + is_all_const = all(map(lambda inp : isinstance(inp, Var) and inp.can_be_folded_to_const() and len(inp.shape) == 0, inputs)) + dependent_on_graph_input = len(any_inheriting) > 0 + + if is_all_const: # All the list items are compile-time scalar constants, so let's create # a new const that concatenates them. val = array_type([inp.val for inp in inputs]) const = mb.const(val=val, name=node.name) context.add(const) - else: - # If at least one input to the construct op is non-const, collect - # the inputs and add them directly to the context. Ops that use this - # node's output will take the list directly as input. - context.add(array_type(inputs), node.name) + return + + elif dependent_on_graph_input: + to_concat = [] + for input in node.inputs: + inheriting = dfs_graph_input_dependent([input]) + if len(inheriting) == 0: + # is const + to_concat.append([context[input].val]) + + else: + # is non_const + iter_node = nodes[input] + while all([is_known_name(i) for i in iter_node.inputs]): + iter_node = nodes[iter_node.inputs[0]] + + if context[iter_node.name].op.op_type == 'gather': + non_const = iter_node.inputs[0] + non_const_name = iter_node.inputs[1] + non_const_idx = context[non_const_name].val + to_concat.append(mb.slice_by_size(x=mb.shape(x=context[non_const]), begin=[non_const_idx], size=[1])) + + else: + to_concat = [] + break + + if len(to_concat) > 0: + context.add(mb.concat(values=to_concat, axis=0), node.name) + return + + # If at least one input to the construct op is neither const nor symbolic, collect + # the inputs and add them directly to the context. Ops that use this + # node's output will take the list directly as input. + context.add(array_type(inputs), node.name) @register_torch_op @@ -1595,6 +1644,12 @@ def pad(context, node): pad = pad.val.reshape((-1, 2))[::-1].reshape(-1).tolist() missing_dims = x.rank - (len(pad) // 2) pad = [0, 0] * missing_dims + pad + else: + missing_dims = (x.rank * 2 - pad.shape[0]) // 2 + pad = mb.concat(values=[pad, [0, 0] * missing_dims], axis=0) + pad = mb.reshape(x=pad, shape=[-1,2]) + pad = mb.reverse(x=pad, axes=[0]) + pad = mb.reshape(x=pad, shape=[-1]) # mil.ops.defs.iOS15.pad asserts 1D tensor if len(inputs) == 4: mode = inputs[2].val From 11d61cc1c913707ca46cb50254ceaee6c0c21dc3 Mon Sep 17 00:00:00 2001 From: xuziqi Date: Mon, 13 Nov 2023 15:23:01 +0800 Subject: [PATCH 2/3] [improve] postpone dfs until it's necessary --- .../converters/mil/frontend/torch/ops.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py index e4a5d0ca1..ce97bf03a 100644 --- a/coremltools/converters/mil/frontend/torch/ops.py +++ b/coremltools/converters/mil/frontend/torch/ops.py @@ -627,6 +627,15 @@ def _array_construct(context, node, array_type): assert len(node.outputs) == 1 inputs = _get_inputs(context, node) + is_all_const = all(map(lambda inp : isinstance(inp, Var) and inp.can_be_folded_to_const() and len(inp.shape) == 0, inputs)) + if is_all_const: + # All the list items are compile-time scalar constants, so let's create + # a new const that concatenates them. + val = array_type([inp.val for inp in inputs]) + const = mb.const(val=val, name=node.name) + context.add(const) + return + nodes = {n.name : n for n in context.torch_graph.nodes} is_known_name = lambda name : name in nodes def dfs_graph_input_dependent(inputs, non_const=None): @@ -648,19 +657,9 @@ def dfs_graph_input_dependent(inputs, non_const=None): non_const.add(i) return non_const any_inheriting = dfs_graph_input_dependent(node.inputs) - - is_all_const = all(map(lambda inp : isinstance(inp, Var) and inp.can_be_folded_to_const() and len(inp.shape) == 0, inputs)) dependent_on_graph_input = len(any_inheriting) > 0 - if is_all_const: - # All the list items are compile-time scalar constants, so let's create - # a new const that concatenates them. - val = array_type([inp.val for inp in inputs]) - const = mb.const(val=val, name=node.name) - context.add(const) - return - - elif dependent_on_graph_input: + if dependent_on_graph_input: to_concat = [] for input in node.inputs: inheriting = dfs_graph_input_dependent([input]) From 9b60a091ac455f9ad2019e284a467a73c3cea24f Mon Sep 17 00:00:00 2001 From: xuziqi Date: Mon, 13 Nov 2023 15:52:55 +0800 Subject: [PATCH 3/3] [improve] add memoization for dfs to accelerate for relative larger models --- coremltools/converters/mil/frontend/torch/ops.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py index ce97bf03a..555f70fb2 100644 --- a/coremltools/converters/mil/frontend/torch/ops.py +++ b/coremltools/converters/mil/frontend/torch/ops.py @@ -638,6 +638,7 @@ def _array_construct(context, node, array_type): nodes = {n.name : n for n in context.torch_graph.nodes} is_known_name = lambda name : name in nodes + inheriting_bookkeeping = {name : -1 for name in nodes.keys()} def dfs_graph_input_dependent(inputs, non_const=None): ''' inputs would be [] if all constant @@ -646,13 +647,15 @@ def dfs_graph_input_dependent(inputs, non_const=None): if some name is not in context.torch_graph.nodes, then it should be a symbolic in graph input ''' if non_const is None: - # effectively only at dfs.layer[0], mutable obj throughout dfs + # init, effectively only at dfs.layer[0] non_const = set() # len(inputs) == 0 is dfs base for i in inputs: if is_known_name(i): - dfs_graph_input_dependent(nodes[i].inputs, non_const) + if inheriting_bookkeeping[i] == -1: + inheriting = dfs_graph_input_dependent(nodes[i].inputs, non_const) + inheriting_bookkeeping[i] = len(inheriting) else: non_const.add(i) return non_const @@ -662,8 +665,8 @@ def dfs_graph_input_dependent(inputs, non_const=None): if dependent_on_graph_input: to_concat = [] for input in node.inputs: - inheriting = dfs_graph_input_dependent([input]) - if len(inheriting) == 0: + inheriting = inheriting_bookkeeping[input] + if inheriting <= 0: # is const to_concat.append([context[input].val])