diff --git a/redis/client.py b/redis/client.py
index 060fc29493..adb57d404e 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -1217,6 +1217,7 @@ def run_in_thread(
         sleep_time: float = 0.0,
         daemon: bool = False,
         exception_handler: Optional[Callable] = None,
+        pubsub=None,
     ) -> "PubSubWorkerThread":
         for channel, handler in self.channels.items():
             if handler is None:
@@ -1230,8 +1231,9 @@ def run_in_thread(
                     f"Shard Channel: '{s_channel}' has no handler registered"
                 )

+        pubsub = self if pubsub is None else pubsub
         thread = PubSubWorkerThread(
-            self, sleep_time, daemon=daemon, exception_handler=exception_handler
+            pubsub, sleep_time, daemon=daemon, exception_handler=exception_handler
         )
         thread.start()
         return thread
diff --git a/redis/multidb/client.py b/redis/multidb/client.py
index 85c719fc1a..b7da716d4c 100644
--- a/redis/multidb/client.py
+++ b/redis/multidb/client.py
@@ -1,8 +1,9 @@
 import threading
 import socket
-from typing import List, Any, Callable
+from typing import List, Any, Callable, Optional

 from redis.background import BackgroundScheduler
+from redis.client import PubSubWorkerThread
 from redis.exceptions import ConnectionError, TimeoutError
 from redis.commands import RedisModuleCommands, CoreCommands, SentinelCommands
 from redis.multidb.command_executor import DefaultCommandExecutor
@@ -201,6 +202,17 @@ def transaction(self, func: Callable[["Pipeline"], None], *watches, **options):

         return self.command_executor.execute_transaction(func, *watches, *options)

+    def pubsub(self, **kwargs):
+        """
+        Return a Publish/Subscribe object. With this object, you can
+        subscribe to channels and listen for messages that get published to
+        them.
+        """
+        if not self.initialized:
+            self.initialize()
+
+        return PubSub(self, **kwargs)
+
     def _check_db_health(self, database: AbstractDatabase, on_error: Callable[[Exception], None] = None) -> None:
         """
         Runs health checks on the given database until first failure.
@@ -311,3 +323,118 @@ def execute(self) -> List[Any]:
                 return self._client.command_executor.execute_pipeline(tuple(self._command_stack))
             finally:
                 self.reset()
+
+class PubSub:
+    """
+    PubSub object for the multi-database client.
+    """
+    def __init__(self, client: MultiDBClient, **kwargs):
+        self._client = client
+        self._client.command_executor.pubsub(**kwargs)
+
+    def __enter__(self) -> "PubSub":
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback) -> None:
+        self.reset()
+
+    def __del__(self) -> None:
+        try:
+            # if this object went out of scope prior to shutting down
+            # subscriptions, close the connection manually before
+            # returning it to the connection pool
+            self.reset()
+        except Exception:
+            pass
+
+    def reset(self) -> None:
+        pass
+
+    def close(self) -> None:
+        self.reset()
+
+    @property
+    def subscribed(self) -> bool:
+        return self._client.command_executor.active_pubsub.subscribed
+
+    def psubscribe(self, *args, **kwargs):
+        """
+        Subscribe to channel patterns. Patterns supplied as keyword arguments
+        expect a pattern name as the key and a callable as the value. A
+        pattern's callable will be invoked automatically when a message is
+        received on that pattern rather than producing a message via
+        ``listen()``.
+        """
+        return self._client.command_executor.execute_pubsub_method('psubscribe', *args, **kwargs)
+
+    def punsubscribe(self, *args):
+        """
+        Unsubscribe from the supplied patterns. If empty, unsubscribe from
+        all patterns.
+ """ + return self._client.command_executor.execute_pubsub_method('punsubscribe', *args) + + def subscribe(self, *args, **kwargs): + """ + Subscribe to channels. Channels supplied as keyword arguments expect + a channel name as the key and a callable as the value. A channel's + callable will be invoked automatically when a message is received on + that channel rather than producing a message via ``listen()`` or + ``get_message()``. + """ + return self._client.command_executor.execute_pubsub_method('subscribe', *args, **kwargs) + + def unsubscribe(self, *args): + """ + Unsubscribe from the supplied channels. If empty, unsubscribe from + all channels + """ + return self._client.command_executor.execute_pubsub_method('unsubscribe', *args) + + def ssubscribe(self, *args, **kwargs): + """ + Subscribes the client to the specified shard channels. + Channels supplied as keyword arguments expect a channel name as the key + and a callable as the value. A channel's callable will be invoked automatically + when a message is received on that channel rather than producing a message via + ``listen()`` or ``get_sharded_message()``. + """ + return self._client.command_executor.execute_pubsub_method('ssubscribe', *args, **kwargs) + + def sunsubscribe(self, *args): + """ + Unsubscribe from the supplied shard_channels. If empty, unsubscribe from + all shard_channels + """ + return self._client.command_executor.execute_pubsub_method('sunsubscribe', *args) + + def get_message( + self, ignore_subscribe_messages: bool = False, timeout: float = 0.0 + ): + """ + Get the next message if one is available, otherwise None. + + If timeout is specified, the system will wait for `timeout` seconds + before returning. Timeout should be specified as a floating point + number, or None, to wait indefinitely. 
+ """ + return self._client.command_executor.execute_pubsub_method( + 'get_message', + ignore_subscribe_messages=ignore_subscribe_messages, timeout=timeout + ) + + get_sharded_message = get_message + + def run_in_thread( + self, + sleep_time: float = 0.0, + daemon: bool = False, + exception_handler: Optional[Callable] = None, + ) -> "PubSubWorkerThread": + return self._client.command_executor.execute_pubsub_run_in_thread( + sleep_time=sleep_time, + daemon=daemon, + exception_handler=exception_handler, + pubsub=self + ) + diff --git a/redis/multidb/command_executor.py b/redis/multidb/command_executor.py index 690ea49a5c..795ef8f8b1 100644 --- a/redis/multidb/command_executor.py +++ b/redis/multidb/command_executor.py @@ -1,13 +1,13 @@ from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import List, Union, Optional, Callable +from typing import List, Optional, Callable -from redis.client import Pipeline +from redis.client import Pipeline, PubSub, PubSubWorkerThread from redis.event import EventDispatcherInterface, OnCommandsFailEvent from redis.multidb.config import DEFAULT_AUTO_FALLBACK_INTERVAL from redis.multidb.database import Database, AbstractDatabase, Databases from redis.multidb.circuit import State as CBState -from redis.multidb.event import RegisterCommandFailure +from redis.multidb.event import RegisterCommandFailure, ActiveDatabaseChanged, ResubscribeOnActiveDatabaseChanged from redis.multidb.failover import FailoverStrategy from redis.multidb.failure_detector import FailureDetector from redis.retry import Retry @@ -34,7 +34,7 @@ def databases(self) -> Databases: @property @abstractmethod - def active_database(self) -> Union[Database, None]: + def active_database(self) -> Optional[Database]: """Returns currently active database.""" pass @@ -44,6 +44,23 @@ def active_database(self, database: AbstractDatabase) -> None: """Sets currently active database.""" pass + @abstractmethod + def pubsub(self, **kwargs): + """Initializes a PubSub object on a currently active database""" + pass + + @property + @abstractmethod + def active_pubsub(self) -> Optional[PubSub]: + """Returns currently active pubsub.""" + pass + + @active_pubsub.setter + @abstractmethod + def active_pubsub(self, pubsub: PubSub) -> None: + """Sets currently active pubsub.""" + pass + @property @abstractmethod def failover_strategy(self) -> FailoverStrategy: @@ -103,7 +120,9 @@ def __init__( self._event_dispatcher = event_dispatcher self._auto_fallback_interval = auto_fallback_interval self._next_fallback_attempt: datetime - self._active_database: Union[Database, None] = None + self._active_database: Optional[Database] = None + self._active_pubsub: Optional[PubSub] = None + self._active_pubsub_kwargs = {} self._setup_event_dispatcher() self._schedule_next_fallback() @@ -128,8 +147,22 @@ def active_database(self) -> Optional[AbstractDatabase]: @active_database.setter def active_database(self, database: AbstractDatabase) -> None: + old_active = self._active_database self._active_database = database + if old_active is not None and old_active is not database: + self._event_dispatcher.dispatch( + ActiveDatabaseChanged(old_active, self._active_database, self, **self._active_pubsub_kwargs) + ) + + @property + def active_pubsub(self) -> Optional[PubSub]: + return self._active_pubsub + + @active_pubsub.setter + def active_pubsub(self, pubsub: PubSub) -> None: + self._active_pubsub = pubsub + @property def failover_strategy(self) -> FailoverStrategy: return self._failover_strategy @@ -143,6 
+176,7 @@ def auto_fallback_interval(self, auto_fallback_interval: int) -> None: self._auto_fallback_interval = auto_fallback_interval def execute_command(self, *args, **options): + """Executes a command and returns the result.""" def callback(): return self._active_database.client.execute_command(*args, **options) @@ -170,6 +204,39 @@ def callback(): return self._execute_with_failure_detection(callback) + def pubsub(self, **kwargs): + def callback(): + if self._active_pubsub is None: + self._active_pubsub = self._active_database.client.pubsub(**kwargs) + self._active_pubsub_kwargs = kwargs + return None + + return self._execute_with_failure_detection(callback) + + def execute_pubsub_method(self, method_name: str, *args, **kwargs): + """ + Executes given method on active pub/sub. + """ + def callback(): + method = getattr(self.active_pubsub, method_name) + return method(*args, **kwargs) + + return self._execute_with_failure_detection(callback, *args) + + def execute_pubsub_run_in_thread( + self, + pubsub, + sleep_time: float = 0.0, + daemon: bool = False, + exception_handler: Optional[Callable] = None, + ) -> "PubSubWorkerThread": + def callback(): + return self._active_pubsub.run_in_thread( + sleep_time, daemon=daemon, exception_handler=exception_handler, pubsub=pubsub + ) + + return self._execute_with_failure_detection(callback) + def _execute_with_failure_detection(self, callback: Callable, cmds: tuple = ()): """ Execute a commands execution callback with failure detection. @@ -199,7 +266,7 @@ def _check_active_database(self): and self._next_fallback_attempt <= datetime.now() ) ): - self._active_database = self._failover_strategy.database + self.active_database = self._failover_strategy.database self._schedule_next_fallback() def _schedule_next_fallback(self) -> None: @@ -210,9 +277,11 @@ def _schedule_next_fallback(self) -> None: def _setup_event_dispatcher(self): """ - Registers command failure event listener. + Registers necessary listeners. """ - event_listener = RegisterCommandFailure(self._failure_detectors) + failure_listener = RegisterCommandFailure(self._failure_detectors) + resubscribe_listener = ResubscribeOnActiveDatabaseChanged() self._event_dispatcher.register_listeners({ - OnCommandsFailEvent: [event_listener], + OnCommandsFailEvent: [failure_listener], + ActiveDatabaseChanged: [resubscribe_listener], }) \ No newline at end of file diff --git a/redis/multidb/event.py b/redis/multidb/event.py index 315802e812..2598bc4d06 100644 --- a/redis/multidb/event.py +++ b/redis/multidb/event.py @@ -1,8 +1,58 @@ from typing import List from redis.event import EventListenerInterface, OnCommandsFailEvent +from redis.multidb.config import Databases +from redis.multidb.database import AbstractDatabase from redis.multidb.failure_detector import FailureDetector +class ActiveDatabaseChanged: + """ + Event fired when an active database has been changed. 
+ """ + def __init__( + self, + old_database: AbstractDatabase, + new_database: AbstractDatabase, + command_executor, + **kwargs + ): + self._old_database = old_database + self._new_database = new_database + self._command_executor = command_executor + self._kwargs = kwargs + + @property + def old_database(self) -> AbstractDatabase: + return self._old_database + + @property + def new_database(self) -> AbstractDatabase: + return self._new_database + + @property + def command_executor(self): + return self._command_executor + + @property + def kwargs(self): + return self._kwargs + +class ResubscribeOnActiveDatabaseChanged(EventListenerInterface): + """ + Re-subscribe currently active pub/sub to a new active database. + """ + def listen(self, event: ActiveDatabaseChanged): + old_pubsub = event.command_executor.active_pubsub + + if old_pubsub is not None: + # Re-assign old channels and patterns so they will be automatically subscribed on connection. + new_pubsub = event.new_database.client.pubsub(**event.kwargs) + new_pubsub.channels = old_pubsub.channels + new_pubsub.patterns = old_pubsub.patterns + new_pubsub.shard_channels = old_pubsub.shard_channels + new_pubsub.on_connect(None) + event.command_executor.active_pubsub = new_pubsub + old_pubsub.close() class RegisterCommandFailure(EventListenerInterface): """ diff --git a/redis/multidb/healthcheck.py b/redis/multidb/healthcheck.py index a96b9cf815..1396a1e997 100644 --- a/redis/multidb/healthcheck.py +++ b/redis/multidb/healthcheck.py @@ -21,6 +21,7 @@ def __init__( retry: Retry, ) -> None: self._retry = retry + self._retry.update_supported_errors([ConnectionRefusedError]) @property def retry(self) -> Retry: diff --git a/tests/test_scenario/conftest.py b/tests/test_scenario/conftest.py index 8ae7441e98..486dc948f1 100644 --- a/tests/test_scenario/conftest.py +++ b/tests/test_scenario/conftest.py @@ -3,13 +3,22 @@ import pytest -from redis.backoff import NoBackoff +from redis.backoff import NoBackoff, ExponentialBackoff +from redis.event import EventDispatcher, EventListenerInterface from redis.multidb.client import MultiDBClient from redis.multidb.config import DatabaseConfig, MultiDbConfig, DEFAULT_HEALTH_CHECK_INTERVAL, \ DEFAULT_FAILURES_THRESHOLD +from redis.multidb.event import ActiveDatabaseChanged +from redis.multidb.healthcheck import EchoHealthCheck from redis.retry import Retry from tests.test_scenario.fault_injector_client import FaultInjectorClient +class CheckActiveDatabaseChangedListener(EventListenerInterface): + def __init__(self): + self.is_changed_flag = False + + def listen(self, event: ActiveDatabaseChanged): + self.is_changed_flag = True def get_endpoint_config(endpoint_name: str): endpoints_config = os.getenv("REDIS_ENDPOINTS_CONFIG_PATH", None) @@ -33,13 +42,22 @@ def fault_injector_client(): return FaultInjectorClient(url) @pytest.fixture() -def r_multi_db(request) -> MultiDBClient: +def r_multi_db(request) -> tuple[MultiDBClient, CheckActiveDatabaseChangedListener]: endpoint_config = get_endpoint_config('re-active-active') username = endpoint_config.get('username', None) password = endpoint_config.get('password', None) failure_threshold = request.param.get('failure_threshold', DEFAULT_FAILURES_THRESHOLD) - command_retry = request.param.get('command_retry', Retry(NoBackoff(), retries=3)) + command_retry = request.param.get('command_retry', Retry(ExponentialBackoff(cap=0.5, base=0.05), retries=3)) + + # Retry configuration different for health checks as initial health check require more time in case + # if infrastructure 
wasn't restored from the previous test. + health_checks = [EchoHealthCheck(Retry(ExponentialBackoff(cap=5, base=0.5), retries=3))] health_check_interval = request.param.get('health_check_interval', DEFAULT_HEALTH_CHECK_INTERVAL) + event_dispatcher = EventDispatcher() + listener = CheckActiveDatabaseChangedListener() + event_dispatcher.register_listeners({ + ActiveDatabaseChanged: [listener], + }) db_configs = [] db_config = DatabaseConfig( @@ -64,12 +82,13 @@ def r_multi_db(request) -> MultiDBClient: ) db_configs.append(db_config1) - config = MultiDbConfig( databases_config=db_configs, + health_checks=health_checks, command_retry=command_retry, failure_threshold=failure_threshold, health_check_interval=health_check_interval, + event_dispatcher=event_dispatcher, ) - return MultiDBClient(config) \ No newline at end of file + return MultiDBClient(config), listener \ No newline at end of file diff --git a/tests/test_scenario/test_active_active.py b/tests/test_scenario/test_active_active.py index 2b9bfc7e74..09d156ce53 100644 --- a/tests/test_scenario/test_active_active.py +++ b/tests/test_scenario/test_active_active.py @@ -1,13 +1,11 @@ +import json import logging import threading from time import sleep import pytest -from redis.backoff import NoBackoff from redis.client import Pipeline -from redis.exceptions import ConnectionError -from redis.retry import Retry from tests.test_scenario.conftest import get_endpoint_config from tests.test_scenario.fault_injector_client import ActionRequest, ActionType @@ -17,7 +15,7 @@ def trigger_network_failure_action(fault_injector_client, event: threading.Event endpoint_config = get_endpoint_config('re-active-active') action_request = ActionRequest( action_type=ActionType.NETWORK_FAILURE, - parameters={"bdb_id": endpoint_config['bdb_id'], "delay": 1, "cluster_index": 0} + parameters={"bdb_id": endpoint_config['bdb_id'], "delay": 2, "cluster_index": 0} ) result = fault_injector_client.trigger_action(action_request) @@ -34,7 +32,6 @@ def trigger_network_failure_action(fault_injector_client, event: threading.Event logger.info(f"Action completed. Status: {status_result['status']}") class TestActiveActiveStandalone: - def teardown_method(self, method): # Timeout so the cluster could recover from network failure. sleep(3) @@ -54,6 +51,8 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector args=(fault_injector_client,event) ) + r_multi_db, listener = r_multi_db + # Client initialized on the first command. r_multi_db.set('key', 'value') thread.start() @@ -68,6 +67,8 @@ def test_multi_db_client_failover_to_another_db(self, r_multi_db, fault_injector assert r_multi_db.get('key') == 'value' sleep(0.1) + assert listener.is_changed_flag == True + @pytest.mark.parametrize( "r_multi_db", [ @@ -83,6 +84,8 @@ def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault args=(fault_injector_client,event) ) + r_multi_db, listener = r_multi_db + # Client initialized on first pipe execution. 
         with r_multi_db.pipeline() as pipe:
             pipe.set('{hash}key1', 'value1')
@@ -119,6 +122,8 @@ def test_context_manager_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client):
                 assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3']
                 sleep(0.1)

+        assert listener.is_changed_flag is True
+
     @pytest.mark.parametrize(
         "r_multi_db",
         [
@@ -134,6 +139,8 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client):
             args=(fault_injector_client,event)
         )

+        r_multi_db, listener = r_multi_db
+
         # Client initialized on first pipe execution.
         pipe = r_multi_db.pipeline()
         pipe.set('{hash}key1', 'value1')
@@ -168,6 +175,8 @@ def test_chaining_pipeline_failover_to_another_db(self, r_multi_db, fault_injector_client):
             assert pipe.execute() == [True, True, True, 'value1', 'value2', 'value3']
             sleep(0.1)

+        assert listener.is_changed_flag is True
+
     @pytest.mark.parametrize(
         "r_multi_db",
         [
@@ -183,6 +192,8 @@ def test_transaction_failover_to_another_db(self, r_multi_db, fault_injector_client):
             args=(fault_injector_client,event)
         )

+        r_multi_db, listener = r_multi_db
+
         def callback(pipe: Pipeline):
             pipe.set('{hash}key1', 'value1')
             pipe.set('{hash}key2', 'value2')
@@ -203,4 +214,90 @@ def callback(pipe: Pipeline):
         # Execute pipeline after network failure
         for _ in range(3):
             r_multi_db.transaction(callback)
-            sleep(0.1)
\ No newline at end of file
+            sleep(0.1)
+
+        assert listener.is_changed_flag is True
+
+    @pytest.mark.parametrize(
+        "r_multi_db",
+        [
+            {"failure_threshold": 2}
+        ],
+        indirect=True
+    )
+    def test_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client):
+        event = threading.Event()
+        thread = threading.Thread(
+            target=trigger_network_failure_action,
+            daemon=True,
+            args=(fault_injector_client,event)
+        )
+
+        r_multi_db, listener = r_multi_db
+        data = json.dumps({'message': 'test'})
+
+        def handler(message):
+            assert message['data'] == data
+
+        pubsub = r_multi_db.pubsub()
+
+        # Assign a handler and run in a separate thread.
+        pubsub.subscribe(**{'test-channel': handler})
+        pubsub_thread = pubsub.run_in_thread(sleep_time=0.1, daemon=True)
+        thread.start()
+
+        # Publish messages before the network failure
+        while not event.is_set():
+            r_multi_db.publish('test-channel', data)
+            sleep(0.1)
+
+        # Publish messages after the network failure
+        for _ in range(3):
+            r_multi_db.publish('test-channel', data)
+            sleep(0.1)
+
+        pubsub_thread.stop()
+
+        assert listener.is_changed_flag is True
+
+    @pytest.mark.parametrize(
+        "r_multi_db",
+        [
+            {"failure_threshold": 2}
+        ],
+        indirect=True
+    )
+    def test_sharded_pubsub_failover_to_another_db(self, r_multi_db, fault_injector_client):
+        event = threading.Event()
+        thread = threading.Thread(
+            target=trigger_network_failure_action,
+            daemon=True,
+            args=(fault_injector_client,event)
+        )
+
+        r_multi_db, listener = r_multi_db
+        data = json.dumps({'message': 'test'})
+
+        def handler(message):
+            assert message['data'] == data
+
+        pubsub = r_multi_db.pubsub()
+
+        # Assign a handler and run in a separate thread.
+        pubsub.ssubscribe(**{'test-channel': handler})
+        pubsub_thread = pubsub.run_in_thread(sleep_time=0.1, daemon=True)
+        thread.start()
+
+        # Publish messages before the network failure
+        while not event.is_set():
+            r_multi_db.spublish('test-channel', data)
+            sleep(0.1)
+
+        # Publish messages after the network failure
+        for _ in range(3):
+            r_multi_db.spublish('test-channel', data)
+            sleep(0.1)
+
+        pubsub_thread.stop()
+
+        assert listener.is_changed_flag is True
\ No newline at end of file
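
Reviewer note: a minimal usage sketch of the pub/sub failover flow this diff introduces, not a definitive example. Only MultiDBClient, MultiDbConfig, DatabaseConfig, pubsub(), subscribe(), run_in_thread(), and publish() come from the changed code; the endpoint URLs, DatabaseConfig arguments, and channel name below are illustrative assumptions.

    import time
    from redis.multidb.client import MultiDBClient
    from redis.multidb.config import DatabaseConfig, MultiDbConfig

    # Two Active-Active endpoints; the constructor arguments are assumed for illustration.
    config = MultiDbConfig(
        databases_config=[
            DatabaseConfig(from_url="redis://db-east:6379"),  # hypothetical endpoint
            DatabaseConfig(from_url="redis://db-west:6379"),  # hypothetical endpoint
        ],
    )
    client = MultiDBClient(config)

    def handler(message):
        print(message['data'])

    # The PubSub wrapper proxies every call to the active database's pub/sub.
    # When the active database changes, ResubscribeOnActiveDatabaseChanged
    # re-subscribes the same channels/patterns on the new database, so the
    # worker thread keeps delivering messages across a failover.
    pubsub = client.pubsub()
    pubsub.subscribe(**{'news': handler})
    worker = pubsub.run_in_thread(sleep_time=0.1, daemon=True)

    client.publish('news', 'hello')  # handled by `handler`, even after a failover
    time.sleep(0.5)
    worker.stop()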