|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import collections.abc as cabc |
3 | 4 | import typing as t |
| 5 | +from contextlib import asynccontextmanager |
| 6 | +from contextlib import AsyncExitStack |
| 7 | +from contextlib import contextmanager |
| 8 | +from contextlib import ExitStack |
4 | 9 | from dataclasses import dataclass |
| 10 | +from unittest import mock |
5 | 11 | from weakref import WeakKeyDictionary |
6 | 12 |
|
7 | 13 | import sqlalchemy as sa |
@@ -406,6 +412,89 @@ async def async_one_or_abort( |
406 | 412 | except (sa_exc.NoResultFound, sa_exc.MultipleResultsFound): |
407 | 413 | abort(code, **(abort_kwargs or {})) |
408 | 414 |
|
| 415 | + @contextmanager |
| 416 | + def test_isolation(self) -> cabc.Iterator[None]: |
| 417 | + """Context manager to isolate the database during a test. Commits are |
| 418 | + rolled back when the ``with`` block exits, and will not be seen by other |
| 419 | + tests. |
| 420 | +
|
| 421 | + This patches the SQLAlchemy engine and session to use a single |
| 422 | + connection, in a transaction that is rolled back when the context exits. |
| 423 | +
|
| 424 | + If the code being tested uses async features, use |
| 425 | + :meth:`async_test_isolation` instead. It will isolate both the sync and |
| 426 | + async operations. |
| 427 | +
|
| 428 | + When using SQLite, follow the `SQLAlchemy docs`__ to fix the driver's |
| 429 | + transaction handling. |
| 430 | +
|
| 431 | + __ https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#sqlite-transactions |
| 432 | +
|
| 433 | + .. versionadded:: 0.2 |
| 434 | + """ |
| 435 | + with ExitStack() as exit_stack: |
| 436 | + # Instruct the session to use nested transactions when it sees that |
| 437 | + # its connection is already in a transaction. |
| 438 | + exit_stack.enter_context( |
| 439 | + mock.patch.dict( |
| 440 | + self.sessionmaker.kw, {"join_transaction_mode": "create_savepoint"} |
| 441 | + ) |
| 442 | + ) |
| 443 | + |
| 444 | + for engine in self.engines.values(): |
| 445 | + # Create the connection, to be closed when the context exits. |
| 446 | + connection: sa.Connection = exit_stack.enter_context(engine.connect()) |
| 447 | + # The connection cannot be closed by code being tested. This |
| 448 | + # ensures the transaction remains active. |
| 449 | + connection.close = _nop # type: ignore[method-assign] |
| 450 | + # The engine will always return the same connection, with the |
| 451 | + # active transaction. |
| 452 | + exit_stack.enter_context( |
| 453 | + mock.patch.object(engine, "connect", lambda _c=connection: _c) |
| 454 | + ) |
| 455 | + # Start the transaction, to be rolled back when the context exits. |
| 456 | + transaction = connection.begin() |
| 457 | + exit_stack.callback(transaction.rollback) |
| 458 | + # If code being tested tries to start the transaction, start a |
| 459 | + # nested transaction instead. |
| 460 | + connection.begin = connection.begin_nested # type: ignore[assignment] |
| 461 | + |
| 462 | + yield None |
| 463 | + |
| 464 | + @asynccontextmanager |
| 465 | + async def async_test_isolation(self) -> cabc.AsyncIterator[None]: |
| 466 | + """Async version of :meth:`test_isolation` to be used as an |
| 467 | + ``async with`` block. It will isolate the sync code as well, so you do |
| 468 | + not need to use ``test_isolation`` as well. |
| 469 | +
|
| 470 | + .. versionadded:: 0.2 |
| 471 | + """ |
| 472 | + async with AsyncExitStack() as exit_stack: |
| 473 | + # Also isolate the sync operations. |
| 474 | + exit_stack.enter_context(self.test_isolation()) |
| 475 | + |
| 476 | + await exit_stack.enter_context( |
| 477 | + mock.patch.dict( |
| 478 | + self.async_sessionmaker.kw, |
| 479 | + {"join_transaction_mode": "create_savepoint"}, |
| 480 | + ) |
| 481 | + ) |
| 482 | + |
| 483 | + for engine in self.async_engines.values(): |
| 484 | + connection: sa_async.AsyncConnection = ( |
| 485 | + await exit_stack.enter_async_context(engine.connect()) |
| 486 | + ) |
| 487 | + connection.close = _async_nop # type: ignore[method-assign] |
| 488 | + connection.aclose = _async_nop # type: ignore[method-assign] |
| 489 | + exit_stack.enter_context( |
| 490 | + mock.patch.object(engine, "connect", lambda _c=connection: _c) |
| 491 | + ) |
| 492 | + transaction = connection.begin() |
| 493 | + exit_stack.push_async_callback(transaction.rollback) |
| 494 | + connection.begin = connection.begin_nested # type: ignore[method-assign] |
| 495 | + |
| 496 | + yield None |
| 497 | + |
409 | 498 |
|
410 | 499 | @dataclass |
411 | 500 | class _State: |
@@ -439,3 +528,11 @@ async def _close_async_sessions(e: BaseException | None) -> None: |
439 | 528 |
|
440 | 529 | for session in sessions.values(): |
441 | 530 | await session.close() |
| 531 | + |
| 532 | + |
| 533 | +def _nop() -> None: |
| 534 | + pass |
| 535 | + |
| 536 | + |
| 537 | +async def _async_nop() -> None: |
| 538 | + pass |
0 commit comments