Skip to content

Commit 9dcf8ef

Browse files
committed
transaction identity
1 parent e64d447 commit 9dcf8ef

File tree

8 files changed

+142
-8
lines changed

8 files changed

+142
-8
lines changed

examples/topic/topic_transactions_example.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@ def writer_example(driver: ydb.Driver, topic: str):
55
session_pool = ydb.QuerySessionPool(driver)
66

77
def callee(tx: ydb.QueryTxContext):
8-
tx_writer: ydb.TopicTxWriter = driver.topic_client.tx_writer(tx, topic) # <=======
8+
tx_writer: ydb.TopicTxWriter = driver.topic_client.tx_writer(tx, topic) # <=======
99
# дефолт - без дедупликации, без ретраев и без producer_id.
1010

1111
with tx.execute(query="select 1") as result_sets:
1212
messages = [result_set.rows[0] for result_set in result_sets]
1313

14-
tx_writer.write(messages) # вне зависимости от состояния вышестоящего стрима поведение должно быть одинаковое
14+
tx_writer.write(messages) # вне зависимости от состояния вышестоящего стрима поведение должно быть одинаковое
1515

1616
session_pool.retry_tx_sync(callee)
1717

@@ -20,7 +20,7 @@ def reader_example(driver: ydb.Driver, reader: ydb.TopicReader):
2020
session_pool = ydb.QuerySessionPool(driver)
2121

2222
def callee(tx: ydb.QueryTxContext):
23-
batch = reader.receive_batch_with_tx(tx, max_messages=5) # <=======
23+
batch = reader.receive_batch_with_tx(tx, max_messages=5) # <=======
2424

2525
with tx.execute(query="INSERT INTO max_values(val) VALUES ($val)", parameters={"$val": max(batch)}) as _:
2626
pass

tests/query/test_query_transaction.py

+12
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,15 @@ def test_execute_two_results(self, tx: QueryTxContext):
9292

9393
assert res == [[1], [2]]
9494
assert counter == 2
95+
96+
def test_tx_identity_before_begin_raises(self, tx: QueryTxContext):
97+
with pytest.raises(RuntimeError):
98+
tx._tx_identity()
99+
100+
def test_tx_identity_after_begin_works(self, tx: QueryTxContext):
101+
tx.begin()
102+
103+
identity = tx._tx_identity()
104+
105+
assert identity.tx_id == tx.tx_id
106+
assert identity.session_id == tx.session_id

ydb/_grpc/grpcwrapper/ydb_topic.py

+16
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,18 @@ def from_proto(msg: ydb_topic_pb2.UpdateTokenResponse) -> typing.Any:
142142
########################################################################################################################
143143

144144

145+
@dataclass
146+
class TransactionIdentity(IToProto):
147+
tx_id: str
148+
session_id: str
149+
150+
def to_proto(self) -> ydb_topic_pb2.TransactionIdentity:
151+
return ydb_topic_pb2.TransactionIdentity(
152+
id=self.tx_id,
153+
session=self.session_id,
154+
)
155+
156+
145157
class StreamWriteMessage:
146158
@dataclass()
147159
class InitRequest(IToProto):
@@ -200,6 +212,7 @@ def from_proto(
200212
class WriteRequest(IToProto):
201213
messages: typing.List["StreamWriteMessage.WriteRequest.MessageData"]
202214
codec: int
215+
tx_identity: Optional[TransactionIdentity]
203216

204217
@dataclass
205218
class MessageData(IToProto):
@@ -238,6 +251,9 @@ def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.WriteRequest:
238251
proto = ydb_topic_pb2.StreamWriteMessage.WriteRequest()
239252
proto.codec = self.codec
240253

254+
if self.tx_identity is not None:
255+
proto.tx = self.tx_identity.to_proto()
256+
241257
for message in self.messages:
242258
proto_mess = proto.messages.add()
243259
proto_mess.CopyFrom(message.to_proto())

ydb/_topic_writer/topic_writer.py

+4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import ydb.aio
1313
from .._grpc.grpcwrapper.ydb_topic import StreamWriteMessage
14+
from .._grpc.grpcwrapper.ydb_topic import TransactionIdentity
1415
from .._grpc.grpcwrapper.common_utils import IToProto
1516
from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec
1617
from .. import connection
@@ -205,6 +206,7 @@ def default_serializer_message_content(data: Any) -> bytes:
205206

206207
def messages_to_proto_requests(
207208
messages: List[InternalMessage],
209+
tx_identity: Optional[TransactionIdentity],
208210
) -> List[StreamWriteMessage.FromClient]:
209211

210212
gropus = _slit_messages_for_send(messages)
@@ -215,6 +217,7 @@ def messages_to_proto_requests(
215217
StreamWriteMessage.WriteRequest(
216218
messages=list(map(InternalMessage.to_message_data, group)),
217219
codec=group[0].codec,
220+
tx_identity=tx_identity,
218221
)
219222
)
220223
res.append(req)
@@ -239,6 +242,7 @@ def messages_to_proto_requests(
239242
),
240243
],
241244
codec=20000,
245+
tx_identity=None,
242246
)
243247
)
244248
.to_proto()

ydb/_topic_writer/topic_writer_asyncio.py

+33-3
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
UpdateTokenRequest,
3636
UpdateTokenResponse,
3737
StreamWriteMessage,
38+
TransactionIdentity,
3839
WriterMessagesFromServerToClient,
3940
)
4041
from .._grpc.grpcwrapper.common_utils import (
@@ -43,6 +44,9 @@
4344
GrpcWrapperAsyncIO,
4445
)
4546

47+
if typing.TYPE_CHECKING:
48+
from ..query.transaction import BaseQueryTxContext
49+
4650
logger = logging.getLogger(__name__)
4751

4852

@@ -165,7 +169,20 @@ async def wait_init(self) -> PublicWriterInitInfo:
165169

166170

167171
class TxWriterAsyncIO(WriterAsyncIO):
168-
...
172+
_tx: object
173+
174+
def __init__(
175+
self,
176+
tx,
177+
driver: SupportedDriverType,
178+
settings: PublicWriterSettings,
179+
_client=None,
180+
):
181+
self._tx = tx
182+
self._loop = asyncio.get_running_loop()
183+
self._closed = False
184+
self._reconnector = WriterAsyncIOReconnector(driver=driver, settings=WriterSettings(settings), tx=self._tx)
185+
self._parent = _client
169186

170187

171188
class WriterAsyncIOReconnector:
@@ -182,6 +199,7 @@ class WriterAsyncIOReconnector:
182199
_codec_selector_batch_num: int
183200
_codec_selector_last_codec: Optional[PublicCodec]
184201
_codec_selector_check_batches_interval: int
202+
_tx: Optional["BaseQueryTxContext"]
185203

186204
if typing.TYPE_CHECKING:
187205
_messages_for_encode: asyncio.Queue[List[InternalMessage]]
@@ -199,7 +217,9 @@ class WriterAsyncIOReconnector:
199217
_stop_reason: asyncio.Future
200218
_init_info: Optional[PublicWriterInitInfo]
201219

202-
def __init__(self, driver: SupportedDriverType, settings: WriterSettings):
220+
def __init__(
221+
self, driver: SupportedDriverType, settings: WriterSettings, tx: Optional["BaseQueryTxContext"] = None
222+
):
203223
self._closed = False
204224
self._loop = asyncio.get_running_loop()
205225
self._driver = driver
@@ -209,6 +229,7 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings):
209229
self._init_info = None
210230
self._stream_connected = asyncio.Event()
211231
self._settings = settings
232+
self._tx = tx
212233

213234
self._codec_functions = {
214235
PublicCodec.RAW: lambda data: data,
@@ -358,10 +379,12 @@ async def _connection_loop(self):
358379
# noinspection PyBroadException
359380
stream_writer = None
360381
try:
382+
tx_identity = None if self._tx is None else self._tx._tx_identity()
361383
stream_writer = await WriterAsyncIOStream.create(
362384
self._driver,
363385
self._init_message,
364386
self._settings.update_token_interval,
387+
tx_identity=tx_identity,
365388
)
366389
try:
367390
if self._init_info is None:
@@ -601,10 +624,13 @@ class WriterAsyncIOStream:
601624
_update_token_event: asyncio.Event
602625
_get_token_function: Optional[Callable[[], str]]
603626

627+
_tx_identity: Optional[TransactionIdentity]
628+
604629
def __init__(
605630
self,
606631
update_token_interval: Optional[Union[int, float]] = None,
607632
get_token_function: Optional[Callable[[], str]] = None,
633+
tx_identity: Optional[TransactionIdentity] = None,
608634
):
609635
self._closed = False
610636

@@ -613,6 +639,8 @@ def __init__(
613639
self._update_token_event = asyncio.Event()
614640
self._update_token_task = None
615641

642+
self._tx_identity = tx_identity
643+
616644
async def close(self):
617645
if self._closed:
618646
return
@@ -629,6 +657,7 @@ async def create(
629657
driver: SupportedDriverType,
630658
init_request: StreamWriteMessage.InitRequest,
631659
update_token_interval: Optional[Union[int, float]] = None,
660+
tx_identity: Optional[TransactionIdentity] = None,
632661
) -> "WriterAsyncIOStream":
633662
stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto)
634663

@@ -638,6 +667,7 @@ async def create(
638667
writer = WriterAsyncIOStream(
639668
update_token_interval=update_token_interval,
640669
get_token_function=creds.get_auth_token if creds else lambda: "",
670+
tx_identity=tx_identity,
641671
)
642672
await writer._start(stream, init_request)
643673
return writer
@@ -684,7 +714,7 @@ def write(self, messages: List[InternalMessage]):
684714
if self._closed:
685715
raise RuntimeError("Can not write on closed stream.")
686716

687-
for request in messages_to_proto_requests(messages):
717+
for request in messages_to_proto_requests(messages, self._tx_identity):
688718
self._stream.write(request)
689719

690720
async def _update_token_loop(self):

ydb/_topic_writer/topic_writer_asyncio_test.py

+58-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .._grpc.grpcwrapper.ydb_topic import (
1919
Codec,
2020
StreamWriteMessage,
21+
TransactionIdentity,
2122
UpdateTokenRequest,
2223
UpdateTokenResponse,
2324
)
@@ -43,6 +44,12 @@
4344
from ..credentials import AnonymousCredentials
4445

4546

47+
FAKE_TRANSACTION_IDENTITY = TransactionIdentity(
48+
tx_id="transaction_id",
49+
session_id="session_id",
50+
)
51+
52+
4653
@pytest.fixture
4754
def default_driver() -> aio.Driver:
4855
driver = mock.Mock(spec=aio.Driver)
@@ -100,6 +107,20 @@ async def writer_and_stream(self, stream) -> WriterWithMockedStream:
100107

101108
await writer.close()
102109

110+
@pytest.fixture
111+
async def writer_and_stream_with_tx_identity(self, stream) -> WriterWithMockedStream:
112+
writer = await self.get_started_writer(
113+
stream,
114+
tx_identity=FAKE_TRANSACTION_IDENTITY,
115+
)
116+
117+
yield TestWriterAsyncIOStream.WriterWithMockedStream(
118+
stream=stream,
119+
writer=writer,
120+
)
121+
122+
await writer.close()
123+
103124
async def test_init_writer(self, stream):
104125
init_seqno = 4
105126
init_message = StreamWriteMessage.InitRequest(
@@ -148,6 +169,7 @@ async def test_write_a_message(self, writer_and_stream: WriterWithMockedStream):
148169
expected_message = StreamWriteMessage.FromClient(
149170
StreamWriteMessage.WriteRequest(
150171
codec=Codec.CODEC_RAW,
172+
tx_identity=None,
151173
messages=[
152174
StreamWriteMessage.WriteRequest.MessageData(
153175
seq_no=1,
@@ -164,6 +186,41 @@ async def test_write_a_message(self, writer_and_stream: WriterWithMockedStream):
164186
sent_message = await writer_and_stream.stream.from_client.get()
165187
assert expected_message == sent_message
166188

189+
async def test_write_a_message_with_tx(self, writer_and_stream_with_tx_identity: WriterWithMockedStream):
190+
data = "123".encode()
191+
now = datetime.datetime.now(datetime.timezone.utc)
192+
writer_and_stream_with_tx_identity.writer.write(
193+
[
194+
InternalMessage(
195+
PublicMessage(
196+
seqno=1,
197+
created_at=now,
198+
data=data,
199+
)
200+
)
201+
]
202+
)
203+
204+
expected_message = StreamWriteMessage.FromClient(
205+
StreamWriteMessage.WriteRequest(
206+
codec=Codec.CODEC_RAW,
207+
tx_identity=FAKE_TRANSACTION_IDENTITY,
208+
messages=[
209+
StreamWriteMessage.WriteRequest.MessageData(
210+
seq_no=1,
211+
created_at=now,
212+
data=data,
213+
metadata_items={},
214+
uncompressed_size=len(data),
215+
partitioning=None,
216+
)
217+
],
218+
)
219+
)
220+
221+
sent_message = await writer_and_stream_with_tx_identity.stream.from_client.get()
222+
assert expected_message == sent_message
223+
167224
async def test_update_token(self, stream: StreamMock):
168225
writer = await self.get_started_writer(stream, update_token_interval=0.1, get_token_function=lambda: "foo-bar")
169226
assert stream.from_client.empty()
@@ -264,7 +321,7 @@ def _create(self):
264321

265322
res = DoubleQueueWriters()
266323

267-
async def async_create(driver, init_message, token_getter):
324+
async def async_create(driver, init_message, token_getter, tx_identity):
268325
return res.get_first()
269326

270327
monkeypatch.setattr(WriterAsyncIOStream, "create", async_create)

ydb/query/transaction.py

+6
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
_apis,
1212
issues,
1313
)
14+
from .._grpc.grpcwrapper import ydb_topic as _ydb_topic
1415
from .._grpc.grpcwrapper import ydb_query as _ydb_query
1516
from ..connection import _RpcState as RpcState
1617

@@ -215,6 +216,11 @@ def tx_id(self) -> Optional[str]:
215216
"""
216217
return self._tx_state.tx_id
217218

219+
def _tx_identity(self) -> _ydb_topic.TransactionIdentity:
220+
if not self.tx_id:
221+
raise RuntimeError("Unable to get tx identity without started tx.")
222+
return _ydb_topic.TransactionIdentity(self.tx_id, self.session_id)
223+
218224
def _begin_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxContext":
219225
self._tx_state._check_invalid_transition(QueryTxStateEnum.BEGINED)
220226

ydb/topic.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
PublicWriteResult as TopicWriteResult,
6666
)
6767

68+
from ydb._topic_writer.topic_writer_asyncio import TxWriterAsyncIO as TopicTxWriterAsyncIO
6869
from ydb._topic_writer.topic_writer_asyncio import WriterAsyncIO as TopicWriterAsyncIO
6970
from ._topic_writer.topic_writer_sync import WriterSync as TopicWriter
7071

@@ -294,7 +295,15 @@ def tx_writer(
294295
# If max_worker in the executor is 1 - then encoders will be called from the thread without parallel.
295296
encoder_executor: Optional[concurrent.futures.Executor] = None,
296297
) -> TopicTxWriterAsyncIO:
297-
298+
args = locals().copy()
299+
del args["self"]
300+
301+
settings = TopicWriterSettings(**args)
302+
303+
if not settings.encoder_executor:
304+
settings.encoder_executor = self._executor
305+
306+
return TopicTxWriterAsyncIO(tx=tx, driver=self._driver, settings=settings, _client=self)
298307

299308
def close(self):
300309
if self._closed:

0 commit comments

Comments
 (0)