Skip to content

Commit cb99a63

Browse files
committed
i have no idea what i'm doing
1 parent e9adad5 commit cb99a63

File tree

2 files changed

+126
-108
lines changed

2 files changed

+126
-108
lines changed

tests/topics/test_topic_transactions.py

+80-79
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
@pytest.mark.asyncio
99
class TestTopicTransactionalReader:
1010
async def test_commit(self, driver: ydb.aio.Driver, topic_with_messages, topic_consumer):
11-
async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader:
12-
async with ydb.aio.QuerySessionPool(driver) as pool:
11+
async with ydb.aio.QuerySessionPool(driver) as pool:
12+
async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader:
1313

1414
async def callee(tx: ydb.aio.QueryTxContext):
1515
batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), 1)
@@ -18,6 +18,7 @@ async def callee(tx: ydb.aio.QueryTxContext):
1818

1919
await pool.retry_tx_async(callee)
2020

21+
async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader:
2122
msg = await wait_for(reader.receive_message(), 1)
2223
assert msg.data.decode() == "456"
2324

@@ -38,114 +39,114 @@ async def callee(tx: ydb.aio.QueryTxContext):
3839
assert msg.data.decode() == "123"
3940

4041

41-
# class TestTopicTransactionalWriter:
42-
# async def test_commit(self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO):
43-
# async with ydb.aio.QuerySessionPool(driver) as pool:
42+
class TestTopicTransactionalWriter:
43+
async def test_commit(self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO):
44+
async with ydb.aio.QuerySessionPool(driver) as pool:
4445

45-
# async def callee(tx: ydb.aio.QueryTxContext):
46-
# tx_writer = driver.topic_client.tx_writer(tx, topic_path)
47-
# await tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
46+
async def callee(tx: ydb.aio.QueryTxContext):
47+
tx_writer = driver.topic_client.tx_writer(tx, topic_path)
48+
await tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
4849

49-
# await pool.retry_tx_async(callee)
50+
await pool.retry_tx_async(callee)
5051

51-
# msg = await wait_for(topic_reader.receive_message(), 0.1)
52-
# assert msg.data.decode() == "123"
52+
msg = await wait_for(topic_reader.receive_message(), 0.1)
53+
assert msg.data.decode() == "123"
5354

54-
# async def test_rollback(self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO):
55-
# async with ydb.aio.QuerySessionPool(driver) as pool:
55+
async def test_rollback(self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO):
56+
async with ydb.aio.QuerySessionPool(driver) as pool:
5657

57-
# async def callee(tx: ydb.aio.QueryTxContext):
58-
# tx_writer = driver.topic_client.tx_writer(tx, topic_path)
59-
# await tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
58+
async def callee(tx: ydb.aio.QueryTxContext):
59+
tx_writer = driver.topic_client.tx_writer(tx, topic_path)
60+
await tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
6061

61-
# await tx.rollback()
62+
await tx.rollback()
6263

63-
# await pool.retry_tx_async(callee)
64+
await pool.retry_tx_async(callee)
6465

65-
# with pytest.raises(asyncio.TimeoutError):
66-
# await wait_for(topic_reader.receive_message(), 0.1)
66+
with pytest.raises(asyncio.TimeoutError):
67+
await wait_for(topic_reader.receive_message(), 0.1)
6768

68-
# async def test_no_msg_written_in_error_case(
69-
# self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO
70-
# ):
71-
# async with ydb.aio.QuerySessionPool(driver) as pool:
69+
async def test_no_msg_written_in_error_case(
70+
self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO
71+
):
72+
async with ydb.aio.QuerySessionPool(driver) as pool:
7273

73-
# async def callee(tx: ydb.aio.QueryTxContext):
74-
# tx_writer = driver.topic_client.tx_writer(tx, topic_path)
75-
# await tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
74+
async def callee(tx: ydb.aio.QueryTxContext):
75+
tx_writer = driver.topic_client.tx_writer(tx, topic_path)
76+
await tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
7677

77-
# raise BaseException("error")
78+
raise BaseException("error")
7879

79-
# with pytest.raises(BaseException):
80-
# await pool.retry_tx_async(callee)
80+
with pytest.raises(BaseException):
81+
await pool.retry_tx_async(callee)
8182

82-
# with pytest.raises(asyncio.TimeoutError):
83-
# await wait_for(topic_reader.receive_message(), 0.1)
83+
with pytest.raises(asyncio.TimeoutError):
84+
await wait_for(topic_reader.receive_message(), 0.1)
8485

85-
# async def test_msg_written_exactly_once_with_retries(
86-
# self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO
87-
# ):
88-
# error_raised = False
89-
# async with ydb.aio.QuerySessionPool(driver) as pool:
86+
async def test_msg_written_exactly_once_with_retries(
87+
self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO
88+
):
89+
error_raised = False
90+
async with ydb.aio.QuerySessionPool(driver) as pool:
9091

91-
# async def callee(tx: ydb.aio.QueryTxContext):
92-
# nonlocal error_raised
93-
# tx_writer = driver.topic_client.tx_writer(tx, topic_path)
94-
# await tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
92+
async def callee(tx: ydb.aio.QueryTxContext):
93+
nonlocal error_raised
94+
tx_writer = driver.topic_client.tx_writer(tx, topic_path)
95+
await tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
9596

96-
# if not error_raised:
97-
# error_raised = True
98-
# raise ydb.issues.Unavailable("some retriable error")
97+
if not error_raised:
98+
error_raised = True
99+
raise ydb.issues.Unavailable("some retriable error")
99100

100-
# await pool.retry_tx_async(callee)
101+
await pool.retry_tx_async(callee)
101102

102-
# msg = await wait_for(topic_reader.receive_message(), 0.1)
103-
# assert msg.data.decode() == "123"
103+
msg = await wait_for(topic_reader.receive_message(), 0.1)
104+
assert msg.data.decode() == "123"
104105

105-
# with pytest.raises(asyncio.TimeoutError):
106-
# await wait_for(topic_reader.receive_message(), 0.1)
106+
with pytest.raises(asyncio.TimeoutError):
107+
await wait_for(topic_reader.receive_message(), 0.1)
107108

108109

109-
# class TestTopicTransactionalWriterSync:
110-
# def test_commit(self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader):
111-
# with ydb.QuerySessionPool(driver_sync) as pool:
110+
class TestTopicTransactionalWriterSync:
111+
def test_commit(self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader):
112+
with ydb.QuerySessionPool(driver_sync) as pool:
112113

113-
# def callee(tx: ydb.QueryTxContext):
114-
# tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path)
115-
# tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
114+
def callee(tx: ydb.QueryTxContext):
115+
tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path)
116+
tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
116117

117-
# pool.retry_tx_sync(callee)
118+
pool.retry_tx_sync(callee)
118119

119-
# msg = topic_reader_sync.receive_message(timeout=0.1)
120-
# assert msg.data.decode() == "123"
120+
msg = topic_reader_sync.receive_message(timeout=0.1)
121+
assert msg.data.decode() == "123"
121122

122-
# def test_rollback(self, driver_sync: ydb.aio.Driver, topic_path, topic_reader_sync: ydb.TopicReader):
123-
# with ydb.QuerySessionPool(driver_sync) as pool:
123+
def test_rollback(self, driver_sync: ydb.aio.Driver, topic_path, topic_reader_sync: ydb.TopicReader):
124+
with ydb.QuerySessionPool(driver_sync) as pool:
124125

125-
# def callee(tx: ydb.QueryTxContext):
126-
# tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path)
127-
# tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
126+
def callee(tx: ydb.QueryTxContext):
127+
tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path)
128+
tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
128129

129-
# tx.rollback()
130+
tx.rollback()
130131

131-
# pool.retry_tx_sync(callee)
132+
pool.retry_tx_sync(callee)
132133

133-
# with pytest.raises(TimeoutError):
134-
# topic_reader_sync.receive_message(timeout=0.1)
134+
with pytest.raises(TimeoutError):
135+
topic_reader_sync.receive_message(timeout=0.1)
135136

136-
# def test_no_msg_written_in_error_case(
137-
# self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReaderAsyncIO
138-
# ):
139-
# with ydb.QuerySessionPool(driver_sync) as pool:
137+
def test_no_msg_written_in_error_case(
138+
self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReaderAsyncIO
139+
):
140+
with ydb.QuerySessionPool(driver_sync) as pool:
140141

141-
# def callee(tx: ydb.QueryTxContext):
142-
# tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path)
143-
# tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
142+
def callee(tx: ydb.QueryTxContext):
143+
tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path)
144+
tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
144145

145-
# raise BaseException("error")
146+
raise BaseException("error")
146147

147-
# with pytest.raises(BaseException):
148-
# pool.retry_tx_sync(callee)
148+
with pytest.raises(BaseException):
149+
pool.retry_tx_sync(callee)
149150

150-
# with pytest.raises(TimeoutError):
151-
# topic_reader_sync.receive_message(timeout=0.1)
151+
with pytest.raises(TimeoutError):
152+
topic_reader_sync.receive_message(timeout=0.1)

ydb/_topic_reader/topic_reader_asyncio.py

+46-29
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
IGrpcWrapperAsyncIO,
2121
SupportedDriverType,
2222
GrpcWrapperAsyncIO,
23+
to_thread,
2324
)
2425
from .._grpc.grpcwrapper.ydb_topic import (
2526
StreamReadMessage,
@@ -68,7 +69,7 @@ def __init__(self):
6869
super().__init__("Topic reader is closed already")
6970

7071

71-
class PublicAsyncIOReader:
72+
class PublicAsyncIOReader(TxListenerAsyncIO):
7273
_loop: asyncio.AbstractEventLoop
7374
_closed: bool
7475
_reconnector: ReaderReconnector
@@ -176,6 +177,12 @@ async def close(self, flush: bool = True):
176177
self._closed = True
177178
await self._reconnector.close(flush)
178179

180+
def _on_after_commit(self, exc):
181+
return super()._on_after_commit(exc)
182+
183+
def _on_after_rollback(self, exc):
184+
return super()._on_after_rollback(exc)
185+
179186

180187
class ReaderReconnector:
181188
_static_reader_reconnector_counter = AtomicCounter()
@@ -189,6 +196,7 @@ class ReaderReconnector:
189196
_stream_reader: Optional["ReaderStream"]
190197
_first_error: asyncio.Future[YdbError]
191198
_batches_to_commit: asyncio.Queue
199+
_wait_executor: Optional[concurrent.futures.ThreadPoolExecutor]
192200

193201
def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
194202
self._id = self._static_reader_reconnector_counter.inc_and_get()
@@ -200,6 +208,7 @@ def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
200208
self._stream_reader = None
201209
self._background_tasks.add(asyncio.create_task(self._connection_loop()))
202210
self._first_error = asyncio.get_running_loop().create_future()
211+
self._wait_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
203212

204213
self._batches_to_commit = asyncio.Queue()
205214

@@ -270,34 +279,39 @@ async def _update_offsets_in_tx_loop(self):
270279
await self._update_offsets_in_tx_call(self._driver, tx, batch)
271280

272281
async def _update_offsets_in_tx_call(self, driver: SupportedDriverType, tx: "BaseQueryTxContext", batch: datatypes.ICommittable) -> None:
273-
partition_session = batch._commit_get_partition_session()
274-
request = UpdateOffsetsInTransactionRequest(
275-
tx=tx._tx_identity(),
276-
consumer=self._settings.consumer,
277-
topics=[
278-
UpdateOffsetsInTransactionRequest.TopicOffsets(
279-
path=partition_session.topic_path,
280-
partitions=[
281-
UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets(
282-
partition_id=partition_session.partition_id,
283-
partition_offsets=[batch._commit_get_offsets_range()]
284-
)
285-
],
286-
)
287-
],
288-
).to_proto()
289-
290-
res = driver(
291-
request,
292-
_apis.TopicService.Stub,
293-
_apis.TopicService.UpdateOffsetsInTransaction,
294-
topic_common.wrap_operation,
295-
)
296-
297-
if asyncio.iscoroutinefunction(driver.__call__):
298-
res = await res
299-
300-
return res
282+
try:
283+
partition_session = batch._commit_get_partition_session()
284+
request = UpdateOffsetsInTransactionRequest(
285+
tx=tx._tx_identity(),
286+
consumer=self._settings.consumer,
287+
topics=[
288+
UpdateOffsetsInTransactionRequest.TopicOffsets(
289+
path=partition_session.topic_path,
290+
partitions=[
291+
UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets(
292+
partition_id=partition_session.partition_id,
293+
partition_offsets=[batch._commit_get_offsets_range()]
294+
)
295+
],
296+
)
297+
],
298+
).to_proto()
299+
300+
args = [
301+
request,
302+
_apis.TopicService.Stub,
303+
_apis.TopicService.UpdateOffsetsInTransaction,
304+
topic_common.wrap_operation,
305+
]
306+
307+
if asyncio.iscoroutinefunction(driver.__call__):
308+
res = await driver(*args)
309+
else:
310+
res = await to_thread(driver, *args, executor=self._wait_executor)
311+
312+
return res
313+
except BaseException as e:
314+
self._set_first_error(e)
301315

302316
async def close(self, flush: bool):
303317
if self._stream_reader:
@@ -307,6 +321,9 @@ async def close(self, flush: bool):
307321

308322
await asyncio.wait(self._background_tasks)
309323

324+
if self._wait_executor is not None:
325+
self._wait_executor.shutdown(wait=flush)
326+
310327
async def flush(self):
311328
if self._stream_reader:
312329
await self._stream_reader.flush()

0 commit comments

Comments
 (0)