Skip to content

Commit 5456cd0

Browse files
authored
Merge branch 'main' into marlin-move-rescales-add-sub
2 parents 07d8738 + 7ce78c0 commit 5456cd0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+171
-184
lines changed

.github/workflows/cuda.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ jobs:
8989
9090
export-voxtral-cuda-artifact:
9191
name: export-voxtral-cuda-${{ matrix.quant.name }}
92+
# Skip this job if the pull request is from a fork (HuggingFace secrets are not available)
93+
if: github.event.pull_request.head.repo.full_name == github.repository || github.event_name != 'pull_request'
9294
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
9395
permissions:
9496
id-token: write
@@ -166,6 +168,8 @@ jobs:
166168
167169
export-gemma3-cuda-artifact:
168170
name: export-gemma3-cuda-${{ matrix.quant.name }}
171+
# Skip this job if the pull request is from a fork (HuggingFace secrets are not available)
172+
if: github.event.pull_request.head.repo.full_name == github.repository || github.event_name != 'pull_request'
169173
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
170174
permissions:
171175
id-token: write

backends/arm/arm_vela.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
76

87
import os
98
import struct

backends/arm/operator_support/embedding_support.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
"""Declare operator support for ``aten.embedding`` in TOSA.
56
7+
Permit embeddings with int32 indices (TOSA lacks int64 support); other dtypes
8+
are rejected by this check.
9+
10+
"""
611

712
import torch
813

@@ -17,6 +22,8 @@
1722

1823
@register_tosa_support_check
1924
class EmbeddingSupported(SupportedTOSAOperatorCheck):
25+
"""Provide TOSA support check for ``aten.embedding``."""
26+
2027
targets = [exir_ops.edge.aten.embedding.default]
2128

2229
tosa_specs = [
@@ -27,16 +34,20 @@ class EmbeddingSupported(SupportedTOSAOperatorCheck):
2734
def is_node_tosa_supported(
2835
self, node: fx.Node, tosa_spec: TosaSpecification
2936
) -> bool: # type: ignore[override, misc]
30-
# Note aten.embedding.default requires int64 indices and TOSA does not
31-
# support it. Int32 indices here for aten.embedding.default is ok since
32-
# it will be decomposed into ops that can handle it.
37+
"""Return True if the node is supported by TOSA.
3338
39+
PyTorch's ``aten.embedding`` typically takes int64 indices, but for
40+
TOSA we only allow int32 indices. The export path decomposes the op so
41+
that int32 indices are ok.
42+
43+
"""
3444
if len(node.all_input_nodes) != 2:
3545
self.reporter.report_reject(
3646
node,
3747
(f"Expected exactly two input nodes, got {len(node.all_input_nodes)}"),
3848
)
3949
return False
50+
4051
indices_val = node.all_input_nodes[1].meta["val"]
4152
indices_dtype = indices_val.dtype
4253

backends/arm/operators/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
76

87
from . import ( # noqa
98
node_visitor,

backends/arm/operators/node_visitor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
76

87
import json
98
from typing import Any, Dict, List, Optional

backends/arm/operators/op_abs.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
76
from typing import Any, List
87

98
import tosa_serializer as ts

backends/arm/operators/op_add.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
76

87
from typing import Any, List
98

backends/arm/operators/op_any.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
76
from typing import Any, cast, List
87

98
import tosa_serializer as ts

backends/arm/operators/op_avg_pool2d.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
76
from typing import Any, List
87

98
import torch

backends/arm/operators/op_cat.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
76

87
from typing import Any, List
98

0 commit comments

Comments
 (0)