Skip to content

Commit 4aa6e38

Browse files
committed
feat: update conversation when agent finishes speaking or is interrupted
1 parent 870616a commit 4aa6e38

File tree

4 files changed

+81
-20
lines changed

4 files changed

+81
-20
lines changed

components/screens/signup/index.tsx

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
import { updateUserSettings } from "@/components/screens/auth";
2+
import { AppContext } from "@/context";
23
import { segmentTrackFinishedSignup } from "@/lib/analytics";
34
import { signupStyles } from "@/lib/style";
45
import { $Enums, Settings } from "@prisma/client";
56
import { Session } from "@supabase/supabase-js";
6-
import React, { ReactElement, useEffect, useRef, useState } from "react";
7+
import React, {
8+
ReactElement,
9+
useContext,
10+
useEffect,
11+
useRef,
12+
useState,
13+
} from "react";
714
import { Animated, Easing, View } from "react-native";
815
import { AINameSection } from "./aiName";
916
import { GenderSection } from "./gender";
@@ -44,6 +51,7 @@ export const SignupFlow: React.FC<Props> = ({
4451
setShowSignupFlow,
4552
setSettings,
4653
}) => {
54+
const { refetchToken } = useContext(AppContext);
4755
const [currentStepIndex, setCurrentStepIndex] = useState<number>(0);
4856
const [isTransitioning, setIsTransitioning] = useState<boolean>(false);
4957
const [topSection, setTopSection] = useState<"A" | "B">("A");
@@ -109,6 +117,7 @@ export const SignupFlow: React.FC<Props> = ({
109117
session.user.id
110118
);
111119
setSettings(settings);
120+
await refetchToken();
112121

113122
if (error === null) {
114123
Animated.parallel([

server/api/routes/chat.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from server.api.constants import SUPABASE_AUDIO_MESSAGES_BUCKET_NAME, LLM
1717
from server.api.utils import add_memories, authorize_user, get_stream_content
1818
from prisma import Prisma, enums, types
19-
from datetime import datetime
19+
from pydantic import BaseModel
20+
from datetime import datetime, timedelta
2021
from server.api.analytics import track_sent_message
2122
from server.agent.index import generate_response
2223
from server.logger.index import fetch_logger
@@ -92,8 +93,12 @@ async def call_update_chat(
9293
audio_messages_enabled: bool,
9394
audio_id: Optional[str],
9495
):
95-
new_user_message = next(msg for msg in reversed(messages) if msg["role"] == "user")
96-
new_user_message = message_to_fixed_string_content(new_user_message)["content"]
96+
# We default new_user_message to an empty string when the messages array has length 1
97+
# This handles the case where the agent is sending the first message in the conversation to greet the user
98+
new_user_message = ''
99+
if len(messages) > 1:
100+
new_user_message = next(msg for msg in reversed(messages) if msg["role"] == "user")
101+
new_user_message = message_to_fixed_string_content(new_user_message)["content"]
97102

98103
data = {
99104
"new_user_message": new_user_message,
@@ -129,10 +134,9 @@ def stream_and_update_chat(
129134
user_first_name: str,
130135
user_gender: str,
131136
audio_messages_enabled: bool,
132-
audio_id: Optional[str] = None,
133-
skip_final_processing: Optional[bool] = False,
137+
audio_id: str,
138+
skip_final_processing: bool,
134139
):
135-
user_message_timestamp = datetime.now()
136140
client = OpenAI(
137141
api_key=os.environ.get("OPENAI_API_KEY"),
138142
)
@@ -171,6 +175,8 @@ def stream_and_update_chat(
171175
content = choice.delta.content
172176
agent_response += content
173177

178+
user_message_timestamp = datetime.now()
179+
agent_message_timestamp = user_message_timestamp - timedelta(seconds=1)
174180
# Run asynchronous operations in a separate thread, which is necessary to prevent the main
175181
# thread from getting blocked during synchronous tasks with high latency, like network requests.
176182
# This is important when streaming voice responses because the voice will pause in the middle of
@@ -192,6 +198,7 @@ def stream_and_update_chat(
192198
user_message_timestamp=user_message_timestamp,
193199
audio_messages_enabled=audio_messages_enabled,
194200
audio_id=audio_id,
201+
agent_message_timestamp=agent_message_timestamp,
195202
)
196203
),
197204
daemon=True,
@@ -208,9 +215,8 @@ async def final_processing_coroutine(
208215
user_message_timestamp: datetime,
209216
audio_messages_enabled: bool,
210217
audio_id: Optional[str],
218+
agent_message_timestamp: datetime,
211219
) -> None:
212-
agent_message_timestamp = datetime.now()
213-
214220
await call_update_chat(
215221
messages=messages,
216222
agent_response=agent_response,
@@ -256,6 +262,7 @@ def stream_text(
256262
user_gender=user_gender,
257263
audio_messages_enabled=audio_messages_enabled,
258264
audio_id=audio_id,
265+
skip_final_processing=False,
259266
)
260267
for chunk in stream:
261268
for choice in chunk.choices:
@@ -349,6 +356,7 @@ def sync_function():
349356
user_id=user_id,
350357
chat_type="type",
351358
user_message_timestamp=user_message_timestamp,
359+
agent_message_timestamp=datetime.timestamp(),
352360
audio_messages_enabled=audio_messages_enabled,
353361
audio_id=audio_id,
354362
)
@@ -435,15 +443,16 @@ async def handle_update_chat(request: UpdateChatRequest):
435443
audio_messages_enabled = request.audio_messages_enabled
436444

437445
# Create new user chat message
438-
await prisma.chatmessages.create(
439-
data=types.ChatMessagesCreateInput(
440-
chatId=chat_id,
441-
role=enums.OpenAIRole.user,
442-
content=new_user_message,
443-
created=datetime.fromtimestamp(request.user_message_timestamp),
444-
displayType="text",
446+
if len(new_user_message) > 0:
447+
await prisma.chatmessages.create(
448+
data=types.ChatMessagesCreateInput(
449+
chatId=chat_id,
450+
role=enums.OpenAIRole.user,
451+
content=new_user_message,
452+
created=datetime.fromtimestamp(request.user_message_timestamp),
453+
displayType="text",
454+
)
445455
)
446-
)
447456

448457
display_type = "audio" if audio_messages_enabled else "text"
449458

server/livekit_worker/llm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from livekit.agents import llm
1515
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
1616
from server.agent.index import generate_response
17-
from server.api.routes.chat import final_processing_coroutine, stream_and_update_chat
17+
from server.api.routes.chat import stream_and_update_chat
1818
from typing import Any, Coroutine
1919
from dataclasses import dataclass
2020
from server.logger.index import fetch_logger
@@ -62,6 +62,7 @@ async def wrapper():
6262
user_gender=user_gender,
6363
audio_messages_enabled=audio_messages_enabled,
6464
audio_id=None,
65+
skip_final_processing=True,
6566
)
6667
it = iter(sync_gen)
6768

server/livekit_worker/main.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,17 @@
2121
import sys
2222
from pathlib import Path
2323
from fastapi import HTTPException, status
24-
from datetime import datetime
24+
from datetime import datetime, timedelta
2525
from sentry_sdk.integrations.asyncio import AsyncioIntegration
2626
from sentry_sdk.integrations.logging import LoggingIntegration
2727
from .voices import VoiceSettingMapping
2828
from server.logger.index import fetch_logger
2929
from dotenv import load_dotenv
30+
from livekit.plugins.openai.llm import _build_oai_context, build_oai_message
31+
from ..api.routes.chat import (
32+
final_processing_coroutine,
33+
message_to_fixed_string_content,
34+
)
3035

3136
load_dotenv()
3237

@@ -200,10 +205,47 @@ async def entrypoint(ctx: JobContext):
200205
),
201206
api_key=os.environ.get("ELEVEN_LABS_API_KEY"),
202207
),
203-
min_endpointing_delay=1,
208+
min_endpointing_delay=2,
204209
chat_ctx=initial_ctx,
205210
)
206211

212+
def handle_update_conversation(msg: llm.ChatMessage):
213+
messages = _build_oai_context(assistant.chat_ctx, id(assistant))
214+
new_agent_message = build_oai_message(msg, id(assistant))
215+
new_agent_message: str = message_to_fixed_string_content(new_agent_message)[
216+
"content"
217+
]
218+
219+
user_message_timestamp = datetime.now()
220+
agent_message_timestamp = datetime.now() + timedelta(seconds=1)
221+
222+
# Use asyncio.create_task to schedule the coroutine
223+
asyncio.create_task(
224+
final_processing_coroutine(
225+
messages=messages,
226+
agent_response=new_agent_message.strip(),
227+
chat_id=chat_id,
228+
user_id=user_id,
229+
chat_type="voice",
230+
user_message_timestamp=user_message_timestamp,
231+
agent_message_timestamp=agent_message_timestamp,
232+
audio_messages_enabled=False,
233+
audio_id=None,
234+
)
235+
)
236+
237+
# We update the database when the agent is interrupted and when the agent finishes talking
238+
# We include the interruption because this event reliably only fires when the agent has actually
239+
# started talking; it does not fire if the agent has not started talking at all and the user simply
240+
# paused long enough for the response process to begin.
241+
@assistant.on("agent_speech_interrupted")
242+
def on_agent_speech_committed(msg: llm.ChatMessage):
243+
handle_update_conversation(msg)
244+
245+
@assistant.on("agent_speech_committed")
246+
def on_agent_speech_committed(msg: llm.ChatMessage):
247+
handle_update_conversation(msg)
248+
207249
assistant.start(ctx.room, participant)
208250

209251
if send_first_chat_message:

0 commit comments

Comments
 (0)