Skip to content

Commit 45e0d03

Browse files
committed
Move host_array to comm.utils
1 parent 8c11a31 commit 45e0d03

File tree

4 files changed

+19
-31
lines changed

4 files changed

+19
-31
lines changed

distributed/comm/asyncio_tcp.py

+1-10
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import dask
1919

20-
from ..utils import ensure_ip, get_ip, get_ipv6
20+
from ..utils import ensure_ip, get_ip, get_ipv6, host_array
2121
from .addressing import parse_host_port, unparse_host_port
2222
from .core import Comm, CommClosedError, Connector, Listener
2323
from .registry import Backend
@@ -29,15 +29,6 @@
2929
_COMM_CLOSED = object()
3030

3131

32-
# Find the function, `host_array()`, to use when allocating new host arrays
33-
try:
34-
import numpy
35-
36-
host_array = lambda n: memoryview(numpy.empty((n,), dtype="u1")) # type: ignore
37-
except ImportError:
38-
host_array = lambda n: memoryview(bytearray(n)) # type: ignore
39-
40-
4132
def coalesce_buffers(
4233
buffers: list[bytes],
4334
target_buffer_size: int = 64 * 1024,

distributed/comm/tcp.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,13 @@
3232
from .addressing import parse_host_port, unparse_host_port
3333
from .core import Comm, CommClosedError, Connector, FatalCommClosedError, Listener
3434
from .registry import Backend
35-
from .utils import ensure_concrete_host, from_frames, get_tcp_server_address, to_frames
35+
from .utils import (
36+
ensure_concrete_host,
37+
from_frames,
38+
get_tcp_server_address,
39+
host_array,
40+
to_frames,
41+
)
3642

3743
logger = logging.getLogger(__name__)
3844

@@ -41,15 +47,6 @@
4147
MAX_BUFFER_SIZE = MEMORY_LIMIT / 2
4248

4349

44-
# Find the function, `host_array()`, to use when allocating new host arrays
45-
try:
46-
import numpy
47-
48-
host_array = lambda n: memoryview(numpy.empty((n,), dtype="u1")) # type: ignore
49-
except ImportError:
50-
host_array = lambda n: memoryview(bytearray(n)) # type: ignore
51-
52-
5350
def set_tcp_timeout(comm):
5451
"""
5552
Set kernel-level TCP timeout on the stream.

distributed/comm/ucx.py

+2-11
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from dask.utils import parse_bytes
1818

1919
from ..diagnostics.nvml import has_cuda_context
20-
from ..utils import ensure_ip, get_ip, get_ipv6, log_errors, nbytes
20+
from ..utils import ensure_ip, get_ip, get_ipv6, host_array, log_errors, nbytes
2121
from .addressing import parse_host_port, unparse_host_port
2222
from .core import Comm, CommClosedError, Connector, Listener
2323
from .registry import Backend, backends
@@ -41,7 +41,6 @@
4141
ucx_create_endpoint = None # type: ignore
4242
ucx_create_listener = None # type: ignore
4343

44-
host_array = None
4544
device_array = None
4645
pre_existing_cuda_context = False
4746
cuda_context_created = False
@@ -57,7 +56,7 @@ def synchronize_stream(stream=0):
5756

5857

5958
def init_once():
60-
global ucp, host_array, device_array
59+
global ucp, device_array
6160
global ucx_create_endpoint, ucx_create_listener
6261
global pre_existing_cuda_context, cuda_context_created
6362

@@ -115,14 +114,6 @@ def init_once():
115114

116115
ucp.init(options=ucx_config, env_takes_precedence=True)
117116

118-
# Find the function, `host_array()`, to use when allocating new host arrays
119-
try:
120-
import numpy
121-
122-
host_array = lambda n: numpy.empty((n,), dtype="u1")
123-
except ImportError:
124-
host_array = lambda n: bytearray(n)
125-
126117
# Find the function, `cuda_array()`, to use when allocating new CUDA arrays
127118
try:
128119
import rmm

distributed/comm/utils.py

+9
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,15 @@
1818
OFFLOAD_THRESHOLD = parse_bytes(OFFLOAD_THRESHOLD)
1919

2020

21+
# Find the function, `host_array()`, to use when allocating new host arrays
22+
try:
23+
import numpy
24+
25+
host_array = lambda n: memoryview(numpy.empty((n,), dtype="u1")) # type: ignore
26+
except ImportError:
27+
host_array = lambda n: memoryview(bytearray(n)) # type: ignore
28+
29+
2130
async def to_frames(
2231
msg,
2332
allow_offload=True,

0 commit comments

Comments
 (0)