
Commit cf18596

WIP
1 parent d451023 commit cf18596

File tree: 10 files changed, +171 −80 lines


python/samples/concepts/audio/04-chat_with_realtime_api_simple.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ async def main() -> None:
     # create the realtime client and optionally add the audio output function, this is optional
     # you can define the protocol to use, either "websocket" or "webrtc"
     # they will behave the same way, even though the underlying protocol is quite different
-    realtime_client = OpenAIRealtime(protocol="webrtc")
+    realtime_client = OpenAIRealtime("webrtc")
     # Create the settings for the session
     settings = OpenAIRealtimeExecutionSettings(
         instructions="""
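
The simplified sample now passes the protocol as the first positional argument. As a minimal usage sketch built from the lines above (the import path and settings fields are assumptions, not shown in this diff):

import asyncio

from semantic_kernel.connectors.ai.open_ai import (  # assumed import path
    OpenAIRealtime,
    OpenAIRealtimeExecutionSettings,
)


async def demo() -> None:
    # protocol is now positional: "websocket" or "webrtc"
    realtime_client = OpenAIRealtime("webrtc")
    settings = OpenAIRealtimeExecutionSettings(instructions="You are a helpful voice assistant.")
    async with realtime_client:
        await realtime_client.update_session(settings=settings, create_response=True)
        async for event in realtime_client.receive():
            ...  # handle audio/text/function events here


if __name__ == "__main__":
    asyncio.run(demo())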

python/samples/concepts/audio/05-chat_with_realtime_api_complex.py

Lines changed: 14 additions & 8 deletions
@@ -52,35 +52,41 @@ def get_weather(location: str) -> str:
     """Get the weather for a location."""
     weather_conditions = ("sunny", "hot", "cloudy", "raining", "freezing", "snowing")
     weather = weather_conditions[randint(0, len(weather_conditions) - 1)]  # nosec
-    logger.info(f"Getting weather for {location}: {weather}")
+    logger.info(f"@ Getting weather for {location}: {weather}")
     return f"The weather in {location} is {weather}."


 @kernel_function
 def get_date_time() -> str:
     """Get the current date and time."""
-    logger.info("Getting current datetime")
+    logger.info("@ Getting current datetime")
     return f"The current date and time is {datetime.now().isoformat()}."


+@kernel_function
+def goodbye():
+    """When the user is done, say goodbye and then call this function."""
+    logger.info("@ Goodbye has been called!")
+    raise KeyboardInterrupt
+
+
 async def main() -> None:
     print_transcript = True
     # create the Kernel and add a simple function for function calling.
     kernel = Kernel()
-    kernel.add_function(plugin_name="weather", function_name="get_weather", function=get_weather)
-    kernel.add_function(plugin_name="time", function_name="get_date_time", function=get_date_time)
+    kernel.add_functions(plugin_name="helpers", functions=[goodbye, get_weather, get_date_time])

     # create the audio player and audio track
     # both take a device_id parameter, which is the index of the device to use, if None the default device is used
-    audio_player = SKAudioPlayer()
+    audio_player = SKAudioPlayer(sample_rate=24000, frame_duration=100, channels=1)
     audio_track = SKAudioTrack()
     # create the realtime client and optionally add the audio output function, this is optional
     # you can define the protocol to use, either "websocket" or "webrtc"
     # they will behave the same way, even though the underlying protocol is quite different
     realtime_client = OpenAIRealtime(
-        protocol="webrtc",
+        protocol="websocket",
         audio_output_callback=audio_player.client_callback,
-        audio_track=audio_track,
+        # audio_track=audio_track,
     )

     # Create the settings for the session
@@ -110,7 +116,7 @@ async def main() -> None:
     chat_history.add_assistant_message("I am Mosscap, a chat bot. I'm trying to figure out what people need.")

     # the context manager calls the create_session method on the client and start listening to the audio stream
-    async with realtime_client, audio_player:
+    async with realtime_client, audio_player, audio_track.stream_to_realtime_client(realtime_client):
         await realtime_client.update_session(
             settings=settings, chat_history=chat_history, kernel=kernel, create_response=True
         )
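
The new goodbye helper shows the pattern this sample uses to end a session: a kernel function that raises KeyboardInterrupt, which unwinds the async with block so the client, player, and track all clean up. A condensed sketch of just that pattern (only the kernel registration mirrors the diff; the realtime wiring is omitted):

import asyncio

from semantic_kernel import Kernel
from semantic_kernel.functions import kernel_function


@kernel_function
def goodbye():
    """When the user is done, say goodbye and then call this function."""
    # Raising KeyboardInterrupt breaks out of the surrounding `async with` /
    # receive loop, so the session context managers close cleanly.
    raise KeyboardInterrupt


kernel = Kernel()
kernel.add_functions(plugin_name="helpers", functions=[goodbye])


async def main() -> None:
    ...  # create the realtime client, update the session with kernel=kernel, and receive events


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("Session ended.")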

python/semantic_kernel/connectors/ai/open_ai/services/open_ai_realtime.py

Lines changed: 72 additions & 12 deletions
@@ -1,6 +1,6 @@
 # Copyright (c) Microsoft. All rights reserved.

-from collections.abc import Callable, Coroutine, Mapping
+from collections.abc import AsyncGenerator, Callable, Coroutine, Mapping
 from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar

 from numpy import ndarray
@@ -15,42 +15,82 @@
     OpenAIRealtimeWebsocketBase,
 )
 from semantic_kernel.connectors.ai.open_ai.settings.open_ai_settings import OpenAISettings
+from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
+from semantic_kernel.connectors.ai.realtime_client_base import RealtimeClientBase
+from semantic_kernel.contents.chat_history import ChatHistory
+from semantic_kernel.contents.events.realtime_event import RealtimeEvent
 from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError

 if TYPE_CHECKING:
     from aiortc.mediastreams import MediaStreamTrack

+    from semantic_kernel.connectors.ai import PromptExecutionSettings
+    from semantic_kernel.contents import ChatHistory
+
 _T = TypeVar("_T", bound="OpenAIRealtime")


-class OpenAIRealtime(OpenAIConfigBase, OpenAIRealtimeBase):
+__all__ = ["OpenAIRealtime"]
+
+
+class RealtimeClientStub(RealtimeClientBase):
+    """This class makes sure that IDE's don't complain about missing methods in the below superclass."""
+
+    async def send(self, event: Any) -> None:
+        pass
+
+    async def create_session(
+        self,
+        settings: "PromptExecutionSettings | None" = None,
+        chat_history: "ChatHistory | None" = None,
+        **kwargs: Any,
+    ) -> None:
+        pass
+
+    def receive(self, **kwargs: Any) -> AsyncGenerator[RealtimeEvent, None]:
+        pass
+
+    async def update_session(
+        self,
+        settings: "PromptExecutionSettings | None" = None,
+        chat_history: "ChatHistory | None" = None,
+        **kwargs: Any,
+    ) -> None:
+        pass
+
+    async def close_session(self) -> None:
+        pass
+
+
+class OpenAIRealtime(OpenAIRealtimeBase, RealtimeClientStub):
     """OpenAI Realtime service."""

-    def __new__(cls: type["_T"], *args: Any, **kwargs: Any) -> "_T":
+    def __new__(cls: type["_T"], protocol: str, *args: Any, **kwargs: Any) -> "_T":
         """Pick the right subclass, based on protocol."""
         subclass_map = {subcl.protocol: subcl for subcl in cls.__subclasses__()}
-        subclass = subclass_map[kwargs.pop("protocol", "websocket")]
+        subclass = subclass_map[protocol]
         return super(OpenAIRealtime, subclass).__new__(subclass)

     def __init__(
         self,
-        protocol: Literal["websocket", "webrtc"] = "websocket",
+        protocol: Literal["websocket", "webrtc"],
+        *,
         audio_output_callback: Callable[[ndarray], Coroutine[Any, Any, None]] | None = None,
         audio_track: "MediaStreamTrack | None" = None,
         ai_model_id: str | None = None,
         api_key: str | None = None,
         org_id: str | None = None,
         service_id: str | None = None,
         default_headers: Mapping[str, str] | None = None,
-        async_client: AsyncOpenAI | None = None,
+        client: AsyncOpenAI | None = None,
         env_file_path: str | None = None,
         env_file_encoding: str | None = None,
         **kwargs: Any,
     ) -> None:
         """Initialize an OpenAIRealtime service.

         Args:
-            protocol: The protocol to use, can be either "websocket" or "webrtc".
+            protocol: The protocol to use, must be either "websocket" or "webrtc".
             audio_output_callback: The audio output callback, optional.
                 This should be a coroutine, that takes a ndarray with audio as input.
                 The goal of this function is to allow you to play the audio with the
@@ -70,7 +110,7 @@ def __init__(
                 the env vars or .env file value.
             default_headers: The default headers mapping of string keys to
                 string values for HTTP requests. (Optional)
-            async_client (Optional[AsyncOpenAI]): An existing client to use. (Optional)
+            client (Optional[AsyncOpenAI]): An existing client to use. (Optional)
             env_file_path (str | None): Use the environment settings file as a fallback to
                 environment variables. (Optional)
             env_file_encoding (str | None): The encoding of the environment settings file. (Optional)
@@ -88,7 +128,6 @@ def __init__(
             raise ServiceInitializationError("Failed to create OpenAI settings.", ex) from ex
         if not openai_settings.realtime_model_id:
             raise ServiceInitializationError("The OpenAI text model ID is required.")
-        kwargs = {"audio_track": audio_track} if protocol == "webrtc" and audio_track else {}
         super().__init__(
             protocol=protocol,
             audio_output_callback=audio_output_callback,
@@ -98,12 +137,12 @@ def __init__(
             org_id=openai_settings.org_id,
             ai_model_type=OpenAIModelTypes.REALTIME,
             default_headers=default_headers,
-            client=async_client,
+            client=client,
             **kwargs,
         )


-class OpenAIRealtimeWebRTC(OpenAIRealtime, OpenAIRealtimeWebRTCBase):
+class OpenAIRealtimeWebRTC(OpenAIRealtime, OpenAIRealtimeWebRTCBase, OpenAIConfigBase):
     """OpenAI Realtime service using WebRTC protocol.

     This should not be used directly, use OpenAIRealtime instead.
@@ -112,12 +151,33 @@ class OpenAIRealtimeWebRTC(OpenAIRealtime, OpenAIRealtimeWebRTCBase):
     protocol: ClassVar[Literal["webrtc"]] = "webrtc"

+    def __init__(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize an OpenAIRealtime service using WebRTC protocol."""
+        super().__init__(
+            *args,
+            **kwargs,
+        )
+

-class OpenAIRealtimeWebSocket(OpenAIRealtime, OpenAIRealtimeWebsocketBase):
+class OpenAIRealtimeWebSocket(OpenAIRealtime, OpenAIRealtimeWebsocketBase, OpenAIConfigBase):
     """OpenAI Realtime service using WebSocket protocol.

     This should not be used directly, use OpenAIRealtime instead.
     Set protocol="websocket" to use this class.
     """

     protocol: ClassVar[Literal["websocket"]] = "websocket"
+
+    def __init__(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(
+            *args,
+            **kwargs,
+        )
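
The __new__ override is what lets OpenAIRealtime("webrtc") hand back an OpenAIRealtimeWebRTC instance: it maps each subclass by its protocol class attribute and instantiates the match. A self-contained sketch of the same dispatch pattern, using placeholder classes rather than the Semantic Kernel ones:

from typing import Any, ClassVar, Literal


class Transport:
    """Base class whose __new__ picks the concrete subclass from the protocol argument."""

    protocol: ClassVar[str] = ""

    def __new__(cls, protocol: str, *args: Any, **kwargs: Any) -> "Transport":
        # Map each registered subclass by its `protocol` class attribute.
        subclass_map = {sub.protocol: sub for sub in cls.__subclasses__()}
        subclass = subclass_map[protocol]
        return super().__new__(subclass)

    def __init__(self, protocol: str, *, endpoint: str | None = None) -> None:
        self.endpoint = endpoint


class WebSocketTransport(Transport):
    protocol: ClassVar[Literal["websocket"]] = "websocket"


class WebRTCTransport(Transport):
    protocol: ClassVar[Literal["webrtc"]] = "webrtc"


client = Transport("webrtc")
assert isinstance(client, WebRTCTransport)
print(type(client).__name__)

Making protocol positional (and required) means the dispatch never has to guess a default, which is the behavior change reflected in the updated samples.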

python/semantic_kernel/connectors/ai/open_ai/services/realtime/open_ai_realtime_base.py

Lines changed: 26 additions & 23 deletions
@@ -1,10 +1,8 @@
 # Copyright (c) Microsoft. All rights reserved.

-import base64
 import json
 import logging
 import sys
-from abc import abstractmethod
 from collections.abc import AsyncGenerator, Callable, Coroutine
 from typing import TYPE_CHECKING, Any, ClassVar, Literal
@@ -146,11 +144,24 @@ async def update_session(
             )
         if chat_history and len(chat_history) > 0:
             for msg in chat_history.messages:
-                await self.send(
-                    ServiceEvent(event_type="service", service_type=SendEvents.CONVERSATION_ITEM_CREATE, event=msg)
-                )
+                for item in msg.items:
+                    match item:
+                        case TextContent():
+                            await self.send(TextEvent(service_type=SendEvents.CONVERSATION_ITEM_CREATE, text=item))
+                        case FunctionCallContent():
+                            await self.send(
+                                FunctionCallEvent(service_type=SendEvents.CONVERSATION_ITEM_CREATE, function_call=item)
+                            )
+                        case FunctionResultContent():
+                            await self.send(
+                                FunctionResultEvent(
+                                    service_type=SendEvents.CONVERSATION_ITEM_CREATE, function_result=item
+                                )
+                            )
+                        case _:
+                            logger.error("Unsupported item type: %s", item)
         if create_response:
-            await self.send(ServiceEvent(event_type="service", service_type=SendEvents.RESPONSE_CREATE))
+            await self.send(ServiceEvent(service_type=SendEvents.RESPONSE_CREATE))

     @override
     def get_prompt_execution_settings_class(self) -> type["PromptExecutionSettings"]:
@@ -191,24 +202,21 @@ async def _parse_function_call_arguments_done(
             index=event.output_index,
             metadata={"call_id": event.call_id},
         )
-        yield FunctionCallEvent(
-            event_type="function_call",
-            service_type=ListenEvents.RESPONSE_FUNCTION_CALL_ARGUMENTS_DONE,
-            function_call=item,
-        )
+        yield FunctionCallEvent(service_type=ListenEvents.RESPONSE_FUNCTION_CALL_ARGUMENTS_DONE, function_call=item)
         chat_history = ChatHistory()
         await self.kernel.invoke_function_call(item, chat_history)
         created_output: FunctionResultContent = chat_history.messages[-1].items[0]  # type: ignore
         # This returns the output to the service
-        await self.send(
-            ServiceEvent(event_type="service", service_type=SendEvents.CONVERSATION_ITEM_CREATE, event=created_output)
+        result = FunctionResultEvent(
+            service_type=SendEvents.CONVERSATION_ITEM_CREATE,
+            function_result=created_output,
         )
+        await self.send(result)
         # The model doesn't start responding to the tool call automatically, so triggering it here.
-        await self.send(ServiceEvent(event_type="service", service_type=SendEvents.RESPONSE_CREATE))
+        await self.send(ServiceEvent(service_type=SendEvents.RESPONSE_CREATE))
         # This allows a user to have a full conversation in his code
-        yield FunctionResultEvent(event_type="function_result", function_result=created_output)
+        yield result

-    @abstractmethod
     async def _send(self, event: RealtimeClientEvent) -> None:
         """Send an event to the service."""
         raise NotImplementedError
@@ -217,14 +225,9 @@ async def _send(self, event: RealtimeClientEvent) -> None:
     async def send(self, event: RealtimeEvent, **kwargs: Any) -> None:
         match event.event_type:
             case "audio":
-                if isinstance(event.audio.data, ndarray):
-                    audio_data = base64.b64encode(event.audio.data.tobytes()).decode("utf-8")
-                else:
-                    audio_data = event.audio.data.decode("utf-8")
                 await self._send(
                     _create_realtime_client_event(
-                        event_type=SendEvents.INPUT_AUDIO_BUFFER_APPEND,
-                        audio=audio_data,
+                        event_type=SendEvents.INPUT_AUDIO_BUFFER_APPEND, audio=event.audio.to_base64_bytestring()
                     )
                 )
             case "text":
@@ -286,7 +289,7 @@ async def send(self, event: RealtimeEvent, **kwargs: Any) -> None:
                 await self._send(
                     _create_realtime_client_event(
                         event_type=event.service_type,
-                        **settings.prepare_settings_dict(),
+                        session=settings.prepare_settings_dict(),
                     )
                 )
             case SendEvents.INPUT_AUDIO_BUFFER_APPEND:
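
The reworked update_session replays chat history item by item and uses structural pattern matching to pick the event type per content class. A standalone sketch of that match/case dispatch, with placeholder dataclasses standing in for the real TextContent, FunctionCallContent, and FunctionResultContent classes from semantic_kernel.contents:

from dataclasses import dataclass


@dataclass
class TextContent:
    text: str


@dataclass
class FunctionCallContent:
    name: str
    arguments: str


@dataclass
class FunctionResultContent:
    result: str


def describe(item: object) -> str:
    # Class patterns match on the runtime type; `case _` is the fallback.
    match item:
        case TextContent():
            return f"text: {item.text}"
        case FunctionCallContent():
            return f"function call: {item.name}({item.arguments})"
        case FunctionResultContent():
            return f"function result: {item.result}"
        case _:
            return f"unsupported item type: {type(item).__name__}"


print(describe(TextContent(text="hello")))
print(describe(FunctionCallContent(name="get_weather", arguments='{"location": "Paris"}')))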

python/semantic_kernel/connectors/ai/open_ai/services/realtime/open_ai_realtime_webrtc.py

Lines changed: 1 addition & 2 deletions
@@ -161,9 +161,8 @@ async def _on_track(self, track: "MediaStreamTrack") -> None:
            try:
                await self._receive_buffer.put(
                    AudioEvent(
-                        event_type="audio",
                        service_type=ListenEvents.RESPONSE_AUDIO_DELTA,
-                        audio=AudioContent(data=frame.to_ndarray(), data_format="np.int16", inner_content=frame),  # type: ignore
+                        audio=AudioContent(data=frame.to_ndarray(), data_format="np.int16", inner_content=frame),
                    ),
                )
            except Exception as e:

python/semantic_kernel/connectors/ai/open_ai/services/realtime/open_ai_realtime_websocket.py

Lines changed: 6 additions & 15 deletions
@@ -20,9 +20,7 @@
 from semantic_kernel.connectors.ai.open_ai.services.realtime.const import ListenEvents
 from semantic_kernel.connectors.ai.open_ai.services.realtime.open_ai_realtime_base import OpenAIRealtimeBase
 from semantic_kernel.contents.audio_content import AudioContent
-from semantic_kernel.contents.events.realtime_event import RealtimeEvent
-from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
-from semantic_kernel.contents.utils.author_role import AuthorRole
+from semantic_kernel.contents.events.realtime_event import AudioEvent, RealtimeEvent
 from semantic_kernel.utils.experimental_decorator import experimental_class

 if TYPE_CHECKING:
@@ -54,18 +52,11 @@ async def receive(
                if self.audio_output_callback:
                    await self.audio_output_callback(np.frombuffer(base64.b64decode(event.delta), dtype=np.int16))
                try:
-                    yield (
-                        event.type,
-                        StreamingChatMessageContent(
-                            role=AuthorRole.ASSISTANT,
-                            items=[
-                                AudioContent(
-                                    data=base64.b64decode(event.delta),
-                                    data_format="base64",
-                                    inner_content=event,
-                                )
-                            ],  # type: ignore
-                            choice_index=event.content_index,
+                    yield AudioEvent(
+                        audio=AudioContent(
+                            data=base64.b64decode(event.delta),
+                            data_format="base64",
+                            inner_content=event,
                        ),
                    )
                except Exception as e:
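
The receive loop above turns each response.audio.delta payload from base64 into 16-bit PCM samples before invoking the audio output callback and yielding an AudioEvent. A minimal round-trip sketch of that decode step with synthetic data, no service connection required:

import base64

import numpy as np

# Pretend this arrived as a `response.audio.delta` payload: base64-encoded PCM16.
pcm = np.array([0, 1000, -1000, 32767, -32768], dtype=np.int16)
delta = base64.b64encode(pcm.tobytes()).decode("ascii")

# The decode used by the websocket receive loop: base64 -> bytes -> int16 samples.
samples = np.frombuffer(base64.b64decode(delta), dtype=np.int16)
assert np.array_equal(samples, pcm)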

python/semantic_kernel/connectors/ai/open_ai/services/realtime/utils.py

Lines changed: 2 additions & 3 deletions
@@ -68,11 +68,10 @@ def kernel_function_metadata_to_function_call_format(
 def _create_realtime_client_event(event_type: SendEvents, **kwargs: Any) -> RealtimeClientEvent:
     match event_type:
         case SendEvents.SESSION_UPDATE:
-            event_kwargs = {"event_id": kwargs.pop("event_id")} if "event_id" in kwargs else {}
             return SessionUpdateEvent(
                 type=event_type,
-                session=Session.model_validate(kwargs),
-                **event_kwargs,
+                session=Session.model_validate(kwargs.pop("session")),
+                **kwargs,
             )
         case SendEvents.INPUT_AUDIO_BUFFER_APPEND:
             return InputAudioBufferAppendEvent(
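
The session-update fix validates the prepared settings dict explicitly under a session key and forwards any remaining kwargs (such as event_id) to the outer event. A small pydantic sketch of that validate-then-wrap pattern, using hypothetical Session and SessionUpdateEvent models in place of the SDK types:

from typing import Any

from pydantic import BaseModel


class Session(BaseModel):
    """Hypothetical stand-in for the SDK's Session model."""

    instructions: str | None = None
    voice: str | None = None


class SessionUpdateEvent(BaseModel):
    """Hypothetical stand-in for the SDK's SessionUpdateEvent model."""

    type: str
    session: Session
    event_id: str | None = None


def create_session_update(event_type: str, **kwargs: Any) -> SessionUpdateEvent:
    # Validate the nested dict explicitly, then forward any remaining kwargs
    # (such as event_id) to the outer event.
    return SessionUpdateEvent(
        type=event_type,
        session=Session.model_validate(kwargs.pop("session")),
        **kwargs,
    )


event = create_session_update(
    "session.update",
    session={"instructions": "Be brief.", "voice": "alloy"},
    event_id="evt_1",
)
print(event.model_dump_json())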
