diff --git a/distributed/metrics.py b/distributed/metrics.py index d2ac55f15b5..983ec736859 100755 --- a/distributed/metrics.py +++ b/distributed/metrics.py @@ -5,6 +5,8 @@ from collections.abc import Callable from functools import wraps +import psutil + from distributed.compatibility import WINDOWS _empty_namedtuple = collections.namedtuple("_empty_namedtuple", ()) @@ -13,14 +15,8 @@ def _psutil_caller(method_name, default=_empty_namedtuple): """ Return a function calling the given psutil *method_name*, - or returning *default* if psutil is not present. + or returning *default* if psutil fails. """ - # Import only once to avoid the cost of a failing import at each wrapper() call - try: - import psutil - except ImportError: # pragma: no cover - return default - meth = getattr(psutil, method_name) @wraps(meth) diff --git a/distributed/pytest_resourceleaks.py b/distributed/pytest_resourceleaks.py index d4483ed9f9f..56d7622ed04 100644 --- a/distributed/pytest_resourceleaks.py +++ b/distributed/pytest_resourceleaks.py @@ -56,6 +56,7 @@ def test1(): import pytest from distributed.metrics import time +from distributed.system import process_memory def pytest_addoption(parser): @@ -172,7 +173,7 @@ class RSSMemoryChecker(ResourceChecker, name="memory"): LEAK_THRESHOLD = 10 * 2**20 def measure(self) -> int: - return psutil.Process().memory_info().rss + return process_memory() def has_leak(self, before: int, after: int) -> bool: return after > before + self.LEAK_THRESHOLD diff --git a/distributed/system.py b/distributed/system.py index 057b9df6728..0072c04d28b 100644 --- a/distributed/system.py +++ b/distributed/system.py @@ -5,6 +5,8 @@ import psutil +from distributed.compatibility import LINUX + __all__ = ("memory_limit", "MEMORY_LIMIT") @@ -63,4 +65,25 @@ def memory_limit() -> int: return limit +def process_memory(proc: psutil.Process | int | None = None) -> int: + """Return total memory used by a process + + Parameters + ---------- + proc: psutil.Process | int, optional + Process or PID to measure. Default: current process + """ + if proc is None: + proc = psutil.Process() + elif isinstance(proc, int): + proc = psutil.Process(proc) + + if LINUX: + minfo = proc.memory_full_info() + return minfo.rss + minfo.swap + else: + minfo = proc.memory_info() + return minfo.rss + + MEMORY_LIMIT = memory_limit() diff --git a/distributed/system_monitor.py b/distributed/system_monitor.py index d92dac6c119..5344348528c 100644 --- a/distributed/system_monitor.py +++ b/distributed/system_monitor.py @@ -11,6 +11,7 @@ from distributed.compatibility import WINDOWS from distributed.diagnostics import nvml from distributed.metrics import monotonic, time +from distributed.system import process_memory class SystemMonitor: @@ -112,7 +113,7 @@ def get_process_memory(self) -> int: as the OS allocating and releasing memory is highly volatile and a constant source of flakiness. """ - return self.proc.memory_info().rss + return process_memory(self.proc) def update(self) -> dict[str, Any]: now = time() diff --git a/distributed/tests/test_system.py b/distributed/tests/test_system.py index d297c70be87..94fbe1f90a8 100644 --- a/distributed/tests/test_system.py +++ b/distributed/tests/test_system.py @@ -7,7 +7,7 @@ import psutil import pytest -from distributed.system import memory_limit +from distributed.system import memory_limit, process_memory def test_memory_limit(): @@ -97,3 +97,7 @@ def test_rlimit(): assert memory_limit() == new_limit except OSError: pytest.skip("resource could not set the RSS limit") + + +def test_process_memory(): + assert 2**20 < process_memory() < 2**40 diff --git a/distributed/utils.py b/distributed/utils.py index d05961f361f..ba2f938340e 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -36,6 +36,7 @@ from typing import ClassVar, Iterator, TypeVar, overload import click +import psutil import tblib.pickling_support try: @@ -200,8 +201,6 @@ def get_ip_interface(ifname): ValueError is raised if the interface does no have an IPv4 address associated with it. """ - import psutil - net_if_addrs = psutil.net_if_addrs() if ifname not in net_if_addrs: diff --git a/distributed/utils_perf.py b/distributed/utils_perf.py index 3c33372d8aa..04abd23f592 100644 --- a/distributed/utils_perf.py +++ b/distributed/utils_perf.py @@ -5,9 +5,12 @@ import threading from collections import deque +import psutil + from dask.utils import format_bytes from distributed.metrics import thread_time +from distributed.system import process_memory logger = _logger = logging.getLogger(__name__) @@ -147,12 +150,7 @@ def __init__(self, warn_over_frac=0.1, info_over_rss_win=10 * 1e6): def enable(self): assert not self._enabled self._fractional_timer = FractionalTimer(n_samples=self.N_SAMPLES) - try: - import psutil - except ImportError: - self._proc = None - else: - self._proc = psutil.Process() + self._proc = psutil.Process() cb = self._gc_callback assert cb not in gc.callbacks @@ -181,10 +179,7 @@ def _gc_callback(self, phase, info): # don't waste time measuring them if info["generation"] != 2: return - if self._proc is not None: - rss = self._proc.memory_info().rss - else: - rss = 0 + rss = process_memory(self._proc) if phase == "start": self._fractional_timer.start_timing() self._gc_rss_before = rss diff --git a/distributed/worker_memory.py b/distributed/worker_memory.py index f45b23f56b8..23ca0a8fd1e 100644 --- a/distributed/worker_memory.py +++ b/distributed/worker_memory.py @@ -377,11 +377,13 @@ def memory_monitor(self, nanny: Nanny) -> None: process = nanny.process.process try: - memory = psutil.Process(process.pid).memory_info().rss + memory = system.process_memory(process.pid) except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied): return # pragma: nocover - if memory / self.memory_limit <= self.memory_terminate_fraction: + assert self.memory_limit is not None + assert self.memory_terminate_fraction is not False + if memory <= self.memory_limit * self.memory_terminate_fraction: return if self._last_terminated_pid != process.pid: