Skip to content

Commit 84af084

Browse files
authored
Fix: re-auth interfering with connectivity checks (#1144)
Connections might have pipelined messages when being picked up from the pool (e.g., re-auth). When unflushed those, they will be `RESET` when returned to the pool. This results in pipelineing `LOGON`, `LOGOFF`, and `RESET` where the `RESET` will skip the queue on the server side. Depending on the timing, this will leave the connection in an undesirable state (e.g., unauthenticated). This PR circumvents this situation by acquiring an unprepared connection from the pool when performing connectivity checks. Such connections will not have any work pipelined.
1 parent 8317bb9 commit 84af084

File tree

8 files changed

+226
-14
lines changed

8 files changed

+226
-14
lines changed

src/neo4j/_async/io/_pool.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,9 @@ async def connection_creator():
274274
return connection_creator
275275
return None
276276

277-
async def _re_auth_connection(self, connection, auth, force):
277+
async def _re_auth_connection(self, connection, auth, force, unprepared):
278+
if unprepared and not force:
279+
return
278280
if auth:
279281
# Assert session auth is supported by the protocol.
280282
# The Bolt implementation will try as hard as it can to make the
@@ -312,7 +314,14 @@ async def _re_auth_connection(self, connection, auth, force):
312314
await connection.send_all()
313315
await connection.fetch_all()
314316

315-
async def _acquire(self, address, auth, deadline, liveness_check_timeout):
317+
async def _acquire(
318+
self,
319+
address,
320+
auth,
321+
deadline,
322+
liveness_check_timeout,
323+
unprepared=False,
324+
):
316325
"""
317326
Acquire a connection to a given address from the pool.
318327
@@ -362,7 +371,7 @@ async def health_check(connection_, deadline_):
362371
)
363372
try:
364373
await self._re_auth_connection(
365-
connection, auth, force_auth
374+
connection, auth, force_auth, unprepared
366375
)
367376
except ConfigurationError:
368377
if auth:
@@ -419,6 +428,7 @@ async def acquire(
419428
bookmarks,
420429
auth: AcquisitionAuth,
421430
liveness_check_timeout,
431+
unprepared=False,
422432
database_callback=None,
423433
):
424434
"""
@@ -431,6 +441,9 @@ async def acquire(
431441
:param bookmarks:
432442
:param auth:
433443
:param liveness_check_timeout:
444+
:param unprepared: If True, no messages will be pipelined on the
445+
connection. Meant to be used if no work is to be executed on the
446+
connection.
434447
:param database_callback:
435448
"""
436449
...
@@ -651,6 +664,7 @@ async def acquire(
651664
bookmarks,
652665
auth: AcquisitionAuth,
653666
liveness_check_timeout,
667+
unprepared=False,
654668
database_callback=None,
655669
):
656670
# The access_mode and database is not needed for a direct connection,
@@ -663,7 +677,7 @@ async def acquire(
663677
)
664678
deadline = Deadline.from_timeout_or_deadline(timeout)
665679
return await self._acquire(
666-
self.address, auth, deadline, liveness_check_timeout
680+
self.address, auth, deadline, liveness_check_timeout, unprepared
667681
)
668682

669683

@@ -1132,6 +1146,7 @@ async def acquire(
11321146
bookmarks,
11331147
auth: AcquisitionAuth | None,
11341148
liveness_check_timeout,
1149+
unprepared=False,
11351150
database_callback=None,
11361151
):
11371152
if access_mode not in {WRITE_ACCESS, READ_ACCESS}:
@@ -1201,6 +1216,7 @@ async def wrapped_database_callback(new_database):
12011216
auth,
12021217
deadline,
12031218
liveness_check_timeout,
1219+
unprepared,
12041220
)
12051221
except (ServiceUnavailable, SessionExpired):
12061222
await self.deactivate(address=address)

src/neo4j/_async/work/session.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -180,14 +180,16 @@ async def _result_error(self, error):
180180

181181
async def _get_server_info(self):
182182
assert not self._connection
183-
await self._connect(READ_ACCESS, liveness_check_timeout=0)
183+
await self._connect(
184+
READ_ACCESS, liveness_check_timeout=0, unprepared=True
185+
)
184186
server_info = self._connection.server_info
185187
await self._disconnect()
186188
return server_info
187189

188190
async def _verify_authentication(self):
189191
assert not self._connection
190-
await self._connect(READ_ACCESS, force_auth=True)
192+
await self._connect(READ_ACCESS, force_auth=True, unprepared=True)
191193
await self._disconnect()
192194

193195
@AsyncNonConcurrentMethodChecker._non_concurrent_method

src/neo4j/_sync/io/_pool.py

+20-4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/neo4j/_sync/work/session.py

+4-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/unit/async_/io/test_direct.py

+27-1
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,11 @@ async def acquire(
7070
bookmarks,
7171
auth,
7272
liveness_check_timeout,
73+
unprepared=False,
7374
database_callback=None,
7475
):
7576
return await self._acquire(
76-
self.address, auth, timeout, liveness_check_timeout
77+
self.address, auth, timeout, liveness_check_timeout, unprepared
7778
)
7879

7980

@@ -277,3 +278,28 @@ async def test_liveness_check(
277278
cx1.reset.reset_mock()
278279
await pool.release(cx1)
279280
cx1.reset.assert_not_called()
281+
282+
283+
@pytest.mark.parametrize("unprepared", (True, False, None))
284+
@mark_async_test
285+
async def test_reauth(async_fake_connection_generator, unprepared):
286+
async with AsyncFakeBoltPool(
287+
async_fake_connection_generator,
288+
("127.0.0.1", 7687),
289+
) as pool:
290+
address = neo4j.Address(("127.0.0.1", 7687))
291+
# pre-populate pool
292+
cx = await pool._acquire(address, None, Deadline(3), None)
293+
await pool.release(cx)
294+
cx.reset_mock()
295+
296+
kwargs = {}
297+
if unprepared is not None:
298+
kwargs["unprepared"] = unprepared
299+
cx = await pool._acquire(address, None, Deadline(3), None, **kwargs)
300+
if unprepared:
301+
cx.re_auth.assert_not_called()
302+
else:
303+
cx.re_auth.assert_called_once()
304+
305+
await pool.release(cx)

tests/unit/async_/work/test_session.py

+62
Original file line numberDiff line numberDiff line change
@@ -877,3 +877,65 @@ async def resolve_db():
877877
cache_spy.set.assert_called_once_with(key, resolved_db)
878878
assert session._pinned_database
879879
assert config.database == resolved_db
880+
881+
882+
@pytest.mark.parametrize(
883+
"method", ("_get_server_info", "_verify_authentication")
884+
)
885+
@mark_async_test
886+
async def test_check_connections_are_unprepared_connection(
887+
async_fake_pool,
888+
method,
889+
):
890+
config = SessionConfig()
891+
async with AsyncSession(async_fake_pool, config) as session:
892+
await getattr(session, method)()
893+
assert len(async_fake_pool.acquired_connection_mocks) == 1
894+
async_fake_pool.acquire.assert_awaited_once()
895+
unprepared = async_fake_pool.acquire.call_args.kwargs.get("unprepared")
896+
assert unprepared is True
897+
898+
899+
async def _explicit_transaction(session: AsyncSession):
900+
async with await session.begin_transaction():
901+
pass
902+
903+
904+
async def _autocommit_transaction(session: AsyncSession):
905+
await session.run("RETURN 1")
906+
907+
908+
async def _tx_func_read(session: AsyncSession):
909+
async def work(tx: AsyncManagedTransaction):
910+
pass
911+
912+
await session.execute_read(work)
913+
914+
915+
async def _tx_func_write(session: AsyncSession):
916+
async def work(tx: AsyncManagedTransaction):
917+
pass
918+
919+
await session.execute_write(work)
920+
921+
922+
@pytest.mark.parametrize(
923+
"method",
924+
(
925+
_explicit_transaction,
926+
_autocommit_transaction,
927+
_tx_func_read,
928+
_tx_func_write,
929+
),
930+
)
931+
@mark_async_test
932+
async def test_work_connections_are_prepared_connection(
933+
async_fake_pool, method
934+
):
935+
config = SessionConfig()
936+
async with AsyncSession(async_fake_pool, config) as session:
937+
await method(session)
938+
assert len(async_fake_pool.acquired_connection_mocks) == 1
939+
async_fake_pool.acquire.assert_awaited_once()
940+
unprepared = async_fake_pool.acquire.call_args.kwargs.get("unprepared")
941+
assert unprepared is False or unprepared is None

tests/unit/sync/io/test_direct.py

+27-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)