diff --git a/redis/client.py b/redis/client.py index d3ab3cfcfe..b0497c1e55 100755 --- a/redis/client.py +++ b/redis/client.py @@ -45,7 +45,7 @@ AfterPubSubConnectionInstantiationEvent, AfterSingleConnectionInstantiationEvent, ClientType, - EventDispatcher, + EventDispatcher, AfterCommandExecutionEvent, ) from redis.exceptions import ( ConnectionError, @@ -478,7 +478,8 @@ def pipeline(self, transaction=True, shard_hint=None) -> "Pipeline": between the client and server. """ return Pipeline( - self.connection_pool, self.response_callbacks, transaction, shard_hint + self.connection_pool, self.response_callbacks, transaction, shard_hint, + event_dispatcher=self._event_dispatcher ) def transaction( @@ -662,16 +663,42 @@ def _execute_command(self, *args, **options): command_name = args[0] conn = self.connection or pool.get_connection() + # Start timing for observability + start_time = time.monotonic() + if self._single_connection_client: self.single_connection_lock.acquire() try: - return conn.retry.call_with_retry( + result = conn.retry.call_with_retry( lambda: self._send_command_parse_response( conn, command_name, *args, **options ), lambda _: self._close_connection(conn), ) + self._event_dispatcher.dispatch( + AfterCommandExecutionEvent( + command_name=command_name, + duration_seconds=time.monotonic() - start_time, + server_address=conn.host, + server_port=conn.port, + db_namespace=str(conn.db), + ) + ) + return result + except Exception as e: + self._event_dispatcher.dispatch( + AfterCommandExecutionEvent( + command_name=command_name, + duration_seconds=time.monotonic() - start_time, + server_address=conn.host, + server_port=conn.port, + db_namespace=str(conn.db), + error=e, + ) + ) + raise + finally: if conn and conn.should_reconnect(): self._close_connection(conn) @@ -1385,6 +1412,7 @@ def __init__( response_callbacks, transaction, shard_hint, + event_dispatcher: EventDispatcher ): self.connection_pool = connection_pool self.connection: Optional[Connection] = None 
@@ -1395,6 +1423,7 @@ def __init__( self.command_stack = [] self.scripts: Set[Script] = set() self.explicit_transaction = False + self._event_dispatcher = event_dispatcher def __enter__(self) -> "Pipeline": return self @@ -1501,12 +1530,41 @@ def immediate_execute_command(self, *args, **options): conn = self.connection_pool.get_connection() self.connection = conn - return conn.retry.call_with_retry( - lambda: self._send_command_parse_response( - conn, command_name, *args, **options - ), - lambda error: self._disconnect_reset_raise_on_watching(conn, error), - ) + # Start timing for observability + start_time = time.monotonic() + + try: + response = conn.retry.call_with_retry( + lambda: self._send_command_parse_response( + conn, command_name, *args, **options + ), + lambda error: self._disconnect_reset_raise_on_watching(conn, error), + ) + + self._event_dispatcher.dispatch( + AfterCommandExecutionEvent( + command_name=command_name, + duration_seconds=time.monotonic() - start_time, + server_address=conn.host, + server_port=conn.port, + db_namespace=str(conn.db), + ) + ) + + return response + except Exception as e: + self._event_dispatcher.dispatch( + AfterCommandExecutionEvent( + command_name=command_name, + duration_seconds=time.monotonic() - start_time, + server_address=conn.host, + server_port=conn.port, + db_namespace=str(conn.db), + error=e, + ) + ) + raise + def pipeline_execute_command(self, *args, **options) -> "Pipeline": """ @@ -1679,8 +1737,10 @@ def execute(self, raise_on_error: bool = True) -> List[Any]: self.load_scripts() if self.transaction or self.explicit_transaction: execute = self._execute_transaction + operation_name = "MULTI" else: execute = self._execute_pipeline + operation_name = "PIPELINE" conn = self.connection if not conn: @@ -1689,11 +1749,40 @@ def execute(self, raise_on_error: bool = True) -> List[Any]: # back to the pool after we're done self.connection = conn + # Start timing for observability + start_time = time.monotonic() + try: - 
return conn.retry.call_with_retry( + response = conn.retry.call_with_retry( lambda: execute(conn, stack, raise_on_error), lambda error: self._disconnect_raise_on_watching(conn, error), ) + + self._event_dispatcher.dispatch( + AfterCommandExecutionEvent( + command_name=operation_name, + duration_seconds=time.monotonic() - start_time, + server_address=conn.host, + server_port=conn.port, + db_namespace=str(conn.db), + batch_size=len(stack), + ) + ) + return response + except Exception as e: + self._event_dispatcher.dispatch( + AfterCommandExecutionEvent( + command_name=operation_name, + duration_seconds=time.monotonic() - start_time, + server_address=conn.host, + server_port=conn.port, + db_namespace=str(conn.db), + error=e, + batch_size=len(stack), + ) + ) + raise + finally: # in reset() the connection is disconnected before returned to the pool if # it is marked for reconnect. diff --git a/redis/cluster.py b/redis/cluster.py index 8f42c1a235..a495e466b3 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -29,7 +29,7 @@ AfterPooledConnectionsInstantiationEvent, AfterPubSubConnectionInstantiationEvent, ClientType, - EventDispatcher, + EventDispatcher, AfterCommandExecutionEvent, ) from redis.exceptions import ( AskError, @@ -984,6 +984,7 @@ def pipeline(self, transaction=None, shard_hint=None): retry=self.retry, lock=self._lock, transaction=transaction, + event_dispatcher=self._event_dispatcher ) def lock( @@ -1367,6 +1368,9 @@ def _execute_command(self, target_node, *args, **kwargs): moved = False ttl = int(self.RedisClusterRequestTTL) + # Start timing for observability + start_time = time.monotonic() + while ttl > 0: ttl -= 1 try: @@ -1401,14 +1405,32 @@ def _execute_command(self, target_node, *args, **kwargs): response = self.cluster_response_callbacks[command]( response, **kwargs ) + + self._emit_after_command_execution_event( + command_name=command, + duration_seconds=time.monotonic() - start_time, + connection=connection, + ) return response - except 
AuthenticationError: + except AuthenticationError as e: + self._emit_after_command_execution_event( + command_name=command, + duration_seconds=time.monotonic() - start_time, + connection=connection, + error=e, + ) raise - except MaxConnectionsError: + except MaxConnectionsError as e: # MaxConnectionsError indicates client-side resource exhaustion # (too many connections in the pool), not a node failure. # Don't treat this as a node failure - just re-raise the error # without reinitializing the cluster. + self._emit_after_command_execution_event( + command_name=command, + duration_seconds=time.monotonic() - start_time, + connection=connection, + error=e, + ) raise except (ConnectionError, TimeoutError) as e: # ConnectionError can also be raised if we couldn't get a @@ -1423,6 +1445,12 @@ def _execute_command(self, target_node, *args, **kwargs): # Reset the cluster node's connection target_node.redis_connection = None self.nodes_manager.initialize() + self._emit_after_command_execution_event( + command_name=command, + duration_seconds=time.monotonic() - start_time, + connection=connection, + error=e, + ) raise e except MovedError as e: # First, we will try to patch the slots/nodes cache with the @@ -1441,13 +1469,33 @@ def _execute_command(self, target_node, *args, **kwargs): else: self.nodes_manager.update_moved_exception(e) moved = True - except TryAgainError: + self._emit_after_command_execution_event( + command_name=command, + duration_seconds=time.monotonic() - start_time, + connection=connection, + error=e, + ) + except TryAgainError as e: if ttl < self.RedisClusterRequestTTL / 2: time.sleep(0.05) + + self._emit_after_command_execution_event( + command_name=command, + duration_seconds=time.monotonic() - start_time, + connection=connection, + error=e, + ) except AskError as e: redirect_addr = get_node_name(host=e.host, port=e.port) asking = True - except (ClusterDownError, SlotNotCoveredError): + + self._emit_after_command_execution_event( + 
command_name=command, + duration_seconds=time.monotonic() - start_time, + connection=connection, + error=e, + ) + except (ClusterDownError, SlotNotCoveredError) as e: # ClusterDownError can occur during a failover and to get # self-healed, we will try to reinitialize the cluster layout # and retry executing the command @@ -1459,12 +1507,32 @@ def _execute_command(self, target_node, *args, **kwargs): time.sleep(0.25) self.nodes_manager.initialize() + + self._emit_after_command_execution_event( + command_name=command, + duration_seconds=time.monotonic() - start_time, + connection=connection, + error=e, + ) raise - except ResponseError: + except ResponseError as e: + self._emit_after_command_execution_event( + command_name=command, + duration_seconds=time.monotonic() - start_time, + connection=connection, + error=e, + ) raise except Exception as e: if connection: connection.disconnect() + + self._emit_after_command_execution_event( + command_name=command, + duration_seconds=time.monotonic() - start_time, + connection=connection, + error=e, + ) raise e finally: if connection is not None: @@ -1472,6 +1540,27 @@ def _execute_command(self, target_node, *args, **kwargs): raise ClusterError("TTL exhausted.") + def _emit_after_command_execution_event( + self, + command_name: str, + duration_seconds: float, + connection: Connection, + error=None + ): + """ + Triggers AfterCommandExecutionEvent emit. 
+ """ + self._event_dispatcher.dispatch( + AfterCommandExecutionEvent( + command_name=command_name, + duration_seconds=duration_seconds, + server_address=connection.host, + server_port=connection.port, + db_namespace=str(connection.db), + error=error, + ) + ) + def close(self) -> None: try: with self._lock: @@ -2326,6 +2415,7 @@ def __init__( lock=None, transaction=False, policy_resolver: PolicyResolver = StaticPolicyResolver(), + event_dispatcher: Optional["EventDispatcher"] = None, **kwargs, ): """ """ @@ -2395,6 +2485,11 @@ def __init__( self._policy_resolver = policy_resolver + if event_dispatcher is None: + self._event_dispatcher = EventDispatcher() + else: + self._event_dispatcher = event_dispatcher + def __repr__(self): """ """ return f"{type(self).__name__}" @@ -2889,7 +2984,7 @@ def __init__(self, pipe: ClusterPipeline): def execute_command(self, *args, **kwargs): return self.pipeline_execute_command(*args, **kwargs) - def _raise_first_error(self, stack): + def _raise_first_error(self, stack, start_time): """ Raise the first exception on the stack """ @@ -2897,6 +2992,16 @@ def _raise_first_error(self, stack): r = c.result if isinstance(r, Exception): self.annotate_exception(r, c.position + 1, c.args) + + self._pipe._event_dispatcher.dispatch( + AfterCommandExecutionEvent( + command_name="PIPELINE", + duration_seconds=time.monotonic() - start_time, + batch_size=len(stack), + error=r, + ) + ) + raise r def execute(self, raise_on_error: bool = True) -> List[Any]: @@ -3076,6 +3181,10 @@ def _send_cluster_commands( # so that we can read them from different sockets as they come back. # we dont' multiplex on the sockets as they come available, # but that shouldn't make too much difference. 
+ + # Start timing for observability + start_time = time.monotonic() + try: node_commands = nodes.values() for n in node_commands: @@ -3083,6 +3192,17 @@ def _send_cluster_commands( for n in node_commands: n.read() + + self._pipe._event_dispatcher.dispatch( + AfterCommandExecutionEvent( + command_name="PIPELINE", + duration_seconds=time.monotonic() - start_time, + server_address=n.connection.host, + server_port=n.connection.port, + db_namespace=str(n.connection.db), + batch_size=len(n.commands), + ) + ) finally: # release all of the redis connections we allocated earlier # back into the connection pool. @@ -3168,7 +3288,7 @@ def _send_cluster_commands( response.append(c.result) if raise_on_error: - self._raise_first_error(stack) + self._raise_first_error(stack, start_time) return response @@ -3373,9 +3493,38 @@ def _immediate_execute_command(self, *args, **options): def _get_connection_and_send_command(self, *args, **options): redis_node, connection = self._get_client_and_connection_for_transaction() - return self._send_command_parse_response( - connection, redis_node, args[0], *args, **options - ) + + # Start timing for observability + start_time = time.monotonic() + + try: + response = self._send_command_parse_response( + connection, redis_node, args[0], *args, **options + ) + + self._pipe._event_dispatcher.dispatch( + AfterCommandExecutionEvent( + command_name=args[0], + duration_seconds=time.monotonic() - start_time, + server_address=connection.host, + server_port=connection.port, + db_namespace=str(connection.db), + ) + ) + + return response + except Exception as e: + self._pipe._event_dispatcher.dispatch( + AfterCommandExecutionEvent( + command_name=args[0], + duration_seconds=time.monotonic() - start_time, + server_address=connection.host, + server_port=connection.port, + db_namespace=str(connection.db), + error=e, + ) + ) + raise def _send_command_parse_response( self, conn, redis_node: Redis, command_name, *args, **options @@ -3413,13 +3562,24 @@ def 
_reinitialize_on_error(self, error): self._executing = False - def _raise_first_error(self, responses, stack): + def _raise_first_error(self, responses, stack, start_time): """ Raise the first exception on the stack """ for r, cmd in zip(responses, stack): if isinstance(r, Exception): self.annotate_exception(r, cmd.position + 1, cmd.args) + + self._pipe._event_dispatcher.dispatch( + AfterCommandExecutionEvent( + command_name='TRANSACTION', + duration_seconds=time.monotonic() - start_time, + server_address=self._transaction_connection.host, + server_port=self._transaction_connection.port, + db_namespace=str(self._transaction_connection.db), + ) + ) + raise r def execute(self, raise_on_error: bool = True) -> List[Any]: @@ -3456,6 +3616,10 @@ def _execute_transaction( ) commands = [c.args for c in stack if EMPTY_RESPONSE not in c.options] packed_commands = connection.pack_commands(commands) + + # Start timing for observability + start_time = time.monotonic() + connection.send_packed_command(packed_commands) errors = [] @@ -3500,6 +3664,17 @@ def _execute_transaction( self._executing = False + self._pipe._event_dispatcher.dispatch( + AfterCommandExecutionEvent( + command_name='TRANSACTION', + duration_seconds=time.monotonic() - start_time, + server_address=connection.host, + server_port=connection.port, + db_namespace=str(connection.db), + batch_size=len(self._command_queue), + ) + ) + # EXEC clears any watched keys self._watching = False @@ -3523,6 +3698,7 @@ def _execute_transaction( self._raise_first_error( response, self._command_queue, + start_time, ) # We have to run response callbacks manually diff --git a/redis/event.py b/redis/event.py index 03c72c6370..0ea9e342ad 100644 --- a/redis/event.py +++ b/redis/event.py @@ -1,11 +1,13 @@ import asyncio import threading from abc import ABC, abstractmethod +from dataclasses import dataclass from enum import Enum from typing import Dict, List, Optional, Type, Union from redis.auth.token import TokenInterface from 
redis.credentials import CredentialProvider, StreamingCredentialProvider +from redis.observability.recorder import record_operation_duration class EventListenerInterface(ABC): @@ -90,6 +92,7 @@ def __init__( ], AfterPubSubConnectionInstantiationEvent: [RegisterReAuthForPubSub()], AfterAsyncClusterInstantiationEvent: [RegisterReAuthForAsyncClusterNodes()], + AfterCommandExecutionEvent: [ExportOperationDurationMetric()], AsyncAfterConnectionReleasedEvent: [ AsyncReAuthConnectionListener(), ], @@ -295,6 +298,20 @@ def commands(self) -> tuple: def exception(self) -> Exception: return self._exception +@dataclass +class AfterCommandExecutionEvent: + """ + Event fired after command execution. + """ + command_name: str + duration_seconds: float + server_address: Optional[str] = None + server_port: Optional[int] = None + db_namespace: Optional[str] = None + error: Optional[Exception] = None + is_blocking: Optional[bool] = None + batch_size: Optional[int] = None + retry_attempts: Optional[int] = None class AsyncOnCommandsFailEvent(OnCommandsFailEvent): pass @@ -466,3 +483,20 @@ def _raise_on_error(self, error: Exception): async def _raise_on_error_async(self, error: Exception): raise EventException(error, self._event) + +class ExportOperationDurationMetric(EventListenerInterface): + """ + Listener that exports operation duration metric after command execution. 
+ """ + def listen(self, event: AfterCommandExecutionEvent): + record_operation_duration( + command_name=event.command_name, + duration_seconds=event.duration_seconds, + server_address=event.server_address, + server_port=event.server_port, + db_namespace=event.db_namespace, + error=event.error, + is_blocking=event.is_blocking, + batch_size=event.batch_size, + retry_attempts=event.retry_attempts, + ) diff --git a/redis/exceptions.py b/redis/exceptions.py index dab17c5c1f..c01c27720a 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -1,28 +1,52 @@ +from enum import Enum + "Core exceptions raised by the Redis client" +class ExceptionType(Enum): + NETWORK = 'network' + TLS = 'tls' + AUTH = 'auth' + SERVER = 'server' + + class RedisError(Exception): - pass + def __init__(self, *args): + super().__init__(*args) + self.error_type = ExceptionType.SERVER + + def __repr__(self): + return f"{self.error_type.value}:{self.__class__.__name__}" class ConnectionError(RedisError): - pass + def __init__(self, *args): + super().__init__(*args) + self.error_type = ExceptionType.NETWORK class TimeoutError(RedisError): - pass + def __init__(self, *args): + super().__init__(*args) + self.error_type = ExceptionType.NETWORK class AuthenticationError(ConnectionError): - pass + def __init__(self, *args): + super().__init__(*args) + self.error_type = ExceptionType.AUTH class AuthorizationError(ConnectionError): - pass + def __init__(self, *args): + super().__init__(*args) + self.error_type = ExceptionType.AUTH class BusyLoadingError(ConnectionError): - pass + def __init__(self, *args): + super().__init__(*args) + self.error_type = ExceptionType.NETWORK class InvalidResponse(RedisError): @@ -70,7 +94,9 @@ class ReadOnlyError(ResponseError): class NoPermissionError(ResponseError): - pass + def __init__(self, *args): + super().__init__(*args) + self.error_type = ExceptionType.AUTH class ModuleError(ResponseError): @@ -84,6 +110,7 @@ class LockError(RedisError, ValueError): # This was 
originally chosen to behave like threading.Lock. def __init__(self, message=None, lock_name=None): + super().__init__(message) self.message = message self.lock_name = lock_name @@ -106,7 +133,9 @@ class AuthenticationWrongNumberOfArgsError(ResponseError): were sent to the AUTH command """ - pass + def __init__(self, *args): + super().__init__(*args) + self.error_type = ExceptionType.AUTH class RedisClusterException(Exception): @@ -114,7 +143,12 @@ class RedisClusterException(Exception): Base exception for the RedisCluster client """ - pass + def __init__(self, *args): + super().__init__(*args) + self.error_type = ExceptionType.SERVER + + def __repr__(self): + return f"{self.error_type.value}:{self.__class__.__name__}" class ClusterError(RedisError): @@ -123,7 +157,9 @@ class ClusterError(RedisError): command execution TTL """ - pass + def __init__(self, *args): + super().__init__(*args) + self.error_type = ExceptionType.SERVER class ClusterDownError(ClusterError, ResponseError): @@ -140,6 +176,7 @@ class ClusterDownError(ClusterError, ResponseError): def __init__(self, resp): self.args = (resp,) self.message = resp + self.error_type = ExceptionType.SERVER class AskError(ResponseError): @@ -160,6 +197,7 @@ class AskError(ResponseError): def __init__(self, resp): """should only redirect to master node""" + super().__init__(resp) self.args = (resp,) self.message = resp slot_id, new_node = resp.split(" ") @@ -176,7 +214,7 @@ class TryAgainError(ResponseError): """ def __init__(self, *args, **kwargs): - pass + super().__init__(*args) class ClusterCrossSlotError(ResponseError): @@ -188,6 +226,10 @@ class ClusterCrossSlotError(ResponseError): message = "Keys in request don't hash to the same slot" + def __init__(self, *args): + super().__init__(*args) + self.error_type = ExceptionType.SERVER + class MovedError(AskError): """ diff --git a/redis/observability/attributes.py b/redis/observability/attributes.py index 59fd3310ac..79244c3002 100644 --- 
a/redis/observability/attributes.py +++ b/redis/observability/attributes.py @@ -185,7 +185,6 @@ def build_connection_attributes( @staticmethod def build_error_attributes( is_internal: bool = False, - error_type: Optional[Exception] = None, ) -> Dict[str, Any]: """ Build error attributes. @@ -198,10 +197,6 @@ def build_error_attributes( Dictionary of error attributes """ attrs: Dict[str, Any] = {REDIS_CLIENT_ERROR_INTERNAL: is_internal} - - if error_type is not None: - attrs[DB_RESPONSE_STATUS_CODE] = None - return attrs @staticmethod @@ -274,7 +269,11 @@ def extract_error_type(exception: Exception) -> str: Returns: Error type string (exception class name) """ - return type(exception).__name__ + + if hasattr(exception, "error_type"): + return repr(exception) + else: + return f"other:{type(exception).__name__}" @staticmethod def build_pool_name( diff --git a/redis/observability/config.py b/redis/observability/config.py index e57a0a3005..1b10cf769e 100644 --- a/redis/observability/config.py +++ b/redis/observability/config.py @@ -16,8 +16,6 @@ class MetricGroup(IntFlag): class TelemetryOption(IntFlag): """Telemetry options to export.""" METRICS = auto() - TRACES = auto() - LOGS = auto() """ @@ -165,5 +163,5 @@ def should_track_command(self, command_name: str) -> bool: def __repr__(self) -> str: return ( - f"OTelConfig(enabled_telemetry={self.enabled_telemetry}, " + f"OTelConfig(enabled_telemetry={self.enabled_telemetry}" ) \ No newline at end of file diff --git a/redis/observability/metrics.py b/redis/observability/metrics.py index 7ffc1e51d8..55f21121f8 100644 --- a/redis/observability/metrics.py +++ b/redis/observability/metrics.py @@ -44,7 +44,7 @@ class RedisMetricsCollector: METER_NAME = "redis-py" METER_VERSION = "1.0.0" - def __init__(self, meter: "Meter", config: OTelConfig): + def __init__(self, meter: Meter, config: OTelConfig): if not OTEL_AVAILABLE: raise ImportError( "OpenTelemetry API is not installed. 
" @@ -207,9 +207,7 @@ def record_error_count( ) attrs.update( - self.attr_builder.build_error_attributes( - error_type=error_type, - ) + self.attr_builder.build_error_attributes() ) self.client_errors.add(1, attributes=attrs) @@ -406,11 +404,8 @@ def record_operation_duration( ) attrs.update( - self.attr_builder.build_error_attributes( - error_type=error_type, - ) + self.attr_builder.build_error_attributes() ) - self.operation_duration.record(duration_seconds, attributes=attrs) def record_connection_closed( diff --git a/redis/observability/recorder.py b/redis/observability/recorder.py index fee2d196bd..9bdbdcdaac 100644 --- a/redis/observability/recorder.py +++ b/redis/observability/recorder.py @@ -24,6 +24,7 @@ from redis.observability.attributes import PubSubDirection, ConnectionState from redis.observability.metrics import RedisMetricsCollector +from redis.observability.providers import get_observability_instance # Global metrics collector instance (lazy-initialized) _metrics_collector: Optional[RedisMetricsCollector] = None @@ -36,6 +37,9 @@ def record_operation_duration( server_port: Optional[int] = None, db_namespace: Optional[str] = None, error: Optional[Exception] = None, + is_blocking: Optional[bool] = None, + batch_size: Optional[int] = None, + retry_attempts: Optional[int] = None, ) -> None: """ Record a Redis command execution duration. 
@@ -50,6 +54,9 @@ def record_operation_duration( server_port: Redis server port db_namespace: Redis database index error: Exception if command failed, None if successful + is_blocking: Whether the operation is a blocking command + batch_size: Number of commands in batch (for pipelines/transactions) + retry_attempts: Number of retry attempts made Example: >>> start = time.monotonic() @@ -82,6 +89,9 @@ def record_operation_duration( response_status_code=status_code, network_peer_address=server_address, network_peer_port=server_port, + is_blocking=is_blocking, + batch_size=batch_size, + retry_attempts=retry_attempts, ) # except Exception: # # Don't let metric recording errors break Redis operations @@ -455,11 +465,8 @@ def _get_or_create_collector() -> Optional[RedisMetricsCollector]: RedisMetricsCollector instance if observability is enabled, None otherwise """ try: - from redis.observability.providers import get_provider_manager - from redis.observability.metrics import RedisMetricsCollector - - manager = get_provider_manager() - if manager is None or not manager.config.enable_metrics: + manager = get_observability_instance().get_provider_manager() + if manager is None or not manager.config.enabled_telemetry: return None # Get meter from the global MeterProvider diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000000..6cd3de39f6 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,347 @@ +from unittest import mock + +import pytest + +import redis +from redis.event import EventDispatcher +from redis.observability import recorder +from redis.observability.config import OTelConfig, MetricGroup +from redis.observability.metrics import RedisMetricsCollector + + +class TestRedisClientEventEmission: + """ + Unit tests that verify AfterCommandExecutionEvent is properly emitted from Redis client + and delivered to the Meter through the event dispatcher chain. 
+ + These tests use fully mocked connection and connection pool - no real Redis + or OTel integration is used. + """ + + @pytest.fixture + def mock_connection(self): + """Create a mock connection with required attributes.""" + conn = mock.MagicMock() + conn.host = 'localhost' + conn.port = 6379 + conn.db = 0 + conn.should_reconnect.return_value = False + + # Mock retry to just execute the function directly + conn.retry.call_with_retry = lambda func, _: func() + + return conn + + @pytest.fixture + def mock_connection_pool(self, mock_connection): + """Create a mock connection pool.""" + pool = mock.MagicMock() + pool.get_connection.return_value = mock_connection + pool.get_encoder.return_value = mock.MagicMock() + return pool + + @pytest.fixture + def mock_meter(self): + """Create a mock Meter that tracks all instrument calls.""" + meter = mock.MagicMock() + + # Create mock histogram for operation duration + self.operation_duration = mock.MagicMock() + + def create_histogram_side_effect(name, **kwargs): + if name == 'db.client.operation.duration': + return self.operation_duration + return mock.MagicMock() + + meter.create_counter.return_value = mock.MagicMock() + meter.create_up_down_counter.return_value = mock.MagicMock() + meter.create_histogram.side_effect = create_histogram_side_effect + + return meter + + @pytest.fixture + def setup_redis_client_with_otel( + self, mock_connection_pool, mock_connection, mock_meter + ): + """ + Setup a Redis client with mocked connection and OTel collector. + Returns tuple of (redis_client, operation_duration_mock). 
+ """ + + # Reset any existing collector state + recorder.reset_collector() + + # Create config with COMMAND group enabled + config = OTelConfig(metric_groups=[MetricGroup.COMMAND]) + + # Create collector with mocked meter + with mock.patch('redis.observability.metrics.OTEL_AVAILABLE', True): + collector = RedisMetricsCollector(mock_meter, config) + + # Patch the recorder to use our collector + with mock.patch.object( + recorder, + '_get_or_create_collector', + return_value=collector + ): + # Create event dispatcher (real one, to test the full chain) + event_dispatcher = EventDispatcher() + + # Create Redis client with mocked connection pool + client = redis.Redis( + connection_pool=mock_connection_pool, + event_dispatcher=event_dispatcher, + ) + + yield client, self.operation_duration + + # Cleanup + recorder.reset_collector() + + def test_execute_command_emits_event_to_meter(self, setup_redis_client_with_otel): + """ + Test that executing a command emits AfterCommandExecutionEvent + which is delivered to the Meter's histogram.record() method. 
+ """ + client, operation_duration_mock = setup_redis_client_with_otel + + # Mock _send_command_parse_response to return a successful response + client._send_command_parse_response = mock.MagicMock(return_value=True) + + # Execute a command + client.execute_command('SET', 'key1', 'value1') + + # Verify the Meter's histogram.record() was called + operation_duration_mock.record.assert_called_once() + + # Get the call arguments + call_args = operation_duration_mock.record.call_args + + # Verify duration was recorded (first positional arg) + duration = call_args[0][0] + assert isinstance(duration, float) + assert duration >= 0 + + # Verify attributes + attrs = call_args[1]['attributes'] + assert attrs['db.operation.name'] == 'SET' + assert attrs['server.address'] == 'localhost' + assert attrs['server.port'] == 6379 + assert attrs['db.namespace'] == '0' + + def test_get_command_emits_event_to_meter( + self, mock_connection_pool, mock_connection, mock_meter + ): + """ + Test that GET command emits AfterCommandExecutionEvent with correct command name. 
+ """ + + recorder.reset_collector() + config = OTelConfig(metric_groups=[MetricGroup.COMMAND]) + + with mock.patch('redis.observability.metrics.OTEL_AVAILABLE', True): + collector = RedisMetricsCollector(mock_meter, config) + + with mock.patch.object( + recorder, '_get_or_create_collector', return_value=collector + ): + event_dispatcher = EventDispatcher() + + client = redis.Redis( + connection_pool=mock_connection_pool, + event_dispatcher=event_dispatcher, + ) + + client._send_command_parse_response = mock.MagicMock(return_value=b'value1') + + # Execute GET command + client.execute_command('GET', 'key1') + + # Verify command name is GET + call_args = self.operation_duration.record.call_args + attrs = call_args[1]['attributes'] + assert attrs['db.operation.name'] == 'GET' + + recorder.reset_collector() + + def test_command_error_emits_event_with_error( + self, mock_connection_pool, mock_connection, mock_meter + ): + """ + Test that when a command execution raises an exception, + AfterCommandExecutionEvent is still emitted with error information. 
+ """ + + recorder.reset_collector() + config = OTelConfig(metric_groups=[MetricGroup.COMMAND]) + + with mock.patch('redis.observability.metrics.OTEL_AVAILABLE', True): + collector = RedisMetricsCollector(mock_meter, config) + + with mock.patch.object( + recorder, '_get_or_create_collector', return_value=collector + ): + event_dispatcher = EventDispatcher() + + client = redis.Redis( + connection_pool=mock_connection_pool, + event_dispatcher=event_dispatcher, + ) + + # Make command raise an exception + test_error = redis.ResponseError("WRONGTYPE Operation error") + client._send_command_parse_response = mock.MagicMock(side_effect=test_error) + + # Execute should raise the error + with pytest.raises(redis.ResponseError): + client.execute_command('LPUSH', 'string_key', 'value') + + # Verify the Meter's histogram.record() was still called + self.operation_duration.record.assert_called_once() + + # Verify error type is recorded in attributes + call_args = self.operation_duration.record.call_args + attrs = call_args[1]['attributes'] + assert attrs['db.operation.name'] == 'LPUSH' + assert 'error.type' in attrs + + recorder.reset_collector() + + def test_server_attributes_recorded_correctly(self, setup_redis_client_with_otel): + """ + Test that server address, port, and db namespace are correctly recorded. + """ + client, operation_duration_mock = setup_redis_client_with_otel + + client._send_command_parse_response = mock.MagicMock(return_value=b'PONG') + + client.execute_command('PING') + + call_args = operation_duration_mock.record.call_args + attrs = call_args[1]['attributes'] + + # Verify server attributes match mock connection + assert attrs['server.address'] == 'localhost' + assert attrs['server.port'] == 6379 + assert attrs['db.namespace'] == '0' + + def test_multiple_commands_emit_multiple_events( + self, mock_connection_pool, mock_connection, mock_meter + ): + """ + Test that each command execution emits a separate event to the Meter. 
+ """ + + recorder.reset_collector() + config = OTelConfig(metric_groups=[MetricGroup.COMMAND]) + + with mock.patch('redis.observability.metrics.OTEL_AVAILABLE', True): + collector = RedisMetricsCollector(mock_meter, config) + + with mock.patch.object( + recorder, '_get_or_create_collector', return_value=collector + ): + event_dispatcher = EventDispatcher() + + client = redis.Redis( + connection_pool=mock_connection_pool, + event_dispatcher=event_dispatcher, + ) + + client._send_command_parse_response = mock.MagicMock(return_value=True) + + # Execute multiple commands + client.execute_command('SET', 'key1', 'value1') + client.execute_command('SET', 'key2', 'value2') + client.execute_command('GET', 'key1') + + # Verify histogram.record() was called three times + assert self.operation_duration.record.call_count == 3 + + # Verify command names in order + calls = self.operation_duration.record.call_args_list + assert calls[0][1]['attributes']['db.operation.name'] == 'SET' + assert calls[1][1]['attributes']['db.operation.name'] == 'SET' + assert calls[2][1]['attributes']['db.operation.name'] == 'GET' + + recorder.reset_collector() + + def test_different_db_namespace_recorded( + self, mock_connection_pool, mock_meter + ): + """ + Test that different db namespace values are correctly recorded. 
+ """ + + # Create connection with different db + mock_connection = mock.MagicMock() + mock_connection.host = 'redis.example.com' + mock_connection.port = 6380 + mock_connection.db = 5 + mock_connection.should_reconnect.return_value = False + mock_connection.retry.call_with_retry = lambda func, _: func() + + mock_connection_pool.get_connection.return_value = mock_connection + + recorder.reset_collector() + config = OTelConfig(metric_groups=[MetricGroup.COMMAND]) + + with mock.patch('redis.observability.metrics.OTEL_AVAILABLE', True): + collector = RedisMetricsCollector(mock_meter, config) + + with mock.patch.object( + recorder, '_get_or_create_collector', return_value=collector + ): + event_dispatcher = EventDispatcher() + + client = redis.Redis( + connection_pool=mock_connection_pool, + event_dispatcher=event_dispatcher, + ) + + client._send_command_parse_response = mock.MagicMock(return_value=True) + + client.execute_command('SET', 'key', 'value') + + call_args = self.operation_duration.record.call_args + attrs = call_args[1]['attributes'] + + # Verify different server attributes + assert attrs['server.address'] == 'redis.example.com' + assert attrs['server.port'] == 6380 + assert attrs['db.namespace'] == '5' + + recorder.reset_collector() + + def test_duration_is_positive(self, setup_redis_client_with_otel): + """ + Test that the recorded duration is a non-negative float value. + """ + client, operation_duration_mock = setup_redis_client_with_otel + + client._send_command_parse_response = mock.MagicMock(return_value=True) + + client.execute_command('SET', 'key', 'value') + + call_args = operation_duration_mock.record.call_args + duration = call_args[0][0] + + assert isinstance(duration, float) + assert duration >= 0 + + def test_no_batch_size_for_single_command(self, setup_redis_client_with_otel): + """ + Test that single commands do not include batch_size attribute + (batch_size is only for pipeline operations).
+ """ + client, operation_duration_mock = setup_redis_client_with_otel + + client._send_command_parse_response = mock.MagicMock(return_value=True) + + client.execute_command('SET', 'key', 'value') + + call_args = operation_duration_mock.record.call_args + attrs = call_args[1]['attributes'] + + # db.operation.batch.size should not be present for single commands + assert 'db.operation.batch.size' not in attrs \ No newline at end of file diff --git a/tests/test_cluster.py b/tests/test_cluster.py index a6cfcd2d94..b605be5554 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -31,6 +31,7 @@ ) from redis.connection import BlockingConnectionPool, Connection, ConnectionPool from redis.crc import key_slot +from redis.event import EventDispatcher from redis.exceptions import ( AskError, ClusterDownError, @@ -43,6 +44,9 @@ ResponseError, TimeoutError, ) +from redis.observability import recorder +from redis.observability.config import OTelConfig, MetricGroup +from redis.observability.metrics import RedisMetricsCollector from redis.retry import Retry from redis.utils import str_if_bytes from tests.test_pubsub import wait_for_message @@ -3687,3 +3691,403 @@ def test_flush(self, r): r.flushall() assert r.get("x") is None assert r.get("y") is None + + +@pytest.mark.onlycluster +class TestClusterEventEmission: + """ + Integration tests that verify AfterCommandExecutionEvent is properly emitted + from RedisCluster and delivered to the Meter through the event dispatcher chain. + + These tests use a real Redis cluster connection but mock the OTel Meter + to verify events are correctly emitted.
+ """ + + @pytest.fixture + def mock_meter(self): + """Create a mock Meter that tracks all instrument calls.""" + meter = Mock() + + # Create mock histogram for operation duration + self.operation_duration = Mock() + + def create_histogram_side_effect(name, **kwargs): + if name == 'db.client.operation.duration': + return self.operation_duration + return Mock() + + meter.create_counter.return_value = Mock() + meter.create_up_down_counter.return_value = Mock() + meter.create_histogram.side_effect = create_histogram_side_effect + + return meter + + @pytest.fixture + def cluster_with_otel(self, r, mock_meter): + """ + Setup a RedisCluster with real connection and mocked OTel collector. + Returns tuple of (cluster, operation_duration_mock). + """ + + # Reset any existing collector state + recorder.reset_collector() + + # Create config with COMMAND group enabled + config = OTelConfig(metric_groups=[MetricGroup.COMMAND]) + + # Create collector with mocked meter + with patch('redis.observability.metrics.OTEL_AVAILABLE', True): + collector = RedisMetricsCollector(mock_meter, config) + + # Patch the recorder to use our collector + with patch.object( + recorder, + '_get_or_create_collector', + return_value=collector + ): + # Create a new event dispatcher and attach it to the cluster + event_dispatcher = EventDispatcher() + r._event_dispatcher = event_dispatcher + + yield r, self.operation_duration + + # Cleanup + recorder.reset_collector() + + def test_execute_command_emits_event_to_meter(self, cluster_with_otel): + """ + Test that execute_command emits AfterCommandExecutionEvent to Meter. 
+ """ + cluster, operation_duration_mock = cluster_with_otel + + # Execute a command + cluster.set('test_key', 'test_value') + + # Verify the Meter's histogram.record() was called + operation_duration_mock.record.assert_called() + + # Get the last call arguments + call_args = operation_duration_mock.record.call_args + + # Verify duration was recorded + duration = call_args[0][0] + assert isinstance(duration, float) + assert duration >= 0 + + # Verify attributes + attrs = call_args[1]['attributes'] + assert attrs['db.operation.name'] == 'SET' + assert 'server.address' in attrs + assert 'server.port' in attrs + assert 'db.namespace' in attrs + + def test_get_command_emits_event_to_meter(self, cluster_with_otel): + """ + Test that GET command emits event with correct command name. + """ + cluster, operation_duration_mock = cluster_with_otel + + # Execute GET command + cluster.get('test_key') + + # Verify command name is GET + call_args = operation_duration_mock.record.call_args + attrs = call_args[1]['attributes'] + assert attrs['db.operation.name'] == 'GET' + + def test_multiple_commands_emit_multiple_events(self, cluster_with_otel): + """ + Test that multiple command executions emit multiple events. + """ + cluster, operation_duration_mock = cluster_with_otel + + # Execute multiple commands + cluster.set('key1', 'value1') + cluster.get('key1') + cluster.delete('key1') + + # Verify histogram.record() was called 3 times + assert operation_duration_mock.record.call_count == 3 + + def test_server_attributes_recorded(self, cluster_with_otel): + """ + Test that server address, port, and db namespace are recorded. 
+ """ + cluster, operation_duration_mock = cluster_with_otel + + cluster.ping() + + call_args = operation_duration_mock.record.call_args + attrs = call_args[1]['attributes'] + + # Verify server attributes are present and have valid values + assert 'server.address' in attrs + assert isinstance(attrs['server.address'], str) + assert len(attrs['server.address']) > 0 + + assert 'server.port' in attrs + assert isinstance(attrs['server.port'], int) + assert attrs['server.port'] > 0 + + assert 'db.namespace' in attrs + + def test_duration_is_positive(self, cluster_with_otel): + """ + Test that the recorded duration is a non-negative float. + """ + cluster, operation_duration_mock = cluster_with_otel + + cluster.set('duration_test', 'value') + + call_args = operation_duration_mock.record.call_args + duration = call_args[0][0] + + assert isinstance(duration, float) + assert duration >= 0 + + def test_no_batch_size_for_single_command(self, cluster_with_otel): + """ + Test that single commands don't include batch_size attribute. + """ + cluster, operation_duration_mock = cluster_with_otel + + cluster.get('single_command_key') + + call_args = operation_duration_mock.record.call_args + attrs = call_args[1]['attributes'] + + # db.operation.batch.size should not be present for single commands + assert 'db.operation.batch.size' not in attrs + + def test_different_commands_emit_correct_names(self, cluster_with_otel): + """ + Test that different commands emit events with correct command names.
+ """ + cluster, operation_duration_mock = cluster_with_otel + + commands_to_test = [ + ('SET', lambda: cluster.set('cmd_test', 'value')), + ('GET', lambda: cluster.get('cmd_test')), + ('DEL', lambda: cluster.delete('cmd_test')), + ('PING', lambda: cluster.ping()), + ] + + for expected_cmd, cmd_func in commands_to_test: + operation_duration_mock.reset_mock() + cmd_func() + + call_args = operation_duration_mock.record.call_args + attrs = call_args[1]['attributes'] + assert attrs['db.operation.name'] == expected_cmd + + +@pytest.mark.onlycluster +class TestClusterPipelineEventEmission: + """ + Integration tests that verify AfterCommandExecutionEvent is properly emitted + from ClusterPipeline and delivered to the Meter through the event dispatcher chain. + + These tests use a real Redis cluster connection but mock the OTel Meter + to verify events are correctly emitted. + """ + + @pytest.fixture + def mock_meter(self): + """Create a mock Meter that tracks all instrument calls.""" + meter = Mock() + + # Create mock histogram for operation duration + self.operation_duration = Mock() + + def create_histogram_side_effect(name, **kwargs): + if name == 'db.client.operation.duration': + return self.operation_duration + return Mock() + + meter.create_counter.return_value = Mock() + meter.create_up_down_counter.return_value = Mock() + meter.create_histogram.side_effect = create_histogram_side_effect + + return meter + + @pytest.fixture + def cluster_pipeline_with_otel(self, r, mock_meter): + """ + Setup a ClusterPipeline with real connection and mocked OTel collector. + Returns tuple of (cluster, operation_duration_mock). 
+ """ + + # Reset any existing collector state + recorder.reset_collector() + + # Create config with COMMAND group enabled + config = OTelConfig(metric_groups=[MetricGroup.COMMAND]) + + # Create collector with mocked meter + with patch('redis.observability.metrics.OTEL_AVAILABLE', True): + collector = RedisMetricsCollector(mock_meter, config) + + # Patch the recorder to use our collector + with patch.object( + recorder, + '_get_or_create_collector', + return_value=collector + ): + # Create a new event dispatcher and attach it to the cluster + event_dispatcher = EventDispatcher() + r._event_dispatcher = event_dispatcher + + yield r, self.operation_duration + + # Cleanup + recorder.reset_collector() + + def test_pipeline_execute_emits_event_to_meter(self, cluster_pipeline_with_otel): + """ + Test that pipeline execute emits AfterCommandExecutionEvent to Meter. + """ + cluster, operation_duration_mock = cluster_pipeline_with_otel + + # Execute a pipeline + pipe = cluster.pipeline() + pipe.set('pipe_key1', 'value1') + pipe.get('pipe_key1') + pipe.execute() + + # Verify the Meter's histogram.record() was called + operation_duration_mock.record.assert_called() + + # Get the last call arguments (pipeline event) + call_args = operation_duration_mock.record.call_args + + # Verify duration was recorded + duration = call_args[0][0] + assert isinstance(duration, float) + assert duration >= 0 + + # Verify attributes + attrs = call_args[1]['attributes'] + assert attrs['db.operation.name'] == 'PIPELINE' + + def test_pipeline_batch_size_recorded(self, cluster_pipeline_with_otel): + """ + Test that pipeline batch_size is correctly recorded. 
+ """ + cluster, operation_duration_mock = cluster_pipeline_with_otel + + # Execute a pipeline with 3 commands + pipe = cluster.pipeline() + pipe.set('batch_key', 'value1') + pipe.get('batch_key') + pipe.delete('batch_key') + pipe.execute() + + # Find the PIPELINE event call + pipeline_call = None + for call_obj in operation_duration_mock.record.call_args_list: + attrs = call_obj[1]['attributes'] + if attrs.get('db.operation.name') == 'PIPELINE': + pipeline_call = call_obj + break + + assert pipeline_call is not None + attrs = pipeline_call[1]['attributes'] + assert 'db.operation.batch.size' in attrs + assert attrs['db.operation.batch.size'] == 3 + + def test_pipeline_server_attributes_recorded(self, cluster_pipeline_with_otel): + """ + Test that server address, port, and db namespace are recorded for pipeline. + """ + cluster, operation_duration_mock = cluster_pipeline_with_otel + + pipe = cluster.pipeline() + pipe.set('server_attr_key', 'value') + pipe.execute() + + # Find the PIPELINE event call + pipeline_call = None + for call_obj in operation_duration_mock.record.call_args_list: + attrs = call_obj[1]['attributes'] + if attrs.get('db.operation.name') == 'PIPELINE': + pipeline_call = call_obj + break + + assert pipeline_call is not None + attrs = pipeline_call[1]['attributes'] + + # Verify server attributes are present + assert 'server.address' in attrs + assert isinstance(attrs['server.address'], str) + + assert 'server.port' in attrs + assert isinstance(attrs['server.port'], int) + + assert 'db.namespace' in attrs + + def test_pipeline_duration_is_positive(self, cluster_pipeline_with_otel): + """ + Test that the recorded duration for pipeline is a non-negative float.
+ """ + cluster, operation_duration_mock = cluster_pipeline_with_otel + + pipe = cluster.pipeline() + pipe.set('duration_key', 'value') + pipe.execute() + + # Find the PIPELINE event call + pipeline_call = None + for call_obj in operation_duration_mock.record.call_args_list: + attrs = call_obj[1]['attributes'] + if attrs.get('db.operation.name') == 'PIPELINE': + pipeline_call = call_obj + break + + assert pipeline_call is not None + duration = pipeline_call[0][0] + assert isinstance(duration, float) + assert duration >= 0 + + def test_multiple_pipeline_executions_emit_multiple_events( + self, cluster_pipeline_with_otel + ): + """ + Test that multiple pipeline executions emit multiple events. + """ + cluster, operation_duration_mock = cluster_pipeline_with_otel + + # Execute first pipeline + pipe1 = cluster.pipeline() + pipe1.set('multi_key1', 'value1') + pipe1.execute() + + # Execute second pipeline + pipe2 = cluster.pipeline() + pipe2.set('multi_key2', 'value2') + pipe2.execute() + + # Count PIPELINE events + pipeline_count = sum( + 1 for call_obj in operation_duration_mock.record.call_args_list + if call_obj[1]['attributes'].get('db.operation.name') == 'PIPELINE' + ) + + assert pipeline_count >= 2 + + def test_empty_pipeline_does_not_emit_event(self, cluster_pipeline_with_otel): + """ + Test that an empty pipeline does not emit events. 
+ """ + cluster, operation_duration_mock = cluster_pipeline_with_otel + + # Execute an empty pipeline + pipe = cluster.pipeline() + pipe.execute() + + # Count PIPELINE events - should be 0 + pipeline_count = sum( + 1 for call_obj in operation_duration_mock.record.call_args_list + if call_obj[1]['attributes'].get('db.operation.name') == 'PIPELINE' + ) + + assert pipeline_count == 0 \ No newline at end of file diff --git a/tests/test_cluster_transaction.py b/tests/test_cluster_transaction.py index 6ebd6df566..e80836d89c 100644 --- a/tests/test_cluster_transaction.py +++ b/tests/test_cluster_transaction.py @@ -9,6 +9,10 @@ from redis.backoff import NoBackoff from redis.client import Redis from redis.cluster import PRIMARY, ClusterNode, NodesManager, RedisCluster +from redis.event import EventDispatcher +from redis.observability import recorder +from redis.observability.config import OTelConfig, MetricGroup +from redis.observability.metrics import RedisMetricsCollector from redis.retry import Retry from .conftest import skip_if_server_version_lt @@ -396,3 +400,258 @@ def test_transaction_discard(self, r): assert not pipe._execution_strategy._watching assert not pipe.command_stack + + +@pytest.mark.onlycluster +class TestClusterTransactionEventEmission: + """ + Integration tests that verify AfterCommandExecutionEvent is properly emitted + from ClusterPipeline (transaction mode) and delivered to the Meter through + the event dispatcher chain. + + These tests use a real Redis cluster connection but mock the OTel Meter + to verify events are correctly emitted. 
+ """ + + @pytest.fixture + def mock_meter(self): + """Create a mock Meter that tracks all instrument calls.""" + meter = Mock() + + # Create mock histogram for operation duration + self.operation_duration = Mock() + + def create_histogram_side_effect(name, **kwargs): + if name == 'db.client.operation.duration': + return self.operation_duration + return Mock() + + meter.create_counter.return_value = Mock() + meter.create_up_down_counter.return_value = Mock() + meter.create_histogram.side_effect = create_histogram_side_effect + + return meter + + @pytest.fixture + def cluster_transaction_with_otel(self, r, mock_meter): + """ + Setup a ClusterPipeline (transaction mode) with real connection + and mocked OTel collector. + Returns tuple of (cluster, operation_duration_mock). + """ + + # Reset any existing collector state + recorder.reset_collector() + + # Create config with COMMAND group enabled + config = OTelConfig(metric_groups=[MetricGroup.COMMAND]) + + # Create collector with mocked meter + with patch('redis.observability.metrics.OTEL_AVAILABLE', True): + collector = RedisMetricsCollector(mock_meter, config) + + # Patch the recorder to use our collector + with patch.object( + recorder, + '_get_or_create_collector', + return_value=collector + ): + # Create a new event dispatcher and attach it to the cluster + event_dispatcher = EventDispatcher() + r._event_dispatcher = event_dispatcher + + yield r, self.operation_duration + + # Cleanup + recorder.reset_collector() + + def test_transaction_execute_emits_event_to_meter( + self, cluster_transaction_with_otel + ): + """ + Test that transaction execute emits AfterCommandExecutionEvent to Meter. 
+ """ + cluster, operation_duration_mock = cluster_transaction_with_otel + + # Execute a transaction + with cluster.pipeline(transaction=True) as tx: + tx.set('{tx_key}1', 'value1') + tx.get('{tx_key}1') + tx.execute() + + # Verify the Meter's histogram.record() was called + operation_duration_mock.record.assert_called() + + # Find the TRANSACTION event call + transaction_call = None + for call_obj in operation_duration_mock.record.call_args_list: + attrs = call_obj[1]['attributes'] + if attrs.get('db.operation.name') == 'TRANSACTION': + transaction_call = call_obj + break + + assert transaction_call is not None + + # Verify duration was recorded + duration = transaction_call[0][0] + assert isinstance(duration, float) + assert duration >= 0 + + # Verify attributes + attrs = transaction_call[1]['attributes'] + assert attrs['db.operation.name'] == 'TRANSACTION' + + def test_transaction_server_attributes_recorded( + self, cluster_transaction_with_otel + ): + """ + Test that server address, port, and db namespace are recorded for transaction. + """ + cluster, operation_duration_mock = cluster_transaction_with_otel + + with cluster.pipeline(transaction=True) as tx: + tx.set('{server_attr}key', 'value') + tx.execute() + + # Find the TRANSACTION event call + transaction_call = None + for call_obj in operation_duration_mock.record.call_args_list: + attrs = call_obj[1]['attributes'] + if attrs.get('db.operation.name') == 'TRANSACTION': + transaction_call = call_obj + break + + assert transaction_call is not None + attrs = transaction_call[1]['attributes'] + + # Verify server attributes are present + assert 'server.address' in attrs + assert isinstance(attrs['server.address'], str) + + assert 'server.port' in attrs + assert isinstance(attrs['server.port'], int) + + assert 'db.namespace' in attrs + + def test_transaction_batch_size_recorded(self, cluster_transaction_with_otel): + """ + Test that transaction batch_size is correctly recorded. 
+ """ + cluster, operation_duration_mock = cluster_transaction_with_otel + + # Execute a transaction with 3 commands + with cluster.pipeline(transaction=True) as tx: + tx.set('{batch}key1', 'value1') + tx.get('{batch}key1') + tx.delete('{batch}key1') + tx.execute() + + # Find the TRANSACTION event call + transaction_call = None + for call_obj in operation_duration_mock.record.call_args_list: + attrs = call_obj[1]['attributes'] + if attrs.get('db.operation.name') == 'TRANSACTION': + transaction_call = call_obj + break + + assert transaction_call is not None + attrs = transaction_call[1]['attributes'] + assert 'db.operation.batch.size' in attrs + assert attrs['db.operation.batch.size'] == 3 + + def test_transaction_duration_is_positive(self, cluster_transaction_with_otel): + """ + Test that the recorded duration for transaction is a non-negative float. + """ + cluster, operation_duration_mock = cluster_transaction_with_otel + + with cluster.pipeline(transaction=True) as tx: + tx.set('{duration}key', 'value') + tx.execute() + + # Find the TRANSACTION event call + transaction_call = None + for call_obj in operation_duration_mock.record.call_args_list: + attrs = call_obj[1]['attributes'] + if attrs.get('db.operation.name') == 'TRANSACTION': + transaction_call = call_obj + break + + assert transaction_call is not None + duration = transaction_call[0][0] + assert isinstance(duration, float) + assert duration >= 0 + + def test_multiple_transaction_executions_emit_multiple_events( + self, cluster_transaction_with_otel + ): + """ + Test that multiple transaction executions emit multiple events.
+ """ + cluster, operation_duration_mock = cluster_transaction_with_otel + + # Execute first transaction + with cluster.pipeline(transaction=True) as tx1: + tx1.set('{multi1}key', 'value1') + tx1.execute() + + # Execute second transaction + with cluster.pipeline(transaction=True) as tx2: + tx2.set('{multi2}key', 'value2') + tx2.execute() + + # Count TRANSACTION events + transaction_count = sum( + 1 for call_obj in operation_duration_mock.record.call_args_list + if call_obj[1]['attributes'].get('db.operation.name') == 'TRANSACTION' + ) + + assert transaction_count >= 2 + + def test_empty_transaction_does_not_emit_event( + self, cluster_transaction_with_otel + ): + """ + Test that an empty transaction does not emit TRANSACTION events. + """ + cluster, operation_duration_mock = cluster_transaction_with_otel + + # Execute an empty transaction + with cluster.pipeline(transaction=True) as tx: + tx.execute() + + # Count TRANSACTION events - should be 0 + transaction_count = sum( + 1 for call_obj in operation_duration_mock.record.call_args_list + if call_obj[1]['attributes'].get('db.operation.name') == 'TRANSACTION' + ) + + assert transaction_count == 0 + + def test_transaction_with_watch_emits_event(self, cluster_transaction_with_otel): + """ + Test that transaction with WATCH emits event correctly. 
+ """ + cluster, operation_duration_mock = cluster_transaction_with_otel + + # Set initial value + cluster.set('{watch}key', '0') + + with cluster.pipeline(transaction=True) as tx: + tx.watch('{watch}key') + val = tx.get('{watch}key') + tx.multi() + tx.set('{watch}key', int(val or 0) + 1) + tx.execute() + + # Find the TRANSACTION event call + transaction_call = None + for call_obj in operation_duration_mock.record.call_args_list: + attrs = call_obj[1]['attributes'] + if attrs.get('db.operation.name') == 'TRANSACTION': + transaction_call = call_obj + break + + assert transaction_call is not None + attrs = transaction_call[1]['attributes'] + assert attrs['db.operation.name'] == 'TRANSACTION' diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index a395934920..19cd252943 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -4,6 +4,11 @@ import pytest from redis import RedisClusterException import redis +from redis.client import Pipeline +from redis.event import EventDispatcher +from redis.observability import recorder +from redis.observability.config import OTelConfig, MetricGroup +from redis.observability.metrics import RedisMetricsCollector from .conftest import skip_if_server_version_lt, wait_for_command @@ -454,3 +459,392 @@ def test_pipeline_with_msetex(self, r): p_transaction.msetex( {"key1_transaction": "value1", "key2_transaction": "value2"}, ex=10 ) + + +class TestPipelineEventEmission: + """ + Unit tests that verify AfterCommandExecutionEvent is properly emitted from Pipeline + and delivered to the Meter through the event dispatcher chain. + + These tests use fully mocked connection and connection pool - no real Redis + or OTel integration is used. 
+ """ + + @pytest.fixture + def mock_connection(self): + """Create a mock connection with required attributes.""" + conn = mock.MagicMock() + conn.host = 'localhost' + conn.port = 6379 + conn.db = 0 + + # Mock retry to just execute the function directly + conn.retry.call_with_retry = lambda func, _: func() + + return conn + + @pytest.fixture + def mock_connection_pool(self, mock_connection): + """Create a mock connection pool.""" + pool = mock.MagicMock() + pool.get_connection.return_value = mock_connection + pool.get_encoder.return_value = mock.MagicMock() + return pool + + @pytest.fixture + def mock_meter(self): + """Create a mock Meter that tracks all instrument calls.""" + meter = mock.MagicMock() + + # Create mock histogram for operation duration + self.operation_duration = mock.MagicMock() + + def create_histogram_side_effect(name, **kwargs): + if name == 'db.client.operation.duration': + return self.operation_duration + return mock.MagicMock() + + meter.create_counter.return_value = mock.MagicMock() + meter.create_up_down_counter.return_value = mock.MagicMock() + meter.create_histogram.side_effect = create_histogram_side_effect + + return meter + + @pytest.fixture + def setup_pipeline_with_otel(self, mock_connection_pool, mock_connection, mock_meter): + """ + Setup a Pipeline with mocked connection and OTel collector. + Returns tuple of (pipeline, operation_duration_mock). 
+ """ + + # Reset any existing collector state + recorder.reset_collector() + + # Create config with COMMAND group enabled + config = OTelConfig(metric_groups=[MetricGroup.COMMAND]) + + # Create collector with mocked meter + with mock.patch('redis.observability.metrics.OTEL_AVAILABLE', True): + collector = RedisMetricsCollector(mock_meter, config) + + # Patch the recorder to use our collector + with mock.patch.object( + recorder, + '_get_or_create_collector', + return_value=collector + ): + # Create event dispatcher (real one, to test the full chain) + event_dispatcher = EventDispatcher() + + # Create pipeline with mocked connection pool + pipeline = Pipeline( + connection_pool=mock_connection_pool, + response_callbacks={}, + transaction=True, + shard_hint=None, + event_dispatcher=event_dispatcher, + ) + + yield pipeline, self.operation_duration + + # Cleanup + recorder.reset_collector() + + def test_pipeline_execute_emits_event_to_meter(self, setup_pipeline_with_otel): + """ + Test that executing a pipeline emits AfterCommandExecutionEvent + which is delivered to the Meter's histogram.record() method. 
+ """ + pipeline, operation_duration_mock = setup_pipeline_with_otel + + # Mock _execute_transaction to return successful responses + pipeline._execute_transaction = mock.MagicMock( + return_value=[True, True, b'value1'] + ) + + # Queue commands in the pipeline + pipeline.command_stack = [ + (('SET', 'key1', 'value1'), {}), + (('SET', 'key2', 'value2'), {}), + (('GET', 'key1'), {}), + ] + + # Execute the pipeline + pipeline.execute() + + # Verify the Meter's histogram.record() was called + operation_duration_mock.record.assert_called_once() + + # Get the call arguments + call_args = operation_duration_mock.record.call_args + + # Verify duration was recorded (first positional arg) + duration = call_args[0][0] + assert isinstance(duration, float) + assert duration >= 0 + + # Verify attributes + attrs = call_args[1]['attributes'] + assert attrs['db.operation.name'] == 'MULTI' + assert attrs['db.operation.batch.size'] == 3 + assert attrs['server.address'] == 'localhost' + assert attrs['server.port'] == 6379 + assert attrs['db.namespace'] == '0' + + def test_pipeline_transaction_emits_multi_command_name( + self, mock_connection_pool, mock_connection, mock_meter + ): + """ + Test that executing a pipeline in transaction mode (MULTI/EXEC) + emits AfterCommandExecutionEvent with command_name='MULTI'. 
+        """
+
+        recorder.reset_collector()
+        config = OTelConfig(metric_groups=[MetricGroup.COMMAND])
+
+        with mock.patch('redis.observability.metrics.OTEL_AVAILABLE', True):
+            collector = RedisMetricsCollector(mock_meter, config)
+
+        with mock.patch.object(recorder, '_get_or_create_collector', return_value=collector):
+            event_dispatcher = EventDispatcher()
+
+            # Create pipeline with transaction=True
+            pipeline = Pipeline(
+                connection_pool=mock_connection_pool,
+                response_callbacks={},
+                transaction=True,  # Transaction mode
+                shard_hint=None,
+                event_dispatcher=event_dispatcher,
+            )
+
+            pipeline._execute_transaction = mock.MagicMock(return_value=[True])
+            pipeline.command_stack = [(('SET', 'key', 'value'), {})]
+
+            pipeline.execute()
+
+            # Verify command name is MULTI
+            call_args = self.operation_duration.record.call_args
+            attrs = call_args[1]['attributes']
+            assert attrs['db.operation.name'] == 'MULTI'
+
+        recorder.reset_collector()
+
+    def test_pipeline_no_transaction_emits_pipeline_command_name(
+        self, mock_connection_pool, mock_connection, mock_meter
+    ):
+        """
+        Test that executing a pipeline without transaction
+        emits AfterCommandExecutionEvent with command_name='PIPELINE'.
+        """
+        recorder.reset_collector()
+        config = OTelConfig(metric_groups=[MetricGroup.COMMAND])
+
+        with mock.patch('redis.observability.metrics.OTEL_AVAILABLE', True):
+            collector = RedisMetricsCollector(mock_meter, config)
+
+        with mock.patch.object(recorder, '_get_or_create_collector', return_value=collector):
+            event_dispatcher = EventDispatcher()
+
+            # Create pipeline with transaction=False
+            pipeline = Pipeline(
+                connection_pool=mock_connection_pool,
+                response_callbacks={},
+                transaction=False,  # Non-transaction mode
+                shard_hint=None,
+                event_dispatcher=event_dispatcher,
+            )
+
+            # Stub the non-transactional executor so no real I/O is attempted
+            pipeline._execute_pipeline = mock.MagicMock(return_value=[True, True])
+            pipeline.command_stack = [
+                (('SET', 'key1', 'value1'), {}),
+                (('SET', 'key2', 'value2'), {}),
+            ]
+
+            pipeline.execute()
+
+            # Verify command name is PIPELINE
+            # (call_args[1] is the keyword-argument dict of the last record() call)
+            call_args = self.operation_duration.record.call_args
+            attrs = call_args[1]['attributes']
+            assert attrs['db.operation.name'] == 'PIPELINE'
+
+        recorder.reset_collector()
+
+    def test_pipeline_error_emits_event_with_error(
+        self, mock_connection_pool, mock_connection, mock_meter
+    ):
+        """
+        Test that when a pipeline execution raises an exception,
+        AfterCommandExecutionEvent is still emitted with error information.
+        """
+        recorder.reset_collector()
+        config = OTelConfig(metric_groups=[MetricGroup.COMMAND])
+
+        with mock.patch('redis.observability.metrics.OTEL_AVAILABLE', True):
+            collector = RedisMetricsCollector(mock_meter, config)
+
+        with mock.patch.object(recorder, '_get_or_create_collector', return_value=collector):
+            event_dispatcher = EventDispatcher()
+
+            pipeline = Pipeline(
+                connection_pool=mock_connection_pool,
+                response_callbacks={},
+                transaction=False,
+                shard_hint=None,
+                event_dispatcher=event_dispatcher,
+            )
+
+            # Make execute raise an exception
+            test_error = redis.ResponseError("WRONGTYPE Operation error")
+            pipeline._execute_pipeline = mock.MagicMock(side_effect=test_error)
+            pipeline.command_stack = [(('LPUSH', 'string_key', 'value'), {})]
+
+            # Execute should raise the error
+            with pytest.raises(redis.ResponseError):
+                pipeline.execute()
+
+            # Verify the Meter's histogram.record() was still called
+            self.operation_duration.record.assert_called_once()
+
+            # Verify error type is recorded in attributes
+            # (call_args[1] is the keyword-argument dict of the last record() call)
+            call_args = self.operation_duration.record.call_args
+            attrs = call_args[1]['attributes']
+            assert attrs['db.operation.name'] == 'PIPELINE'
+            assert 'error.type' in attrs
+
+        recorder.reset_collector()
+
+    def test_pipeline_batch_size_recorded_correctly(self, setup_pipeline_with_otel):
+        """
+        Test that the 'db.operation.batch.size' attribute equals
+        the number of commands queued in the pipeline.
+        """
+        pipeline, operation_duration_mock = setup_pipeline_with_otel
+
+        pipeline._execute_transaction = mock.MagicMock(
+            return_value=[True, True, True, True, True]
+        )
+
+        # Queue exactly 5 commands
+        pipeline.command_stack = [
+            (('SET', 'key1', 'v1'), {}),
+            (('SET', 'key2', 'v2'), {}),
+            (('SET', 'key3', 'v3'), {}),
+            (('SET', 'key4', 'v4'), {}),
+            (('SET', 'key5', 'v5'), {}),
+        ]
+
+        pipeline.execute()
+
+        # Verify batch_size is 5
+        call_args = operation_duration_mock.record.call_args
+        attrs = call_args[1]['attributes']
+        assert attrs['db.operation.batch.size'] == 5
+
+    def test_pipeline_server_attributes_recorded(self, setup_pipeline_with_otel):
+        """
+        Test that server address, port, and db namespace are correctly recorded.
+        """
+        pipeline, operation_duration_mock = setup_pipeline_with_otel
+
+        pipeline._execute_transaction = mock.MagicMock(return_value=[True])
+        pipeline.command_stack = [(('PING',), {})]
+
+        pipeline.execute()
+
+        call_args = operation_duration_mock.record.call_args
+        attrs = call_args[1]['attributes']
+
+        # Verify server attributes match mock connection
+        assert attrs['server.address'] == 'localhost'
+        assert attrs['server.port'] == 6379
+        assert attrs['db.namespace'] == '0'
+
+    def test_multiple_pipeline_executions_emit_multiple_events(
+        self, mock_connection_pool, mock_connection, mock_meter
+    ):
+        """
+        Test that each pipeline execution emits a separate event to the Meter.
+        """
+        recorder.reset_collector()
+        config = OTelConfig(metric_groups=[MetricGroup.COMMAND])
+
+        with mock.patch('redis.observability.metrics.OTEL_AVAILABLE', True):
+            collector = RedisMetricsCollector(mock_meter, config)
+
+        with mock.patch.object(recorder, '_get_or_create_collector', return_value=collector):
+            # Both pipelines share one dispatcher
+            event_dispatcher = EventDispatcher()
+
+            # First pipeline execution
+            pipeline1 = Pipeline(
+                connection_pool=mock_connection_pool,
+                response_callbacks={},
+                transaction=True,
+                shard_hint=None,
+                event_dispatcher=event_dispatcher,
+            )
+            pipeline1._execute_transaction = mock.MagicMock(return_value=[True])
+            pipeline1.command_stack = [(('SET', 'key1', 'value1'), {})]
+            pipeline1.execute()
+
+            # Second pipeline execution
+            pipeline2 = Pipeline(
+                connection_pool=mock_connection_pool,
+                response_callbacks={},
+                transaction=True,
+                shard_hint=None,
+                event_dispatcher=event_dispatcher,
+            )
+            pipeline2._execute_transaction = mock.MagicMock(return_value=[True, True])
+            pipeline2.command_stack = [
+                (('SET', 'key2', 'value2'), {}),
+                (('SET', 'key3', 'value3'), {}),
+            ]
+            pipeline2.execute()
+
+            # Verify histogram.record() was called twice
+            assert self.operation_duration.record.call_count == 2
+
+        recorder.reset_collector()
+
+    def test_empty_pipeline_does_not_emit_event(
+        self, mock_connection_pool, mock_connection, mock_meter
+    ):
+        """
+        Test that an empty pipeline (no commands) does not emit an event.
+        """
+        recorder.reset_collector()
+        config = OTelConfig(metric_groups=[MetricGroup.COMMAND])
+
+        with mock.patch('redis.observability.metrics.OTEL_AVAILABLE', True):
+            collector = RedisMetricsCollector(mock_meter, config)
+
+        with mock.patch.object(recorder, '_get_or_create_collector', return_value=collector):
+            event_dispatcher = EventDispatcher()
+
+            pipeline = Pipeline(
+                connection_pool=mock_connection_pool,
+                response_callbacks={},
+                transaction=True,
+                shard_hint=None,
+                event_dispatcher=event_dispatcher,
+            )
+
+            # Empty command stack
+            pipeline.command_stack = []
+
+            # Execute empty pipeline
+            result = pipeline.execute()
+
+            # Should return empty list
+            assert result == []
+
+            # No event should be emitted for empty pipeline
+            self.operation_duration.record.assert_not_called()
+
+        recorder.reset_collector()