Skip to content

Commit 8fe2c66

Browse files
committed
updated web agent
1 parent 0cbd210 commit 8fe2c66

File tree

11 files changed

+138
-70
lines changed

11 files changed

+138
-70
lines changed

packages/llm/local_llm/agents/voice_chat.py

+44-42
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from local_llm.plugins import (
66
UserPrompt, ChatQuery, PrintStream,
7-
RivaASR, RivaTTS, RateLimit,
7+
AutoASR, AutoTTS, RateLimit, ProcessProxy,
88
AudioOutputDevice, AudioOutputFile
99
)
1010

@@ -16,66 +16,68 @@ class VoiceChat(Agent):
1616
def __init__(self, **kwargs):
1717
super().__init__(**kwargs)
1818

19-
# ASR
20-
self.asr = RivaASR(**kwargs)
21-
22-
self.asr.add(PrintStream(partial=False, prefix='## ', color='blue'), RivaASR.OutputFinal)
23-
self.asr.add(PrintStream(partial=False, prefix='>> ', color='magenta'), RivaASR.OutputPartial)
24-
25-
self.asr.add(self.asr_partial, RivaASR.OutputPartial, threaded=False) # pause output when user is speaking
26-
self.asr.add(self.asr_final, RivaASR.OutputFinal, threaded=False) # clear queues on final ASR transcript
27-
28-
self.asr_history = None # store the partial ASR transcript
29-
3019
# LLM
31-
self.llm = ChatQuery(**kwargs)
20+
self.llm = ChatQuery(**kwargs) #ProcessProxy('ChatQuery', **kwargs) #
21+
self.llm.add(PrintStream(color='green'))
22+
23+
# ASR
24+
self.asr = AutoASR.from_pretrained(**kwargs)
3225

33-
self.llm.add(PrintStream(color='green', relay=True).add(self.on_eos))
34-
self.asr.add(self.llm, RivaASR.OutputFinal) # runs after asr_final() and any interruptions occur
26+
if self.asr:
27+
self.asr.add(PrintStream(partial=False, prefix='## ', color='blue'), AutoASR.OutputFinal)
28+
self.asr.add(PrintStream(partial=False, prefix='>> ', color='magenta'), AutoASR.OutputPartial)
29+
30+
self.asr.add(self.asr_partial, AutoASR.OutputPartial) # pause output when user is speaking
31+
self.asr.add(self.asr_final, AutoASR.OutputFinal) # clear queues on final ASR transcript
32+
self.asr.add(self.llm, AutoASR.OutputFinal) # runs after asr_final() and any interruptions occur
33+
34+
self.asr_history = None # store the partial ASR transcript
3535

3636
# TTS
37-
self.tts = RivaTTS(**kwargs)
38-
self.tts_output = RateLimit(kwargs['sample_rate_hz'], chunk=9600) # slow down TTS to realtime and be able to pause it
39-
40-
self.tts.add(self.tts_output)
41-
self.llm.add(self.tts, ChatQuery.OutputWords)
37+
self.tts = AutoTTS.from_pretrained(**kwargs)
4238

43-
# Audio Output
44-
self.audio_output_device = kwargs.get('audio_output_device')
45-
self.audio_output_file = kwargs.get('audio_output_file')
46-
47-
if self.audio_output_device is not None:
48-
self.audio_output_device = AudioOutputDevice(**kwargs)
49-
self.tts_output.add(self.audio_output_device)
50-
51-
if self.audio_output_file is not None:
52-
self.audio_output_file = AudioOutputFile(**kwargs)
53-
self.tts_output.add(self.audio_output_file)
39+
if self.tts:
40+
self.tts_output = RateLimit(kwargs['sample_rate_hz'], chunk=9600) # slow down TTS to realtime and be able to pause it
41+
self.tts.add(self.tts_output)
42+
self.llm.add(self.tts, ChatQuery.OutputWords)
43+
44+
self.audio_output_device = kwargs.get('audio_output_device')
45+
self.audio_output_file = kwargs.get('audio_output_file')
46+
47+
if self.audio_output_device is not None:
48+
self.audio_output_device = AudioOutputDevice(**kwargs)
49+
self.tts_output.add(self.audio_output_device)
50+
51+
if self.audio_output_file is not None:
52+
self.audio_output_file = AudioOutputFile(**kwargs)
53+
self.tts_output.add(self.audio_output_file)
5454

5555
# text prompts from web UI or CLI
5656
self.prompt = UserPrompt(interactive=True, **kwargs)
5757
self.prompt.add(self.llm)
5858

59-
self.pipeline = [self.prompt, self.asr]
59+
# setup pipeline with two entry nodes
60+
self.pipeline = [self.prompt]
61+
62+
if self.asr:
63+
self.pipeline.append(self.asr)
6064

6165
def asr_partial(self, text):
6266
self.asr_history = text
6367
if len(text.split(' ')) < 2:
6468
return
65-
self.tts_output.pause(1.0)
69+
if self.tts:
70+
self.tts_output.pause(1.0)
6671

6772
def asr_final(self, text):
6873
self.asr_history = None
74+
self.on_interrupt()
6975

70-
self.llm.interrupt()
71-
self.tts.interrupt()
72-
73-
self.tts_output.interrupt(block=False) # might be paused/asleep
74-
75-
def on_eos(self, text):
76-
if text.endswith('</s>'):
77-
print_table(self.llm.model.stats)
78-
76+
def on_interrupt(self):
77+
self.llm.interrupt(recursive=False)
78+
if self.tts:
79+
self.tts.interrupt(recursive=False)
80+
self.tts_output.interrupt(block=False, recursive=False) # might be paused/asleep
7981

8082
if __name__ == "__main__":
8183
parser = ArgParser(extras=ArgParser.Defaults+['asr', 'tts', 'audio_output'])

packages/llm/local_llm/agents/web_chat.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,14 @@ def __init__(self, **kwargs):
2525
"""
2626
super().__init__(**kwargs)
2727

28-
self.asr.add(self.on_asr_partial, RivaASR.OutputPartial)
29-
#self.asr.add(self.on_asr_final, RivaASR.OutputFinal)
28+
if self.asr:
29+
self.asr.add(self.on_asr_partial, RivaASR.OutputPartial)
30+
#self.asr.add(self.on_asr_final, RivaASR.OutputFinal)
3031

3132
self.llm.add(self.on_llm_reply)
32-
self.tts_output.add(self.on_tts_samples)
33+
34+
if self.tts:
35+
self.tts_output.add(self.on_tts_samples)
3336

3437
self.server = WebServer(msg_callback=self.on_message, **kwargs)
3538

@@ -40,13 +43,19 @@ def on_message(self, msg, msg_type=0, metadata='', **kwargs):
4043
self.send_chat_history()
4144
if 'client_state' in msg:
4245
if msg['client_state'] == 'connected':
46+
if self.tts:
47+
self.server.send_message({'tts_voices': self.tts.voices, 'tts_voice': self.tts.voice, 'tts_rate': self.tts.rate})
4348
threading.Timer(1.0, lambda: self.send_chat_history()).start()
44-
if 'tts_voice' in msg:
49+
if 'tts_voice' in msg and self.tts:
4550
self.tts.voice = msg['tts_voice']
51+
if 'tts_rate' in msg and self.tts:
52+
self.tts.rate = float(msg['tts_rate'])
4653
elif msg_type == WebServer.MESSAGE_TEXT: # chat input
54+
self.on_interrupt()
4755
self.prompt(msg.strip('"'))
4856
elif msg_type == WebServer.MESSAGE_AUDIO: # web audio (mic)
49-
self.asr(msg)
57+
if self.asr:
58+
self.asr(msg)
5059
elif msg_type == WebServer.MESSAGE_IMAGE:
5160
logging.info(f"recieved {metadata} image message {msg.size} -> {msg.filename}")
5261
self.llm.chat_history.reset()
@@ -79,7 +88,7 @@ def send_chat_history(self, history=None):
7988

8089
history = history.to_list()
8190

82-
if self.asr_history:
91+
if self.asr and self.asr_history:
8392
history.append({'role': 'user', 'text': self.asr_history})
8493

8594
def web_text(text):

packages/llm/local_llm/local_llm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def from_pretrained(model, api=None, **kwargs):
5151
model_name = os.path.basename(model)
5252

5353
if not api:
54-
api = default_model_api(model_path, quant)
54+
api = default_model_api(model_path, kwargs.get('quant'))
5555

5656
kwargs['name'] = model_name
5757
kwargs['api'] = api

packages/llm/local_llm/plugins/audio/auto_asr.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,14 @@ def from_pretrained(asr=None, **kwargs):
3939
return RivaASR(**kwargs)
4040
else:
4141
raise ValueError(f"ASR model type should be 'riva'")
42-
42+
43+
def add_punctuation(self, text):
44+
"""
45+
Make sure that the transcript ends in some kind of punctuation
46+
"""
47+
x = text.strip()
48+
49+
if not any([x[-1] == y for y in ('.', ',', '?', '!', ':')]):
50+
return text + '.'
51+
52+
return text

packages/llm/local_llm/plugins/audio/auto_tts.py

+30-7
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ def from_pretrained(tts=None, **kwargs):
4040
return None
4141

4242
if FastPitchTTS.is_model(tts):
43-
return FastPitchTTS(model=tts, **kwargs)
43+
return FastPitchTTS(**{**kwargs, 'model': tts})
4444
elif XTTS.is_model(tts):
45-
return XTTS(model=tts, **kwargs)
45+
return XTTS(**{**kwargs, 'model': tts})
4646
elif tts.lower() == 'riva':
4747
return RivaTTS(**kwargs)
4848
else:
@@ -92,7 +92,6 @@ def buffer_text(self, text):
9292
# see if input is needed to prevent a gap-out
9393
if 'time' in self.buffering:
9494
timeout = self.needs_text_by - time.perf_counter() - 0.05 # TODO make this RTFX factor adjustable
95-
9695
if timeout > 0:
9796
return None # we can keep accumulating text
9897

@@ -111,7 +110,7 @@ def buffer_text(self, text):
111110
return None
112111

113112
# for commas, make sure there are at least a handful of proceeding words
114-
if self.text_buffer[punc_pos] == ',' and len(self.text_buffer[:punc_pos].split(' ')) < 4:
113+
if len(self.text_buffer[:punc_pos].split(' ')) < 4: #and self.text_buffer[punc_pos] == ',':
115114
return None
116115

117116
# make sure that the character following the punctuation isn't alphanumeric
@@ -162,10 +161,14 @@ def filter_text(self, text, numbers_to_words=False):
162161
return None
163162

164163
# text = text.strip()
165-
text = text.replace('</s>', '')
164+
for stop_token in StopTokens:
165+
text = text.replace(stop_token, '')
166+
167+
#text = text.replace('</s>', '')
166168
text = text.replace('\n', ' ')
167-
#text = text.replace(' ', ' ')
168-
169+
text = text.replace('...', ' ')
170+
text = self.filter_chars(text)
171+
169172
if numbers_to_words:
170173
text = self.numbers_to_words(text)
171174

@@ -174,6 +177,26 @@ def filter_text(self, text, numbers_to_words=False):
174177

175178
return text
176179

180+
def filter_chars(self, text):
181+
"""
182+
Filter out non-alphanumeric and non-punctuation characters
183+
"""
184+
def filter_char(input):
185+
for idx, char in enumerate(input):
186+
if char.isalnum() or any([char == x for x in ('.', ',', '?', '!', ':', ';', '-', "'", '"', ' ', '/')]):
187+
continue
188+
else:
189+
return input.replace(char, ' ')
190+
return input
191+
192+
while True:
193+
filtered = filter_char(text)
194+
if filtered == text:
195+
return text
196+
else:
197+
text = filtered
198+
continue
199+
177200
def numbers_to_words(self, text):
178201
"""
179202
Convert instances of numbers to words in the text.

packages/llm/local_llm/plugins/audio/riva_asr.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(self, riva_server='localhost:50051',
5353
self.sample_rate = sample_rate_hz
5454
self.confidence_threshold = asr_confidence
5555
self.silence_threshold = asr_silence
56-
self.keep_alive_timeout = 99 # requests timeout after 1000 seconds
56+
self.keep_alive_timeout = 5 # requests timeout after 1000 seconds
5757

5858
self.asr_service = riva.client.ASRService(self.auth)
5959

@@ -104,7 +104,7 @@ def generate(self, audio_generator):
104104
score = result.alternatives[0].confidence
105105
if score >= self.confidence_threshold:
106106
logging.debug(f"submitting ASR transcript (confidence={score:.3f}) -> '{transcript}'")
107-
self.output(transcript, AutoASR.OutputFinal)
107+
self.output(self.add_punctuation(transcript), AutoASR.OutputFinal)
108108
else:
109109
logging.warning(f"dropping ASR transcript (confidence={score:.3f} < {self.confidence_threshold:.3f}) -> '{transcript}'")
110110
else:

packages/llm/local_llm/plugins/audio/riva_tts.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(self, riva_server='localhost:50051',
4343
self.pitch = voice_pitch
4444
self.volume = voice_volume
4545

46-
self.language = language
46+
self.language = language_code
4747
self.sample_rate = sample_rate_hz
4848

4949
# find out how to query these for non-English models

packages/llm/local_llm/plugins/process_proxy.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,9 @@ def input(self, input):
4444
self.data_parent.send_bytes(input)
4545

4646
def start(self):
47-
self.control_parent.send('start')
48-
self.assert_message(self.control_parent.recv(), 'started')
47+
if not self.is_alive():
48+
self.control_parent.send('start')
49+
self.assert_message(self.control_parent.recv(), 'started')
4950
return super().start()
5051

5152
def run(self):

packages/llm/local_llm/utils/args.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def __init__(self, extras=Defaults, **kwargs):
9898
self.add_argument("--language-code", default="en-US", help="Language code of the ASR/TTS to be used.")
9999

100100
if 'tts' in extras:
101-
self.add_argument("--tts", type=str, default="riva", help="name of path of the TTS model to use (e.g. 'riva', 'xtts', 'none', 'disabled')")
101+
self.add_argument("--tts", type=str, default=None, help="name of path of the TTS model to use (e.g. 'riva', 'xtts', 'none', 'disabled')")
102102
self.add_argument("--tts-buffering", type=str, default="punctuation", help="buffering method for TTS ('none', 'punctuation', 'time', 'punctuation,time')")
103103
self.add_argument("--voice", type=str, default="English-US.Female-1", help="Voice model name to use for TTS")
104104
self.add_argument("--voice-rate", type=float, default=1.0, help="TTS SSML voice speaker rate (between 25-250%%)")
@@ -107,7 +107,7 @@ def __init__(self, extras=Defaults, **kwargs):
107107
#self.add_argument("--voice-min-words", type=int, default=4, help="the minimum number of words the TTS should wait to speak")
108108

109109
if 'asr' in extras:
110-
self.add_argument("--asr", type=str, default="riva", help="name or path of the ASR model to use (e.g. 'riva', 'none', 'disabled')")
110+
self.add_argument("--asr", type=str, default=None, help="name or path of the ASR model to use (e.g. 'riva', 'none', 'disabled')")
111111
self.add_argument("--asr-confidence", type=float, default=-2.5, help="minimum ASR confidence (only applies to 'final' transcripts)")
112112
self.add_argument("--asr-silence", type=float, default=-1.0, help="audio with RMS equal to or below this amount will be considered silent (negative will disable silence detection)")
113113
self.add_argument("--asr-chunk", type=int, default=1600, help="the number of audio samples to buffer as input to ASR")

packages/llm/local_llm/web/server.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,9 @@ def __init__(self, web_host='0.0.0.0', web_port=8050, ws_port=49000,
111111
self.ssl_context.load_cert_chain(certfile=self.ssl_cert, keyfile=self.ssl_key)
112112

113113
# websocket
114-
self.ws_port = ws_port
115114
self.websocket = None
115+
self.ws_port = ws_port
116+
self.kwargs['ws_port'] = ws_port
116117

117118
self.ws_server = websocket_serve(self.on_websocket, host=self.host, port=self.ws_port, ssl_context=self.ssl_context, max_size=None)
118119
self.ws_thread = threading.Thread(target=lambda: self.ws_server.serve_forever(), daemon=True)

0 commit comments

Comments
 (0)