Skip to content

Added support for Pub/Sub mode in MultiDbClient #3722

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: vv-active-active-pipeline
Choose a base branch
from
Open
4 changes: 3 additions & 1 deletion redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,7 @@ def run_in_thread(
sleep_time: float = 0.0,
daemon: bool = False,
exception_handler: Optional[Callable] = None,
pubsub = None
) -> "PubSubWorkerThread":
for channel, handler in self.channels.items():
if handler is None:
Expand All @@ -1230,8 +1231,9 @@ def run_in_thread(
f"Shard Channel: '{s_channel}' has no handler registered"
)

pubsub = self if pubsub is None else pubsub
thread = PubSubWorkerThread(
self, sleep_time, daemon=daemon, exception_handler=exception_handler
pubsub, sleep_time, daemon=daemon, exception_handler=exception_handler
)
thread.start()
return thread
Expand Down
129 changes: 128 additions & 1 deletion redis/multidb/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import threading
import socket
from typing import List, Any, Callable
from typing import List, Any, Callable, Optional

from redis.background import BackgroundScheduler
from redis.client import PubSubWorkerThread
from redis.exceptions import ConnectionError, TimeoutError
from redis.commands import RedisModuleCommands, CoreCommands, SentinelCommands
from redis.multidb.command_executor import DefaultCommandExecutor
Expand Down Expand Up @@ -201,6 +202,17 @@ def transaction(self, func: Callable[["Pipeline"], None], *watches, **options):

return self.command_executor.execute_transaction(func, *watches, *options)

def pubsub(self, **kwargs):
"""
Return a Publish/Subscribe object. With this object, you can
subscribe to channels and listen for messages that get published to
them.
"""
if not self.initialized:
self.initialize()

return PubSub(self, **kwargs)

def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Exception], None] = None) -> None:
"""
Runs health checks on the given database until first failure.
Expand Down Expand Up @@ -311,3 +323,118 @@ def execute(self) -> List[Any]:
return self._client.command_executor.execute_pipeline(tuple(self._command_stack))
finally:
self.reset()

class PubSub:
"""
PubSub object for multi database client.
"""
def __init__(self, client: MultiDBClient, **kwargs):
self._client = client
self._client.command_executor.pubsub(**kwargs)

def __enter__(self) -> "PubSub":
return self

def __exit__(self, exc_type, exc_value, traceback) -> None:
self.reset()

def __del__(self) -> None:
try:
# if this object went out of scope prior to shutting down
# subscriptions, close the connection manually before
# returning it to the connection pool
self.reset()
except Exception:
pass

def reset(self) -> None:
pass

def close(self) -> None:
self.reset()

@property
def subscribed(self) -> bool:
return self._client.command_executor.active_pubsub.subscribed

def psubscribe(self, *args, **kwargs):
"""
Subscribe to channel patterns. Patterns supplied as keyword arguments
expect a pattern name as the key and a callable as the value. A
pattern's callable will be invoked automatically when a message is
received on that pattern rather than producing a message via
``listen()``.
"""
return self._client.command_executor.execute_pubsub_method('psubscribe', *args, **kwargs)

def punsubscribe(self, *args):
"""
Unsubscribe from the supplied patterns. If empty, unsubscribe from
all patterns.
"""
return self._client.command_executor.execute_pubsub_method('punsubscribe', *args)

def subscribe(self, *args, **kwargs):
"""
Subscribe to channels. Channels supplied as keyword arguments expect
a channel name as the key and a callable as the value. A channel's
callable will be invoked automatically when a message is received on
that channel rather than producing a message via ``listen()`` or
``get_message()``.
"""
return self._client.command_executor.execute_pubsub_method('subscribe', *args, **kwargs)

def unsubscribe(self, *args):
"""
Unsubscribe from the supplied channels. If empty, unsubscribe from
all channels
"""
return self._client.command_executor.execute_pubsub_method('unsubscribe', *args)

def ssubscribe(self, *args, **kwargs):
"""
Subscribes the client to the specified shard channels.
Channels supplied as keyword arguments expect a channel name as the key
and a callable as the value. A channel's callable will be invoked automatically
when a message is received on that channel rather than producing a message via
``listen()`` or ``get_sharded_message()``.
"""
return self._client.command_executor.execute_pubsub_method('ssubscribe', *args, **kwargs)

def sunsubscribe(self, *args):
"""
Unsubscribe from the supplied shard_channels. If empty, unsubscribe from
all shard_channels
"""
return self._client.command_executor.execute_pubsub_method('sunsubscribe', *args)

def get_message(
self, ignore_subscribe_messages: bool = False, timeout: float = 0.0
):
"""
Get the next message if one is available, otherwise None.

If timeout is specified, the system will wait for `timeout` seconds
before returning. Timeout should be specified as a floating point
number, or None, to wait indefinitely.
"""
return self._client.command_executor.execute_pubsub_method(
'get_message',
ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout
)

get_sharded_message = get_message

def run_in_thread(
self,
sleep_time: float = 0.0,
daemon: bool = False,
exception_handler: Optional[Callable] = None,
) -> "PubSubWorkerThread":
return self._client.command_executor.execute_pubsub_run_in_thread(
sleep_time=sleep_time,
daemon=daemon,
exception_handler=exception_handler,
pubsub=self
)

87 changes: 78 additions & 9 deletions redis/multidb/command_executor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from typing import List, Union, Optional, Callable
from typing import List, Optional, Callable

from redis.client import Pipeline
from redis.client import Pipeline, PubSub, PubSubWorkerThread
from redis.event import EventDispatcherInterface, OnCommandsFailEvent
from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL
from redis.multidb.database import Database, AbstractDatabase, Databases
from redis.multidb.circuit import State as CBState
from redis.multidb.event import RegisterCommandFailure
from redis.multidb.event import RegisterCommandFailure, ActiveDatabaseChanged, ResubscribeOnActiveDatabaseChanged
from redis.multidb.failover import FailoverStrategy
from redis.multidb.failure_detector import FailureDetector
from redis.retry import Retry
Expand All @@ -34,7 +34,7 @@ def databases(self) -> Databases:

@property
@abstractmethod
def active_database(self) -> Union[Database, None]:
def active_database(self) -> Optional[Database]:
"""Returns currently active database."""
pass

Expand All @@ -44,6 +44,23 @@ def active_database(self, database: AbstractDatabase) -> None:
"""Sets currently active database."""
pass

@abstractmethod
def pubsub(self, **kwargs):
"""Initializes a PubSub object on a currently active database"""
pass

@property
@abstractmethod
def active_pubsub(self) -> Optional[PubSub]:
"""Returns currently active pubsub."""
pass

@active_pubsub.setter
@abstractmethod
def active_pubsub(self, pubsub: PubSub) -> None:
"""Sets currently active pubsub."""
pass

@property
@abstractmethod
def failover_strategy(self) -> FailoverStrategy:
Expand Down Expand Up @@ -103,7 +120,9 @@ def __init__(
self._event_dispatcher = event_dispatcher
self._auto_fallback_interval = auto_fallback_interval
self._next_fallback_attempt: datetime
self._active_database: Union[Database, None] = None
self._active_database: Optional[Database] = None
self._active_pubsub: Optional[PubSub] = None
self._active_pubsub_kwargs = {}
self._setup_event_dispatcher()
self._schedule_next_fallback()

Expand All @@ -128,8 +147,22 @@ def active_database(self) -> Optional[AbstractDatabase]:

@active_database.setter
def active_database(self, database: AbstractDatabase) -> None:
old_active = self._active_database
self._active_database = database

if old_active is not None and old_active is not database:
self._event_dispatcher.dispatch(
ActiveDatabaseChanged(old_active, self._active_database, self, **self._active_pubsub_kwargs)
)

@property
def active_pubsub(self) -> Optional[PubSub]:
return self._active_pubsub

@active_pubsub.setter
def active_pubsub(self, pubsub: PubSub) -> None:
self._active_pubsub = pubsub

@property
def failover_strategy(self) -> FailoverStrategy:
return self._failover_strategy
Expand All @@ -143,6 +176,7 @@ def auto_fallback_interval(self, auto_fallback_interval: int) -> None:
self._auto_fallback_interval = auto_fallback_interval

def execute_command(self, *args, **options):
"""Executes a command and returns the result."""
def callback():
return self._active_database.client.execute_command(*args, **options)

Expand Down Expand Up @@ -170,6 +204,39 @@ def callback():

return self._execute_with_failure_detection(callback)

def pubsub(self, **kwargs):
def callback():
if self._active_pubsub is None:
self._active_pubsub = self._active_database.client.pubsub(**kwargs)
self._active_pubsub_kwargs = kwargs
return None

return self._execute_with_failure_detection(callback)

def execute_pubsub_method(self, method_name: str, *args, **kwargs):
"""
Executes given method on active pub/sub.
"""
def callback():
method = getattr(self.active_pubsub, method_name)
return method(*args, **kwargs)

return self._execute_with_failure_detection(callback, *args)

def execute_pubsub_run_in_thread(
self,
pubsub,
sleep_time: float = 0.0,
daemon: bool = False,
exception_handler: Optional[Callable] = None,
) -> "PubSubWorkerThread":
def callback():
return self._active_pubsub.run_in_thread(
sleep_time, daemon=daemon, exception_handler=exception_handler, pubsub=pubsub
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to provide the pubsub reference here? If a failover has occurred, shouldn’t _active_pubsub already contain all the necessary information? By passing in this reference, don’t we risk holding onto a pubsub object from a failed instance and potentially trying to work with it?

)

return self._execute_with_failure_detection(callback)

def _execute_with_failure_detection(self, callback: Callable, cmds: tuple = ()):
"""
Execute a commands execution callback with failure detection.
Expand Down Expand Up @@ -199,7 +266,7 @@ def _check_active_database(self):
and self._next_fallback_attempt <= datetime.now()
)
):
self._active_database = self._failover_strategy.database
self.active_database = self._failover_strategy.database
self._schedule_next_fallback()

def _schedule_next_fallback(self) -> None:
Expand All @@ -210,9 +277,11 @@ def _schedule_next_fallback(self) -> None:

def _setup_event_dispatcher(self):
"""
Registers command failure event listener.
Registers necessary listeners.
"""
event_listener = RegisterCommandFailure(self._failure_detectors)
failure_listener = RegisterCommandFailure(self._failure_detectors)
resubscribe_listener = ResubscribeOnActiveDatabaseChanged()
self._event_dispatcher.register_listeners({
OnCommandsFailEvent: [event_listener],
OnCommandsFailEvent: [failure_listener],
ActiveDatabaseChanged: [resubscribe_listener],
})
50 changes: 50 additions & 0 deletions redis/multidb/event.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,58 @@
from typing import List

from redis.event import EventListenerInterface, OnCommandsFailEvent
from redis.multidb.config import Databases
from redis.multidb.database import AbstractDatabase
from redis.multidb.failure_detector import FailureDetector

class ActiveDatabaseChanged:
"""
Event fired when an active database has been changed.
"""
def __init__(
self,
old_database: AbstractDatabase,
new_database: AbstractDatabase,
command_executor,
**kwargs
):
self._old_database = old_database
self._new_database = new_database
self._command_executor = command_executor
self._kwargs = kwargs

@property
def old_database(self) -> AbstractDatabase:
return self._old_database

@property
def new_database(self) -> AbstractDatabase:
return self._new_database

@property
def command_executor(self):
return self._command_executor

@property
def kwargs(self):
return self._kwargs

class ResubscribeOnActiveDatabaseChanged(EventListenerInterface):
"""
Re-subscribe currently active pub/sub to a new active database.
"""
def listen(self, event: ActiveDatabaseChanged):
old_pubsub = event.command_executor.active_pubsub

if old_pubsub is not None:
# Re-assign old channels and patterns so they will be automatically subscribed on connection.
new_pubsub = event.new_database.client.pubsub(**event.kwargs)
new_pubsub.channels = old_pubsub.channels
new_pubsub.patterns = old_pubsub.patterns
new_pubsub.shard_channels = old_pubsub.shard_channels
new_pubsub.on_connect(None)
event.command_executor.active_pubsub = new_pubsub
old_pubsub.close()

class RegisterCommandFailure(EventListenerInterface):
"""
Expand Down
1 change: 1 addition & 0 deletions redis/multidb/healthcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(
retry: Retry,
) -> None:
self._retry = retry
self._retry.update_supported_errors([ConnectionRefusedError])

@property
def retry(self) -> Retry:
Expand Down
Loading