Skip to content

Commit 5365c55

Browse files
authored
Arm backend: Fix bug in decompose linear vector norm (#11755)
The introduction of the decomposition for linalg vector norm revealed a bug: when dim is None, all dimensions should be reduced. Signed-off-by: Elena Zhelezina <[email protected]>
1 parent b22a2be commit 5365c55

File tree

5 files changed

+11
-6
lines changed

5 files changed

+11
-6
lines changed

backends/arm/_passes/decompose_linalg_vector_norm_pass.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,12 @@ def call_operator(self, op, args, kwargs, meta):
5151
f"is not supported for linalg_vector_norm operator"
5252
)
5353

54+
# Sum over all dimensions if dim is None
5455
if norm_dim is None:
55-
raise ValueError("The norm_dim for linalg_vector_norm is None.")
56-
57-
dims = [norm_dim] if isinstance(norm_dim, int) else list(norm_dim)
56+
rank = input_tensor.data.dim()
57+
dims = list(range(rank))
58+
else:
59+
dims = [norm_dim] if isinstance(norm_dim, int) else list(norm_dim)
5860

5961
# Decomposition based on norm order.
6062
if norm_order == 1:

backends/arm/_passes/decompose_sum_pass.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ def call_operator(self, op, args, kwargs, meta):
6363
case _:
6464
raise ValueError(f"Invalid number of arguments ({len(args)}) provided.")
6565

66+
# If dims is None, sum over all dimensions
67+
if dims is None:
68+
shape = input_node.data.size()
69+
dims = list(range(len(shape)))
70+
6671
view_op, sum_op = _get_sum_decomp(op)
6772

6873
for dim in dims:

backends/arm/operators/op_sum.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,6 @@ def define_node(
106106
if inputs[0].dtype == ts.DType.INT8:
107107
return super().define_node(node, tosa_graph, inputs, output)
108108

109-
validate_num_inputs(self.target, inputs, 3)
110-
111109
tensor = inputs[0]
112110
input_shape = list(tensor.shape)
113111
dim = int(inputs[1].number % len(input_shape))

backends/arm/test/models/test_torch_functions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ def test_torch_fns_MI(test_data):
130130
"topk": "NotImplementedError: No registered serialization name for <class 'torch.return_types.topk'> found",
131131
"sort": "NotImplementedError: No registered serialization name for <class 'torch.return_types.sort'> found",
132132
"t": "MLETORCH-855: Issue with Quantization folding.",
133-
"norm": "An error occurred when running the 'KeepDimsFalseToSqueezePass' pass after the following passes:",
134133
},
135134
strict=False,
136135
)

backends/arm/test/ops/test_sum.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class Sum(torch.nn.Module):
3333
"4d_dims_no_keep": lambda: (torch.rand(1, 1, 5, 8), 1, False),
3434
"4d_dim_3_keep": lambda: (torch.rand(1, 2, 3, 4), 3, True),
3535
"4d_dims_keep": lambda: (torch.rand(1, 2, 8, 8), [2, 3, 0], True),
36+
"dim_None": lambda: (torch.rand(10), None, True),
3637
}
3738

3839
def forward(self, x: torch.Tensor, dim: int, keepdim: bool):

0 commit comments

Comments
 (0)