diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index cbc0bc67..adbf779f 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -18,21 +18,15 @@ jobs: fail-fast: false matrix: python-version: [3.8, 3.9] - environment: [py-proto5, py-tls-proto5, py-proto4, py-tls-proto4, py-proto3, py-tls-proto3] - folder: [ydb, tests --ignore=tests/topics, tests/topics] + environment: [py, py-tls, py-proto4, py-tls-proto4, py-proto3, py-tls-proto3] + folder: [ydb, tests] exclude: - - environment: py-tls-proto5 + - environment: py-tls folder: ydb - environment: py-tls-proto4 folder: ydb - environment: py-tls-proto3 folder: ydb - - environment: py-tls-proto5 - folder: tests/topics - - environment: py-tls-proto4 - folder: tests/topics - - environment: py-tls-proto3 - folder: tests/topics steps: - uses: actions/checkout@v1 diff --git a/examples/topic/topic_transactions_async_example.py b/examples/topic/topic_transactions_async_example.py new file mode 100644 index 00000000..cae61063 --- /dev/null +++ b/examples/topic/topic_transactions_async_example.py @@ -0,0 +1,86 @@ +import asyncio +import argparse +import logging +import ydb + + +async def connect(endpoint: str, database: str) -> ydb.aio.Driver: + config = ydb.DriverConfig(endpoint=endpoint, database=database) + config.credentials = ydb.credentials_from_env_variables() + driver = ydb.aio.Driver(config) + await driver.wait(5, fail_fast=True) + return driver + + +async def create_topic(driver: ydb.aio.Driver, topic: str, consumer: str): + try: + await driver.topic_client.drop_topic(topic) + except ydb.SchemeError: + pass + + await driver.topic_client.create_topic(topic, consumers=[consumer]) + + +async def write_with_tx_example(driver: ydb.aio.Driver, topic: str, message_count: int = 10): + async with ydb.aio.QuerySessionPool(driver) as session_pool: + + async def callee(tx: ydb.aio.QueryTxContext): + tx_writer: ydb.TopicTxWriterAsyncIO = driver.topic_client.tx_writer(tx, topic) + + for i in range(message_count): + async with await tx.execute(query=f"select {i} as res;") as result_stream: + async for result_set in result_stream: + message = str(result_set.rows[0]["res"]) + await tx_writer.write(ydb.TopicWriterMessage(message)) + print(f"Message {result_set.rows[0]['res']} was written with tx.") + + await session_pool.retry_tx_async(callee) + + +async def read_with_tx_example(driver: ydb.aio.Driver, topic: str, consumer: str, message_count: int = 10): + async with driver.topic_client.reader(topic, consumer) as reader: + async with ydb.aio.QuerySessionPool(driver) as session_pool: + for _ in range(message_count): + + async def callee(tx: ydb.aio.QueryTxContext): + batch = await reader.receive_batch_with_tx(tx, max_messages=1) + print(f"Message {batch.messages[0].data.decode()} was read with tx.") + + await session_pool.retry_tx_async(callee) + + +async def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + description="""YDB topic basic example.\n""", + ) + parser.add_argument("-d", "--database", default="/local", help="Name of the database to use") + parser.add_argument("-e", "--endpoint", default="grpc://localhost:2136", help="Endpoint url to use") + parser.add_argument("-p", "--path", default="test-topic", help="Topic name") + parser.add_argument("-c", "--consumer", default="consumer", help="Consumer name") + parser.add_argument("-v", "--verbose", default=False, action="store_true") + parser.add_argument( + "-s", + "--skip-drop-and-create-topic", + default=False, + action="store_true", + help="Use existed topic, skip remove it and re-create", + ) + + args = parser.parse_args() + + if args.verbose: + logger = logging.getLogger("topicexample") + logger.setLevel(logging.DEBUG) + logger.addHandler(logging.StreamHandler()) + + async with await connect(args.endpoint, args.database) as driver: + if not args.skip_drop_and_create_topic: + await create_topic(driver, args.path, args.consumer) + + await write_with_tx_example(driver, args.path) + await read_with_tx_example(driver, args.path, args.consumer) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/topic/topic_transactions_example.py b/examples/topic/topic_transactions_example.py new file mode 100644 index 00000000..0f7432e7 --- /dev/null +++ b/examples/topic/topic_transactions_example.py @@ -0,0 +1,85 @@ +import argparse +import logging +import ydb + + +def connect(endpoint: str, database: str) -> ydb.Driver: + config = ydb.DriverConfig(endpoint=endpoint, database=database) + config.credentials = ydb.credentials_from_env_variables() + driver = ydb.Driver(config) + driver.wait(5, fail_fast=True) + return driver + + +def create_topic(driver: ydb.Driver, topic: str, consumer: str): + try: + driver.topic_client.drop_topic(topic) + except ydb.SchemeError: + pass + + driver.topic_client.create_topic(topic, consumers=[consumer]) + + +def write_with_tx_example(driver: ydb.Driver, topic: str, message_count: int = 10): + with ydb.QuerySessionPool(driver) as session_pool: + + def callee(tx: ydb.QueryTxContext): + tx_writer: ydb.TopicTxWriter = driver.topic_client.tx_writer(tx, topic) + + for i in range(message_count): + result_stream = tx.execute(query=f"select {i} as res;") + for result_set in result_stream: + message = str(result_set.rows[0]["res"]) + tx_writer.write(ydb.TopicWriterMessage(message)) + print(f"Message {message} was written with tx.") + + session_pool.retry_tx_sync(callee) + + +def read_with_tx_example(driver: ydb.Driver, topic: str, consumer: str, message_count: int = 10): + with driver.topic_client.reader(topic, consumer) as reader: + with ydb.QuerySessionPool(driver) as session_pool: + for _ in range(message_count): + + def callee(tx: ydb.QueryTxContext): + batch = reader.receive_batch_with_tx(tx, max_messages=1) + print(f"Message {batch.messages[0].data.decode()} was read with tx.") + + session_pool.retry_tx_sync(callee) + + +def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + description="""YDB topic basic example.\n""", + ) + parser.add_argument("-d", "--database", default="/local", help="Name of the database to use") + parser.add_argument("-e", "--endpoint", default="grpc://localhost:2136", help="Endpoint url to use") + parser.add_argument("-p", "--path", default="test-topic", help="Topic name") + parser.add_argument("-c", "--consumer", default="consumer", help="Consumer name") + parser.add_argument("-v", "--verbose", default=False, action="store_true") + parser.add_argument( + "-s", + "--skip-drop-and-create-topic", + default=False, + action="store_true", + help="Use existed topic, skip remove it and re-create", + ) + + args = parser.parse_args() + + if args.verbose: + logger = logging.getLogger("topicexample") + logger.setLevel(logging.DEBUG) + logger.addHandler(logging.StreamHandler()) + + with connect(args.endpoint, args.database) as driver: + if not args.skip_drop_and_create_topic: + create_topic(driver, args.path, args.consumer) + + write_with_tx_example(driver, args.path) + read_with_tx_example(driver, args.path, args.consumer) + + +if __name__ == "__main__": + main() diff --git a/tests/query/test_query_transaction.py b/tests/query/test_query_transaction.py index dfc88897..4533e528 100644 --- a/tests/query/test_query_transaction.py +++ b/tests/query/test_query_transaction.py @@ -92,3 +92,15 @@ def test_execute_two_results(self, tx: QueryTxContext): assert res == [[1], [2]] assert counter == 2 + + def test_tx_identity_before_begin_raises(self, tx: QueryTxContext): + with pytest.raises(RuntimeError): + tx._tx_identity() + + def test_tx_identity_after_begin_works(self, tx: QueryTxContext): + tx.begin() + + identity = tx._tx_identity() + + assert identity.tx_id == tx.tx_id + assert identity.session_id == tx.session_id diff --git a/tests/topics/test_topic_reader.py b/tests/topics/test_topic_reader.py index 23b5b4be..623dc8c0 100644 --- a/tests/topics/test_topic_reader.py +++ b/tests/topics/test_topic_reader.py @@ -174,12 +174,13 @@ def test_read_and_commit_with_close_reader(self, driver_sync, topic_with_message assert message != message2 def test_read_and_commit_with_ack(self, driver_sync, topic_with_messages, topic_consumer): - reader = driver_sync.topic_client.reader(topic_with_messages, topic_consumer) - message = reader.receive_message() - reader.commit_with_ack(message) + with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader: + message = reader.receive_message() + reader.commit_with_ack(message) + + with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader: + batch = reader.receive_batch() - reader = driver_sync.topic_client.reader(topic_with_messages, topic_consumer) - batch = reader.receive_batch() assert message != batch.messages[0] def test_read_compressed_messages(self, driver_sync, topic_path, topic_consumer): @@ -247,3 +248,6 @@ async def wait(fut): datas.sort() assert datas == ["10", "11"] + + await reader0.close() + await reader1.close() diff --git a/tests/topics/test_topic_transactions.py b/tests/topics/test_topic_transactions.py new file mode 100644 index 00000000..b79df740 --- /dev/null +++ b/tests/topics/test_topic_transactions.py @@ -0,0 +1,469 @@ +import asyncio +from asyncio import wait_for +import pytest +from unittest import mock +import ydb + +DEFAULT_TIMEOUT = 0.5 +DEFAULT_RETRY_SETTINGS = ydb.RetrySettings(max_retries=1) + + +@pytest.mark.asyncio +class TestTopicTransactionalReader: + async def test_commit(self, driver: ydb.aio.Driver, topic_with_messages, topic_consumer): + async with ydb.aio.QuerySessionPool(driver) as pool: + async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader: + + async def callee(tx: ydb.aio.QueryTxContext): + batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "456" + + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + assert len(reader._reconnector._tx_to_batches_map) == 0 + + async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader: + msg = await wait_for(reader.receive_message(), DEFAULT_TIMEOUT) + assert msg.data.decode() == "789" + + async def test_rollback(self, driver: ydb.aio.Driver, topic_with_messages, topic_consumer): + async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader: + async with ydb.aio.QuerySessionPool(driver) as pool: + + async def callee(tx: ydb.aio.QueryTxContext): + batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + await tx.rollback() + + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + assert len(reader._reconnector._tx_to_batches_map) == 0 + + msg = await wait_for(reader.receive_message(), DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + + async def test_tx_failed_if_update_offsets_call_failed( + self, driver: ydb.aio.Driver, topic_with_messages, topic_consumer + ): + async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader: + async with ydb.aio.QuerySessionPool(driver) as pool: + with mock.patch.object( + reader._reconnector, + "_do_commit_batches_with_tx_call", + side_effect=ydb.Error("Update offsets in tx failed"), + ): + + async def callee(tx: ydb.aio.QueryTxContext): + batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + with pytest.raises(ydb.Error, match="Transaction was failed"): + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + assert len(reader._reconnector._tx_to_batches_map) == 0 + + msg = await wait_for(reader.receive_message(), DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + + async def test_error_in_lambda(self, driver: ydb.aio.Driver, topic_with_messages, topic_consumer): + async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader: + async with ydb.aio.QuerySessionPool(driver) as pool: + + async def callee(tx: ydb.aio.QueryTxContext): + batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + raise RuntimeError("Something went wrong") + + with pytest.raises(RuntimeError): + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + assert len(reader._reconnector._tx_to_batches_map) == 0 + + msg = await wait_for(reader.receive_message(), DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + + async def test_error_during_commit(self, driver: ydb.aio.Driver, topic_with_messages, topic_consumer): + async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader: + async with ydb.aio.QuerySessionPool(driver) as pool: + + async def callee(tx: ydb.aio.QueryTxContext): + with mock.patch.object( + tx, + "_commit_call", + side_effect=ydb.Unavailable("YDB Unavailable"), + ): + batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + await tx.commit() + + with pytest.raises(ydb.Unavailable): + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + assert len(reader._reconnector._tx_to_batches_map) == 0 + + msg = await wait_for(reader.receive_message(), DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + + async def test_error_during_rollback(self, driver: ydb.aio.Driver, topic_with_messages, topic_consumer): + async with driver.topic_client.reader(topic_with_messages, topic_consumer) as reader: + async with ydb.aio.QuerySessionPool(driver) as pool: + + async def callee(tx: ydb.aio.QueryTxContext): + with mock.patch.object( + tx, + "_rollback_call", + side_effect=ydb.Unavailable("YDB Unavailable"), + ): + batch = await wait_for(reader.receive_batch_with_tx(tx, max_messages=1), DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + await tx.rollback() + + with pytest.raises(ydb.Unavailable): + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + assert len(reader._reconnector._tx_to_batches_map) == 0 + + msg = await wait_for(reader.receive_message(), DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + + +class TestTopicTransactionalReaderSync: + def test_commit(self, driver_sync: ydb.Driver, topic_with_messages, topic_consumer): + with ydb.QuerySessionPool(driver_sync) as pool: + with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader: + + def callee(tx: ydb.QueryTxContext): + batch = reader.receive_batch_with_tx(tx, max_messages=1, timeout=DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + batch = reader.receive_batch_with_tx(tx, max_messages=1, timeout=DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "456" + + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + assert len(reader._async_reader._reconnector._tx_to_batches_map) == 0 + + with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader: + msg = reader.receive_message(timeout=DEFAULT_TIMEOUT) + assert msg.data.decode() == "789" + + def test_rollback(self, driver_sync: ydb.Driver, topic_with_messages, topic_consumer): + with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader: + with ydb.QuerySessionPool(driver_sync) as pool: + + def callee(tx: ydb.QueryTxContext): + batch = reader.receive_batch_with_tx(tx, max_messages=1, timeout=DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + tx.rollback() + + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + assert len(reader._async_reader._reconnector._tx_to_batches_map) == 0 + + msg = reader.receive_message(timeout=DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + + def test_tx_failed_if_update_offsets_call_failed( + self, driver_sync: ydb.Driver, topic_with_messages, topic_consumer + ): + with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader: + with ydb.QuerySessionPool(driver_sync) as pool: + with mock.patch.object( + reader._async_reader._reconnector, + "_do_commit_batches_with_tx_call", + side_effect=ydb.Error("Update offsets in tx failed"), + ): + + def callee(tx: ydb.QueryTxContext): + batch = reader.receive_batch_with_tx(tx, max_messages=1, timeout=DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + with pytest.raises(ydb.Error, match="Transaction was failed"): + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + assert len(reader._async_reader._reconnector._tx_to_batches_map) == 0 + + msg = reader.receive_message(timeout=DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + + def test_error_in_lambda(self, driver_sync: ydb.Driver, topic_with_messages, topic_consumer): + with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader: + with ydb.QuerySessionPool(driver_sync) as pool: + + def callee(tx: ydb.QueryTxContext): + batch = reader.receive_batch_with_tx(tx, max_messages=1, timeout=DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + raise RuntimeError("Something went wrong") + + with pytest.raises(RuntimeError): + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + assert len(reader._async_reader._reconnector._tx_to_batches_map) == 0 + + msg = reader.receive_message(timeout=DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + + def test_error_during_commit(self, driver_sync: ydb.Driver, topic_with_messages, topic_consumer): + with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader: + with ydb.QuerySessionPool(driver_sync) as pool: + + def callee(tx: ydb.QueryTxContext): + with mock.patch.object( + tx, + "_commit_call", + side_effect=ydb.Unavailable("YDB Unavailable"), + ): + batch = reader.receive_batch_with_tx(tx, max_messages=1, timeout=DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + tx.commit() + + with pytest.raises(ydb.Unavailable): + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + assert len(reader._async_reader._reconnector._tx_to_batches_map) == 0 + + msg = reader.receive_message(timeout=DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + + def test_error_during_rollback(self, driver_sync: ydb.Driver, topic_with_messages, topic_consumer): + with driver_sync.topic_client.reader(topic_with_messages, topic_consumer) as reader: + with ydb.QuerySessionPool(driver_sync) as pool: + + def callee(tx: ydb.QueryTxContext): + with mock.patch.object( + tx, + "_rollback_call", + side_effect=ydb.Unavailable("YDB Unavailable"), + ): + batch = reader.receive_batch_with_tx(tx, max_messages=1, timeout=DEFAULT_TIMEOUT) + assert len(batch.messages) == 1 + assert batch.messages[0].data.decode() == "123" + + tx.rollback() + + with pytest.raises(ydb.Unavailable): + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + assert len(reader._async_reader._reconnector._tx_to_batches_map) == 0 + + msg = reader.receive_message(timeout=DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + + +class TestTopicTransactionalWriter: + async def test_commit(self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO): + async with ydb.aio.QuerySessionPool(driver) as pool: + + async def callee(tx: ydb.aio.QueryTxContext): + tx_writer = driver.topic_client.tx_writer(tx, topic_path) + await tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + msg = await wait_for(topic_reader.receive_message(), 0.1) + assert msg.data.decode() == "123" + + async def test_rollback(self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO): + async with ydb.aio.QuerySessionPool(driver) as pool: + + async def callee(tx: ydb.aio.QueryTxContext): + tx_writer = driver.topic_client.tx_writer(tx, topic_path) + await tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + + await tx.rollback() + + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + with pytest.raises(asyncio.TimeoutError): + await wait_for(topic_reader.receive_message(), 0.1) + + async def test_no_msg_written_in_error_case( + self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO + ): + async with ydb.aio.QuerySessionPool(driver) as pool: + + async def callee(tx: ydb.aio.QueryTxContext): + tx_writer = driver.topic_client.tx_writer(tx, topic_path) + await tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + + raise BaseException("error") + + with pytest.raises(BaseException): + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + with pytest.raises(asyncio.TimeoutError): + await wait_for(topic_reader.receive_message(), 0.1) + + async def test_no_msg_written_in_tx_commit_error( + self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO + ): + async with ydb.aio.QuerySessionPool(driver) as pool: + + async def callee(tx: ydb.aio.QueryTxContext): + with mock.patch.object( + tx, + "_commit_call", + side_effect=ydb.Unavailable("YDB Unavailable"), + ): + tx_writer = driver.topic_client.tx_writer(tx, topic_path) + await tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + await tx.commit() + + with pytest.raises(ydb.Unavailable): + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + with pytest.raises(asyncio.TimeoutError): + await wait_for(topic_reader.receive_message(), 0.1) + + async def test_msg_written_exactly_once_with_retries( + self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO + ): + error_raised = False + async with ydb.aio.QuerySessionPool(driver) as pool: + + async def callee(tx: ydb.aio.QueryTxContext): + nonlocal error_raised + tx_writer = driver.topic_client.tx_writer(tx, topic_path) + await tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + + if not error_raised: + error_raised = True + raise ydb.issues.Unavailable("some retriable error") + + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + msg = await wait_for(topic_reader.receive_message(), 0.1) + assert msg.data.decode() == "123" + + with pytest.raises(asyncio.TimeoutError): + await wait_for(topic_reader.receive_message(), 0.1) + + async def test_writes_do_not_conflict_with_executes(self, driver: ydb.aio.Driver, topic_path): + async with ydb.aio.QuerySessionPool(driver) as pool: + + async def callee(tx: ydb.aio.QueryTxContext): + tx_writer = driver.topic_client.tx_writer(tx, topic_path) + for _ in range(3): + async with await tx.execute("select 1"): + await tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + + await pool.retry_tx_async(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + +class TestTopicTransactionalWriterSync: + def test_commit(self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader): + with ydb.QuerySessionPool(driver_sync) as pool: + + def callee(tx: ydb.QueryTxContext): + tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path) + tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + msg = topic_reader_sync.receive_message(timeout=DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + + def test_rollback(self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader): + with ydb.QuerySessionPool(driver_sync) as pool: + + def callee(tx: ydb.QueryTxContext): + tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path) + tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + + tx.rollback() + + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + with pytest.raises(TimeoutError): + topic_reader_sync.receive_message(timeout=DEFAULT_TIMEOUT) + + def test_no_msg_written_in_error_case( + self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader + ): + with ydb.QuerySessionPool(driver_sync) as pool: + + def callee(tx: ydb.QueryTxContext): + tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path) + tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + + raise BaseException("error") + + with pytest.raises(BaseException): + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + with pytest.raises(TimeoutError): + topic_reader_sync.receive_message(timeout=DEFAULT_TIMEOUT) + + def test_no_msg_written_in_tx_commit_error( + self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader + ): + with ydb.QuerySessionPool(driver_sync) as pool: + + def callee(tx: ydb.QueryTxContext): + with mock.patch.object( + tx, + "_commit_call", + side_effect=ydb.Unavailable("YDB Unavailable"), + ): + tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path) + tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + tx.commit() + + with pytest.raises(ydb.Unavailable): + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + with pytest.raises(TimeoutError): + topic_reader_sync.receive_message(timeout=DEFAULT_TIMEOUT) + + def test_msg_written_exactly_once_with_retries( + self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader + ): + error_raised = False + with ydb.QuerySessionPool(driver_sync) as pool: + + def callee(tx: ydb.QueryTxContext): + nonlocal error_raised + tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path) + tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + + if not error_raised: + error_raised = True + raise ydb.issues.Unavailable("some retriable error") + + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) + + msg = topic_reader_sync.receive_message(timeout=DEFAULT_TIMEOUT) + assert msg.data.decode() == "123" + + with pytest.raises(TimeoutError): + topic_reader_sync.receive_message(timeout=DEFAULT_TIMEOUT) + + def test_writes_do_not_conflict_with_executes(self, driver_sync: ydb.Driver, topic_path): + with ydb.QuerySessionPool(driver_sync) as pool: + + def callee(tx: ydb.QueryTxContext): + tx_writer = driver_sync.topic_client.tx_writer(tx, topic_path) + for _ in range(3): + with tx.execute("select 1"): + tx_writer.write(ydb.TopicWriterMessage(data="123".encode())) + + pool.retry_tx_sync(callee, retry_settings=DEFAULT_RETRY_SETTINGS) diff --git a/tox.ini b/tox.ini index df029d2a..f91e7d8a 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py-proto5,py-proto4,py-proto3,py-tls-proto5,py-tls-proto4,py-tls-proto3,style,pylint,black,protoc,py-cov-proto4 +envlist = py,py-proto4,py-proto3,py-tls,py-tls-proto4,py-tls-proto3,style,pylint,black,protoc,py-cov-proto4 minversion = 4.2.6 skipsdist = True ignore_basepython_conflict = true @@ -30,7 +30,7 @@ deps = -r{toxinidir}/test-requirements.txt protobuf<4.0.0 -[testenv:py-proto5] +[testenv:py] commands = pytest -v -m "not tls" --docker-compose-remove-volumes --docker-compose=docker-compose.yml {posargs} deps = @@ -60,7 +60,7 @@ deps = -r{toxinidir}/test-requirements.txt protobuf<4.0.0 -[testenv:py-tls-proto5] +[testenv:py-tls] commands = pytest -v -m tls --docker-compose-remove-volumes --docker-compose=docker-compose-tls.yml {posargs} deps = diff --git a/ydb/_apis.py b/ydb/_apis.py index 2a9a14e8..e54f25d2 100644 --- a/ydb/_apis.py +++ b/ydb/_apis.py @@ -116,6 +116,7 @@ class TopicService(object): DropTopic = "DropTopic" StreamRead = "StreamRead" StreamWrite = "StreamWrite" + UpdateOffsetsInTransaction = "UpdateOffsetsInTransaction" class QueryService(object): diff --git a/ydb/_errors.py b/ydb/_errors.py index 17002d25..1e2308ef 100644 --- a/ydb/_errors.py +++ b/ydb/_errors.py @@ -5,6 +5,7 @@ _errors_retriable_fast_backoff_types = [ issues.Unavailable, + issues.ClientInternalError, ] _errors_retriable_slow_backoff_types = [ issues.Aborted, diff --git a/ydb/_grpc/grpcwrapper/common_utils.py b/ydb/_grpc/grpcwrapper/common_utils.py index 7fb5b684..6a7275b4 100644 --- a/ydb/_grpc/grpcwrapper/common_utils.py +++ b/ydb/_grpc/grpcwrapper/common_utils.py @@ -161,9 +161,6 @@ def __init__(self, convert_server_grpc_to_wrapper): self._stream_call = None self._wait_executor = None - def __del__(self): - self._clean_executor(wait=False) - async def start(self, driver: SupportedDriverType, stub, method): if asyncio.iscoroutinefunction(driver.__call__): await self._start_asyncio_driver(driver, stub, method) diff --git a/ydb/_grpc/grpcwrapper/ydb_topic.py b/ydb/_grpc/grpcwrapper/ydb_topic.py index 600dfb69..6db50a11 100644 --- a/ydb/_grpc/grpcwrapper/ydb_topic.py +++ b/ydb/_grpc/grpcwrapper/ydb_topic.py @@ -142,6 +142,18 @@ def from_proto(msg: ydb_topic_pb2.UpdateTokenResponse) -> typing.Any: ######################################################################################################################## +@dataclass +class TransactionIdentity(IToProto): + tx_id: str + session_id: str + + def to_proto(self) -> ydb_topic_pb2.TransactionIdentity: + return ydb_topic_pb2.TransactionIdentity( + id=self.tx_id, + session=self.session_id, + ) + + class StreamWriteMessage: @dataclass() class InitRequest(IToProto): @@ -200,6 +212,7 @@ def from_proto( class WriteRequest(IToProto): messages: typing.List["StreamWriteMessage.WriteRequest.MessageData"] codec: int + tx_identity: Optional[TransactionIdentity] @dataclass class MessageData(IToProto): @@ -238,6 +251,9 @@ def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.WriteRequest: proto = ydb_topic_pb2.StreamWriteMessage.WriteRequest() proto.codec = self.codec + if self.tx_identity is not None: + proto.tx.CopyFrom(self.tx_identity.to_proto()) + for message in self.messages: proto_mess = proto.messages.add() proto_mess.CopyFrom(message.to_proto()) @@ -298,6 +314,8 @@ def from_proto(cls, proto_ack: ydb_topic_pb2.StreamWriteMessage.WriteResponse.Wr ) except ValueError: message_write_status = reason + elif proto_ack.HasField("written_in_tx"): + message_write_status = StreamWriteMessage.WriteResponse.WriteAck.StatusWrittenInTx() else: raise NotImplementedError("unexpected ack status") @@ -310,6 +328,9 @@ def from_proto(cls, proto_ack: ydb_topic_pb2.StreamWriteMessage.WriteResponse.Wr class StatusWritten: offset: int + class StatusWrittenInTx: + pass + @dataclass class StatusSkipped: reason: "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason" @@ -1188,6 +1209,52 @@ def to_public(self) -> ydb_topic_public_types.PublicMeteringMode: return ydb_topic_public_types.PublicMeteringMode.UNSPECIFIED +@dataclass +class UpdateOffsetsInTransactionRequest(IToProto): + tx: TransactionIdentity + topics: List[UpdateOffsetsInTransactionRequest.TopicOffsets] + consumer: str + + def to_proto(self): + return ydb_topic_pb2.UpdateOffsetsInTransactionRequest( + tx=self.tx.to_proto(), + consumer=self.consumer, + topics=list( + map( + UpdateOffsetsInTransactionRequest.TopicOffsets.to_proto, + self.topics, + ) + ), + ) + + @dataclass + class TopicOffsets(IToProto): + path: str + partitions: List[UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets] + + def to_proto(self): + return ydb_topic_pb2.UpdateOffsetsInTransactionRequest.TopicOffsets( + path=self.path, + partitions=list( + map( + UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets.to_proto, + self.partitions, + ) + ), + ) + + @dataclass + class PartitionOffsets(IToProto): + partition_id: int + partition_offsets: List[OffsetsRange] + + def to_proto(self) -> ydb_topic_pb2.UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets: + return ydb_topic_pb2.UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets( + partition_id=self.partition_id, + partition_offsets=list(map(OffsetsRange.to_proto, self.partition_offsets)), + ) + + @dataclass class CreateTopicRequest(IToProto, IFromPublic): path: str diff --git a/ydb/_topic_reader/datatypes.py b/ydb/_topic_reader/datatypes.py index b48501af..74f06a08 100644 --- a/ydb/_topic_reader/datatypes.py +++ b/ydb/_topic_reader/datatypes.py @@ -108,6 +108,9 @@ def ack_notify(self, offset: int): waiter = self._ack_waiters.popleft() waiter._finish_ok() + def _update_last_commited_offset_if_needed(self, offset: int): + self.committed_offset = max(self.committed_offset, offset) + def close(self): if self.closed: return @@ -211,3 +214,9 @@ def _pop_batch(self, message_count: int) -> PublicBatch: self._bytes_size = self._bytes_size - new_batch._bytes_size return new_batch + + def _update_partition_offsets(self, tx, exc=None): + if exc is not None: + return + offsets = self._commit_get_offsets_range() + self._partition_session._update_last_commited_offset_if_needed(offsets.end) diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 7061b4e4..c9704d55 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -5,7 +5,7 @@ import gzip import typing from asyncio import Task -from collections import OrderedDict +from collections import defaultdict, OrderedDict from typing import Optional, Set, Dict, Union, Callable import ydb @@ -19,17 +19,24 @@ from .._grpc.grpcwrapper.common_utils import ( IGrpcWrapperAsyncIO, SupportedDriverType, + to_thread, GrpcWrapperAsyncIO, ) from .._grpc.grpcwrapper.ydb_topic import ( StreamReadMessage, UpdateTokenRequest, UpdateTokenResponse, + UpdateOffsetsInTransactionRequest, Codec, ) from .._errors import check_retriable_error import logging +from ..query.base import TxEvent + +if typing.TYPE_CHECKING: + from ..query.transaction import BaseQueryTxContext + logger = logging.getLogger(__name__) @@ -77,7 +84,7 @@ def __init__( ): self._loop = asyncio.get_running_loop() self._closed = False - self._reconnector = ReaderReconnector(driver, settings) + self._reconnector = ReaderReconnector(driver, settings, self._loop) self._parent = _parent async def __aenter__(self): @@ -88,8 +95,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): def __del__(self): if not self._closed: - task = self._loop.create_task(self.close(flush=False)) - topic_common.wrap_set_name_for_asyncio_task(task, task_name="close reader") + logger.warning("Topic reader was not closed properly. Consider using method close().") async def wait_message(self): """ @@ -112,6 +118,23 @@ async def receive_batch( max_messages=max_messages, ) + async def receive_batch_with_tx( + self, + tx: "BaseQueryTxContext", + max_messages: typing.Union[int, None] = None, + ) -> typing.Union[datatypes.PublicBatch, None]: + """ + Get one messages batch with tx from reader. + All messages in a batch from same partition. + + use asyncio.wait_for for wait with timeout. + """ + await self._reconnector.wait_message() + return self._reconnector.receive_batch_with_tx_nowait( + tx=tx, + max_messages=max_messages, + ) + async def receive_message(self) -> typing.Optional[datatypes.PublicMessage]: """ Block until receive new message @@ -165,11 +188,18 @@ class ReaderReconnector: _state_changed: asyncio.Event _stream_reader: Optional["ReaderStream"] _first_error: asyncio.Future[YdbError] + _tx_to_batches_map: Dict[str, typing.List[datatypes.PublicBatch]] - def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings): + def __init__( + self, + driver: Driver, + settings: topic_reader.PublicReaderSettings, + loop: Optional[asyncio.AbstractEventLoop] = None, + ): self._id = self._static_reader_reconnector_counter.inc_and_get() self._settings = settings self._driver = driver + self._loop = loop if loop is not None else asyncio.get_running_loop() self._background_tasks = set() self._state_changed = asyncio.Event() @@ -177,6 +207,8 @@ def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings): self._background_tasks.add(asyncio.create_task(self._connection_loop())) self._first_error = asyncio.get_running_loop().create_future() + self._tx_to_batches_map = dict() + async def _connection_loop(self): attempt = 0 while True: @@ -190,6 +222,7 @@ async def _connection_loop(self): if not retry_info.is_retriable: self._set_first_error(err) return + await asyncio.sleep(retry_info.sleep_timeout_seconds) attempt += 1 @@ -222,9 +255,87 @@ def receive_batch_nowait(self, max_messages: Optional[int] = None): max_messages=max_messages, ) + def receive_batch_with_tx_nowait(self, tx: "BaseQueryTxContext", max_messages: Optional[int] = None): + batch = self._stream_reader.receive_batch_nowait( + max_messages=max_messages, + ) + + self._init_tx(tx) + + self._tx_to_batches_map[tx.tx_id].append(batch) + + tx._add_callback(TxEvent.AFTER_COMMIT, batch._update_partition_offsets, self._loop) + + return batch + def receive_message_nowait(self): return self._stream_reader.receive_message_nowait() + def _init_tx(self, tx: "BaseQueryTxContext"): + if tx.tx_id not in self._tx_to_batches_map: # Init tx callbacks + self._tx_to_batches_map[tx.tx_id] = [] + tx._add_callback(TxEvent.BEFORE_COMMIT, self._commit_batches_with_tx, self._loop) + tx._add_callback(TxEvent.AFTER_COMMIT, self._handle_after_tx_commit, self._loop) + tx._add_callback(TxEvent.AFTER_ROLLBACK, self._handle_after_tx_rollback, self._loop) + + async def _commit_batches_with_tx(self, tx: "BaseQueryTxContext"): + grouped_batches = defaultdict(lambda: defaultdict(list)) + for batch in self._tx_to_batches_map[tx.tx_id]: + grouped_batches[batch._partition_session.topic_path][batch._partition_session.partition_id].append(batch) + + request = UpdateOffsetsInTransactionRequest(tx=tx._tx_identity(), consumer=self._settings.consumer, topics=[]) + + for topic_path in grouped_batches: + topic_offsets = UpdateOffsetsInTransactionRequest.TopicOffsets(path=topic_path, partitions=[]) + for partition_id in grouped_batches[topic_path]: + partition_offsets = UpdateOffsetsInTransactionRequest.TopicOffsets.PartitionOffsets( + partition_id=partition_id, + partition_offsets=[ + batch._commit_get_offsets_range() for batch in grouped_batches[topic_path][partition_id] + ], + ) + topic_offsets.partitions.append(partition_offsets) + request.topics.append(topic_offsets) + + try: + return await self._do_commit_batches_with_tx_call(request) + except BaseException: + exc = issues.ClientInternalError("Failed to update offsets in tx.") + tx._set_external_error(exc) + self._stream_reader._set_first_error(exc) + finally: + del self._tx_to_batches_map[tx.tx_id] + + async def _do_commit_batches_with_tx_call(self, request: UpdateOffsetsInTransactionRequest): + args = [ + request.to_proto(), + _apis.TopicService.Stub, + _apis.TopicService.UpdateOffsetsInTransaction, + topic_common.wrap_operation, + ] + + if asyncio.iscoroutinefunction(self._driver.__call__): + res = await self._driver(*args) + else: + res = await to_thread(self._driver, *args, executor=None) + + return res + + async def _handle_after_tx_rollback(self, tx: "BaseQueryTxContext", exc: Optional[BaseException]) -> None: + if tx.tx_id in self._tx_to_batches_map: + del self._tx_to_batches_map[tx.tx_id] + exc = issues.ClientInternalError("Reconnect due to transaction rollback") + self._stream_reader._set_first_error(exc) + + async def _handle_after_tx_commit(self, tx: "BaseQueryTxContext", exc: Optional[BaseException]) -> None: + if tx.tx_id in self._tx_to_batches_map: + del self._tx_to_batches_map[tx.tx_id] + + if exc is not None: + self._stream_reader._set_first_error( + issues.ClientInternalError("Reconnect due to transaction commit failed") + ) + def commit(self, batch: datatypes.ICommittable) -> datatypes.PartitionSession.CommitAckWaiter: return self._stream_reader.commit(batch) diff --git a/ydb/_topic_reader/topic_reader_sync.py b/ydb/_topic_reader/topic_reader_sync.py index eda1d374..3e6806d0 100644 --- a/ydb/_topic_reader/topic_reader_sync.py +++ b/ydb/_topic_reader/topic_reader_sync.py @@ -1,5 +1,6 @@ import asyncio import concurrent.futures +import logging import typing from typing import List, Union, Optional @@ -20,6 +21,11 @@ TopicReaderClosedError, ) +if typing.TYPE_CHECKING: + from ..query.transaction import BaseQueryTxContext + +logger = logging.getLogger(__name__) + class TopicReaderSync: _caller: CallFromSyncToAsync @@ -52,7 +58,8 @@ async def create_reader(): self._parent = _parent def __del__(self): - self.close(flush=False) + if not self._closed: + logger.warning("Topic reader was not closed properly. Consider using method close().") def __enter__(self): return self @@ -109,6 +116,31 @@ def receive_batch( timeout, ) + def receive_batch_with_tx( + self, + tx: "BaseQueryTxContext", + *, + max_messages: typing.Union[int, None] = None, + max_bytes: typing.Union[int, None] = None, + timeout: Union[float, None] = None, + ) -> Union[PublicBatch, None]: + """ + Get one messages batch with tx from reader + It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available. + + if no new message in timeout seconds (default - infinite): raise TimeoutError() + if timeout <= 0 - it will fast wait only one event loop cycle - without wait any i/o operations or pauses, get messages from internal buffer only. + """ + self._check_closed() + + return self._caller.safe_call_with_result( + self._async_reader.receive_batch_with_tx( + tx=tx, + max_messages=max_messages, + ), + timeout, + ) + def commit(self, mess: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch]): """ Put commit message to internal buffer. diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index aa5fe974..a3e407ed 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -11,6 +11,7 @@ import ydb.aio from .._grpc.grpcwrapper.ydb_topic import StreamWriteMessage +from .._grpc.grpcwrapper.ydb_topic import TransactionIdentity from .._grpc.grpcwrapper.common_utils import IToProto from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec from .. import connection @@ -53,8 +54,12 @@ class Written: class Skipped: pass + @dataclass(eq=True) + class WrittenInTx: + pass + -PublicWriteResultTypes = Union[PublicWriteResult.Written, PublicWriteResult.Skipped] +PublicWriteResultTypes = Union[PublicWriteResult.Written, PublicWriteResult.Skipped, PublicWriteResult.WrittenInTx] class WriterSettings(PublicWriterSettings): @@ -205,6 +210,7 @@ def default_serializer_message_content(data: Any) -> bytes: def messages_to_proto_requests( messages: List[InternalMessage], + tx_identity: Optional[TransactionIdentity], ) -> List[StreamWriteMessage.FromClient]: gropus = _slit_messages_for_send(messages) @@ -215,6 +221,7 @@ def messages_to_proto_requests( StreamWriteMessage.WriteRequest( messages=list(map(InternalMessage.to_message_data, group)), codec=group[0].codec, + tx_identity=tx_identity, ) ) res.append(req) @@ -239,6 +246,7 @@ def messages_to_proto_requests( ), ], codec=20000, + tx_identity=None, ) ) .to_proto() diff --git a/ydb/_topic_writer/topic_writer_asyncio.py b/ydb/_topic_writer/topic_writer_asyncio.py index 32d8fefe..1ea6c250 100644 --- a/ydb/_topic_writer/topic_writer_asyncio.py +++ b/ydb/_topic_writer/topic_writer_asyncio.py @@ -1,7 +1,6 @@ import asyncio import concurrent.futures import datetime -import functools import gzip import typing from collections import deque @@ -35,6 +34,7 @@ UpdateTokenRequest, UpdateTokenResponse, StreamWriteMessage, + TransactionIdentity, WriterMessagesFromServerToClient, ) from .._grpc.grpcwrapper.common_utils import ( @@ -43,6 +43,11 @@ GrpcWrapperAsyncIO, ) +from ..query.base import TxEvent + +if typing.TYPE_CHECKING: + from ..query.transaction import BaseQueryTxContext + logger = logging.getLogger(__name__) @@ -74,10 +79,8 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): raise def __del__(self): - if self._closed or self._loop.is_closed(): - return - - self._loop.call_soon(functools.partial(self.close, flush=False)) + if not self._closed: + logger.warning("Topic writer was not closed properly. Consider using method close().") async def close(self, *, flush: bool = True): if self._closed: @@ -164,6 +167,57 @@ async def wait_init(self) -> PublicWriterInitInfo: return await self._reconnector.wait_init() +class TxWriterAsyncIO(WriterAsyncIO): + _tx: "BaseQueryTxContext" + + def __init__( + self, + tx: "BaseQueryTxContext", + driver: SupportedDriverType, + settings: PublicWriterSettings, + _client=None, + _is_implicit=False, + ): + self._tx = tx + self._loop = asyncio.get_running_loop() + self._closed = False + self._reconnector = WriterAsyncIOReconnector(driver=driver, settings=WriterSettings(settings), tx=self._tx) + self._parent = _client + self._is_implicit = _is_implicit + + # For some reason, creating partition could conflict with other session operations. + # Could be removed later. + self._first_write = True + + tx._add_callback(TxEvent.BEFORE_COMMIT, self._on_before_commit, self._loop) + tx._add_callback(TxEvent.BEFORE_ROLLBACK, self._on_before_rollback, self._loop) + + async def write( + self, + messages: Union[Message, List[Message]], + ): + """ + send one or number of messages to server. + it put message to internal buffer + + For wait with timeout use asyncio.wait_for. + """ + if self._first_write: + self._first_write = False + return await super().write_with_ack(messages) + return await super().write(messages) + + async def _on_before_commit(self, tx: "BaseQueryTxContext"): + if self._is_implicit: + return + await self.close() + + async def _on_before_rollback(self, tx: "BaseQueryTxContext"): + if self._is_implicit: + return + await self.close(flush=False) + + class WriterAsyncIOReconnector: _closed: bool _loop: asyncio.AbstractEventLoop @@ -178,6 +232,7 @@ class WriterAsyncIOReconnector: _codec_selector_batch_num: int _codec_selector_last_codec: Optional[PublicCodec] _codec_selector_check_batches_interval: int + _tx: Optional["BaseQueryTxContext"] if typing.TYPE_CHECKING: _messages_for_encode: asyncio.Queue[List[InternalMessage]] @@ -195,7 +250,9 @@ class WriterAsyncIOReconnector: _stop_reason: asyncio.Future _init_info: Optional[PublicWriterInitInfo] - def __init__(self, driver: SupportedDriverType, settings: WriterSettings): + def __init__( + self, driver: SupportedDriverType, settings: WriterSettings, tx: Optional["BaseQueryTxContext"] = None + ): self._closed = False self._loop = asyncio.get_running_loop() self._driver = driver @@ -205,6 +262,7 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings): self._init_info = None self._stream_connected = asyncio.Event() self._settings = settings + self._tx = tx self._codec_functions = { PublicCodec.RAW: lambda data: data, @@ -354,10 +412,12 @@ async def _connection_loop(self): # noinspection PyBroadException stream_writer = None try: + tx_identity = None if self._tx is None else self._tx._tx_identity() stream_writer = await WriterAsyncIOStream.create( self._driver, self._init_message, self._settings.update_token_interval, + tx_identity=tx_identity, ) try: if self._init_info is None: @@ -387,7 +447,7 @@ async def _connection_loop(self): done.pop().result() # need for raise exception - reason of stop task except issues.Error as err: err_info = check_retriable_error(err, retry_settings, attempt) - if not err_info.is_retriable: + if not err_info.is_retriable or self._tx is not None: # no retries in tx writer self._stop(err) return @@ -533,6 +593,8 @@ def _handle_receive_ack(self, ack): result = PublicWriteResult.Skipped() elif isinstance(status, write_ack_msg.StatusWritten): result = PublicWriteResult.Written(offset=status.offset) + elif isinstance(status, write_ack_msg.StatusWrittenInTx): + result = PublicWriteResult.WrittenInTx() else: raise TopicWriterError("internal error - receive unexpected ack message.") message_future.set_result(result) @@ -597,10 +659,13 @@ class WriterAsyncIOStream: _update_token_event: asyncio.Event _get_token_function: Optional[Callable[[], str]] + _tx_identity: Optional[TransactionIdentity] + def __init__( self, update_token_interval: Optional[Union[int, float]] = None, get_token_function: Optional[Callable[[], str]] = None, + tx_identity: Optional[TransactionIdentity] = None, ): self._closed = False @@ -609,6 +674,8 @@ def __init__( self._update_token_event = asyncio.Event() self._update_token_task = None + self._tx_identity = tx_identity + async def close(self): if self._closed: return @@ -625,6 +692,7 @@ async def create( driver: SupportedDriverType, init_request: StreamWriteMessage.InitRequest, update_token_interval: Optional[Union[int, float]] = None, + tx_identity: Optional[TransactionIdentity] = None, ) -> "WriterAsyncIOStream": stream = GrpcWrapperAsyncIO(StreamWriteMessage.FromServer.from_proto) @@ -634,6 +702,7 @@ async def create( writer = WriterAsyncIOStream( update_token_interval=update_token_interval, get_token_function=creds.get_auth_token if creds else lambda: "", + tx_identity=tx_identity, ) await writer._start(stream, init_request) return writer @@ -680,7 +749,7 @@ def write(self, messages: List[InternalMessage]): if self._closed: raise RuntimeError("Can not write on closed stream.") - for request in messages_to_proto_requests(messages): + for request in messages_to_proto_requests(messages, self._tx_identity): self._stream.write(request) async def _update_token_loop(self): diff --git a/ydb/_topic_writer/topic_writer_asyncio_test.py b/ydb/_topic_writer/topic_writer_asyncio_test.py index b288d0f0..cf88f797 100644 --- a/ydb/_topic_writer/topic_writer_asyncio_test.py +++ b/ydb/_topic_writer/topic_writer_asyncio_test.py @@ -18,6 +18,7 @@ from .._grpc.grpcwrapper.ydb_topic import ( Codec, StreamWriteMessage, + TransactionIdentity, UpdateTokenRequest, UpdateTokenResponse, ) @@ -43,6 +44,12 @@ from ..credentials import AnonymousCredentials +FAKE_TRANSACTION_IDENTITY = TransactionIdentity( + tx_id="transaction_id", + session_id="session_id", +) + + @pytest.fixture def default_driver() -> aio.Driver: driver = mock.Mock(spec=aio.Driver) @@ -148,6 +155,44 @@ async def test_write_a_message(self, writer_and_stream: WriterWithMockedStream): expected_message = StreamWriteMessage.FromClient( StreamWriteMessage.WriteRequest( codec=Codec.CODEC_RAW, + tx_identity=None, + messages=[ + StreamWriteMessage.WriteRequest.MessageData( + seq_no=1, + created_at=now, + data=data, + metadata_items={}, + uncompressed_size=len(data), + partitioning=None, + ) + ], + ) + ) + + sent_message = await writer_and_stream.stream.from_client.get() + assert expected_message == sent_message + + async def test_write_a_message_with_tx(self, writer_and_stream: WriterWithMockedStream): + writer_and_stream.writer._tx_identity = FAKE_TRANSACTION_IDENTITY + + data = "123".encode() + now = datetime.datetime.now(datetime.timezone.utc) + writer_and_stream.writer.write( + [ + InternalMessage( + PublicMessage( + seqno=1, + created_at=now, + data=data, + ) + ) + ] + ) + + expected_message = StreamWriteMessage.FromClient( + StreamWriteMessage.WriteRequest( + codec=Codec.CODEC_RAW, + tx_identity=FAKE_TRANSACTION_IDENTITY, messages=[ StreamWriteMessage.WriteRequest.MessageData( seq_no=1, @@ -264,7 +309,7 @@ def _create(self): res = DoubleQueueWriters() - async def async_create(driver, init_message, token_getter): + async def async_create(driver, init_message, token_getter, tx_identity): return res.get_first() monkeypatch.setattr(WriterAsyncIOStream, "create", async_create) diff --git a/ydb/_topic_writer/topic_writer_sync.py b/ydb/_topic_writer/topic_writer_sync.py index a5193caf..4796d7ac 100644 --- a/ydb/_topic_writer/topic_writer_sync.py +++ b/ydb/_topic_writer/topic_writer_sync.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import logging import typing from concurrent.futures import Future from typing import Union, List, Optional @@ -14,13 +15,23 @@ TopicWriterClosedError, ) -from .topic_writer_asyncio import WriterAsyncIO +from ..query.base import TxEvent + +from .topic_writer_asyncio import ( + TxWriterAsyncIO, + WriterAsyncIO, +) from .._topic_common.common import ( _get_shared_event_loop, TimeoutType, CallFromSyncToAsync, ) +if typing.TYPE_CHECKING: + from ..query.transaction import BaseQueryTxContext + +logger = logging.getLogger(__name__) + class WriterSync: _caller: CallFromSyncToAsync @@ -63,7 +74,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): raise def __del__(self): - self.close(flush=False) + if not self._closed: + logger.warning("Topic writer was not closed properly. Consider using method close().") def close(self, *, flush: bool = True, timeout: TimeoutType = None): if self._closed: @@ -122,3 +134,39 @@ def write_with_ack( self._check_closed() return self._caller.unsafe_call_with_result(self._async_writer.write_with_ack(messages), timeout=timeout) + + +class TxWriterSync(WriterSync): + def __init__( + self, + tx: "BaseQueryTxContext", + driver: SupportedDriverType, + settings: PublicWriterSettings, + *, + eventloop: Optional[asyncio.AbstractEventLoop] = None, + _parent=None, + ): + + self._closed = False + + if eventloop: + loop = eventloop + else: + loop = _get_shared_event_loop() + + self._caller = CallFromSyncToAsync(loop) + + async def create_async_writer(): + return TxWriterAsyncIO(tx, driver, settings, _is_implicit=True) + + self._async_writer = self._caller.safe_call_with_result(create_async_writer(), None) + self._parent = _parent + + tx._add_callback(TxEvent.BEFORE_COMMIT, self._on_before_commit, None) + tx._add_callback(TxEvent.BEFORE_ROLLBACK, self._on_before_rollback, None) + + def _on_before_commit(self, tx: "BaseQueryTxContext"): + self.close() + + def _on_before_rollback(self, tx: "BaseQueryTxContext"): + self.close(flush=False) diff --git a/ydb/aio/driver.py b/ydb/aio/driver.py index 9cd6fd2b..267997fb 100644 --- a/ydb/aio/driver.py +++ b/ydb/aio/driver.py @@ -62,4 +62,5 @@ def __init__( async def stop(self, timeout=10): await self.table_client._stop_pool_if_needed(timeout=timeout) + self.topic_client.close() await super().stop(timeout=timeout) diff --git a/ydb/aio/query/pool.py b/ydb/aio/query/pool.py index 947db658..fda22388 100644 --- a/ydb/aio/query/pool.py +++ b/ydb/aio/query/pool.py @@ -158,6 +158,8 @@ async def retry_tx_async( async def wrapped_callee(): async with self.checkout() as session: async with session.transaction(tx_mode=tx_mode) as tx: + if tx_mode.name in ["serializable_read_write", "snapshot_read_only"]: + await tx.begin() result = await callee(tx, *args, **kwargs) await tx.commit() return result diff --git a/ydb/aio/query/transaction.py b/ydb/aio/query/transaction.py index 5b63a32b..f0547e5f 100644 --- a/ydb/aio/query/transaction.py +++ b/ydb/aio/query/transaction.py @@ -16,6 +16,28 @@ class QueryTxContext(BaseQueryTxContext): + def __init__(self, driver, session_state, session, tx_mode): + """ + An object that provides a simple transaction context manager that allows statements execution + in a transaction. You don't have to open transaction explicitly, because context manager encapsulates + transaction control logic, and opens new transaction if: + + 1) By explicit .begin() method; + 2) On execution of a first statement, which is strictly recommended method, because that avoids useless round trip + + This context manager is not thread-safe, so you should not manipulate on it concurrently. + + :param driver: A driver instance + :param session_state: A state of session + :param tx_mode: Transaction mode, which is a one from the following choises: + 1) QuerySerializableReadWrite() which is default mode; + 2) QueryOnlineReadOnly(allow_inconsistent_reads=False); + 3) QuerySnapshotReadOnly(); + 4) QueryStaleReadOnly(). + """ + super().__init__(driver, session_state, session, tx_mode) + self._init_callback_handler(base.CallbackHandlerMode.ASYNC) + async def __aenter__(self) -> "QueryTxContext": """ Enters a context manager and returns a transaction @@ -30,7 +52,7 @@ async def __aexit__(self, *args, **kwargs): it is not finished explicitly """ await self._ensure_prev_stream_finished() - if self._tx_state._state == QueryTxStateEnum.BEGINED: + if self._tx_state._state == QueryTxStateEnum.BEGINED and self._external_error is None: # It's strictly recommended to close transactions directly # by using commit_tx=True flag while executing statement or by # .commit() or .rollback() methods, but here we trying to do best @@ -65,7 +87,9 @@ async def commit(self, settings: Optional[BaseRequestSettings] = None) -> None: :return: A committed transaction or exception if commit is failed """ - if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): + self._check_external_error_set() + + if self._tx_state._should_skip(QueryTxStateEnum.COMMITTED): return if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: @@ -74,7 +98,13 @@ async def commit(self, settings: Optional[BaseRequestSettings] = None) -> None: await self._ensure_prev_stream_finished() - await self._commit_call(settings) + try: + await self._execute_callbacks_async(base.TxEvent.BEFORE_COMMIT) + await self._commit_call(settings) + await self._execute_callbacks_async(base.TxEvent.AFTER_COMMIT, exc=None) + except BaseException as e: + await self._execute_callbacks_async(base.TxEvent.AFTER_COMMIT, exc=e) + raise e async def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None: """Calls rollback on a transaction if it is open otherwise is no-op. If transaction execution @@ -84,7 +114,9 @@ async def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None :return: A committed transaction or exception if commit is failed """ - if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED): + self._check_external_error_set() + + if self._tx_state._should_skip(QueryTxStateEnum.ROLLBACKED): return if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: @@ -93,7 +125,13 @@ async def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None await self._ensure_prev_stream_finished() - await self._rollback_call(settings) + try: + await self._execute_callbacks_async(base.TxEvent.BEFORE_ROLLBACK) + await self._rollback_call(settings) + await self._execute_callbacks_async(base.TxEvent.AFTER_ROLLBACK, exc=None) + except BaseException as e: + await self._execute_callbacks_async(base.TxEvent.AFTER_ROLLBACK, exc=e) + raise e async def execute( self, diff --git a/ydb/driver.py b/ydb/driver.py index 49bd223c..3998aeee 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -288,4 +288,5 @@ def __init__( def stop(self, timeout=10): self.table_client._stop_pool_if_needed(timeout=timeout) + self.topic_client.close() super().stop(timeout=timeout) diff --git a/ydb/issues.py b/ydb/issues.py index 065dcbc8..8b098667 100644 --- a/ydb/issues.py +++ b/ydb/issues.py @@ -179,6 +179,10 @@ class SessionPoolEmpty(Error, queue.Empty): status = StatusCode.SESSION_POOL_EMPTY +class ClientInternalError(Error): + status = StatusCode.CLIENT_INTERNAL_ERROR + + class UnexpectedGrpcMessage(Error): def __init__(self, message: str): super().__init__(message) diff --git a/ydb/query/base.py b/ydb/query/base.py index 57a769bb..a5ebedd9 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -1,6 +1,8 @@ import abc +import asyncio import enum import functools +from collections import defaultdict import typing from typing import ( @@ -17,6 +19,10 @@ from .. import _utilities from .. import _apis +from ydb._topic_common.common import CallFromSyncToAsync, _get_shared_event_loop +from ydb._grpc.grpcwrapper.common_utils import to_thread + + if typing.TYPE_CHECKING: from .transaction import BaseQueryTxContext @@ -196,3 +202,64 @@ def wrap_execute_query_response( return convert.ResultSet.from_message(response_pb.result_set, settings) return None + + +class TxEvent(enum.Enum): + BEFORE_COMMIT = "BEFORE_COMMIT" + AFTER_COMMIT = "AFTER_COMMIT" + BEFORE_ROLLBACK = "BEFORE_ROLLBACK" + AFTER_ROLLBACK = "AFTER_ROLLBACK" + + +class CallbackHandlerMode(enum.Enum): + SYNC = "SYNC" + ASYNC = "ASYNC" + + +def _get_sync_callback(method: typing.Callable, loop: Optional[asyncio.AbstractEventLoop]): + if asyncio.iscoroutinefunction(method): + if loop is None: + loop = _get_shared_event_loop() + + def async_to_sync_callback(*args, **kwargs): + caller = CallFromSyncToAsync(loop) + return caller.safe_call_with_result(method(*args, **kwargs), 10) + + return async_to_sync_callback + return method + + +def _get_async_callback(method: typing.Callable): + if asyncio.iscoroutinefunction(method): + return method + + async def sync_to_async_callback(*args, **kwargs): + return await to_thread(method, *args, **kwargs, executor=None) + + return sync_to_async_callback + + +class CallbackHandler: + def _init_callback_handler(self, mode: CallbackHandlerMode) -> None: + self._callbacks = defaultdict(list) + self._callback_mode = mode + + def _execute_callbacks_sync(self, event_name: str, *args, **kwargs) -> None: + for callback in self._callbacks[event_name]: + callback(self, *args, **kwargs) + + async def _execute_callbacks_async(self, event_name: str, *args, **kwargs) -> None: + tasks = [asyncio.create_task(callback(self, *args, **kwargs)) for callback in self._callbacks[event_name]] + if not tasks: + return + await asyncio.gather(*tasks) + + def _prepare_callback( + self, callback: typing.Callable, loop: Optional[asyncio.AbstractEventLoop] + ) -> typing.Callable: + if self._callback_mode == CallbackHandlerMode.SYNC: + return _get_sync_callback(callback, loop) + return _get_async_callback(callback) + + def _add_callback(self, event_name: str, callback: typing.Callable, loop: Optional[asyncio.AbstractEventLoop]): + self._callbacks[event_name].append(self._prepare_callback(callback, loop)) diff --git a/ydb/query/pool.py b/ydb/query/pool.py index e3775c4d..43cc2e8d 100644 --- a/ydb/query/pool.py +++ b/ydb/query/pool.py @@ -167,6 +167,8 @@ def retry_tx_sync( def wrapped_callee(): with self.checkout(timeout=retry_settings.max_session_acquire_timeout) as session: with session.transaction(tx_mode=tx_mode) as tx: + if tx_mode.name in ["serializable_read_write", "snapshot_read_only"]: + tx.begin() result = callee(tx, *args, **kwargs) tx.commit() return result diff --git a/ydb/query/transaction.py b/ydb/query/transaction.py index 414401da..ae7642db 100644 --- a/ydb/query/transaction.py +++ b/ydb/query/transaction.py @@ -11,6 +11,7 @@ _apis, issues, ) +from .._grpc.grpcwrapper import ydb_topic as _ydb_topic from .._grpc.grpcwrapper import ydb_query as _ydb_query from ..connection import _RpcState as RpcState @@ -42,10 +43,22 @@ class QueryTxStateHelper(abc.ABC): QueryTxStateEnum.DEAD: [], } + _SKIP_TRANSITIONS = { + QueryTxStateEnum.NOT_INITIALIZED: [], + QueryTxStateEnum.BEGINED: [], + QueryTxStateEnum.COMMITTED: [QueryTxStateEnum.COMMITTED, QueryTxStateEnum.ROLLBACKED], + QueryTxStateEnum.ROLLBACKED: [QueryTxStateEnum.COMMITTED, QueryTxStateEnum.ROLLBACKED], + QueryTxStateEnum.DEAD: [], + } + @classmethod def valid_transition(cls, before: QueryTxStateEnum, after: QueryTxStateEnum) -> bool: return after in cls._VALID_TRANSITIONS[before] + @classmethod + def should_skip(cls, before: QueryTxStateEnum, after: QueryTxStateEnum) -> bool: + return after in cls._SKIP_TRANSITIONS[before] + @classmethod def terminal(cls, state: QueryTxStateEnum) -> bool: return len(cls._VALID_TRANSITIONS[state]) == 0 @@ -88,8 +101,8 @@ def _check_tx_ready_to_use(self) -> None: if QueryTxStateHelper.terminal(self._state): raise RuntimeError(f"Transaction is in terminal state: {self._state.value}") - def _already_in(self, target: QueryTxStateEnum) -> bool: - return self._state == target + def _should_skip(self, target: QueryTxStateEnum) -> bool: + return QueryTxStateHelper.should_skip(self._state, target) def _construct_tx_settings(tx_state: QueryTxState) -> _ydb_query.TransactionSettings: @@ -170,7 +183,7 @@ def wrap_tx_rollback_response( return tx -class BaseQueryTxContext: +class BaseQueryTxContext(base.CallbackHandler): def __init__(self, driver, session_state, session, tx_mode): """ An object that provides a simple transaction context manager that allows statements execution @@ -196,6 +209,7 @@ def __init__(self, driver, session_state, session, tx_mode): self._session_state = session_state self.session = session self._prev_stream = None + self._external_error = None @property def session_id(self) -> str: @@ -215,6 +229,19 @@ def tx_id(self) -> Optional[str]: """ return self._tx_state.tx_id + def _tx_identity(self) -> _ydb_topic.TransactionIdentity: + if not self.tx_id: + raise RuntimeError("Unable to get tx identity without started tx.") + return _ydb_topic.TransactionIdentity(self.tx_id, self.session_id) + + def _set_external_error(self, exc: BaseException) -> None: + self._external_error = exc + + def _check_external_error_set(self): + if self._external_error is None: + return + raise issues.ClientInternalError("Transaction was failed by external error.") from self._external_error + def _begin_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxContext": self._tx_state._check_invalid_transition(QueryTxStateEnum.BEGINED) @@ -228,6 +255,7 @@ def _begin_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxCo ) def _commit_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxContext": + self._check_external_error_set() self._tx_state._check_invalid_transition(QueryTxStateEnum.COMMITTED) return self._driver( @@ -240,6 +268,7 @@ def _commit_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxC ) def _rollback_call(self, settings: Optional[BaseRequestSettings]) -> "BaseQueryTxContext": + self._check_external_error_set() self._tx_state._check_invalid_transition(QueryTxStateEnum.ROLLBACKED) return self._driver( @@ -262,6 +291,7 @@ def _execute_call( settings: Optional[BaseRequestSettings], ) -> Iterable[_apis.ydb_query.ExecuteQueryResponsePart]: self._tx_state._check_tx_ready_to_use() + self._check_external_error_set() request = base.create_execute_query_request( query=query, @@ -283,18 +313,41 @@ def _execute_call( ) def _move_to_beginned(self, tx_id: str) -> None: - if self._tx_state._already_in(QueryTxStateEnum.BEGINED) or not tx_id: + if self._tx_state._should_skip(QueryTxStateEnum.BEGINED) or not tx_id: return self._tx_state._change_state(QueryTxStateEnum.BEGINED) self._tx_state.tx_id = tx_id def _move_to_commited(self) -> None: - if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): + if self._tx_state._should_skip(QueryTxStateEnum.COMMITTED): return self._tx_state._change_state(QueryTxStateEnum.COMMITTED) class QueryTxContext(BaseQueryTxContext): + def __init__(self, driver, session_state, session, tx_mode): + """ + An object that provides a simple transaction context manager that allows statements execution + in a transaction. You don't have to open transaction explicitly, because context manager encapsulates + transaction control logic, and opens new transaction if: + + 1) By explicit .begin() method; + 2) On execution of a first statement, which is strictly recommended method, because that avoids useless round trip + + This context manager is not thread-safe, so you should not manipulate on it concurrently. + + :param driver: A driver instance + :param session_state: A state of session + :param tx_mode: Transaction mode, which is a one from the following choises: + 1) QuerySerializableReadWrite() which is default mode; + 2) QueryOnlineReadOnly(allow_inconsistent_reads=False); + 3) QuerySnapshotReadOnly(); + 4) QueryStaleReadOnly(). + """ + + super().__init__(driver, session_state, session, tx_mode) + self._init_callback_handler(base.CallbackHandlerMode.SYNC) + def __enter__(self) -> "BaseQueryTxContext": """ Enters a context manager and returns a transaction @@ -309,7 +362,7 @@ def __exit__(self, *args, **kwargs): it is not finished explicitly """ self._ensure_prev_stream_finished() - if self._tx_state._state == QueryTxStateEnum.BEGINED: + if self._tx_state._state == QueryTxStateEnum.BEGINED and self._external_error is None: # It's strictly recommended to close transactions directly # by using commit_tx=True flag while executing statement or by # .commit() or .rollback() methods, but here we trying to do best @@ -345,7 +398,8 @@ def commit(self, settings: Optional[BaseRequestSettings] = None) -> None: :return: A committed transaction or exception if commit is failed """ - if self._tx_state._already_in(QueryTxStateEnum.COMMITTED): + self._check_external_error_set() + if self._tx_state._should_skip(QueryTxStateEnum.COMMITTED): return if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: @@ -354,7 +408,13 @@ def commit(self, settings: Optional[BaseRequestSettings] = None) -> None: self._ensure_prev_stream_finished() - self._commit_call(settings) + try: + self._execute_callbacks_sync(base.TxEvent.BEFORE_COMMIT) + self._commit_call(settings) + self._execute_callbacks_sync(base.TxEvent.AFTER_COMMIT, exc=None) + except BaseException as e: # TODO: probably should be less wide + self._execute_callbacks_sync(base.TxEvent.AFTER_COMMIT, exc=e) + raise e def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None: """Calls rollback on a transaction if it is open otherwise is no-op. If transaction execution @@ -364,7 +424,8 @@ def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None: :return: A committed transaction or exception if commit is failed """ - if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED): + self._check_external_error_set() + if self._tx_state._should_skip(QueryTxStateEnum.ROLLBACKED): return if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED: @@ -373,7 +434,13 @@ def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None: self._ensure_prev_stream_finished() - self._rollback_call(settings) + try: + self._execute_callbacks_sync(base.TxEvent.BEFORE_ROLLBACK) + self._rollback_call(settings) + self._execute_callbacks_sync(base.TxEvent.AFTER_ROLLBACK, exc=None) + except BaseException as e: # TODO: probably should be less wide + self._execute_callbacks_sync(base.TxEvent.AFTER_ROLLBACK, exc=e) + raise e def execute( self, diff --git a/ydb/topic.py b/ydb/topic.py index 55f4ea04..52f98e61 100644 --- a/ydb/topic.py +++ b/ydb/topic.py @@ -25,6 +25,8 @@ "TopicWriteResult", "TopicWriter", "TopicWriterAsyncIO", + "TopicTxWriter", + "TopicTxWriterAsyncIO", "TopicWriterInitInfo", "TopicWriterMessage", "TopicWriterSettings", @@ -33,6 +35,7 @@ import concurrent.futures import datetime from dataclasses import dataclass +import logging from typing import List, Union, Mapping, Optional, Dict, Callable from . import aio, Credentials, _apis, issues @@ -65,8 +68,10 @@ PublicWriteResult as TopicWriteResult, ) +from ydb._topic_writer.topic_writer_asyncio import TxWriterAsyncIO as TopicTxWriterAsyncIO from ydb._topic_writer.topic_writer_asyncio import WriterAsyncIO as TopicWriterAsyncIO from ._topic_writer.topic_writer_sync import WriterSync as TopicWriter +from ._topic_writer.topic_writer_sync import TxWriterSync as TopicTxWriter from ._topic_common.common import ( wrap_operation as _wrap_operation, @@ -88,6 +93,8 @@ PublicAlterAutoPartitioningSettings as TopicAlterAutoPartitioningSettings, ) +logger = logging.getLogger(__name__) + class TopicClientAsyncIO: _closed: bool @@ -108,7 +115,8 @@ def __init__(self, driver: aio.Driver, settings: Optional[TopicClientSettings] = ) def __del__(self): - self.close() + if not self._closed: + logger.warning("Topic client was not closed properly. Consider using method close().") async def create_topic( self, @@ -276,6 +284,35 @@ def writer( return TopicWriterAsyncIO(self._driver, settings, _client=self) + def tx_writer( + self, + tx, + topic, + *, + producer_id: Optional[str] = None, # default - random + session_metadata: Mapping[str, str] = None, + partition_id: Union[int, None] = None, + auto_seqno: bool = True, + auto_created_at: bool = True, + codec: Optional[TopicCodec] = None, # default mean auto-select + # encoders: map[codec_code] func(encoded_bytes)->decoded_bytes + # the func will be called from multiply threads in parallel. + encoders: Optional[Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]]] = None, + # custom encoder executor for call builtin and custom decoders. If None - use shared executor pool. + # If max_worker in the executor is 1 - then encoders will be called from the thread without parallel. + encoder_executor: Optional[concurrent.futures.Executor] = None, + ) -> TopicTxWriterAsyncIO: + args = locals().copy() + del args["self"] + del args["tx"] + + settings = TopicWriterSettings(**args) + + if not settings.encoder_executor: + settings.encoder_executor = self._executor + + return TopicTxWriterAsyncIO(tx=tx, driver=self._driver, settings=settings, _client=self) + def close(self): if self._closed: return @@ -287,7 +324,7 @@ def _check_closed(self): if not self._closed: return - raise RuntimeError("Topic client closed") + raise issues.Error("Topic client closed") class TopicClient: @@ -310,7 +347,8 @@ def __init__(self, driver: driver.Driver, settings: Optional[TopicClientSettings ) def __del__(self): - self.close() + if not self._closed: + logger.warning("Topic client was not closed properly. Consider using method close().") def create_topic( self, @@ -487,6 +525,36 @@ def writer( return TopicWriter(self._driver, settings, _parent=self) + def tx_writer( + self, + tx, + topic, + *, + producer_id: Optional[str] = None, # default - random + session_metadata: Mapping[str, str] = None, + partition_id: Union[int, None] = None, + auto_seqno: bool = True, + auto_created_at: bool = True, + codec: Optional[TopicCodec] = None, # default mean auto-select + # encoders: map[codec_code] func(encoded_bytes)->decoded_bytes + # the func will be called from multiply threads in parallel. + encoders: Optional[Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]]] = None, + # custom encoder executor for call builtin and custom decoders. If None - use shared executor pool. + # If max_worker in the executor is 1 - then encoders will be called from the thread without parallel. + encoder_executor: Optional[concurrent.futures.Executor] = None, # default shared client executor pool + ) -> TopicWriter: + args = locals().copy() + del args["self"] + del args["tx"] + self._check_closed() + + settings = TopicWriterSettings(**args) + + if not settings.encoder_executor: + settings.encoder_executor = self._executor + + return TopicTxWriter(tx, self._driver, settings, _parent=self) + def close(self): if self._closed: return @@ -498,7 +566,7 @@ def _check_closed(self): if not self._closed: return - raise RuntimeError("Topic client closed") + raise issues.Error("Topic client closed") @dataclass