|
40 | 40 | import torch.distributed as dist
|
41 | 41 | import torch.multiprocessing as mp
|
42 | 42 |
|
43 |
| -# pyre-fixme[21]: no attribute ProcessGroupNCCL |
44 | 43 | # pyre-fixme[21]: no attribute ProcessGroupGloo
|
45 | 44 | from torch.distributed import (
|
46 | 45 | DeviceMesh,
|
47 | 46 | PrefixStore,
|
48 | 47 | ProcessGroup as BaseProcessGroup,
|
49 | 48 | ProcessGroupGloo as BaseProcessGroupGloo,
|
50 |
| - ProcessGroupNCCL as BaseProcessGroupNCCL, |
51 | 49 | Store,
|
52 | 50 | TCPStore,
|
53 | 51 | )
|
@@ -687,6 +685,9 @@ def _wrap_work(self, work: Work, opts: object) -> Work:
|
687 | 685 | return _WorkCUDATimeout(self, work, timeout)
|
688 | 686 |
|
689 | 687 | def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
|
| 688 | + # pyre-fixme[21]: no attribute ProcessGroupNCCL |
| 689 | + from torch.distributed import ProcessGroupNCCL as BaseProcessGroupNCCL |
| 690 | + |
690 | 691 | self._errored = None
|
691 | 692 |
|
692 | 693 | pg = BaseProcessGroup(store, rank, world_size)
|
@@ -1717,6 +1718,8 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
|
1717 | 1718 |
|
1718 | 1719 | @classmethod
|
1719 | 1720 | def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
|
| 1721 | + from torch.distributed import ProcessGroupNCCL as BaseProcessGroupNCCL |
| 1722 | + |
1720 | 1723 | pg = BaseProcessGroup(store, rank, world_size)
|
1721 | 1724 | pg._set_default_backend(ProcessGroup.BackendType.NCCL)
|
1722 | 1725 | # pyre-fixme[16]: no attribute ProcessGroupNCCL
|
|
0 commit comments