@@ -97,6 +97,55 @@ class SentinelManagedSSLConnection(SentinelManagedConnection, SSLConnection):
9797 pass
9898
9999
100+ class SentinelConnectionPoolProxy :
101+ def __init__ (
102+ self ,
103+ connection_pool ,
104+ is_master ,
105+ check_connection ,
106+ service_name ,
107+ sentinel_manager ,
108+ ):
109+ self .connection_pool_ref = weakref .ref (connection_pool )
110+ self .is_master = is_master
111+ self .check_connection = check_connection
112+ self .service_name = service_name
113+ self .sentinel_manager = sentinel_manager
114+ self .reset ()
115+
116+ def reset (self ):
117+ self .master_address = None
118+ self .slave_rr_counter = None
119+
120+ async def get_master_address (self ):
121+ master_address = await self .sentinel_manager .discover_master (self .service_name )
122+ if self .is_master and self .master_address != master_address :
123+ self .master_address = master_address
124+ # disconnect any idle connections so that they reconnect
125+ # to the new master the next time that they are used.
126+ connection_pool = self .connection_pool_ref ()
127+ if connection_pool is not None :
128+ await connection_pool .disconnect (inuse_connections = False )
129+ return master_address
130+
131+ async def rotate_slaves (self ) -> AsyncIterator :
132+ """Round-robin slave balancer"""
133+ slaves = await self .sentinel_manager .discover_slaves (self .service_name )
134+ if slaves :
135+ if self .slave_rr_counter is None :
136+ self .slave_rr_counter = random .randint (0 , len (slaves ) - 1 )
137+ for _ in range (len (slaves )):
138+ self .slave_rr_counter = (self .slave_rr_counter + 1 ) % len (slaves )
139+ slave = slaves [self .slave_rr_counter ]
140+ yield slave
141+ # Fallback to the master connection
142+ try :
143+ yield await self .get_master_address ()
144+ except MasterNotFoundError :
145+ pass
146+ raise SlaveNotFoundError (f"No slave found for { self .service_name !r} " )
147+
148+
100149class SentinelConnectionPool (ConnectionPool ):
101150 """
102151 Sentinel backed connection pool.
@@ -116,6 +165,44 @@ def __init__(self, service_name, sentinel_manager, **kwargs):
116165 )
117166 self .is_master = kwargs .pop ("is_master" , True )
118167 self .check_connection = kwargs .pop ("check_connection" , False )
168+ self .proxy = SentinelConnectionPoolProxy (
169+ connection_pool = self ,
170+ is_master = self .is_master ,
171+ check_connection = self .check_connection ,
172+ service_name = service_name ,
173+ sentinel_manager = sentinel_manager ,
174+ )
175+ super ().__init__ (** kwargs )
176+ self .connection_kwargs ["connection_pool" ] = weakref .proxy (self )
177+ self .service_name = service_name
178+ self .sentinel_manager = sentinel_manager
179+
180+ def __repr__ (self ):
181+ return (
182+ f"<{ self .__class__ .__module__ } .{ self .__class__ .__name__ } "
183+ f"(service={ self .service_name } ({ self .is_master and 'master' or 'slave' } ))>"
184+ )
185+
186+ def reset (self ):
187+ super ().reset ()
188+ self .proxy .reset ()
189+
190+ @property
191+ def master_address (self ):
192+ return self .proxy .master_address
193+
194+ def owns_connection (self , connection : Connection ):
195+ check = not self .is_master or (
196+ self .is_master and self .master_address == (connection .host , connection .port )
197+ )
198+ return check and super ().owns_connection (connection )
199+
200+ async def get_master_address (self ):
201+ return await self .proxy .get_master_address ()
202+
203+ def rotate_slaves (self ) -> AsyncIterator :
204+ """Round-robin slave balancer"""
205+ return self .proxy .rotate_slaves ()
119206 super ().__init__ (** kwargs )
120207 self .connection_kwargs ["connection_pool" ] = weakref .proxy (self )
121208 self .service_name = service_name
0 commit comments