|
1 | 1 | import pytest
|
| 2 | +import unittest.mock as mock |
2 | 3 |
|
| 4 | +from ydb.query.base import TxListener |
3 | 5 | from ydb.query.transaction import QueryTxContext
|
4 | 6 | from ydb.query.transaction import QueryTxStateEnum
|
5 | 7 |
|
@@ -104,3 +106,67 @@ def test_tx_identity_after_begin_works(self, tx: QueryTxContext):
|
104 | 106 |
|
105 | 107 | assert identity.tx_id == tx.tx_id
|
106 | 108 | assert identity.session_id == tx.session_id
|
| 109 | + |
| 110 | + |
| 111 | +class TestQueryTransactionListeners: |
| 112 | + class FakeListener(TxListener): |
| 113 | + def __init__(self): |
| 114 | + self._call_stack = [] |
| 115 | + |
| 116 | + def _on_before_commit(self): |
| 117 | + self._call_stack.append("before_commit") |
| 118 | + |
| 119 | + def _on_after_commit(self, exc): |
| 120 | + if exc is not None: |
| 121 | + self._call_stack.append("after_commit_exc") |
| 122 | + return |
| 123 | + self._call_stack.append("after_commit") |
| 124 | + |
| 125 | + def _on_before_rollback(self): |
| 126 | + self._call_stack.append("before_rollback") |
| 127 | + |
| 128 | + def _on_after_rollback(self, exc): |
| 129 | + if exc is not None: |
| 130 | + self._call_stack.append("after_rollback_exc") |
| 131 | + return |
| 132 | + self._call_stack.append("after_rollback") |
| 133 | + |
| 134 | + @property |
| 135 | + def call_stack(self): |
| 136 | + return self._call_stack |
| 137 | + |
| 138 | + def test_tx_commit_normal(self, tx: QueryTxContext): |
| 139 | + listener = TestQueryTransactionListeners.FakeListener() |
| 140 | + tx._add_listener(listener) |
| 141 | + tx.begin() |
| 142 | + tx.commit() |
| 143 | + |
| 144 | + assert listener.call_stack == ["before_commit", "after_commit"] |
| 145 | + |
| 146 | + def test_tx_commit_exc(self, tx: QueryTxContext): |
| 147 | + listener = TestQueryTransactionListeners.FakeListener() |
| 148 | + tx._add_listener(listener) |
| 149 | + tx.begin() |
| 150 | + with mock.patch.object(tx, "_commit_call", side_effect=BaseException("commit failed")): |
| 151 | + with pytest.raises(BaseException): |
| 152 | + tx.commit() |
| 153 | + |
| 154 | + assert listener.call_stack == ["before_commit", "after_commit_exc"] |
| 155 | + |
| 156 | + def test_tx_rollback_normal(self, tx: QueryTxContext): |
| 157 | + listener = TestQueryTransactionListeners.FakeListener() |
| 158 | + tx._add_listener(listener) |
| 159 | + tx.begin() |
| 160 | + tx.rollback() |
| 161 | + |
| 162 | + assert listener.call_stack == ["before_rollback", "after_rollback"] |
| 163 | + |
| 164 | + def test_tx_rollback_exc(self, tx: QueryTxContext): |
| 165 | + listener = TestQueryTransactionListeners.FakeListener() |
| 166 | + tx._add_listener(listener) |
| 167 | + tx.begin() |
| 168 | + with mock.patch.object(tx, "_rollback_call", side_effect=BaseException("commit failed")): |
| 169 | + with pytest.raises(BaseException): |
| 170 | + tx.rollback() |
| 171 | + |
| 172 | + assert listener.call_stack == ["before_rollback", "after_rollback_exc"] |
0 commit comments