Skip to content

Commit 0d84260

Browse files
committed
use finer-grained locks
1 parent 222a55f commit 0d84260

File tree

2 files changed

+202
-160
lines changed

2 files changed

+202
-160
lines changed

pymongo/asynchronous/pool.py

Lines changed: 101 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -733,8 +733,11 @@ def __init__(
733733
# from the right side.
734734
self.conns: collections.deque[AsyncConnection] = collections.deque()
735735
self.active_contexts: set[_CancellationContext] = set()
736+
# The main lock for the pool. The lock should only be used to protect
737+
# updating attributes.
738+
# If possible, avoid any additional work while holding the lock.
739+
# If looping over an attribute, copy the container and do not take the lock.
736740
self.lock = _async_create_lock()
737-
self._max_connecting_cond = _async_create_condition(self.lock)
738741
self.active_sockets = 0
739742
# Monotonically increasing connection ID required for CMAP Events.
740743
self.next_connection_id = 1
@@ -760,15 +763,19 @@ def __init__(
760763
# The first portion of the wait queue.
761764
# Enforces: maxPoolSize
762765
# Also used for: clearing the wait queue
763-
self.size_cond = _async_create_condition(self.lock)
766+
# Use a different lock to prevent lock contention. This lock protects
767+
# "requests".
768+
self.size_cond = _async_create_condition(_async_create_lock())
764769
self.requests = 0
765770
self.max_pool_size = self.opts.max_pool_size
766771
if not self.max_pool_size:
767772
self.max_pool_size = float("inf")
768773
# The second portion of the wait queue.
769774
# Enforces: maxConnecting
770775
# Also used for: clearing the wait queue
771-
self._max_connecting_cond = _async_create_condition(self.lock)
776+
# Use a different lock to prevent lock contention. This lock protects
777+
# "_pending".
778+
self._max_connecting_cond = _async_create_condition(_async_create_lock())
772779
self._max_connecting = self.opts.max_connecting
773780
self._pending = 0
774781
self._client_id = client_id
@@ -797,20 +804,24 @@ def __init__(
797804

798805
async def ready(self) -> None:
799806
# Take the lock to avoid the race condition described in PYTHON-2699.
800-
async with self.lock:
801-
if self.state != PoolState.READY:
807+
state_changed = False
808+
if self.state != PoolState.READY:
809+
async with self.lock:
802810
self.state = PoolState.READY
803-
if self.enabled_for_cmap:
804-
assert self.opts._event_listeners is not None
805-
self.opts._event_listeners.publish_pool_ready(self.address)
806-
if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
807-
_debug_log(
808-
_CONNECTION_LOGGER,
809-
message=_ConnectionStatusMessage.POOL_READY,
810-
clientId=self._client_id,
811-
serverHost=self.address[0],
812-
serverPort=self.address[1],
813-
)
811+
state_changed = True
812+
if not state_changed:
813+
return
814+
if self.enabled_for_cmap:
815+
assert self.opts._event_listeners is not None
816+
self.opts._event_listeners.publish_pool_ready(self.address)
817+
if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
818+
_debug_log(
819+
_CONNECTION_LOGGER,
820+
message=_ConnectionStatusMessage.POOL_READY,
821+
clientId=self._client_id,
822+
serverHost=self.address[0],
823+
serverPort=self.address[1],
824+
)
814825

815826
@property
816827
def closed(self) -> bool:
@@ -824,38 +835,45 @@ async def _reset(
824835
interrupt_connections: bool = False,
825836
) -> None:
826837
old_state = self.state
827-
async with self.size_cond:
828-
if self.closed:
829-
return
830-
if self.opts.pause_enabled and pause and not self.opts.load_balanced:
838+
if self.closed:
839+
return
840+
if self.opts.pause_enabled and pause and not self.opts.load_balanced:
841+
async with self.lock:
831842
old_state, self.state = self.state, PoolState.PAUSED
843+
844+
with self.lock:
832845
self.gen.inc(service_id)
833-
newpid = os.getpid()
834-
if self.pid != newpid:
835-
self.pid = newpid
836-
self.active_sockets = 0
837-
self.operation_count = 0
838-
if service_id is None:
839-
sockets, self.conns = self.conns, collections.deque()
840-
else:
841-
discard: collections.deque = collections.deque() # type: ignore[type-arg]
842-
keep: collections.deque = collections.deque() # type: ignore[type-arg]
843-
for conn in self.conns:
844-
if conn.service_id == service_id:
845-
discard.append(conn)
846-
else:
847-
keep.append(conn)
848-
sockets = discard
846+
newpid = os.getpid()
847+
if self.pid != newpid:
848+
self.pid = newpid
849+
with self.lock:
850+
self.active_sockets = 0
851+
self.operation_count = 0
852+
if service_id is None:
853+
new_conns = collections.deque()
854+
with self.lock:
855+
sockets, self.conns = self.conns, new_conns
856+
else:
857+
discard: collections.deque = collections.deque() # type: ignore[type-arg]
858+
keep: collections.deque = collections.deque() # type: ignore[type-arg]
859+
for conn in self.conns.copy():
860+
if conn.service_id == service_id:
861+
discard.append(conn)
862+
else:
863+
keep.append(conn)
864+
sockets = discard
865+
with self.lock:
849866
self.conns = keep
850867

851868
if close:
852-
self.state = PoolState.CLOSED
869+
with self.lock:
870+
self.state = PoolState.CLOSED
853871
# Clear the wait queue
854872
self._max_connecting_cond.notify_all()
855873
self.size_cond.notify_all()
856874

857875
if interrupt_connections:
858-
for context in self.active_contexts:
876+
for context in self.active_contexts.copy():
859877
context.cancel()
860878

861879
listeners = self.opts._event_listeners
@@ -914,9 +932,8 @@ async def update_is_writable(self, is_writable: Optional[bool]) -> None:
914932
Pool.
915933
"""
916934
self.is_writable = is_writable
917-
async with self.lock:
918-
for _socket in self.conns:
919-
_socket.update_is_writable(self.is_writable) # type: ignore[arg-type]
935+
for _socket in self.conns.copy():
936+
_socket.update_is_writable(self.is_writable) # type: ignore[arg-type]
920937

921938
async def reset(
922939
self, service_id: Optional[ObjectId] = None, interrupt_connections: bool = False
@@ -947,12 +964,9 @@ async def remove_stale_sockets(self, reference_generation: int) -> None:
947964

948965
if self.opts.max_idle_time_seconds is not None:
949966
close_conns = []
950-
async with self.lock:
951-
while (
952-
self.conns
953-
and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
954-
):
955-
close_conns.append(self.conns.pop())
967+
conns = self.conns.copy()
968+
while conns and conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds:
969+
close_conns.append(self.conns.pop())
956970
if not _IS_SYNC:
957971
await asyncio.gather(
958972
*[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns], # type: ignore[func-returns-value]
@@ -963,12 +977,12 @@ async def remove_stale_sockets(self, reference_generation: int) -> None:
963977
await conn.close_conn(ConnectionClosedReason.IDLE)
964978

965979
while True:
980+
# There are enough sockets in the pool.
981+
if len(self.conns) + self.active_sockets >= self.opts.min_pool_size:
982+
return
983+
if self.requests >= self.opts.min_pool_size:
984+
return
966985
async with self.size_cond:
967-
# There are enough sockets in the pool.
968-
if len(self.conns) + self.active_sockets >= self.opts.min_pool_size:
969-
return
970-
if self.requests >= self.opts.min_pool_size:
971-
return
972986
self.requests += 1
973987
incremented = False
974988
try:
@@ -978,15 +992,15 @@ async def remove_stale_sockets(self, reference_generation: int) -> None:
978992
if self._pending >= self._max_connecting:
979993
return
980994
self._pending += 1
981-
incremented = True
995+
incremented = True
982996
conn = await self.connect()
983997
close_conn = False
984-
async with self.lock:
985-
# Close connection and return if the pool was reset during
986-
# socket creation or while acquiring the pool lock.
987-
if self.gen.get_overall() != reference_generation:
988-
close_conn = True
989-
if not close_conn:
998+
# Close connection and return if the pool was reset during
999+
# socket creation or while acquiring the pool lock.
1000+
if self.gen.get_overall() != reference_generation:
1001+
close_conn = True
1002+
if not close_conn:
1003+
async with self.lock:
9901004
self.conns.appendleft(conn)
9911005
self.active_contexts.discard(conn.cancel_context)
9921006
if close_conn:
@@ -1011,11 +1025,11 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
10111025
Note that the pool does not keep a reference to the socket -- you
10121026
must call checkin() when you're done with it.
10131027
"""
1028+
# Use a temporary context so that interrupt_connections can cancel creating the socket.
1029+
tmp_context = _CancellationContext()
1030+
conn_id = self.next_connection_id
10141031
async with self.lock:
1015-
conn_id = self.next_connection_id
10161032
self.next_connection_id += 1
1017-
# Use a temporary context so that interrupt_connections can cancel creating the socket.
1018-
tmp_context = _CancellationContext()
10191033
self.active_contexts.add(tmp_context)
10201034

10211035
listeners = self.opts._event_listeners
@@ -1254,7 +1268,7 @@ async def _get_conn(
12541268
try:
12551269
async with self.lock:
12561270
self.active_sockets += 1
1257-
incremented = True
1271+
incremented = True
12581272
while conn is None:
12591273
# CMAP: we MUST wait for either maxConnecting OR for a socket
12601274
# to be checked back into the pool.
@@ -1272,7 +1286,8 @@ async def _get_conn(
12721286
self._raise_if_not_ready(checkout_started_time, emit_event=False)
12731287

12741288
try:
1275-
conn = self.conns.popleft()
1289+
async with self.lock:
1290+
conn = self.conns.popleft()
12761291
except IndexError:
12771292
self._pending += 1
12781293
if conn: # We got a socket from the pool
@@ -1291,10 +1306,11 @@ async def _get_conn(
12911306
if conn:
12921307
# We checked out a socket but authentication failed.
12931308
await conn.close_conn(ConnectionClosedReason.ERROR)
1309+
if incremented:
1310+
async with self.lock:
1311+
self.active_sockets -= 1
12941312
async with self.size_cond:
12951313
self.requests -= 1
1296-
if incremented:
1297-
self.active_sockets -= 1
12981314
self.size_cond.notify()
12991315

13001316
if not emitted_event:
@@ -1330,7 +1346,8 @@ async def checkin(self, conn: AsyncConnection) -> None:
13301346
conn.active = False
13311347
conn.pinned_txn = False
13321348
conn.pinned_cursor = False
1333-
self.__pinned_sockets.discard(conn)
1349+
async with self.lock:
1350+
self.__pinned_sockets.discard(conn)
13341351
listeners = self.opts._event_listeners
13351352
async with self.lock:
13361353
self.active_contexts.discard(conn.cancel_context)
@@ -1371,28 +1388,32 @@ async def checkin(self, conn: AsyncConnection) -> None:
13711388
)
13721389
else:
13731390
close_conn = False
1374-
async with self.lock:
1375-
# Hold the lock to ensure this section does not race with
1376-
# Pool.reset().
1377-
if self.stale_generation(conn.generation, conn.service_id):
1378-
close_conn = True
1379-
else:
1380-
conn.update_last_checkin_time()
1381-
conn.update_is_writable(bool(self.is_writable))
1391+
conn.update_last_checkin_time()
1392+
conn.update_is_writable(bool(self.is_writable))
1393+
if self.stale_generation(conn.generation, conn.service_id):
1394+
close_conn = True
1395+
else:
1396+
with self.lock:
13821397
self.conns.appendleft(conn)
1398+
with self._max_connecting_cond:
13831399
# Notify any threads waiting to create a connection.
13841400
self._max_connecting_cond.notify()
13851401
if close_conn:
13861402
await conn.close_conn(ConnectionClosedReason.STALE)
13871403

1388-
async with self.size_cond:
1389-
if txn:
1404+
async with self.lock:
1405+
self.active_sockets -= 1
1406+
self.operation_count -= 1
1407+
1408+
if txn:
1409+
async with self.lock:
13901410
self.ntxns -= 1
1391-
elif cursor:
1411+
elif cursor:
1412+
async with self.lock:
13921413
self.ncursors -= 1
1414+
1415+
async with self.size_cond:
13931416
self.requests -= 1
1394-
self.active_sockets -= 1
1395-
self.operation_count -= 1
13961417
self.size_cond.notify()
13971418

13981419
async def _perished(self, conn: AsyncConnection) -> bool:

0 commit comments

Comments
 (0)