Skip to content

Commit 9cc565c

Browse files
authored
basic mac support (#158)
1 parent bed99ed commit 9cc565c

File tree

5 files changed

+87
-6
lines changed

5 files changed

+87
-6
lines changed

.github/workflows/unittest-mac.yaml

+70
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,70 @@
1+
name: Unit Tests
2+
3+
on:
4+
push:
5+
branches:
6+
- main
7+
pull_request:
8+
9+
jobs:
10+
unittest-mac:
11+
runs-on: macos-m2-15
12+
steps:
13+
- name: Checkout
14+
uses: actions/checkout@v4
15+
16+
- name: Setup miniconda
17+
uses: pytorch/test-infra/.github/actions/setup-miniconda@main
18+
with:
19+
python-version: 3.12
20+
21+
- name: Install Rust
22+
run: |
23+
set -ex
24+
25+
curl https://sh.rustup.rs -sSf | sh -s -- --default-toolchain=stable --profile=default -y
26+
. "$HOME/.cargo/env"
27+
28+
- name: Install Dependencies
29+
run: |
30+
set -ex
31+
32+
if [[ -n "$CONDA_ENV" ]]; then
33+
# Use binaries under conda environment
34+
export PATH="$CONDA_ENV/bin":$PATH
35+
fi
36+
. "$HOME/.cargo/env"
37+
38+
conda install libprotobuf -y
39+
40+
python -m pip install --upgrade pip
41+
42+
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
43+
44+
pip install -e .[dev] -v
45+
46+
- name: Run Python Tests
47+
run: |
48+
set -ex
49+
50+
if [[ -n "$CONDA_ENV" ]]; then
51+
# Use binaries under conda environment
52+
export PATH="$CONDA_ENV/bin":$PATH
53+
fi
54+
55+
# Run tests
56+
pytest -v
57+
58+
- name: Run Rust Tests
59+
run: |
60+
set -ex
61+
62+
if [[ -n "$CONDA_ENV" ]]; then
63+
# Use binaries under conda environment
64+
export PATH="$CONDA_ENV/bin":$PATH
65+
fi
66+
. "$HOME/.cargo/env"
67+
68+
export RUSTFLAGS="-C link-arg=-undefined -C link-arg=dynamic_lookup"
69+
70+
cargo test -v

torchft/checkpointing/pg_transport_test.py

+4-1
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
1+
import sys
12
from datetime import timedelta
2-
from unittest import TestCase, skipUnless
3+
from unittest import TestCase, skipIf, skipUnless
34

45
import torch
56
from torch.distributed import TCPStore
@@ -14,6 +15,8 @@
1415

1516

1617
class PGTransportTest(TestCase):
18+
# pyre-fixme[56]: Pyre was not able to infer the type of argument
19+
@skipIf(sys.platform == "darwin", "not passing on mac")
1720
def test_pg_transport_gloo(self) -> None:
1821
store: TCPStore = TCPStore(
1922
host_name="localhost", port=0, is_master=True, wait_for_workers=False

torchft/multiprocessing_test.py

+4-1
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,10 @@ def test_monitored_queue_put(self) -> None:
2525

2626
mq = _MonitoredPipe(local)
2727
mq.send(1)
28-
with self.assertRaisesRegex(ConnectionResetError, "Connection reset by peer"):
28+
with self.assertRaisesRegex(
29+
(ConnectionResetError, BrokenPipeError),
30+
"(Connection reset by peer|Broken pipe)",
31+
):
2932
while True:
3033
mq.send(1)
3134

torchft/process_group.py

+5-2
Original file line number | Diff line number | Diff line change
@@ -40,14 +40,12 @@
4040
import torch.distributed as dist
4141
import torch.multiprocessing as mp
4242

43-
# pyre-fixme[21]: no attribute ProcessGroupNCCL
4443
# pyre-fixme[21]: no attribute ProcessGroupGloo
4544
from torch.distributed import (
4645
DeviceMesh,
4746
PrefixStore,
4847
ProcessGroup as BaseProcessGroup,
4948
ProcessGroupGloo as BaseProcessGroupGloo,
50-
ProcessGroupNCCL as BaseProcessGroupNCCL,
5149
Store,
5250
TCPStore,
5351
)
@@ -687,6 +685,9 @@ def _wrap_work(self, work: Work, opts: object) -> Work:
687685
return _WorkCUDATimeout(self, work, timeout)
688686

689687
def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
688+
# pyre-fixme[21]: no attribute ProcessGroupNCCL
689+
from torch.distributed import ProcessGroupNCCL as BaseProcessGroupNCCL
690+
690691
self._errored = None
691692

692693
pg = BaseProcessGroup(store, rank, world_size)
@@ -1717,6 +1718,8 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
17171718

17181719
@classmethod
17191720
def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
1721+
from torch.distributed import ProcessGroupNCCL as BaseProcessGroupNCCL
1722+
17201723
pg = BaseProcessGroup(store, rank, world_size)
17211724
pg._set_default_backend(ProcessGroup.BackendType.NCCL)
17221725
# pyre-fixme[16]: no attribute ProcessGroupNCCL

torchft/process_group_test.py

+4-2
Original file line number | Diff line number | Diff line change
@@ -6,10 +6,11 @@
66

77
import gc
88
import os
9+
import sys
910
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
1011
from datetime import timedelta
1112
from typing import Any, Callable, Dict, List, cast
12-
from unittest import TestCase, skipUnless
13+
from unittest import TestCase, skipIf, skipUnless
1314
from unittest.mock import Mock
1415

1516
import torch
@@ -949,7 +950,7 @@ def worker(pg: ProcessGroup, rank: int, dev: str) -> str:
949950
# nccl: Tensor-likes are not equal/not close (due to abort)
950951
with self.assertRaisesRegex(
951952
Exception,
952-
r"(Connection closed by peer|Timed out waiting|no error|Read error|not equal|not close)",
953+
r"(Connection closed by peer|timed out after|Timed out waiting|no error|Read error|not equal|not close)",
953954
):
954955
test(pg, rank, t1.clone())
955956
raise RuntimeError("no error")
@@ -992,6 +993,7 @@ def test_collective_with_resiliency(self, collective: str) -> None:
992993
self._run_with_resiliency(collective, device="cpu")
993994

994995

996+
@skipIf(sys.platform == "darwin", "not reliable on mac")
995997
class BabyGlooMultiPgTest(MultiPgBaseTest):
996998
BACKEND = "baby_gloo"
997999
WORLD_SIZE = 3

0 commit comments

Comments
 (0)