Skip to content

Commit 036451a

Browse files
committed
Merge remote-tracking branch 'origin/add-async-compat-schema-registry' into add-async-compat-schema-registry
2 parents b8da4a8 + 920ddb6 commit 036451a

File tree

9 files changed

+241
-1
lines changed

9 files changed

+241
-1
lines changed

src/confluent_kafka/schema_registry/_sync/avro.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
SchemaRegistryClient)
3131
from confluent_kafka.serialization import (SerializationError,
3232
SerializationContext)
33-
from confluent_kafka.schema_registry.common import _ContextStringIO
33+
from confluent_kafka.schema_registry.common import _ContextStringIO, asyncinit
3434
from confluent_kafka.schema_registry.rule_registry import RuleRegistry
3535
from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, ParsedSchemaCache
3636

src/confluent_kafka/schema_registry/avro.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@
1616
# limitations under the License.
1717

1818
from .common.avro import *
19+
from ._async.avro import *
1920
from ._sync.avro import *

src/confluent_kafka/schema_registry/common/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,22 @@ def __enter__(self):
8989
def __exit__(self, *args):
9090
self.close()
9191
return False
92+
93+
94+
def asyncinit(cls):
95+
"""
96+
Decorator to make a class async-initializable.
97+
"""
98+
__new__ = cls.__new__
99+
100+
async def init(obj, *arg, **kwarg):
101+
await obj.__init__(*arg, **kwarg)
102+
return obj
103+
104+
def new(klass, *arg, **kwarg):
105+
obj = __new__(klass)
106+
coro = init(obj, *arg, **kwarg)
107+
return coro
108+
109+
cls.__new__ = new
110+
return cls

src/confluent_kafka/schema_registry/json_schema.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@
1616
# limitations under the License.
1717

1818
from .common.json_schema import *
19+
from ._async.json_schema import *
1920
from ._sync.json_schema import *

src/confluent_kafka/schema_registry/protobuf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@
1616
# limitations under the License.
1717

1818
from .common.protobuf import *
19+
from ._async.protobuf import *
1920
from ._sync.protobuf import *

src/confluent_kafka/schema_registry/schema_registry_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# limitations under the License.
1717

1818
from .common.schema_registry_client import *
19+
from ._async.schema_registry_client import *
1920
from ._sync.schema_registry_client import *
2021

2122
from .error import SchemaRegistryError

src/confluent_kafka/schema_registry/serde.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@
1717
#
1818

1919
from .common.serde import *
20+
from ._async.serde import *
2021
from ._sync.serde import *

tests/common/_async/consumer.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright 2025 Confluent Inc.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
19+
import asyncio
20+
21+
from confluent_kafka.cimpl import Consumer
22+
from confluent_kafka.error import ConsumeError, KeyDeserializationError, ValueDeserializationError
23+
from confluent_kafka.serialization import MessageField, SerializationContext
24+
25+
ASYNC_CONSUMER_POLL_INTERVAL_SECONDS: int = 0.2
26+
ASYNC_CONSUMER_POLL_INFINITE_TIMEOUT_SECONDS: int = -1
27+
28+
class AsyncConsumer(Consumer):
29+
def __init__(
30+
self,
31+
conf: dict,
32+
loop: asyncio.AbstractEventLoop = None,
33+
poll_interval_seconds: int = ASYNC_CONSUMER_POLL_INTERVAL_SECONDS
34+
):
35+
super().__init__(conf)
36+
37+
self._loop = loop or asyncio.get_event_loop()
38+
self._poll_interval = poll_interval_seconds
39+
40+
def __aiter__(self):
41+
return self
42+
43+
async def __anext__(self):
44+
return await self.poll(None)
45+
46+
async def poll(self, timeout: int = -1):
47+
timeout = None if timeout == -1 else timeout
48+
async with asyncio.timeout(timeout):
49+
while True:
50+
# Zero timeout here is what makes it non-blocking
51+
msg = super().poll(0)
52+
if msg is not None:
53+
return msg
54+
else:
55+
await asyncio.sleep(self._poll_interval)
56+
57+
58+
class TestAsyncDeserializingConsumer(AsyncConsumer):
59+
def __init__(self, conf):
60+
conf_copy = conf.copy()
61+
self._key_deserializer = conf_copy.pop('key.deserializer', None)
62+
self._value_deserializer = conf_copy.pop('value.deserializer', None)
63+
super().__init__(conf_copy)
64+
65+
async def poll(self, timeout=-1):
66+
msg = await super().poll(timeout)
67+
68+
if msg is None:
69+
return None
70+
71+
if msg.error() is not None:
72+
raise ConsumeError(msg.error(), kafka_message=msg)
73+
74+
ctx = SerializationContext(msg.topic(), MessageField.VALUE, msg.headers())
75+
value = msg.value()
76+
if self._value_deserializer is not None:
77+
try:
78+
value = await self._value_deserializer(value, ctx)
79+
except Exception as se:
80+
raise ValueDeserializationError(exception=se, kafka_message=msg)
81+
82+
key = msg.key()
83+
ctx.field = MessageField.KEY
84+
if self._key_deserializer is not None:
85+
try:
86+
key = await self._key_deserializer(key, ctx)
87+
except Exception as se:
88+
raise KeyDeserializationError(exception=se, kafka_message=msg)
89+
90+
msg.set_key(key)
91+
msg.set_value(value)
92+
return msg
93+
94+
def consume(self, num_messages=1, timeout=-1):
95+
"""
96+
:py:func:`Consumer.consume` not implemented, use
97+
:py:func:`DeserializingConsumer.poll` instead
98+
"""
99+
100+
raise NotImplementedError

tests/common/_async/producer.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
#
4+
# Copyright 2025 Confluent Inc.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
19+
from confluent_kafka.cimpl import Producer
20+
import inspect
21+
import asyncio
22+
23+
from confluent_kafka.error import KeySerializationError, ValueSerializationError
24+
from confluent_kafka.serialization import MessageField, SerializationContext
25+
26+
ASYNC_PRODUCER_POLL_INTERVAL: int = 0.2
27+
28+
class AsyncProducer(Producer):
29+
def __init__(
30+
self,
31+
conf: dict,
32+
loop: asyncio.AbstractEventLoop = None,
33+
poll_interval: int = ASYNC_PRODUCER_POLL_INTERVAL
34+
):
35+
super().__init__(conf)
36+
37+
self._loop = loop or asyncio.get_event_loop()
38+
self._poll_interval = poll_interval
39+
40+
self._poll_task = None
41+
self._waiters: int = 0
42+
43+
async def produce(
44+
self, topic, value=None, key=None, partition=-1,
45+
on_delivery=None, timestamp=0, headers=None
46+
):
47+
fut = self._loop.create_future()
48+
self._waiters += 1
49+
try:
50+
if self._poll_task is None or self._poll_task.done():
51+
self._poll_task = asyncio.create_task(self._poll_dr(self._poll_interval))
52+
53+
def wrapped_on_delivery(err, msg):
54+
if on_delivery is not None:
55+
if inspect.iscoroutinefunction(on_delivery):
56+
asyncio.run_coroutine_threadsafe(
57+
on_delivery(err, msg),
58+
self._loop
59+
)
60+
else:
61+
self._loop.call_soon_threadsafe(on_delivery, err, msg)
62+
63+
if err:
64+
self._loop.call_soon_threadsafe(fut.set_exception, err)
65+
else:
66+
self._loop.call_soon_threadsafe(fut.set_result, msg)
67+
68+
super().produce(
69+
topic,
70+
value,
71+
key,
72+
headers=headers,
73+
partition=partition,
74+
timestamp=timestamp,
75+
on_delivery=wrapped_on_delivery
76+
)
77+
return await fut
78+
finally:
79+
self._waiters -= 1
80+
81+
async def _poll_dr(self, interval: int):
82+
"""Poll delivery reports at interval seconds"""
83+
while self._waiters:
84+
super().poll(0)
85+
await asyncio.sleep(interval)
86+
87+
88+
class TestAsyncSerializingProducer(AsyncProducer):
89+
def __init__(self, conf):
90+
conf_copy = conf.copy()
91+
92+
self._key_serializer = conf_copy.pop('key.serializer', None)
93+
self._value_serializer = conf_copy.pop('value.serializer', None)
94+
95+
super(TestAsyncSerializingProducer, self).__init__(conf_copy)
96+
97+
async def produce(self, topic, key=None, value=None, partition=-1,
98+
on_delivery=None, timestamp=0, headers=None):
99+
ctx = SerializationContext(topic, MessageField.KEY, headers)
100+
if self._key_serializer is not None:
101+
try:
102+
key = await self._key_serializer(key, ctx)
103+
except Exception as se:
104+
raise KeySerializationError(se)
105+
ctx.field = MessageField.VALUE
106+
if self._value_serializer is not None:
107+
try:
108+
value = await self._value_serializer(value, ctx)
109+
except Exception as se:
110+
raise ValueSerializationError(se)
111+
112+
return await super().produce(topic, value, key,
113+
headers=headers,
114+
partition=partition,
115+
timestamp=timestamp,
116+
on_delivery=on_delivery)

0 commit comments

Comments
 (0)