Skip to content

Commit 454e4e3

Browse files
committed
test tx listeners
1 parent c4895eb commit 454e4e3

File tree

2 files changed

+136
-0
lines changed

2 files changed

+136
-0
lines changed

tests/aio/query/test_query_transaction.py

+70
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import pytest
2+
from unittest import mock
23

34
from ydb.aio.query.transaction import QueryTxContext
45
from ydb.query.transaction import QueryTxStateEnum
6+
from ydb.query.base import TxListenerAsyncIO
57

68

79
class TestAsyncQueryTransaction:
@@ -107,3 +109,71 @@ async def test_execute_two_results(self, tx: QueryTxContext):
107109

108110
assert res == [[1], [2]]
109111
assert counter == 2
112+
113+
114+
class TestQueryTransactionListeners:
115+
class FakeListener(TxListenerAsyncIO):
116+
def __init__(self):
117+
self._call_stack = []
118+
119+
async def _on_before_commit(self):
120+
self._call_stack.append("before_commit")
121+
122+
async def _on_after_commit(self, exc):
123+
if exc is not None:
124+
self._call_stack.append("after_commit_exc")
125+
return
126+
self._call_stack.append("after_commit")
127+
128+
async def _on_before_rollback(self):
129+
self._call_stack.append("before_rollback")
130+
131+
async def _on_after_rollback(self, exc):
132+
if exc is not None:
133+
self._call_stack.append("after_rollback_exc")
134+
return
135+
self._call_stack.append("after_rollback")
136+
137+
@property
138+
def call_stack(self):
139+
return self._call_stack
140+
141+
@pytest.mark.asyncio
142+
async def test_tx_commit_normal(self, tx: QueryTxContext):
143+
listener = TestQueryTransactionListeners.FakeListener()
144+
tx._add_listener(listener)
145+
await tx.begin()
146+
await tx.commit()
147+
148+
assert listener.call_stack == ["before_commit", "after_commit"]
149+
150+
@pytest.mark.asyncio
151+
async def test_tx_commit_exc(self, tx: QueryTxContext):
152+
listener = TestQueryTransactionListeners.FakeListener()
153+
tx._add_listener(listener)
154+
await tx.begin()
155+
with mock.patch.object(tx, "_commit_call", side_effect=BaseException("commit failed")):
156+
with pytest.raises(BaseException):
157+
await tx.commit()
158+
159+
assert listener.call_stack == ["before_commit", "after_commit_exc"]
160+
161+
@pytest.mark.asyncio
162+
async def test_tx_rollback_normal(self, tx: QueryTxContext):
163+
listener = TestQueryTransactionListeners.FakeListener()
164+
tx._add_listener(listener)
165+
await tx.begin()
166+
await tx.rollback()
167+
168+
assert listener.call_stack == ["before_rollback", "after_rollback"]
169+
170+
@pytest.mark.asyncio
171+
async def test_tx_rollback_exc(self, tx: QueryTxContext):
172+
listener = TestQueryTransactionListeners.FakeListener()
173+
tx._add_listener(listener)
174+
await tx.begin()
175+
with mock.patch.object(tx, "_rollback_call", side_effect=BaseException("commit failed")):
176+
with pytest.raises(BaseException):
177+
await tx.rollback()
178+
179+
assert listener.call_stack == ["before_rollback", "after_rollback_exc"]

tests/query/test_query_transaction.py

+66
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import pytest
2+
import unittest.mock as mock
23

4+
from ydb.query.base import TxListener
35
from ydb.query.transaction import QueryTxContext
46
from ydb.query.transaction import QueryTxStateEnum
57

@@ -104,3 +106,67 @@ def test_tx_identity_after_begin_works(self, tx: QueryTxContext):
104106

105107
assert identity.tx_id == tx.tx_id
106108
assert identity.session_id == tx.session_id
109+
110+
111+
class TestQueryTransactionListeners:
112+
class FakeListener(TxListener):
113+
def __init__(self):
114+
self._call_stack = []
115+
116+
def _on_before_commit(self):
117+
self._call_stack.append("before_commit")
118+
119+
def _on_after_commit(self, exc):
120+
if exc is not None:
121+
self._call_stack.append("after_commit_exc")
122+
return
123+
self._call_stack.append("after_commit")
124+
125+
def _on_before_rollback(self):
126+
self._call_stack.append("before_rollback")
127+
128+
def _on_after_rollback(self, exc):
129+
if exc is not None:
130+
self._call_stack.append("after_rollback_exc")
131+
return
132+
self._call_stack.append("after_rollback")
133+
134+
@property
135+
def call_stack(self):
136+
return self._call_stack
137+
138+
def test_tx_commit_normal(self, tx: QueryTxContext):
139+
listener = TestQueryTransactionListeners.FakeListener()
140+
tx._add_listener(listener)
141+
tx.begin()
142+
tx.commit()
143+
144+
assert listener.call_stack == ["before_commit", "after_commit"]
145+
146+
def test_tx_commit_exc(self, tx: QueryTxContext):
147+
listener = TestQueryTransactionListeners.FakeListener()
148+
tx._add_listener(listener)
149+
tx.begin()
150+
with mock.patch.object(tx, "_commit_call", side_effect=BaseException("commit failed")):
151+
with pytest.raises(BaseException):
152+
tx.commit()
153+
154+
assert listener.call_stack == ["before_commit", "after_commit_exc"]
155+
156+
def test_tx_rollback_normal(self, tx: QueryTxContext):
157+
listener = TestQueryTransactionListeners.FakeListener()
158+
tx._add_listener(listener)
159+
tx.begin()
160+
tx.rollback()
161+
162+
assert listener.call_stack == ["before_rollback", "after_rollback"]
163+
164+
def test_tx_rollback_exc(self, tx: QueryTxContext):
165+
listener = TestQueryTransactionListeners.FakeListener()
166+
tx._add_listener(listener)
167+
tx.begin()
168+
with mock.patch.object(tx, "_rollback_call", side_effect=BaseException("commit failed")):
169+
with pytest.raises(BaseException):
170+
tx.rollback()
171+
172+
assert listener.call_stack == ["before_rollback", "after_rollback_exc"]

0 commit comments

Comments
 (0)