Skip to content

Commit e9adad5

Browse files
committed
transactional reader draft
1 parent 0657186 commit e9adad5

File tree

4 files changed

+200
-82
lines changed

4 files changed

+200
-82
lines changed

tests/topics/test_topic_transactions.py

+82-82
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import ydb
55

66

7-
@pytest.mark.skip("Not implemented yet.")
7+
# @pytest.mark.skip("Not implemented yet.")
88
@pytest.mark.asyncio
99
class TestTopicTransactionalReader:
1010
async def test_commit(self, driver: ydb.aio.Driver, topic_with_messages, topic_consumer):
@@ -13,8 +13,8 @@ async def test_commit(self, driver: ydb.aio.Driver, topic_with_messages, topic_c
1313

1414
async def callee(tx: ydb.aio.QueryTxContext):
1515
batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), 1)
16-
assert len(batch) == 1
17-
assert batch[0].data.decode() == "123"
16+
assert len(batch.messages) == 1
17+
assert batch.messages[0].data.decode() == "123"
1818

1919
await pool.retry_tx_async(callee)
2020

@@ -27,8 +27,8 @@ async def test_rollback(self, driver: ydb.aio.Driver, topic_with_messages, topic
2727

2828
async def callee(tx: ydb.aio.QueryTxContext):
2929
batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), 1)
30-
assert len(batch) == 1
31-
assert batch[0].data.decode() == "123"
30+
assert len(batch.messages) == 1
31+
assert batch.messages[0].data.decode() == "123"
3232

3333
await tx.rollback()
3434

@@ -38,114 +38,114 @@ async def callee(tx: ydb.aio.QueryTxContext):
3838
assert msg.data.decode() == "123"
3939

4040

41-
class TestTopicTransactionalWriter:
42-
async def test_commit(self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO):
43-
async with ydb.aio.QuerySessionPool(driver) as pool:
41+
# class TestTopicTransactionalWriter:
42+
# async def test_commit(self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO):
43+
# async with ydb.aio.QuerySessionPool(driver) as pool:
4444

45-
async def callee(tx: ydb.aio.QueryTxContext):
46-
tx_writer = driver.topic_client.tx_writer(tx, topic_path)
47-
await tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
45+
# async def callee(tx: ydb.aio.QueryTxContext):
46+
# tx_writer = driver.topic_client.tx_writer(tx, topic_path)
47+
# await tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
4848

49-
await pool.retry_tx_async(callee)
49+
# await pool.retry_tx_async(callee)
5050

51-
msg = await wait_for(topic_reader.receive_message(), 0.1)
52-
assert msg.data.decode() == "123"
51+
# msg = await wait_for(topic_reader.receive_message(), 0.1)
52+
# assert msg.data.decode() == "123"
5353

54-
async def test_rollback(self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO):
55-
async with ydb.aio.QuerySessionPool(driver) as pool:
54+
# async def test_rollback(self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO):
55+
# async with ydb.aio.QuerySessionPool(driver) as pool:
5656

57-
async def callee(tx: ydb.aio.QueryTxContext):
58-
tx_writer = driver.topic_client.tx_writer(tx, topic_path)
59-
await tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
57+
# async def callee(tx: ydb.aio.QueryTxContext):
58+
# tx_writer = driver.topic_client.tx_writer(tx, topic_path)
59+
# await tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
6060

61-
await tx.rollback()
61+
# await tx.rollback()
6262

63-
await pool.retry_tx_async(callee)
63+
# await pool.retry_tx_async(callee)
6464

65-
with pytest.raises(asyncio.TimeoutError):
66-
await wait_for(topic_reader.receive_message(), 0.1)
65+
# with pytest.raises(asyncio.TimeoutError):
66+
# await wait_for(topic_reader.receive_message(), 0.1)
6767

68-
async def test_no_msg_written_in_error_case(
69-
self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO
70-
):
71-
async with ydb.aio.QuerySessionPool(driver) as pool:
68+
# async def test_no_msg_written_in_error_case(
69+
# self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO
70+
# ):
71+
# async with ydb.aio.QuerySessionPool(driver) as pool:
7272

73-
async def callee(tx: ydb.aio.QueryTxContext):
74-
tx_writer = driver.topic_client.tx_writer(tx, topic_path)
75-
await tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
73+
# async def callee(tx: ydb.aio.QueryTxContext):
74+
# tx_writer = driver.topic_client.tx_writer(tx, topic_path)
75+
# await tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
7676

77-
raise BaseException("error")
77+
# raise BaseException("error")
7878

79-
with pytest.raises(BaseException):
80-
await pool.retry_tx_async(callee)
79+
# with pytest.raises(BaseException):
80+
# await pool.retry_tx_async(callee)
8181

82-
with pytest.raises(asyncio.TimeoutError):
83-
await wait_for(topic_reader.receive_message(), 0.1)
82+
# with pytest.raises(asyncio.TimeoutError):
83+
# await wait_for(topic_reader.receive_message(), 0.1)
8484

85-
async def test_msg_written_exactly_once_with_retries(
86-
self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO
87-
):
88-
error_raised = False
89-
async with ydb.aio.QuerySessionPool(driver) as pool:
85+
# async def test_msg_written_exactly_once_with_retries(
86+
# self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO
87+
# ):
88+
# error_raised = False
89+
# async with ydb.aio.QuerySessionPool(driver) as pool:
9090

91-
async def callee(tx: ydb.aio.QueryTxContext):
92-
nonlocal error_raised
93-
tx_writer = driver.topic_client.tx_writer(tx, topic_path)
94-
await tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
91+
# async def callee(tx: ydb.aio.QueryTxContext):
92+
# nonlocal error_raised
93+
# tx_writer = driver.topic_client.tx_writer(tx, topic_path)
94+
# await tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
9595

96-
if not error_raised:
97-
error_raised = True
98-
raise ydb.issues.Unavailable("some retriable error")
96+
# if not error_raised:
97+
# error_raised = True
98+
# raise ydb.issues.Unavailable("some retriable error")
9999

100-
await pool.retry_tx_async(callee)
100+
# await pool.retry_tx_async(callee)
101101

102-
msg = await wait_for(topic_reader.receive_message(), 0.1)
103-
assert msg.data.decode() == "123"
102+
# msg = await wait_for(topic_reader.receive_message(), 0.1)
103+
# assert msg.data.decode() == "123"
104104

105-
with pytest.raises(asyncio.TimeoutError):
106-
await wait_for(topic_reader.receive_message(), 0.1)
105+
# with pytest.raises(asyncio.TimeoutError):
106+
# await wait_for(topic_reader.receive_message(), 0.1)
107107

108108

109-
class TestTopicTransactionalWriterSync:
110-
def test_commit(self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader):
111-
with ydb.QuerySessionPool(driver_sync) as pool:
109+
# class TestTopicTransactionalWriterSync:
110+
# def test_commit(self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader):
111+
# with ydb.QuerySessionPool(driver_sync) as pool:
112112

113-
def callee(tx: ydb.QueryTxContext):
114-
tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path)
115-
tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
113+
# def callee(tx: ydb.QueryTxContext):
114+
# tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path)
115+
# tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
116116

117-
pool.retry_tx_sync(callee)
117+
# pool.retry_tx_sync(callee)
118118

119-
msg = topic_reader_sync.receive_message(timeout=0.1)
120-
assert msg.data.decode() == "123"
119+
# msg = topic_reader_sync.receive_message(timeout=0.1)
120+
# assert msg.data.decode() == "123"
121121

122-
def test_rollback(self, driver_sync: ydb.aio.Driver, topic_path, topic_reader_sync: ydb.TopicReader):
123-
with ydb.QuerySessionPool(driver_sync) as pool:
122+
# def test_rollback(self, driver_sync: ydb.aio.Driver, topic_path, topic_reader_sync: ydb.TopicReader):
123+
# with ydb.QuerySessionPool(driver_sync) as pool:
124124

125-
def callee(tx: ydb.QueryTxContext):
126-
tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path)
127-
tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
125+
# def callee(tx: ydb.QueryTxContext):
126+
# tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path)
127+
# tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
128128

129-
tx.rollback()
129+
# tx.rollback()
130130

131-
pool.retry_tx_sync(callee)
131+
# pool.retry_tx_sync(callee)
132132

133-
with pytest.raises(TimeoutError):
134-
topic_reader_sync.receive_message(timeout=0.1)
133+
# with pytest.raises(TimeoutError):
134+
# topic_reader_sync.receive_message(timeout=0.1)
135135

136-
def test_no_msg_written_in_error_case(
137-
self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReaderAsyncIO
138-
):
139-
with ydb.QuerySessionPool(driver_sync) as pool:
136+
# def test_no_msg_written_in_error_case(
137+
# self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReaderAsyncIO
138+
# ):
139+
# with ydb.QuerySessionPool(driver_sync) as pool:
140140

141-
def callee(tx: ydb.QueryTxContext):
142-
tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path)
143-
tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
141+
# def callee(tx: ydb.QueryTxContext):
142+
# tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path)
143+
# tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
144144

145-
raise BaseException("error")
145+
# raise BaseException("error")
146146

147-
with pytest.raises(BaseException):
148-
pool.retry_tx_sync(callee)
147+
# with pytest.raises(BaseException):
148+
# pool.retry_tx_sync(callee)
149149

150-
with pytest.raises(TimeoutError):
151-
topic_reader_sync.receive_message(timeout=0.1)
150+
# with pytest.raises(TimeoutError):
151+
# topic_reader_sync.receive_message(timeout=0.1)

ydb/_apis.py

+1
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ class TopicService(object):
116116
DropTopic = "DropTopic"
117117
StreamRead = "StreamRead"
118118
StreamWrite = "StreamWrite"
119+
UpdateOffsetsInTransaction = "UpdateOffsetsInTransaction"
119120

120121

121122
class QueryService(object):

ydb/_grpc/grpcwrapper/ydb_topic.py

+46
Original file line numberDiff line numberDiff line change
@@ -1209,6 +1209,52 @@ def to_public(self) -> ydb_topic_public_types.PublicMeteringMode:
12091209
return ydb_topic_public_types.PublicMeteringMode.UNSPECIFIED
12101210

12111211

1212+
@dataclass
1213+
class UpdateOffsetsInTransactionRequest(IToProto):
1214+
tx: TransactionIdentity
1215+
topics: List[UpdateOffsetsInTransactionRequest.TopicOffsets]
1216+
consumer: str
1217+
1218+
def to_proto(self):
1219+
return ydb_topic_pb2.UpdateOffsetsInTransactionRequest(
1220+
tx=self.tx.to_proto(),
1221+
consumer=self.consumer,
1222+
topics=list(
1223+
map(
1224+
UpdateOffsetsInTransactionRequest.TopicOffsets.to_proto,
1225+
self.topics,
1226+
)
1227+
),
1228+
)
1229+
1230+
@dataclass
1231+
class TopicOffsets(IToProto):
1232+
path: str
1233+
partitions: List[UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets]
1234+
1235+
def to_proto(self):
1236+
return ydb_topic_pb2.UpdateOffsetsInTransactionRequest.TopicOffsets(
1237+
path=self.path,
1238+
partitions=list(
1239+
map(
1240+
UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets.to_proto,
1241+
self.partitions,
1242+
)
1243+
)
1244+
)
1245+
1246+
@dataclass
1247+
class PartitionOffsets(IToProto):
1248+
partition_id: int
1249+
partition_offsets: List[OffsetsRange]
1250+
1251+
def to_proto(self) -> ydb_topic_pb2.UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets:
1252+
return ydb_topic_pb2.UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets(
1253+
partition_id=self.partition_id,
1254+
partition_offsets=list(map(OffsetsRange.to_proto, self.partition_offsets)),
1255+
)
1256+
1257+
12121258
@dataclass
12131259
class CreateTopicRequest(IToProto, IFromPublic):
12141260
path: str

ydb/_topic_reader/topic_reader_asyncio.py

+71
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,17 @@
2525
StreamReadMessage,
2626
UpdateTokenRequest,
2727
UpdateTokenResponse,
28+
UpdateOffsetsInTransactionRequest,
2829
Codec,
2930
)
3031
from .._errors import check_retriable_error
3132
import logging
3233

34+
from ..query.base import TxListenerAsyncIO
35+
36+
if typing.TYPE_CHECKING:
37+
from ..query.transaction import BaseQueryTxContext
38+
3339
logger = logging.getLogger(__name__)
3440

3541

@@ -121,6 +127,23 @@ async def receive_message(self) -> typing.Optional[datatypes.PublicMessage]:
121127
await self._reconnector.wait_message()
122128
return self._reconnector.receive_message_nowait()
123129

130+
async def receive_batch_with_tx(
131+
self,
132+
tx: "BaseQueryTxContext",
133+
max_messages: typing.Union[int, None] = None,
134+
) -> typing.Union[datatypes.PublicBatch, None]:
135+
"""
136+
Get one messages batch from reader.
137+
All messages in a batch from same partition.
138+
139+
use asyncio.wait_for for wait with timeout.
140+
"""
141+
await self._reconnector.wait_message()
142+
return await self._reconnector.receive_batch_with_tx_nowait(
143+
tx,
144+
max_messages=max_messages,
145+
)
146+
124147
def commit(self, batch: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch]):
125148
"""
126149
Write commit message to a buffer.
@@ -165,6 +188,7 @@ class ReaderReconnector:
165188
_state_changed: asyncio.Event
166189
_stream_reader: Optional["ReaderStream"]
167190
_first_error: asyncio.Future[YdbError]
191+
_batches_to_commit: asyncio.Queue
168192

169193
def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
170194
self._id = self._static_reader_reconnector_counter.inc_and_get()
@@ -177,6 +201,8 @@ def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
177201
self._background_tasks.add(asyncio.create_task(self._connection_loop()))
178202
self._first_error = asyncio.get_running_loop().create_future()
179203

204+
self._batches_to_commit = asyncio.Queue()
205+
180206
async def _connection_loop(self):
181207
attempt = 0
182208
while True:
@@ -228,6 +254,51 @@ def receive_message_nowait(self):
228254
def commit(self, batch: datatypes.ICommittable) -> datatypes.PartitionSession.CommitAckWaiter:
229255
return self._stream_reader.commit(batch)
230256

257+
async def _commit_with_tx(self, tx: "BaseQueryTxContext", batch: datatypes.ICommittable) -> None:
258+
pass
259+
260+
async def receive_batch_with_tx_nowait(self, tx: "BaseQueryTxContext", max_messages: Optional[int] = None):
261+
batch = self.receive_batch_nowait(max_messages=max_messages)
262+
tx._add_listener(batch)
263+
await self._update_offsets_in_tx_call(self._driver, tx, batch)
264+
return batch
265+
# self._batches_to_commit.put_nowait((tx, batch))
266+
267+
async def _update_offsets_in_tx_loop(self):
268+
while True:
269+
tx, batch = self._batches_to_commit.get()
270+
await self._update_offsets_in_tx_call(self._driver, tx, batch)
271+
272+
async def _update_offsets_in_tx_call(self, driver: SupportedDriverType, tx: "BaseQueryTxContext", batch: datatypes.ICommittable) -> None:
273+
partition_session = batch._commit_get_partition_session()
274+
request = UpdateOffsetsInTransactionRequest(
275+
tx=tx._tx_identity(),
276+
consumer=self._settings.consumer,
277+
topics=[
278+
UpdateOffsetsInTransactionRequest.TopicOffsets(
279+
path=partition_session.topic_path,
280+
partitions=[
281+
UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets(
282+
partition_id=partition_session.partition_id,
283+
partition_offsets=[batch._commit_get_offsets_range()]
284+
)
285+
],
286+
)
287+
],
288+
).to_proto()
289+
290+
res = driver(
291+
request,
292+
_apis.TopicService.Stub,
293+
_apis.TopicService.UpdateOffsetsInTransaction,
294+
topic_common.wrap_operation,
295+
)
296+
297+
if asyncio.iscoroutinefunction(driver.__call__):
298+
res = await res
299+
300+
return res
301+
231302
async def close(self, flush: bool):
232303
if self._stream_reader:
233304
await self._stream_reader.close(flush)

0 commit comments

Comments
 (0)