Skip to content

Commit ec456af

Browse files
authored
Topic transactions feature (#559)
1 parent f2bbcf2 commit ec456af

27 files changed

+1345
-57
lines changed

.github/workflows/tests.yaml

+3-9
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,15 @@ jobs:
1818
fail-fast: false
1919
matrix:
2020
python-version: [3.8, 3.9]
21-
environment: [py-proto5, py-tls-proto5, py-proto4, py-tls-proto4, py-proto3, py-tls-proto3]
22-
folder: [ydb, tests --ignore=tests/topics, tests/topics]
21+
environment: [py, py-tls, py-proto4, py-tls-proto4, py-proto3, py-tls-proto3]
22+
folder: [ydb, tests]
2323
exclude:
24-
- environment: py-tls-proto5
24+
- environment: py-tls
2525
folder: ydb
2626
- environment: py-tls-proto4
2727
folder: ydb
2828
- environment: py-tls-proto3
2929
folder: ydb
30-
- environment: py-tls-proto5
31-
folder: tests/topics
32-
- environment: py-tls-proto4
33-
folder: tests/topics
34-
- environment: py-tls-proto3
35-
folder: tests/topics
3630

3731
steps:
3832
- uses: actions/checkout@v1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import asyncio
2+
import argparse
3+
import logging
4+
import ydb
5+
6+
7+
async def connect(endpoint: str, database: str) -> ydb.aio.Driver:
8+
config = ydb.DriverConfig(endpoint=endpoint, database=database)
9+
config.credentials = ydb.credentials_from_env_variables()
10+
driver = ydb.aio.Driver(config)
11+
await driver.wait(5, fail_fast=True)
12+
return driver
13+
14+
15+
async def create_topic(driver: ydb.aio.Driver, topic: str, consumer: str):
16+
try:
17+
await driver.topic_client.drop_topic(topic)
18+
except ydb.SchemeError:
19+
pass
20+
21+
await driver.topic_client.create_topic(topic, consumers=[consumer])
22+
23+
24+
async def write_with_tx_example(driver: ydb.aio.Driver, topic: str, message_count: int = 10):
25+
async with ydb.aio.QuerySessionPool(driver) as session_pool:
26+
27+
async def callee(tx: ydb.aio.QueryTxContext):
28+
tx_writer: ydb.TopicTxWriterAsyncIO = driver.topic_client.tx_writer(tx, topic)
29+
30+
for i in range(message_count):
31+
async with await tx.execute(query=f"select {i} as res;") as result_stream:
32+
async for result_set in result_stream:
33+
message = str(result_set.rows[0]["res"])
34+
await tx_writer.write(ydb.TopicWriterMessage(message))
35+
print(f"Message {result_set.rows[0]['res']} was written with tx.")
36+
37+
await session_pool.retry_tx_async(callee)
38+
39+
40+
async def read_with_tx_example(driver: ydb.aio.Driver, topic: str, consumer: str, message_count: int = 10):
41+
async with driver.topic_client.reader(topic, consumer) as reader:
42+
async with ydb.aio.QuerySessionPool(driver) as session_pool:
43+
for _ in range(message_count):
44+
45+
async def callee(tx: ydb.aio.QueryTxContext):
46+
batch = await reader.receive_batch_with_tx(tx, max_messages=1)
47+
print(f"Message {batch.messages[0].data.decode()} was read with tx.")
48+
49+
await session_pool.retry_tx_async(callee)
50+
51+
52+
async def main():
53+
parser = argparse.ArgumentParser(
54+
formatter_class=argparse.RawDescriptionHelpFormatter,
55+
description="""YDB topic basic example.\n""",
56+
)
57+
parser.add_argument("-d", "--database", default="/local", help="Name of the database to use")
58+
parser.add_argument("-e", "--endpoint", default="grpc://localhost:2136", help="Endpoint url to use")
59+
parser.add_argument("-p", "--path", default="test-topic", help="Topic name")
60+
parser.add_argument("-c", "--consumer", default="consumer", help="Consumer name")
61+
parser.add_argument("-v", "--verbose", default=False, action="store_true")
62+
parser.add_argument(
63+
"-s",
64+
"--skip-drop-and-create-topic",
65+
default=False,
66+
action="store_true",
67+
help="Use existed topic, skip remove it and re-create",
68+
)
69+
70+
args = parser.parse_args()
71+
72+
if args.verbose:
73+
logger = logging.getLogger("topicexample")
74+
logger.setLevel(logging.DEBUG)
75+
logger.addHandler(logging.StreamHandler())
76+
77+
async with await connect(args.endpoint, args.database) as driver:
78+
if not args.skip_drop_and_create_topic:
79+
await create_topic(driver, args.path, args.consumer)
80+
81+
await write_with_tx_example(driver, args.path)
82+
await read_with_tx_example(driver, args.path, args.consumer)
83+
84+
85+
if __name__ == "__main__":
86+
asyncio.run(main())
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import argparse
2+
import logging
3+
import ydb
4+
5+
6+
def connect(endpoint: str, database: str) -> ydb.Driver:
7+
config = ydb.DriverConfig(endpoint=endpoint, database=database)
8+
config.credentials = ydb.credentials_from_env_variables()
9+
driver = ydb.Driver(config)
10+
driver.wait(5, fail_fast=True)
11+
return driver
12+
13+
14+
def create_topic(driver: ydb.Driver, topic: str, consumer: str):
15+
try:
16+
driver.topic_client.drop_topic(topic)
17+
except ydb.SchemeError:
18+
pass
19+
20+
driver.topic_client.create_topic(topic, consumers=[consumer])
21+
22+
23+
def write_with_tx_example(driver: ydb.Driver, topic: str, message_count: int = 10):
24+
with ydb.QuerySessionPool(driver) as session_pool:
25+
26+
def callee(tx: ydb.QueryTxContext):
27+
tx_writer: ydb.TopicTxWriter = driver.topic_client.tx_writer(tx, topic)
28+
29+
for i in range(message_count):
30+
result_stream = tx.execute(query=f"select {i} as res;")
31+
for result_set in result_stream:
32+
message = str(result_set.rows[0]["res"])
33+
tx_writer.write(ydb.TopicWriterMessage(message))
34+
print(f"Message {message} was written with tx.")
35+
36+
session_pool.retry_tx_sync(callee)
37+
38+
39+
def read_with_tx_example(driver: ydb.Driver, topic: str, consumer: str, message_count: int = 10):
40+
with driver.topic_client.reader(topic, consumer) as reader:
41+
with ydb.QuerySessionPool(driver) as session_pool:
42+
for _ in range(message_count):
43+
44+
def callee(tx: ydb.QueryTxContext):
45+
batch = reader.receive_batch_with_tx(tx, max_messages=1)
46+
print(f"Message {batch.messages[0].data.decode()} was read with tx.")
47+
48+
session_pool.retry_tx_sync(callee)
49+
50+
51+
def main():
52+
parser = argparse.ArgumentParser(
53+
formatter_class=argparse.RawDescriptionHelpFormatter,
54+
description="""YDB topic basic example.\n""",
55+
)
56+
parser.add_argument("-d", "--database", default="/local", help="Name of the database to use")
57+
parser.add_argument("-e", "--endpoint", default="grpc://localhost:2136", help="Endpoint url to use")
58+
parser.add_argument("-p", "--path", default="test-topic", help="Topic name")
59+
parser.add_argument("-c", "--consumer", default="consumer", help="Consumer name")
60+
parser.add_argument("-v", "--verbose", default=False, action="store_true")
61+
parser.add_argument(
62+
"-s",
63+
"--skip-drop-and-create-topic",
64+
default=False,
65+
action="store_true",
66+
help="Use existed topic, skip remove it and re-create",
67+
)
68+
69+
args = parser.parse_args()
70+
71+
if args.verbose:
72+
logger = logging.getLogger("topicexample")
73+
logger.setLevel(logging.DEBUG)
74+
logger.addHandler(logging.StreamHandler())
75+
76+
with connect(args.endpoint, args.database) as driver:
77+
if not args.skip_drop_and_create_topic:
78+
create_topic(driver, args.path, args.consumer)
79+
80+
write_with_tx_example(driver, args.path)
81+
read_with_tx_example(driver, args.path, args.consumer)
82+
83+
84+
if __name__ == "__main__":
85+
main()

tests/query/test_query_transaction.py

+12
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,15 @@ def test_execute_two_results(self, tx: QueryTxContext):
9292

9393
assert res == [[1], [2]]
9494
assert counter == 2
95+
96+
def test_tx_identity_before_begin_raises(self, tx: QueryTxContext):
97+
with pytest.raises(RuntimeError):
98+
tx._tx_identity()
99+
100+
def test_tx_identity_after_begin_works(self, tx: QueryTxContext):
101+
tx.begin()
102+
103+
identity = tx._tx_identity()
104+
105+
assert identity.tx_id == tx.tx_id
106+
assert identity.session_id == tx.session_id

tests/topics/test_topic_reader.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -174,12 +174,13 @@ def test_read_and_commit_with_close_reader(self, driver_sync, topic_with_message
174174
assert message != message2
175175

176176
def test_read_and_commit_with_ack(self, driver_sync, topic_with_messages, topic_consumer):
177-
reader = driver_sync.topic_client.reader(topic_with_messages, topic_consumer)
178-
message = reader.receive_message()
179-
reader.commit_with_ack(message)
177+
with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader:
178+
message = reader.receive_message()
179+
reader.commit_with_ack(message)
180+
181+
with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader:
182+
batch = reader.receive_batch()
180183

181-
reader = driver_sync.topic_client.reader(topic_with_messages, topic_consumer)
182-
batch = reader.receive_batch()
183184
assert message != batch.messages[0]
184185

185186
def test_read_compressed_messages(self, driver_sync, topic_path, topic_consumer):
@@ -247,3 +248,6 @@ async def wait(fut):
247248
datas.sort()
248249

249250
assert datas == ["10", "11"]
251+
252+
await reader0.close()
253+
await reader1.close()

0 commit comments

Comments
 (0)