Skip to content

Commit 34edcd5

Browse files
committed
sync writer
1 parent 920cb3e commit 34edcd5

File tree

3 files changed

+136
-2
lines changed

3 files changed

+136
-2
lines changed

tests/topics/test_topic_transactions.py

+62-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ async def callee(tx: ydb.aio.QueryTxContext):
3838
assert msg.data.decode() == "123"
3939

4040

41-
# @pytest.mark.skip("Not implemented yet.")
4241
class TestTopicTransactionalWriter:
4342
async def test_commit(self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO):
4443
async with ydb.aio.QuerySessionPool(driver) as pool:
@@ -65,3 +64,65 @@ async def callee(tx: ydb.aio.QueryTxContext):
6564

6665
with pytest.raises(asyncio.TimeoutError):
6766
await wait_for(topic_reader.receive_message(), 0.1)
67+
68+
async def test_no_msg_writter_in_error_case(
69+
self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO
70+
):
71+
async with ydb.aio.QuerySessionPool(driver) as pool:
72+
73+
async def callee(tx: ydb.aio.QueryTxContext):
74+
tx_writer = driver.topic_client.tx_writer(tx, topic_path)
75+
await tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
76+
77+
raise BaseException("error")
78+
79+
with pytest.raises(BaseException):
80+
await pool.retry_tx_async(callee)
81+
82+
with pytest.raises(asyncio.TimeoutError):
83+
await wait_for(topic_reader.receive_message(), 0.1)
84+
85+
86+
class TestTopicTransactionalWriterSync:
87+
def test_commit(self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader):
88+
with ydb.QuerySessionPool(driver_sync) as pool:
89+
90+
def callee(tx: ydb.QueryTxContext):
91+
tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path)
92+
tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
93+
94+
pool.retry_tx_sync(callee)
95+
96+
msg = topic_reader_sync.receive_message(timeout=0.1)
97+
assert msg.data.decode() == "123"
98+
99+
def test_rollback(self, driver_sync: ydb.aio.Driver, topic_path, topic_reader_sync: ydb.TopicReader):
100+
with ydb.QuerySessionPool(driver_sync) as pool:
101+
102+
def callee(tx: ydb.QueryTxContext):
103+
tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path)
104+
tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
105+
106+
tx.rollback()
107+
108+
pool.retry_tx_sync(callee)
109+
110+
with pytest.raises(TimeoutError):
111+
topic_reader_sync.receive_message(timeout=0.1)
112+
113+
def test_no_msg_writter_in_error_case(
114+
self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReaderAsyncIO
115+
):
116+
with ydb.QuerySessionPool(driver_sync) as pool:
117+
118+
def callee(tx: ydb.QueryTxContext):
119+
tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path)
120+
tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
121+
122+
raise BaseException("error")
123+
124+
with pytest.raises(BaseException):
125+
pool.retry_tx_sync(callee)
126+
127+
with pytest.raises(TimeoutError):
128+
topic_reader_sync.receive_message(timeout=0.1)

ydb/_topic_writer/topic_writer_sync.py

+43-1
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,20 @@
1414
TopicWriterClosedError,
1515
)
1616

17-
from .topic_writer_asyncio import WriterAsyncIO
17+
from ..query.base import TxListener
18+
19+
from .topic_writer_asyncio import (
20+
TxWriterAsyncIO,
21+
WriterAsyncIO,
22+
)
1823
from .._topic_common.common import (
1924
_get_shared_event_loop,
2025
TimeoutType,
2126
CallFromSyncToAsync,
2227
)
2328

29+
if typing.TYPE_CHECKING:
30+
from ..query.transaction import BaseQueryTxContext
2431

2532
class WriterSync:
2633
_caller: CallFromSyncToAsync
@@ -122,3 +129,38 @@ def write_with_ack(
122129
self._check_closed()
123130

124131
return self._caller.unsafe_call_with_result(self._async_writer.write_with_ack(messages), timeout=timeout)
132+
133+
134+
class TxWriterSync(WriterSync, TxListener):
135+
def __init__(
136+
self,
137+
tx: "BaseQueryTxContext",
138+
driver: SupportedDriverType,
139+
settings: PublicWriterSettings,
140+
*,
141+
eventloop: Optional[asyncio.AbstractEventLoop] = None,
142+
_parent=None,
143+
):
144+
145+
self._closed = False
146+
147+
if eventloop:
148+
loop = eventloop
149+
else:
150+
loop = _get_shared_event_loop()
151+
152+
self._caller = CallFromSyncToAsync(loop)
153+
154+
async def create_async_writer():
155+
return TxWriterAsyncIO(tx, driver, settings)
156+
157+
self._async_writer = self._caller.safe_call_with_result(create_async_writer(), None)
158+
self._parent = _parent
159+
160+
tx._add_listener(self)
161+
162+
def _on_before_commit(self):
163+
self.close()
164+
165+
def _on_before_rollback(self):
166+
self.close()

ydb/topic.py

+31
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
from ydb._topic_writer.topic_writer_asyncio import TxWriterAsyncIO as TopicTxWriterAsyncIO
6969
from ydb._topic_writer.topic_writer_asyncio import WriterAsyncIO as TopicWriterAsyncIO
7070
from ._topic_writer.topic_writer_sync import WriterSync as TopicWriter
71+
from ._topic_writer.topic_writer_sync import TxWriterSync as TxTopicWriter
7172

7273
from ._topic_common.common import (
7374
wrap_operation as _wrap_operation,
@@ -517,6 +518,36 @@ def writer(
517518

518519
return TopicWriter(self._driver, settings, _parent=self)
519520

521+
def tx_writer(
522+
self,
523+
tx,
524+
topic,
525+
*,
526+
producer_id: Optional[str] = None, # default - random
527+
session_metadata: Mapping[str, str] = None,
528+
partition_id: Union[int, None] = None,
529+
auto_seqno: bool = True,
530+
auto_created_at: bool = True,
531+
codec: Optional[TopicCodec] = None, # default mean auto-select
532+
# encoders: map[codec_code] func(encoded_bytes)->decoded_bytes
533+
# the func will be called from multiply threads in parallel.
534+
encoders: Optional[Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]]] = None,
535+
# custom encoder executor for call builtin and custom decoders. If None - use shared executor pool.
536+
# If max_worker in the executor is 1 - then encoders will be called from the thread without parallel.
537+
encoder_executor: Optional[concurrent.futures.Executor] = None, # default shared client executor pool
538+
) -> TopicWriter:
539+
args = locals().copy()
540+
del args["self"]
541+
del args["tx"]
542+
self._check_closed()
543+
544+
settings = TopicWriterSettings(**args)
545+
546+
if not settings.encoder_executor:
547+
settings.encoder_executor = self._executor
548+
549+
return TxTopicWriter(tx, self._driver, settings, _parent=self)
550+
520551
def close(self):
521552
if self._closed:
522553
return

0 commit comments

Comments
 (0)