
Commit c455d55

feat: update conversation when agent finishes speaking or is interrupted
1 parent 870616a

4 files changed: +65 −9 lines changed

components/screens/signup/index.tsx

Lines changed: 10 additions & 1 deletion
@@ -1,9 +1,16 @@
 import { updateUserSettings } from "@/components/screens/auth";
+import { AppContext } from "@/context";
 import { segmentTrackFinishedSignup } from "@/lib/analytics";
 import { signupStyles } from "@/lib/style";
 import { $Enums, Settings } from "@prisma/client";
 import { Session } from "@supabase/supabase-js";
-import React, { ReactElement, useEffect, useRef, useState } from "react";
+import React, {
+  ReactElement,
+  useContext,
+  useEffect,
+  useRef,
+  useState,
+} from "react";
 import { Animated, Easing, View } from "react-native";
 import { AINameSection } from "./aiName";
 import { GenderSection } from "./gender";
@@ -44,6 +51,7 @@ export const SignupFlow: React.FC<Props> = ({
   setShowSignupFlow,
   setSettings,
 }) => {
+  const { refetchToken } = useContext(AppContext);
   const [currentStepIndex, setCurrentStepIndex] = useState<number>(0);
   const [isTransitioning, setIsTransitioning] = useState<boolean>(false);
   const [topSection, setTopSection] = useState<"A" | "B">("A");
@@ -109,6 +117,7 @@ export const SignupFlow: React.FC<Props> = ({
       session.user.id
     );
     setSettings(settings);
+    await refetchToken();

     if (error === null) {
       Animated.parallel([

server/api/routes/chat.py

Lines changed: 10 additions & 6 deletions
@@ -16,7 +16,8 @@
 from server.api.constants import SUPABASE_AUDIO_MESSAGES_BUCKET_NAME, LLM
 from server.api.utils import add_memories, authorize_user, get_stream_content
 from prisma import Prisma, enums, types
-from datetime import datetime
+from pydantic import BaseModel
+from datetime import datetime, timedelta
 from server.api.analytics import track_sent_message
 from server.agent.index import generate_response
 from server.logger.index import fetch_logger
@@ -129,10 +130,9 @@ def stream_and_update_chat(
     user_first_name: str,
     user_gender: str,
     audio_messages_enabled: bool,
-    audio_id: Optional[str] = None,
-    skip_final_processing: Optional[bool] = False,
+    audio_id: str,
+    skip_final_processing: bool,
 ):
-    user_message_timestamp = datetime.now()
     client = OpenAI(
         api_key=os.environ.get("OPENAI_API_KEY"),
     )
@@ -171,6 +171,8 @@ def stream_and_update_chat(
             content = choice.delta.content
             agent_response += content

+    user_message_timestamp = datetime.now()
+    agent_message_timestamp = user_message_timestamp - timedelta(seconds=1)
    # Run asynchronous operations in a separate thread, which is necessary to prevent the main
    # thread from getting blocked during synchronous tasks with high latency, like network requests.
    # This is important when streaming voice responses because the voice will pause in the middle of
@@ -192,6 +194,7 @@ def stream_and_update_chat(
                 user_message_timestamp=user_message_timestamp,
                 audio_messages_enabled=audio_messages_enabled,
                 audio_id=audio_id,
+                agent_message_timestamp=agent_message_timestamp,
             )
         ),
         daemon=True,
@@ -208,9 +211,8 @@ async def final_processing_coroutine(
     user_message_timestamp: datetime,
     audio_messages_enabled: bool,
     audio_id: Optional[str],
+    agent_message_timestamp: datetime,
 ) -> None:
-    agent_message_timestamp = datetime.now()
-
     await call_update_chat(
         messages=messages,
         agent_response=agent_response,
@@ -256,6 +258,7 @@ def stream_text(
         user_gender=user_gender,
         audio_messages_enabled=audio_messages_enabled,
         audio_id=audio_id,
+        skip_final_processing=False,
     )
     for chunk in stream:
         for choice in chunk.choices:
@@ -349,6 +352,7 @@ def sync_function():
             user_id=user_id,
             chat_type="type",
             user_message_timestamp=user_message_timestamp,
+            agent_message_timestamp=datetime.now(),
             audio_messages_enabled=audio_messages_enabled,
             audio_id=audio_id,
         )
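
For context on the comment above about running asynchronous operations off the main thread: the idea is to hand the coroutine to a fresh event loop running on a daemon thread, so the token stream keeps flowing while slow network work happens elsewhere. A minimal sketch of that pattern, with illustrative names rather than the repository's actual helpers:

import asyncio
import threading

async def final_processing(agent_response: str) -> None:
    # Stand-in for the real coroutine: database writes, analytics, etc.
    await asyncio.sleep(1)

def schedule_final_processing(agent_response: str) -> None:
    # asyncio.run() builds a private event loop inside the worker thread,
    # so the generator streaming tokens on the main thread never blocks.
    # daemon=True lets the process exit without waiting on this thread.
    threading.Thread(
        target=lambda: asyncio.run(final_processing(agent_response)),
        daemon=True,
    ).start()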

server/livekit_worker/llm.py

Lines changed: 2 additions & 1 deletion
@@ -14,7 +14,7 @@
 from livekit.agents import llm
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 from server.agent.index import generate_response
-from server.api.routes.chat import final_processing_coroutine, stream_and_update_chat
+from server.api.routes.chat import stream_and_update_chat
 from typing import Any, Coroutine
 from dataclasses import dataclass
 from server.logger.index import fetch_logger
@@ -62,6 +62,7 @@ async def wrapper():
             user_gender=user_gender,
             audio_messages_enabled=audio_messages_enabled,
             audio_id=None,
+            skip_final_processing=True,
         )
         it = iter(sync_gen)

server/livekit_worker/main.py

Lines changed: 43 additions & 1 deletion
@@ -21,12 +21,17 @@
 import sys
 from pathlib import Path
 from fastapi import HTTPException, status
-from datetime import datetime
+from datetime import datetime, timedelta
 from sentry_sdk.integrations.asyncio import AsyncioIntegration
 from sentry_sdk.integrations.logging import LoggingIntegration
 from .voices import VoiceSettingMapping
 from server.logger.index import fetch_logger
 from dotenv import load_dotenv
+from livekit.plugins.openai.llm import _build_oai_context, build_oai_message
+from ..api.routes.chat import (
+    final_processing_coroutine,
+    message_to_fixed_string_content,
+)

 load_dotenv()

@@ -204,6 +209,43 @@ async def entrypoint(ctx: JobContext):
         chat_ctx=initial_ctx,
     )

+    def handle_update_conversation(msg: llm.ChatMessage):
+        messages = _build_oai_context(assistant.chat_ctx, id(assistant))
+        new_agent_message = build_oai_message(msg, id(assistant))
+        new_agent_message: str = message_to_fixed_string_content(new_agent_message)[
+            "content"
+        ]
+
+        user_message_timestamp = datetime.now()
+        agent_message_timestamp = datetime.now() + timedelta(seconds=1)
+
+        # Use asyncio.create_task to schedule the coroutine
+        asyncio.create_task(
+            final_processing_coroutine(
+                messages=messages,
+                agent_response=new_agent_message.strip(),
+                chat_id=chat_id,
+                user_id=user_id,
+                chat_type="voice",
+                user_message_timestamp=user_message_timestamp,
+                agent_message_timestamp=agent_message_timestamp,
+                audio_messages_enabled=False,
+                audio_id=None,
+            )
+        )
+
+    # We update the database when the agent finishes talking and when it is interrupted.
+    # We include the interruption case because that event reliably fires only when the agent
+    # has actually started talking; it does not fire if the agent never started talking and
+    # the user simply paused long enough for the response process to begin.
+    @assistant.on("agent_speech_interrupted")
+    def on_agent_speech_interrupted(msg: llm.ChatMessage):
+        handle_update_conversation(msg)
+
+    @assistant.on("agent_speech_committed")
+    def on_agent_speech_committed(msg: llm.ChatMessage):
+        handle_update_conversation(msg)
+
     assistant.start(ctx.room, participant)

     if send_first_chat_message:
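
A note on the asyncio.create_task call above: it only works because these event callbacks are invoked from inside the worker's already-running event loop; calling it from a plain thread with no loop raises RuntimeError. A minimal sketch of the same fire-and-forget shape, with an illustrative handler rather than the repository's actual code:

import asyncio

async def update_conversation() -> None:
    # Stand-in for final_processing_coroutine: persist the finished turn.
    await asyncio.sleep(0)

def on_agent_event() -> None:
    # Valid only while an event loop is running in this thread; the task
    # then proceeds in the background without the callback awaiting it.
    asyncio.create_task(update_conversation())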
