forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path__init__.py
140 lines (120 loc) · 4.34 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# mypy: allow-untyped-defs
import pdb
import sys
import torch
def is_available() -> bool:
    """
    Report whether the ``torch.distributed`` package is usable in this build.

    When this returns ``False``, ``torch.distributed`` exposes no other APIs.
    The package is currently supported on Linux, macOS and Windows. When
    building PyTorch from source, set ``USE_DISTRIBUTED=1`` to enable it.
    The default is ``USE_DISTRIBUTED=1`` on Linux and Windows and
    ``USE_DISTRIBUTED=0`` on macOS.

    Returns:
        bool: ``True`` if the C extension exposes the ``_c10d_init`` entry
        point, ``False`` otherwise.
    """
    # A private sentinel distinguishes "attribute missing" from any real value.
    # Like ``hasattr``, ``getattr`` with a default only swallows AttributeError.
    missing = object()
    return getattr(torch._C, "_c10d_init", missing) is not missing
# Run the C++ c10d initialization at import time when distributed support is
# compiled in; a falsy return value is treated as an initialization failure.
if is_available():
    if not torch._C._c10d_init():
        raise RuntimeError("Failed to initialize torch.distributed")

# Custom runtime errors thrown from the distributed package, re-exported from
# the C extension. These aliases are defined whether or not is_available().
DistError, DistBackendError, DistNetworkError, DistStoreError = (
    torch._C._DistError,
    torch._C._DistBackendError,
    torch._C._DistNetworkError,
    torch._C._DistStoreError,
)
if is_available():
from torch._C._distributed_c10d import (
_broadcast_coalesced,
_compute_bucket_assignment_by_size,
_ControlCollectives,
_DEFAULT_FIRST_BUCKET_BYTES,
_make_nccl_premul_sum,
_register_builtin_comm_hook,
_register_comm_hook,
_StoreCollectives,
_test_python_store,
_verify_params_across_processes,
Backend as _Backend,
BuiltinCommHookType,
DebugLevel,
FileStore,
get_debug_level,
GradBucket,
Logger,
PrefixStore,
ProcessGroup as ProcessGroup,
Reducer,
set_debug_level,
set_debug_level_from_env,
Store,
TCPStore,
Work as _Work,
)
class _DistributedPdb(pdb.Pdb):
"""
Supports using PDB from inside a multiprocessing child process.
Usage:
_DistributedPdb().set_trace()
"""
def interaction(self, *args, **kwargs):
_stdin = sys.stdin
try:
sys.stdin = open("/dev/stdin")
pdb.Pdb.interaction(self, *args, **kwargs)
finally:
sys.stdin = _stdin
def breakpoint(rank: int = 0):
    """
    Drop into an interactive debugger on exactly one rank.

    Only the chosen rank opens the debugger. Afterwards every rank, including
    the debugged one, enters ``barrier()``, so the other ranks wait until the
    debugging session has finished before continuing.

    Args:
        rank (int): Which rank to break on. Default: ``0``
    """
    if get_rank() == rank:
        debugger = _DistributedPdb()
        banner = (
            "\n!!! ATTENTION !!!\n\n"
            f"Type 'up' to get to the frame that called dist.breakpoint(rank={rank})\n"
        )
        debugger.message(banner)
        # set_trace() must be called from this frame (not from a helper) so
        # that a single 'up' reaches the caller, as the banner says.
        debugger.set_trace()

    # If Meta/Python keys are in the TLS, we want to make sure that we ignore them
    # and hit the (default) CPU/CUDA implementation of barrier.
    saved_meta_in_tls = torch._C._meta_in_tls_dispatch_include()
    dispatch_guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
    torch._C._set_meta_in_tls_dispatch_include(False)
    try:
        barrier()
    finally:
        # Undo the TLS change first, then drop the dispatch guard object.
        torch._C._set_meta_in_tls_dispatch_include(saved_meta_in_tls)
        del dispatch_guard
# These two c10d bindings are only imported off Windows (presumably they are
# not available in Windows builds — confirm against the C extension).
if sys.platform != "win32":
    from torch._C._distributed_c10d import _round_robin_process_groups, HashStore
from .device_mesh import DeviceMesh, init_device_mesh
# Names prefixed with an underscore are not exported by the star-import below,
# so the private helpers this package re-exports are imported explicitly after it.
# See the comment in `distributed_c10d.py` above `_backend` on why we expose
# this.
from .distributed_c10d import *  # noqa: F403
from .distributed_c10d import (
    _all_gather_base,
    _coalescing_manager,
    _CoalescingManager,
    _create_process_group_wrapper,
    _get_process_group_name,
    _rank_not_in_group,
    _reduce_scatter_base,
    get_node_local_rank,
)
from .remote_device import _remote_device
from .rendezvous import (
    _create_store_from_options,
    register_rendezvous_handler,
    rendezvous,
)
# Import-time side effect: initialize the c10d debug level from the process
# environment (the exact variable(s) read are defined by the C extension).
set_debug_level_from_env()
else:
    # Fallback for builds without distributed support (USE_DISTRIBUTED=0).
    # This stub is sufficient to get
    # python test/test_public_bindings.py -k test_correct_module_names
    # working even when USE_DISTRIBUTED=0. Feel free to add more
    # stubs as necessary.
    # We cannot define stubs directly because they confuse pyre; instead the
    # placeholder class is attached under the public name ``ProcessGroup`` to
    # the already-registered ``torch.distributed`` module object via sys.modules.
    class _ProcessGroupStub:
        # Empty placeholder: it only needs to exist under the public name.
        pass

    sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub  # type: ignore[attr-defined]