Skip to content

Commit f924d83

Browse files
committed
Fix review comments
1 parent c7c1f9f commit f924d83

File tree

4 files changed

+129
-42
lines changed

4 files changed

+129
-42
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import asyncio
2+
import argparse
3+
import logging
4+
import ydb
5+
6+
7+
async def connect(endpoint: str, database: str) -> ydb.aio.Driver:
8+
config = ydb.DriverConfig(endpoint=endpoint, database=database)
9+
config.credentials = ydb.credentials_from_env_variables()
10+
driver = ydb.aio.Driver(config)
11+
await driver.wait(5, fail_fast=True)
12+
return driver
13+
14+
15+
async def create_topic(driver: ydb.aio.Driver, topic: str, consumer: str):
16+
try:
17+
await driver.topic_client.drop_topic(topic)
18+
except ydb.SchemeError:
19+
pass
20+
21+
await driver.topic_client.create_topic(topic, consumers=[consumer])
22+
23+
24+
async def write_with_tx_example(driver: ydb.aio.Driver, topic: str, message_count: int = 10):
25+
async with ydb.aio.QuerySessionPool(driver) as session_pool:
26+
27+
async def callee(tx: ydb.aio.QueryTxContext):
28+
tx_writer: ydb.TopicTxWriterAsyncIO = driver.topic_client.tx_writer(tx, topic)
29+
for i in range(message_count):
30+
result_stream = await tx.execute(query=f"select {i} as res")
31+
messages = [result_set.rows[0]["res"] async for result_set in result_stream]
32+
33+
await tx_writer.write([ydb.TopicWriterMessage(data=str(message)) for message in messages])
34+
35+
print(f"Messages {messages} were written with tx.")
36+
37+
await session_pool.retry_tx_async(callee)
38+
39+
40+
async def read_with_tx_example(driver: ydb.aio.Driver, topic: str, consumer: str, message_count: int = 10):
41+
async with driver.topic_client.reader(topic, consumer) as reader:
42+
async with ydb.aio.QuerySessionPool(driver) as session_pool:
43+
for _ in range(message_count):
44+
45+
async def callee(tx: ydb.aio.QueryTxContext):
46+
batch = await reader.receive_batch_with_tx(tx, max_messages=1)
47+
print(f"Messages [{batch.messages[0].data.decode()}] were read with tx.")
48+
49+
await session_pool.retry_tx_async(callee)
50+
51+
52+
async def main():
53+
parser = argparse.ArgumentParser(
54+
formatter_class=argparse.RawDescriptionHelpFormatter,
55+
description="""YDB topic basic example.\n""",
56+
)
57+
parser.add_argument("-d", "--database", default="/local", help="Name of the database to use")
58+
parser.add_argument("-e", "--endpoint", default="grpc://localhost:2136", help="Endpoint url to use")
59+
parser.add_argument("-p", "--path", default="test-topic", help="Topic name")
60+
parser.add_argument("-c", "--consumer", default="consumer", help="Consumer name")
61+
parser.add_argument("-v", "--verbose", default=False, action="store_true")
62+
parser.add_argument(
63+
"-s",
64+
"--skip-drop-and-create-topic",
65+
default=False,
66+
action="store_true",
67+
help="Use existed topic, skip remove it and re-create",
68+
)
69+
70+
args = parser.parse_args()
71+
72+
if args.verbose:
73+
logger = logging.getLogger("topicexample")
74+
logger.setLevel(logging.DEBUG)
75+
logger.addHandler(logging.StreamHandler())
76+
77+
async with await connect(args.endpoint, args.database) as driver:
78+
if not args.skip_drop_and_create_topic:
79+
await create_topic(driver, args.path, args.consumer)
80+
81+
await write_with_tx_example(driver, args.path)
82+
await read_with_tx_example(driver, args.path, args.consumer)
83+
84+
85+
if __name__ == "__main__":
86+
asyncio.run(main())
+27-32
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,54 @@
1-
import asyncio
21
import argparse
32
import logging
43
import ydb
54

65

7-
async def connect(endpoint: str, database: str) -> ydb.aio.Driver:
6+
def connect(endpoint: str, database: str) -> ydb.Driver:
87
config = ydb.DriverConfig(endpoint=endpoint, database=database)
98
config.credentials = ydb.credentials_from_env_variables()
10-
driver = ydb.aio.Driver(config)
11-
await driver.wait(5, fail_fast=True)
9+
driver = ydb.Driver(config)
10+
driver.wait(5, fail_fast=True)
1211
return driver
1312

1413

15-
async def create_topic(driver: ydb.aio.Driver, topic: str, consumer: str):
14+
def create_topic(driver: ydb.Driver, topic: str, consumer: str):
1615
try:
17-
await driver.topic_client.drop_topic(topic)
16+
driver.topic_client.drop_topic(topic)
1817
except ydb.SchemeError:
1918
pass
2019

21-
await driver.topic_client.create_topic(topic, consumers=[consumer])
20+
driver.topic_client.create_topic(topic, consumers=[consumer])
2221

2322

24-
async def write_with_tx_example(driver: ydb.aio.Driver, topic: str, message_count: int = 10):
25-
async with ydb.aio.QuerySessionPool(driver) as session_pool:
23+
def write_with_tx_example(driver: ydb.Driver, topic: str, message_count: int = 10):
24+
with ydb.QuerySessionPool(driver) as session_pool:
2625

27-
async def callee(tx: ydb.aio.QueryTxContext):
28-
print(f"TX ID: {tx.tx_id}")
29-
print(f"TX STATE: {tx._tx_state._state.value}")
26+
def callee(tx: ydb.aio.QueryTxContext):
3027
tx_writer: ydb.TopicTxWriterAsyncIO = driver.topic_client.tx_writer(tx, topic)
31-
print(f"TX ID: {tx.tx_id}")
32-
print(f"TX STATE: {tx._tx_state._state.value}")
3328
for i in range(message_count):
34-
result_stream = await tx.execute(query=f"select {i} as res")
35-
messages = [result_set.rows[0]["res"] async for result_set in result_stream]
29+
result_stream = tx.execute(query=f"select {i} as res")
30+
messages = [result_set.rows[0]["res"] for result_set in result_stream]
3631

37-
await tx_writer.write([ydb.TopicWriterMessage(data=str(message)) for message in messages])
32+
tx_writer.write([ydb.TopicWriterMessage(data=str(message)) for message in messages])
3833

3934
print(f"Messages {messages} were written with tx.")
4035

41-
await session_pool.retry_tx_async(callee)
36+
session_pool.retry_tx_sync(callee)
4237

4338

44-
async def read_with_tx_example(driver: ydb.aio.Driver, topic: str, consumer: str, message_count: int = 10):
45-
async with driver.topic_client.reader(topic, consumer) as reader:
46-
async with ydb.aio.QuerySessionPool(driver) as session_pool:
39+
def read_with_tx_example(driver: ydb.Driver, topic: str, consumer: str, message_count: int = 10):
40+
with driver.topic_client.reader(topic, consumer) as reader:
41+
with ydb.QuerySessionPool(driver) as session_pool:
4742
for _ in range(message_count):
4843

49-
async def callee(tx: ydb.aio.QueryTxContext):
50-
batch = await reader.receive_batch_with_tx(tx, max_messages=1)
51-
print(f"Messages {batch.messages[0].data} were read with tx.")
44+
def callee(tx: ydb.aio.QueryTxContext):
45+
batch = reader.receive_batch_with_tx(tx, max_messages=1)
46+
print(f"Messages [{batch.messages[0].data.decode()}] were read with tx.")
5247

53-
await session_pool.retry_tx_async(callee)
48+
session_pool.retry_tx_sync(callee)
5449

5550

56-
async def main():
51+
def main():
5752
parser = argparse.ArgumentParser(
5853
formatter_class=argparse.RawDescriptionHelpFormatter,
5954
description="""YDB topic basic example.\n""",
@@ -78,13 +73,13 @@ async def main():
7873
logger.setLevel(logging.DEBUG)
7974
logger.addHandler(logging.StreamHandler())
8075

81-
driver = await connect(args.endpoint, args.database)
82-
if not args.skip_drop_and_create_topic:
83-
await create_topic(driver, args.path, args.consumer)
76+
with connect(args.endpoint, args.database) as driver:
77+
if not args.skip_drop_and_create_topic:
78+
create_topic(driver, args.path, args.consumer)
8479

85-
await write_with_tx_example(driver, args.path)
86-
await read_with_tx_example(driver, args.path, args.consumer)
80+
write_with_tx_example(driver, args.path)
81+
read_with_tx_example(driver, args.path, args.consumer)
8782

8883

8984
if __name__ == "__main__":
90-
asyncio.run(main())
85+
main()

ydb/_topic_reader/topic_reader_asyncio.py

+16-9
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(
8484
):
8585
self._loop = asyncio.get_running_loop()
8686
self._closed = False
87-
self._reconnector = ReaderReconnector(driver, settings)
87+
self._reconnector = ReaderReconnector(driver, settings, self._loop)
8888
self._parent = _parent
8989

9090
async def __aenter__(self):
@@ -190,18 +190,24 @@ class ReaderReconnector:
190190
_first_error: asyncio.Future[YdbError]
191191
_tx_to_batches_map: Dict[str, typing.List[datatypes.PublicBatch]]
192192

193-
def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
193+
def __init__(
194+
self,
195+
driver: Driver,
196+
settings: topic_reader.PublicReaderSettings,
197+
loop: Optional[asyncio.AbstractEventLoop] = None,
198+
):
194199
self._id = self._static_reader_reconnector_counter.inc_and_get()
195200
self._settings = settings
196201
self._driver = driver
202+
self._loop = loop if loop is not None else asyncio.get_running_loop()
197203
self._background_tasks = set()
198204

199205
self._state_changed = asyncio.Event()
200206
self._stream_reader = None
201207
self._background_tasks.add(asyncio.create_task(self._connection_loop()))
202208
self._first_error = asyncio.get_running_loop().create_future()
203209

204-
self._tx_to_batches_map = defaultdict(list)
210+
self._tx_to_batches_map = dict()
205211

206212
async def _connection_loop(self):
207213
attempt = 0
@@ -254,22 +260,23 @@ def receive_batch_with_tx_nowait(self, tx: "BaseQueryTxContext", max_messages: O
254260
max_messages=max_messages,
255261
)
256262

257-
self._init_tx_if_needed(tx)
263+
self._init_tx(tx)
258264

259265
self._tx_to_batches_map[tx.tx_id].append(batch)
260266

261-
tx._add_callback(TxEvent.AFTER_COMMIT, batch._update_partition_offsets, None) # probably should be current loop
267+
tx._add_callback(TxEvent.AFTER_COMMIT, batch._update_partition_offsets, self._loop)
262268

263269
return batch
264270

265271
def receive_message_nowait(self):
266272
return self._stream_reader.receive_message_nowait()
267273

268-
def _init_tx_if_needed(self, tx: "BaseQueryTxContext"):
274+
def _init_tx(self, tx: "BaseQueryTxContext"):
269275
if tx.tx_id not in self._tx_to_batches_map: # Init tx callbacks
270-
tx._add_callback(TxEvent.BEFORE_COMMIT, self._commit_batches_with_tx, None)
271-
tx._add_callback(TxEvent.AFTER_COMMIT, self._handle_after_tx_commit, None)
272-
tx._add_callback(TxEvent.AFTER_ROLLBACK, self._handle_after_tx_rollback, None)
276+
self._tx_to_batches_map[tx.tx_id] = []
277+
tx._add_callback(TxEvent.BEFORE_COMMIT, self._commit_batches_with_tx, self._loop)
278+
tx._add_callback(TxEvent.AFTER_COMMIT, self._handle_after_tx_commit, self._loop)
279+
tx._add_callback(TxEvent.AFTER_ROLLBACK, self._handle_after_tx_rollback, self._loop)
273280

274281
async def _commit_batches_with_tx(self, tx: "BaseQueryTxContext"):
275282
grouped_batches = defaultdict(lambda: defaultdict(list))

ydb/_topic_writer/topic_writer_asyncio.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import asyncio
22
import concurrent.futures
33
import datetime
4-
import functools
54
import gzip
65
import typing
76
from collections import deque

0 commit comments

Comments
 (0)