Skip to content

Commit 65bd8cf

Browse files
feat: add deepgram asr flux model (#2088)
* feat: add deepgram asr flux model * feat: add deepgram asr flux model
1 parent 3dfcdf7 commit 65bd8cf

File tree

9 files changed

+318
-87
lines changed

9 files changed

+318
-87
lines changed

ai_agents/agents/ten_packages/extension/deepgram_asr_python/README.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
All parameters are configured through the `params` object:
88

9+
### nova model
10+
911
```json
1012
{
1113
"params": {
@@ -21,3 +23,20 @@ All parameters are configured through the `params` object:
2123
}
2224
}
2325
```
26+
27+
### flux model
28+
29+
```json
30+
{
31+
"params": {
32+
"key": "${env:DEEPGRAM_API_KEY}",
33+
"url": "wss://api.deepgram.com/v2/listen",
34+
"model": "flux-general-en",
35+
"sample_rate": 16000,
36+
"encoding": "linear16",
37+
"eager_eot_threshold": 0.6,
38+
"eot_threshold": 0.8,
39+
"eot_timeout_ms": 700
40+
}
41+
}
42+
```

ai_agents/agents/ten_packages/extension/deepgram_asr_python/config.py

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,87 @@
33
from ten_ai_base.utils import encrypt
44

55

6+
# Default params for nova models (v1 API)
7+
NOVA_DEFAULT_PARAMS = {
8+
"url": "wss://api.deepgram.com/v1/listen",
9+
"model": "nova-3",
10+
"language": "en-US",
11+
"sample_rate": 16000,
12+
"encoding": "linear16",
13+
"interim_results": True,
14+
"punctuate": True,
15+
"keep_alive": True,
16+
}
17+
18+
# Default params for flux models (v2 API)
19+
FLUX_DEFAULT_PARAMS = {
20+
"url": "wss://api.deepgram.com/v2/listen",
21+
"model": "flux-general-en",
22+
"sample_rate": 16000,
23+
"encoding": "linear16",
24+
"eager_eot_threshold": 0.6,
25+
"eot_threshold": 0.8,
26+
"eot_timeout_ms": 700,
27+
"keep_alive": False,
28+
}
29+
30+
631
class DeepgramASRConfig(BaseModel):
732
"""Deepgram ASR Configuration"""
833

934
# Debugging and dumping
1035
dump: bool = False
1136
dump_path: str = "/tmp"
12-
finalize_mode: str = "disconnect" # "disconnect" or "mute_pkg"
37+
finalize_mode: str = "mute_pkg" # "flush_api" or "mute_pkg" or "ignore"
1338
mute_pkg_duration_ms: int = 1000
1439
# Additional parameters
15-
params: Dict[str, Any] = Field(default_factory=dict)
40+
params: dict[str, Any] = Field(default_factory=dict)
41+
42+
def _get_default_params(self, model: str) -> Dict[str, Any]:
43+
"""Get default params based on model type."""
44+
model_lower = (model or "").lower()
45+
if "flux" in model_lower:
46+
return FLUX_DEFAULT_PARAMS.copy()
47+
else:
48+
# Default to nova params (includes nova-3, nova-2, etc.)
49+
return NOVA_DEFAULT_PARAMS.copy()
50+
51+
def _get_default_finalize_mode(self, model: str) -> str:
52+
"""Get default finalize mode based on model type."""
53+
model_lower = (model or "").lower()
54+
if "flux" in model_lower:
55+
return "ignore"
56+
else:
57+
return "flush_api"
58+
59+
def apply_defaults(self) -> None:
60+
"""Apply default params based on model type."""
61+
params_dict = self.params or {}
62+
# Get current model or use default
63+
current_model = params_dict.get("model", "") or "nova-3"
64+
65+
# Get defaults for this model type
66+
defaults = self._get_default_params(current_model)
67+
68+
# Ensure model is set
69+
if not params_dict.get("model"):
70+
params_dict["model"] = current_model
71+
72+
# Apply defaults for missing params
73+
for key, value in defaults.items():
74+
if key not in params_dict or params_dict[key] is None:
75+
params_dict[key] = value
76+
77+
self.params = params_dict
78+
79+
# Set finalize_mode from params or use default based on model type
80+
if (
81+
"finalize_mode" in params_dict
82+
and params_dict["finalize_mode"] is not None
83+
):
84+
self.finalize_mode = params_dict["finalize_mode"]
85+
else:
86+
self.finalize_mode = self._get_default_finalize_mode(current_model)
1687

1788
def update(self, params: Dict[str, Any]) -> None:
1889
"""Update configuration with additional parameters."""
@@ -31,6 +102,11 @@ def to_json(self, sensitive_handling: bool = False) -> str:
31102
config_dict["params"][key] = encrypt(value)
32103
return str(config_dict)
33104

105+
@property
106+
def is_flux_model(self) -> bool:
107+
params_dict = self.params or {}
108+
return "flux" in (params_dict.get("model", "") or "").lower()
109+
34110
@property
35111
def normalized_language(self) -> str:
36112
"""Convert language code to normalized format for Deepgram"""
@@ -49,5 +125,9 @@ def normalized_language(self) -> str:
49125
"ar": "ar-AE",
50126
}
51127
params_dict = self.params or {}
52-
language_code = params_dict.get("language", "") or ""
128+
if self.is_flux_model:
129+
# For flux models, use the 'language' param directly
130+
language_code = params_dict.get("language", "en-US")
131+
else:
132+
language_code = params_dict.get("language", "") or ""
53133
return language_map.get(language_code, language_code)

ai_agents/agents/ten_packages/extension/deepgram_asr_python/extension.py

Lines changed: 160 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from datetime import datetime
22
import os
33
import asyncio
4+
import copy
45
from typing import Dict, Any
56

67
from typing_extensions import override
@@ -22,6 +23,7 @@
2223
from ten_runtime import (
2324
AsyncTenEnv,
2425
AudioFrame,
26+
Data,
2527
)
2628
from ten_ai_base.const import (
2729
LOG_CATEGORY_VENDOR,
@@ -73,6 +75,10 @@ async def on_init(self, ten_env: AsyncTenEnv) -> None:
7375
try:
7476
self.config = DeepgramASRConfig.model_validate_json(config_json)
7577
self.config.update(self.config.params)
78+
79+
# Apply default params based on model type (nova or flux)
80+
self.config.apply_defaults()
81+
7682
ten_env.log_info(
7783
f"config: {self.config.to_json(sensitive_handling=True)}",
7884
category=LOG_CATEGORY_KEY_POINT,
@@ -164,12 +170,73 @@ async def finalize(self, _session_id: str | None) -> None:
164170
)
165171

166172
finalize_mode = self.config.finalize_mode
167-
if finalize_mode == "disconnect":
168-
await self._handle_finalize_disconnect()
169-
elif finalize_mode == "mute_pkg":
170-
await self._handle_finalize_mute_pkg()
173+
if self.config.is_flux_model:
174+
if finalize_mode == "ignore":
175+
await self._finalize_end()
176+
else:
177+
await self._handle_finalize_mute_pkg()
171178
else:
172-
raise ValueError(f"invalid finalize mode: {finalize_mode}")
179+
if finalize_mode == "flush_api":
180+
await self._handle_finalize_flush_api()
181+
elif finalize_mode == "mute_pkg":
182+
await self._handle_finalize_mute_pkg()
183+
else:
184+
raise ValueError(f"invalid finalize mode: {finalize_mode}")
185+
186+
async def _handle_event_result(self, event: str) -> None:
187+
"""Handle ASR event result"""
188+
self.ten_env.log_info(
189+
f"_handle_event_result: {event}",
190+
category=LOG_CATEGORY_KEY_POINT,
191+
)
192+
if event == "StartOfTurn":
193+
data = Data.create("sos")
194+
await self.ten_env.send_data(data)
195+
elif event == "EndOfTurn":
196+
data = Data.create("eos")
197+
await self.ten_env.send_data(data)
198+
elif event == "EagerEndOfTurn":
199+
data = Data.create("eager_eos")
200+
await self.ten_env.send_data(data)
201+
202+
def _build_metadata_with_asr_info(
203+
self,
204+
additional_fields: Dict[str, Any] | None = None,
205+
) -> Dict[str, Any]:
206+
"""Build metadata according to protocol: session_id at root, others in asr_info.
207+
208+
Args:
209+
additional_fields: Additional fields to add to asr_info
210+
211+
Returns:
212+
Metadata dict with structure: {"session_id": "...", "asr_info": {...}}
213+
"""
214+
# Start with a copy of base metadata if available
215+
base_metadata = (
216+
copy.deepcopy(self.metadata) if self.metadata is not None else {}
217+
)
218+
219+
# Extract session_id from base metadata if present
220+
session_id = base_metadata.pop("session_id", None)
221+
222+
# Collect all other fields into asr_info
223+
asr_info = copy.deepcopy(base_metadata)
224+
225+
# Add vendor field to asr_info
226+
asr_info["vendor"] = self.vendor()
227+
asr_info["model"] = self.config.params.get("model", "unknown")
228+
229+
# Add additional fields to asr_info if provided
230+
if additional_fields:
231+
asr_info.update(additional_fields)
232+
233+
# Build final metadata structure
234+
metadata: Dict[str, Any] = {}
235+
if session_id is not None:
236+
metadata["session_id"] = session_id
237+
metadata["asr_info"] = asr_info
238+
239+
return metadata
173240

174241
async def _handle_asr_result(
175242
self,
@@ -178,6 +245,7 @@ async def _handle_asr_result(
178245
start_ms: int = 0,
179246
duration_ms: int = 0,
180247
language: str = "",
248+
metadata: Dict[str, Any] | None = None,
181249
):
182250
"""Process ASR recognition result"""
183251
assert self.config is not None
@@ -192,12 +260,22 @@ async def _handle_asr_result(
192260
duration_ms=duration_ms,
193261
language=language,
194262
words=[],
263+
metadata=metadata if metadata is not None else {},
195264
)
196265

197266
await self.send_asr_result(asr_result)
198267

199268
async def _handle_finalize_disconnect(self):
200-
"""Handle disconnect mode finalization"""
269+
"""Handle disconnect mode finalization.
270+
271+
Deprecated: This method uses flush_api for finalization.
272+
"""
273+
if self.recognition:
274+
await self.recognition.stop()
275+
self.ten_env.log_debug("Deepgram finalize completed")
276+
277+
async def _handle_finalize_flush_api(self):
278+
"""Handle flush API mode finalization"""
201279
if self.recognition:
202280
await self.recognition.stop()
203281
self.ten_env.log_debug("Deepgram finalize completed")
@@ -324,41 +402,90 @@ async def on_open(self) -> None:
324402
@override
325403
async def on_result(self, message_data: Dict[str, Any]) -> None:
326404
"""Handle recognition result callback"""
405+
assert self.config is not None
327406

328407
try:
329408
# Extract basic fields
330-
is_final = message_data.get("is_final", False)
409+
if self.config.is_flux_model:
410+
event = message_data.get("event", "")
411+
result_to_send = message_data.get("transcript", "")
331412

332-
# Extract transcript and words from channel.alternatives[0]
333-
channel = message_data.get("channel", {})
334-
alternatives = channel.get("alternatives", [])
335-
if not alternatives:
336-
self.ten_env.log_debug("No alternatives in Deepgram result")
337-
return
413+
is_final = event == "EndOfTurn"
338414

339-
first_alt = alternatives[0]
340-
result_to_send = first_alt.get("transcript", "")
415+
start_ms = int(
416+
message_data.get("audio_window_start", 0) * 1000 or 0
417+
)
418+
end_ms = int(
419+
message_data.get("audio_window_end", 0) * 1000 or 0
420+
)
421+
duration_ms = end_ms - start_ms
341422

342-
# Extract timing information (in seconds, convert to milliseconds)
343-
start_seconds = message_data.get("start", 0)
344-
duration_seconds = message_data.get("duration", 0)
345-
start_ms = int(start_seconds * 1000)
346-
duration_ms = int(duration_seconds * 1000)
423+
actual_start_ms = int(
424+
self.audio_timeline.get_audio_duration_before_time(start_ms)
425+
+ self.sent_user_audio_duration_ms_before_last_reset
426+
)
427+
await self._handle_event_result(event)
347428

348-
# Calculate actual start time using audio timeline
349-
actual_start_ms = int(
350-
self.audio_timeline.get_audio_duration_before_time(start_ms)
351-
+ self.sent_user_audio_duration_ms_before_last_reset
352-
)
429+
# Build metadata with asr_info for flux model
430+
turn_index = message_data.get("turn_index")
431+
end_of_turn_confidence = message_data.get(
432+
"end_of_turn_confidence"
433+
)
353434

354-
# Process ASR result
355-
await self._handle_asr_result(
356-
text=result_to_send,
357-
final=is_final,
358-
start_ms=actual_start_ms,
359-
duration_ms=duration_ms,
360-
language=self.config.normalized_language,
361-
)
435+
asr_info_fields: Dict[str, Any] = {
436+
"turn_event": event,
437+
}
438+
if turn_index is not None:
439+
asr_info_fields["turn_index"] = turn_index
440+
if end_of_turn_confidence is not None:
441+
asr_info_fields["end_of_turn_confidence"] = (
442+
end_of_turn_confidence
443+
)
444+
445+
metadata = self._build_metadata_with_asr_info(asr_info_fields)
446+
447+
# Process ASR result
448+
await self._handle_asr_result(
449+
text=result_to_send,
450+
final=is_final,
451+
start_ms=actual_start_ms,
452+
duration_ms=duration_ms,
453+
language=self.config.normalized_language,
454+
metadata=metadata,
455+
)
456+
457+
else:
458+
is_final = message_data.get("is_final", False)
459+
# Extract transcript and words from channel.alternatives[0]
460+
channel = message_data.get("channel", {})
461+
alternatives = channel.get("alternatives", [])
462+
if not alternatives:
463+
self.ten_env.log_debug("No alternatives in Deepgram result")
464+
return
465+
466+
first_alt = alternatives[0]
467+
result_to_send = first_alt.get("transcript", "")
468+
469+
# Extract timing information (in seconds, convert to milliseconds)
470+
start_seconds = message_data.get("start", 0)
471+
duration_seconds = message_data.get("duration", 0)
472+
start_ms = int(start_seconds * 1000)
473+
duration_ms = int(duration_seconds * 1000)
474+
475+
# Calculate actual start time using audio timeline
476+
actual_start_ms = int(
477+
self.audio_timeline.get_audio_duration_before_time(start_ms)
478+
+ self.sent_user_audio_duration_ms_before_last_reset
479+
)
480+
481+
# Process ASR result
482+
await self._handle_asr_result(
483+
text=result_to_send,
484+
final=is_final,
485+
start_ms=actual_start_ms,
486+
duration_ms=duration_ms,
487+
language=self.config.normalized_language,
488+
)
362489

363490
except Exception as e:
364491
self.ten_env.log_error(f"Error processing Deepgram result: {e}")

0 commit comments

Comments
 (0)