
Commit 37fec86 (2 parents: 6128b69 + 05cddd3)

Update on "[rfc][dynamo] "skip_guard_eval" stance for power users"
# Motivation

We have spent quite some time this year improving guard performance and soundness. Nevertheless, guards STILL take time. We have seen multiple requests and pieces of evidence from power users who want close to 0% guard overhead. First we saw this in vLLM, where even 1% overhead is bad, and recently we saw it in hqq (low-precision LLM generation) - #138386. To put some numbers in perspective, low-precision LLM inference reaches around 250 tokens/second, i.e., each token takes a mere 4 milliseconds. Even a 200 us guard overhead is 5% of the total. Here, users ask: "we can guarantee that there will be no more recompilations in the steady state, so give us the lowest possible guard overhead".

# Design

A must-have consideration is to support fast inference even when the model has recompiled, i.e., has multiple cache entries for a code object (because of dynamism, or just a tensor dtype change in the case of hqq). So we still have to run guards to figure out which compiled graph to run. What we need is the "minimal set of differentiating guards" - the minimal set of guards we can run to choose the compiled graph.

Note that this works ONLY under the assumption that the user really guarantees no more recompilation scenarios (no more mutations, no more dynamism after the model has been warmed up). If the user violates this assumption and the violation is not covered by the diff guard set, we will choose a wrong compiled graph to run.

When we designed C++ guards, Ed and Voz suggested using a trie structure to directly represent this "diff guard set". Due to complexity, we went with a tree structure instead and relied on a piece of GuardManager state - "fail_count" - to fail fast. I realized that we can rely on this "fail_count" to find the diff guard set. If we recompile, it means that the guard eval check_fns of all existing cache lines have failed. Whenever a guard check_fn fails, we increment the counter in the failing node (and propagate it to the root node) to fail faster next time. To run the "diff guard set", we therefore only have to run the nodes in the tree with fail_count > 0 (see the sketches below).

This PR relies on that observation to introduce a new stance - "skip_guard_eval". The idea is that the user warms up their model with torch.compile and then runs the steady state with this stance. The stance goes through the existing cache lines for the intercepted code object but runs only the diff guard set, which dramatically reduces guard overhead. In case all guards fail, we fall back to eager (though if this happens the user is violating the assumption, so we should perhaps hard error; I need to fix a silly issue with _dynamo.disable to hard error here).

A bonus is that this "theoretically" works with graph breaks as well, but I need more testing to convince myself of that.

# Evaluation

I tried the hqq model in #138386. With very small changes in the user code ([hqq PR](mobiusml/hqq#127)), I see the throughput increase from **160 tokens/sec to 174 tokens/sec**.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames rec

[ghstack-poisoned]
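The sketches below illustrate the two key pieces. They are illustrative Python only: `GuardNode` is a hypothetical stand-in for the C++ GuardManager tree (the names `check_fn` and `fail_count` mirror the description above), and the stance entry point shown (`torch.compiler.set_stance`) is an assumption about how the stance will be exposed.

```python
# Hypothetical sketch of the fail_count idea; the real implementation is the
# C++ GuardManager tree, not this class.
class GuardNode:
    def __init__(self, check_fn, children=()):
        self.check_fn = check_fn        # leaf guard check for this node
        self.children = list(children)  # nested guard managers
        # Incremented whenever this node's guard fails; in the real design the
        # count is also propagated up to the root, so failing paths stay "hot".
        self.fail_count = 0

    def run_full(self, value):
        """Normal mode: evaluate every guard in the tree."""
        if not self.check_fn(value):
            self.fail_count += 1
            return False
        return all(child.run_full(value) for child in self.children)

    def run_diff_only(self, value):
        """skip_guard_eval mode: only guards that have ever failed can
        differentiate between cache entries, so skip everything else."""
        if self.fail_count == 0:
            return True  # never failed -> not a differentiating guard
        if not self.check_fn(value):
            return False
        return all(child.run_diff_only(value) for child in self.children)
```

And the intended warm-up/steady-state workflow, again as a sketch (the model and inputs are placeholders):

```python
import torch

model = torch.nn.Linear(8, 8)  # placeholder model
warmup_inputs = [torch.randn(4, 8) for _ in range(3)]
steady_state_inputs = [torch.randn(4, 8) for _ in range(100)]

compiled_model = torch.compile(model)

# Warm-up: let all recompilations (dynamism, dtype changes, ...) happen here.
for x in warmup_inputs:
    compiled_model(x)

# Steady state: the user guarantees no new recompilation triggers, so only
# the diff guard set is run to pick the right compiled graph.
with torch.compiler.set_stance("skip_guard_eval"):
    for x in steady_state_inputs:
        compiled_model(x)
```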

File tree: 520 files changed, +7029 −4723 lines

(+1 −1)

@@ -1 +1 @@
-ca4783992ed7602a39528ba304d61f00396b2a5a
+16b633b4daa7f3d3442be62a3589bd60b2f7fdc7

.ci/docker/libtorch/Dockerfile (+5)

@@ -66,6 +66,11 @@ RUN bash ./install_cuda.sh 12.4
 RUN bash ./install_magma.sh 12.4
 RUN ln -sf /usr/local/cuda-12.4 /usr/local/cuda
 
+FROM cuda as cuda12.6
+RUN bash ./install_cuda.sh 12.6
+RUN bash ./install_magma.sh 12.6
+RUN ln -sf /usr/local/cuda-12.6 /usr/local/cuda
+
 FROM cpu as rocm
 ARG PYTORCH_ROCM_ARCH
 ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}

.ci/docker/requirements-ci.txt (+2 −2)

@@ -257,7 +257,7 @@ tb-nightly==2.13.0a20230426
 #test that import:
 
 # needed by torchgen utils
-typing-extensions
+typing-extensions>=4.10.0
 #Description: type hints for python
 #Pinned versions:
 #test that import:
@@ -331,7 +331,7 @@ sympy==1.13.1 ; python_version >= "3.9"
 #Pinned versions:
 #test that import:
 
-onnx==1.16.1
+onnx==1.17.0
 #Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal
 #Pinned versions:
 #test that import:

.github/ci_commit_pins/audio.txt (+1 −1)

@@ -1 +1 @@
-79047bf6bdec9e32c4cffd0f9835b347781fefbf
+fa44bdab1fe49bab58389e7b6a33061ffced9bc7

.github/workflows/build-libtorch-images.yml (+1 −1)

@@ -44,7 +44,7 @@ jobs:
     runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral"
     strategy:
       matrix:
-        cuda_version: ["12.4", "12.1", "11.8"]
+        cuda_version: ["12.6", "12.4", "12.1", "11.8"]
     env:
       GPU_ARCH_TYPE: cuda
       GPU_ARCH_VERSION: ${{ matrix.cuda_version }}

.github/workflows/build-manywheel-images.yml (+1 −1)

@@ -48,7 +48,7 @@ jobs:
     runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral"
     strategy:
       matrix:
-        cuda_version: ["12.4", "12.1", "11.8"]
+        cuda_version: ["12.6", "12.4", "12.1", "11.8"]
     env:
       GPU_ARCH_TYPE: cuda
       GPU_ARCH_VERSION: ${{ matrix.cuda_version }}

.github/workflows/inductor-cu124.yml (+1)

@@ -21,6 +21,7 @@ jobs:
   get-label-type:
     name: get-label-type
     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
+    if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
     with:
       triggering_actor: ${{ github.triggering_actor }}
       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}

.github/workflows/inductor-micro-benchmark-x86.yml (+1)

@@ -17,6 +17,7 @@ permissions: read-all
 
 jobs:
   linux-jammy-cpu-py3_9-gcc11-inductor-build:
+    if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
     name: linux-jammy-cpu-py3.9-gcc11-inductor
     uses: ./.github/workflows/_linux-build.yml
     with:

.github/workflows/inductor-micro-benchmark.yml (+1)

@@ -19,6 +19,7 @@ jobs:
   get-label-type:
     name: get-label-type
     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
+    if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
     with:
       triggering_actor: ${{ github.triggering_actor }}
       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}

.github/workflows/inductor-perf-compare.yml (+1)

@@ -25,6 +25,7 @@ jobs:
   get-test-label-type:
     name: get-test-label-type
     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
+    if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
     with:
       triggering_actor: ${{ github.triggering_actor }}
       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}

.github/workflows/inductor-perf-test-nightly-a10g.yml (+1)

@@ -71,6 +71,7 @@ jobs:
   get-label-type:
     name: get-label-type
     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
+    if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
     with:
       triggering_actor: ${{ github.triggering_actor }}
       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}

.github/workflows/inductor-perf-test-nightly-aarch64.yml (+1)

@@ -51,6 +51,7 @@ jobs:
   get-label-type:
     name: get-label-type
     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
+    if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
     with:
       triggering_actor: ${{ github.triggering_actor }}
       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}

.github/workflows/inductor-perf-test-nightly-x86.yml (+1)

@@ -51,6 +51,7 @@ jobs:
   get-label-type:
     name: get-label-type
     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
+    if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
     with:
       triggering_actor: ${{ github.triggering_actor }}
       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}

.github/workflows/inductor-perf-test-nightly.yml (+1)

@@ -69,6 +69,7 @@ jobs:
   get-label-type:
     name: get-label-type
     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
+    if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
     with:
       triggering_actor: ${{ github.triggering_actor }}
       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}

.github/workflows/inductor-periodic.yml (+1)

@@ -21,6 +21,7 @@ jobs:
   get-label-type:
     name: get-label-type
     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
+    if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
     with:
       triggering_actor: ${{ github.triggering_actor }}
       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}

.github/workflows/inductor-rocm.yml (+1)

@@ -25,6 +25,7 @@ jobs:
   get-label-type:
     name: get-label-type
     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
+    if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
     with:
       triggering_actor: ${{ github.triggering_actor }}
       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}

.github/workflows/inductor.yml (+1)

@@ -21,6 +21,7 @@ jobs:
   get-label-type:
     name: get-label-type
     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
+    if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
     with:
       triggering_actor: ${{ github.triggering_actor }}
       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}

.github/workflows/nightly.yml (+1)

@@ -20,6 +20,7 @@ jobs:
   get-label-type:
     name: get-label-type
     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
+    if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
     with:
       triggering_actor: ${{ github.triggering_actor }}
       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}

.github/workflows/periodic.yml (+1)

@@ -41,6 +41,7 @@ jobs:
   get-label-type:
     name: get-label-type
     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
+    if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
     with:
       triggering_actor: ${{ github.triggering_actor }}
       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}

.github/workflows/pull.yml (+21 −16)

@@ -38,6 +38,7 @@ jobs:
   get-label-type:
     name: get-label-type
     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
+    if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
     with:
       triggering_actor: ${{ github.triggering_actor }}
       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
@@ -53,10 +54,11 @@ jobs:
       docker-image-name: pytorch-linux-jammy-py3.9-gcc11
       test-matrix: |
         { include: [
-          { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
-          { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
-          { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
-          { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
+          { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
+          { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
+          { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
+          { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
+          { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
           { config: "docs_test", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
           { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
           { config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
@@ -185,10 +187,11 @@ jobs:
       docker-image-name: pytorch-linux-focal-py3.9-clang10
       test-matrix: |
         { include: [
-          { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
-          { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
-          { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
-          { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
+          { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
+          { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
+          { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
+          { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
+          { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
           { config: "crossref", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
           { config: "crossref", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
           { config: "dynamo", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
@@ -217,10 +220,11 @@ jobs:
       docker-image-name: pytorch-linux-focal-py3.11-clang10
       test-matrix: |
         { include: [
-          { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
-          { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
-          { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
-          { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
+          { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
+          { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
+          { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
+          { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
+          { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
           { config: "crossref", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
           { config: "crossref", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
           { config: "dynamo", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
@@ -251,10 +255,11 @@ jobs:
       docker-image-name: pytorch-linux-focal-py3.12-clang10
       test-matrix: |
         { include: [
-          { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
-          { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
-          { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
-          { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
+          { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
+          { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
+          { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
+          { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
+          { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" },
           { config: "dynamo", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
           { config: "dynamo", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },
           { config: "dynamo", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" },

.github/workflows/rocm.yml (+1)

@@ -26,6 +26,7 @@ jobs:
   contents: read
 
   linux-focal-rocm6_2-py3_10-build:
+    if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
     name: linux-focal-rocm6.2-py3.10
     uses: ./.github/workflows/_linux-build.yml
     with:

.github/workflows/slow.yml (+1)

@@ -39,6 +39,7 @@ jobs:
   get-label-type:
     name: get-label-type
     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
+    if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
     with:
       triggering_actor: ${{ github.triggering_actor }}
       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}

.github/workflows/trunk.yml (+1)

@@ -37,6 +37,7 @@ jobs:
   get-label-type:
     name: get-label-type
     uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main
+    if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch'
     with:
       triggering_actor: ${{ github.triggering_actor }}
       issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}

CONTRIBUTING.md (+5)

@@ -286,6 +286,11 @@ The following packages should be installed with either `conda` or `pip`:
 - `expecttest` and `hypothesis` - required to run tests
 - `mypy` - recommended for linting
 - `pytest` - recommended to run tests more selectively
+Running
+```
+pip install -r requirements
+```
+will install these dependencies for you.
 
 All PyTorch test suites are located in the `test` folder and start with
 `test_`. Run the entire test

aten/src/ATen/CMakeLists.txt (+1 −1)

@@ -54,7 +54,7 @@ if(NOT BUILD_LITE_INTERPRETER)
 endif()
 EXCLUDE(ATen_CORE_SRCS "${ATen_CORE_SRCS}" ${ATen_CORE_TEST_SRCS})
 
-file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/vec256/vsx/*.h" "cpu/vec/vec256/zarch/*.h" "cpu/vec/sve/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h")
+file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec128/*.h" "cpu/vec/vec256/*.h" "cpu/vec/vec256/vsx/*.h" "cpu/vec/vec256/zarch/*.h" "cpu/vec/sve/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h")
 file(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp" "functorch/*.cpp")
 file(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh" "cuda/tunable/*.cuh" "cuda/tunable/*.h")
 file(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp" "cuda/tunable/*.cpp")

aten/src/ATen/Parallel.h (+2 −2)

@@ -133,7 +133,7 @@ TORCH_API std::string get_parallel_info();
 TORCH_API void set_num_interop_threads(int);
 
 // Returns the number of threads used for inter-op parallelism
-TORCH_API int get_num_interop_threads();
+TORCH_API size_t get_num_interop_threads();
 
 // Launches inter-op parallel task
 TORCH_API void launch(std::function<void()> func);
@@ -142,7 +142,7 @@ void launch_no_thread_state(std::function<void()> fn);
 } // namespace internal
 
 // Launches intra-op parallel task
-TORCH_API void intraop_launch(std::function<void()> func);
+TORCH_API void intraop_launch(const std::function<void()>& func);
 
 // Returns number of intra-op threads used by default
 TORCH_API int intraop_default_num_threads();

aten/src/ATen/ParallelFuture.h (+1 −1)

@@ -8,6 +8,6 @@ namespace at {
 
 // Launches intra-op parallel task, returns a future
 TORCH_API c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
-    std::function<void()> func);
+    const std::function<void()>& func);
 
 } // namespace at

aten/src/ATen/ParallelNative.cpp (+3 −3)

@@ -273,10 +273,10 @@ bool in_parallel_region() {
 #endif // C10_MOBILE
 }
 
-void intraop_launch(std::function<void()> func) {
+void intraop_launch(const std::function<void()>& func) {
 #ifndef C10_MOBILE
   if (!in_parallel_region() && get_num_threads() > 1) {
-    _get_intraop_pool().run(std::move(func));
+    _get_intraop_pool().run(func);
   } else {
     // execute inline if we're in parallel region
     func();
@@ -289,7 +289,7 @@ void intraop_launch(std::function<void()> func) {
 }
 
 c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
-    std::function<void()> func) {
+    const std::function<void()>& func) {
 #ifndef C10_MOBILE
   auto future = c10::make_intrusive<c10::ivalue::Future>(c10::NoneType::get());
   if (!in_parallel_region() && get_num_threads() > 1) {

aten/src/ATen/ParallelOpenMP.cpp (+5 −4)

@@ -14,9 +14,10 @@
 
 namespace at {
 #if AT_MKLDNN_ENABLED()
-namespace native { namespace mkldnn {
+namespace native::mkldnn {
+// NOLINTNEXTLINE(misc-use-internal-linkage)
 void clear_computation_cache();
-}} // namespace native::mkldnn
+} // namespace native::mkldnn
 #endif
 
 namespace {
@@ -100,13 +101,13 @@ bool in_parallel_region() {
 #endif
 }
 
-void intraop_launch(std::function<void()> func) {
+void intraop_launch(const std::function<void()>& func) {
   // execute inline in openmp case
   func();
 }
 
 c10::intrusive_ptr<c10::ivalue::Future> intraop_launch_future(
-    std::function<void()> func) {
+    const std::function<void()>& func) {
   func();
   auto future = c10::make_intrusive<c10::ivalue::Future>(NoneType::get());
   future->markCompleted();

aten/src/ATen/ParallelThreadPoolNative.cpp (+2 −2)

@@ -56,7 +56,7 @@ void set_num_interop_threads(int nthreads) {
       "has started or set_num_interop_threads called");
 }
 
-int get_num_interop_threads() {
+size_t get_num_interop_threads() {
   at::internal::lazy_init_num_threads();
   int nthreads = num_interop_threads.load();
   if (nthreads > 0) {
@@ -82,7 +82,7 @@ void launch_no_thread_state(std::function<void()> fn) {
 void launch(std::function<void()> func) {
   // NOLINTNEXTLINE(modernize-avoid-bind)
   internal::launch_no_thread_state(std::bind([](
-      std::function<void()> f, ThreadLocalState thread_locals) {
+      const std::function<void()>& f, const ThreadLocalState& thread_locals) {
     ThreadLocalStateGuard guard(thread_locals);
     f();
   },

aten/src/ATen/TensorIterator.cpp (−2)

@@ -1483,8 +1483,6 @@ FastSetupType TensorIteratorBase::compute_fast_setup_type(const TensorIteratorCo
   return FastSetupType::NONE;
 }
 
-TensorIteratorBase::TensorIteratorBase() = default;
-
 void TensorIteratorBase::build(TensorIteratorConfig& config) {
   // populate some persistent configuration fields
   is_reduction_ = config.is_reduction_;