Skip to content
Open
130 changes: 99 additions & 31 deletions kazoo/client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Kazoo Zookeeper Client"""
from __future__ import annotations

from collections import defaultdict, deque
from functools import partial
import inspect
import logging
from os.path import split
import re
from typing import TYPE_CHECKING, overload
import warnings

from kazoo.exceptions import (
Expand Down Expand Up @@ -63,6 +66,20 @@
from kazoo.recipe.queue import Queue, LockingQueue
from kazoo.recipe.watchers import ChildrenWatch, DataWatch

if TYPE_CHECKING:
from typing import (
Any,
List,
Optional,
Sequence,
Tuple,
Union,
Callable,
Literal,
)
from kazoo.protocol.states import ZnodeStat

WatchListener = Callable[[WatchedEvent], None]

CLOSED_STATES = (
KeeperState.EXPIRED_SESSION,
Expand Down Expand Up @@ -268,17 +285,17 @@ def __init__(
self._stopped.set()
self._writer_stopped.set()

self.retry = self._conn_retry = None
_retry = self._conn_retry = None

if type(connection_retry) is dict:
self._conn_retry = KazooRetry(**connection_retry)
elif type(connection_retry) is KazooRetry:
self._conn_retry = connection_retry

if type(command_retry) is dict:
self.retry = KazooRetry(**command_retry)
_retry = KazooRetry(**command_retry)
elif type(command_retry) is KazooRetry:
self.retry = command_retry
_retry = command_retry

if type(self._conn_retry) is KazooRetry:
if self.handler.sleep_func != self._conn_retry.sleep_func:
Expand All @@ -287,14 +304,14 @@ def __init__(
" must use the same sleep func"
)

if type(self.retry) is KazooRetry:
if self.handler.sleep_func != self.retry.sleep_func:
if type(_retry) is KazooRetry:
if self.handler.sleep_func != _retry.sleep_func:
raise ConfigurationError(
"Command retry handler and event handler "
"must use the same sleep func"
)

if self.retry is None or self._conn_retry is None:
if _retry is None or self._conn_retry is None:
old_retry_keys = dict(_RETRY_COMPAT_DEFAULTS)
for key in old_retry_keys:
try:
Expand All @@ -310,16 +327,16 @@ def __init__(
except KeyError:
pass

retry_keys = {}
retry_keys: Any = {}
for oldname, value in old_retry_keys.items():
retry_keys[_RETRY_COMPAT_MAPPING[oldname]] = value

if self._conn_retry is None:
self._conn_retry = KazooRetry(
sleep_func=self.handler.sleep_func, **retry_keys
)
if self.retry is None:
self.retry = KazooRetry(
if _retry is None:
_retry = KazooRetry(
sleep_func=self.handler.sleep_func, **retry_keys
)

Expand Down Expand Up @@ -364,14 +381,7 @@ def __init__(
sasl_options=sasl_options,
)

# Every retry call should have its own copy of the retry helper
# to avoid shared retry counts
self._retry = self.retry

def _retry(*args, **kwargs):
return self._retry.copy()(*args, **kwargs)

self.retry = _retry
self._retry = _retry

self.Barrier = partial(Barrier, self)
self.Counter = partial(Counter, self)
Expand All @@ -398,6 +408,12 @@ def _retry(*args, **kwargs):
% (kwargs.keys(),)
)

@property
def retry(self) -> KazooRetry:
# Every retry call should have its own copy of the retry helper
# to avoid shared retry counts
return self._retry.copy()

def _reset(self):
"""Resets a variety of client states for a new connection."""
self._queue = deque()
Expand Down Expand Up @@ -910,14 +926,14 @@ def sync(self, path):

def create(
self,
path,
value=b"",
acl=None,
ephemeral=False,
sequence=False,
makepath=False,
include_data=False,
):
path: str,
value: bytes = b"",
acl: Optional[Sequence[ACL]] = None,
ephemeral: bool = False,
sequence: bool = False,
makepath: bool = False,
include_data: bool = False,
) -> Union[str, Tuple[str, ZnodeStat]]:
"""Create a node with the given value as its data. Optionally
set an ACL on the node.

Expand Down Expand Up @@ -1122,7 +1138,7 @@ def _create_async_inner(
raise async_result.exception
return async_result

def ensure_path(self, path, acl=None):
def ensure_path(self, path: str, acl: Optional[List[ACL]] = None) -> bool:
"""Recursively create a path if it doesn't exist.

:param path: Path of node.
Expand Down Expand Up @@ -1171,7 +1187,9 @@ def exists_completion(path, result):

return async_result

def exists(self, path, watch=None):
def exists(
self, path: str, watch: Optional[WatchListener] = None
) -> Optional[ZnodeStat]:
"""Check if a node exists.

If a watch is provided, it will be left on the node with the
Expand Down Expand Up @@ -1211,7 +1229,9 @@ def exists_async(self, path, watch=None):
)
return async_result

def get(self, path, watch=None):
def get(
self, path: str, watch: Optional[WatchListener] = None
) -> Tuple[bytes, ZnodeStat]:
"""Get the value of a node.

If a watch is provided, it will be left on the node with the
Expand Down Expand Up @@ -1254,7 +1274,53 @@ def get_async(self, path, watch=None):
)
return async_result

def get_children(self, path, watch=None, include_data=False):
@overload
def get_children( # noqa: F811
self,
path: str,
) -> List[str]:
...

@overload
def get_children( # noqa: F811
self,
path: str,
watch: WatchListener,
) -> List[str]:
...

@overload
def get_children( # noqa: F811
self,
path: str,
watch: Optional[WatchListener],
) -> List[str]:
...

@overload
def get_children( # noqa: F811
self,
path: str,
watch: Optional[WatchListener],
include_data: Literal[True],
) -> List[Tuple[str, ZnodeStat]]:
...

@overload
def get_children( # noqa: F811
self,
path: str,
watch: Optional[WatchListener] = None,
include_data: Literal[False] = False,
) -> List[str]:
...

def get_children( # noqa: F811
self,
path: str,
watch: Optional[WatchListener] = None,
include_data: bool = False,
) -> Union[List[Tuple[str, ZnodeStat]], List[str]]:
"""Get a list of child nodes of a path.

If a watch is provided it will be left on the node with the
Expand Down Expand Up @@ -1400,7 +1466,7 @@ def set_acls_async(self, path, acls, version=-1):
)
return async_result

def set(self, path, value, version=-1):
def set(self, path: str, value: bytes, version: int = -1) -> ZnodeStat:
"""Set the value of a node.

If the version of the node being updated is newer than the
Expand Down Expand Up @@ -1473,7 +1539,9 @@ def transaction(self):
"""
return TransactionRequest(self)

def delete(self, path, version=-1, recursive=False):
def delete(
self, path: str, version: int = -1, recursive: bool = False
) -> Optional[bool]:
"""Delete a node.

The call will succeed if such a node exists, and the given
Expand Down
45 changes: 31 additions & 14 deletions kazoo/recipe/barrier.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,24 @@
:Status: Unknown

"""
from __future__ import annotations

import os
import socket
from threading import Event
from typing import TYPE_CHECKING, cast
import uuid

from kazoo.exceptions import KazooException, NoNodeError, NodeExistsError
from kazoo.protocol.states import EventType

if TYPE_CHECKING:
from typing import Optional
from typing_extensions import Literal

from kazoo.client import KazooClient
from kazoo.protocol.states import WatchedEvent


class Barrier(object):
"""Kazoo Barrier
Expand All @@ -27,7 +38,7 @@ class Barrier(object):

"""

def __init__(self, client, path):
def __init__(self, client: KazooClient, path: str):
"""Create a Kazoo Barrier

:param client: A :class:`~kazoo.client.KazooClient` instance.
Expand All @@ -37,11 +48,11 @@ def __init__(self, client, path):
self.client = client
self.path = path

def create(self):
def create(self) -> None:
"""Establish the barrier if it doesn't exist already"""
self.client.retry(self.client.ensure_path, self.path)

def remove(self):
def remove(self) -> bool:
"""Remove the barrier

:returns: Whether the barrier actually needed to be removed.
Expand All @@ -54,17 +65,17 @@ def remove(self):
except NoNodeError:
return False

def wait(self, timeout=None):
def wait(self, timeout: Optional[float] = None) -> bool:
"""Wait on the barrier to be cleared

:returns: True if the barrier has been cleared, otherwise
False.
:rtype: bool

"""
cleared = self.client.handler.event_object()
cleared = cast(Event, self.client.handler.event_object())

def wait_for_clear(event):
def wait_for_clear(event: WatchedEvent) -> None:
if event.type == EventType.DELETED:
cleared.set()

Expand Down Expand Up @@ -93,7 +104,13 @@ class DoubleBarrier(object):

"""

def __init__(self, client, path, num_clients, identifier=None):
def __init__(
self,
client: KazooClient,
path: str,
num_clients: int,
identifier: Optional[str] = None,
):
"""Create a Double Barrier

:param client: A :class:`~kazoo.client.KazooClient` instance.
Expand All @@ -118,7 +135,7 @@ def __init__(self, client, path, num_clients, identifier=None):
self.node_name = uuid.uuid4().hex
self.create_path = self.path + "/" + self.node_name

def enter(self):
def enter(self) -> None:
"""Enter the barrier, blocks until all nodes have entered"""
try:
self.client.retry(self._inner_enter)
Expand All @@ -128,7 +145,7 @@ def enter(self):
self._best_effort_cleanup()
self.participating = False

def _inner_enter(self):
def _inner_enter(self) -> Literal[True]:
# make sure our barrier parent node exists
if not self.assured_path:
self.client.ensure_path(self.path)
Expand All @@ -145,7 +162,7 @@ def _inner_enter(self):
except NodeExistsError:
pass

def created(event):
def created(event: WatchedEvent) -> None:
if event.type == EventType.CREATED:
ready.set()

Expand All @@ -159,7 +176,7 @@ def created(event):
self.client.ensure_path(self.path + "/ready")
return True

def leave(self):
def leave(self) -> None:
"""Leave the barrier, blocks until all nodes have left"""
try:
self.client.retry(self._inner_leave)
Expand All @@ -168,7 +185,7 @@ def leave(self):
self._best_effort_cleanup()
self.participating = False

def _inner_leave(self):
def _inner_leave(self) -> Literal[True]:
# Delete the ready node if its around
try:
self.client.delete(self.path + "/ready")
Expand All @@ -188,7 +205,7 @@ def _inner_leave(self):

ready = self.client.handler.event_object()

def deleted(event):
def deleted(event: WatchedEvent) -> None:
if event.type == EventType.DELETED:
ready.set()

Expand All @@ -214,7 +231,7 @@ def deleted(event):
# Wait for the lowest to be deleted
ready.wait()

def _best_effort_cleanup(self):
def _best_effort_cleanup(self) -> None:
try:
self.client.retry(self.client.delete, self.create_path)
except NoNodeError:
Expand Down
Loading