diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c6d9d1ced..82c0ae79a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,6 +4,8 @@ CHANGELOG 7.0.7 (unreleased) ------------------ +- Add async context manager to acquire advisory write locks for objects + [nilbacardit26] - Docs: Update documentation and configuration settings - Chore: Update sphinx-guillotina-theme version to 1.0.9 [rboixaderg] diff --git a/guillotina/db/locks.py b/guillotina/db/locks.py new file mode 100644 index 000000000..ee6f67883 --- /dev/null +++ b/guillotina/db/locks.py @@ -0,0 +1,54 @@ +from contextlib import asynccontextmanager +from guillotina import task_vars +from guillotina.exceptions import ObjectLockedError +from guillotina.exceptions import ReadOnlyError +from guillotina.exceptions import TransactionNotFound + +import asyncio +import asyncpg +import hashlib + + +def _oid_lock_key(oid: str) -> int: + digest = hashlib.blake2b(oid.encode("utf-8"), digest_size=8).digest() + return int.from_bytes(digest, "big", signed=True) + + +@asynccontextmanager +async def lock_object_for_write(oid: str, *, retries: int = 3, delay: float = 0.05): + """ + Acquire a transaction-scoped advisory lock for a Guillotina object. + Must be used inside an active transaction. + """ + txn = task_vars.txn.get() + if txn is None: + raise TransactionNotFound() + if getattr(txn, "read_only", False): + raise ReadOnlyError() + + storage = txn.storage + if getattr(txn, "_db_txn", None) is None: + await storage.start_transaction(txn) + + if retries < 1: + retries = 1 + + key = _oid_lock_key(oid) + async with storage.acquire(txn, "object_lock") as conn: + for attempt in range(1, retries + 1): + try: + locked = await conn.fetchval("SELECT pg_try_advisory_xact_lock($1);", key) + except asyncpg.exceptions.UndefinedFunctionError as ex: + raise NotImplementedError("Object locks require PostgreSQL advisory locks") from ex + if locked: + break + if attempt < retries and delay > 0: + await asyncio.sleep(delay) + else: + raise ObjectLockedError(oid, retries) + + try: + yield + finally: + # xact lock is released on commit/rollback + pass diff --git a/guillotina/exceptions.py b/guillotina/exceptions.py index 26c5a4628..1afc14ee2 100644 --- a/guillotina/exceptions.py +++ b/guillotina/exceptions.py @@ -145,6 +145,13 @@ class TIDConflictError(ConflictError): pass +class ObjectLockedError(Exception): + def __init__(self, oid, retries): + super().__init__(f"Object {oid} is locked for modification after {retries} retries") + self.oid = oid + self.retries = retries + + class RestartCommit(Exception): """ Commits requires restart diff --git a/guillotina/tests/test_postgres.py b/guillotina/tests/test_postgres.py index 0bb9f0fe8..eff427b81 100644 --- a/guillotina/tests/test_postgres.py +++ b/guillotina/tests/test_postgres.py @@ -2,11 +2,13 @@ from guillotina.component import get_adapter from guillotina.content import Folder from guillotina.db.interfaces import IVacuumProvider +from guillotina.db.locks import lock_object_for_write from guillotina.db.storages.cockroach import CockroachStorage from guillotina.db.storages.pg import PostgresqlStorage from guillotina.db.transaction_manager import TransactionManager from guillotina.exceptions import ConflictError from guillotina.exceptions import ConflictIdOnContainer +from guillotina.exceptions import ObjectLockedError from guillotina.tests import mocks from guillotina.tests.utils import create_content from unittest.mock import Mock @@ -135,6 +137,37 @@ async def test_restart_connection_pg(db, dummy_guillotina): await cleanup(aps) +@pytest.mark.skipif(DATABASE != "postgres", reason="Requires postgres advisory locks") +async def test_object_lock_for_write(db, dummy_guillotina): + aps = await get_aps(db) + tm = TransactionManager(aps) + + async with tm: + txn1 = await tm.begin() + ob = create_content() + txn1.register(ob) + await tm.commit(txn=txn1) + + txn1 = await tm.begin() + async with lock_object_for_write(ob.__uuid__, retries=1, delay=0): + txn2 = await tm.begin() + try: + with pytest.raises(ObjectLockedError): + async with lock_object_for_write(ob.__uuid__, retries=2, delay=0): + pass + finally: + await tm.abort(txn=txn2) + await tm.abort(txn=txn1) + + txn3 = await tm.begin() + async with lock_object_for_write(ob.__uuid__, retries=1, delay=0): + pass + await tm.abort(txn=txn3) + + await aps.remove() + await cleanup(aps) + + @pytest.mark.skipif( DATABASE in ("cockroachdb", "DUMMY"), reason="Cockroach does not have cascade support",