11from datetime import datetime
22import os
33import asyncio
4+ import copy
45from typing import Dict , Any
56
67from typing_extensions import override
2223from ten_runtime import (
2324 AsyncTenEnv ,
2425 AudioFrame ,
26+ Data ,
2527)
2628from ten_ai_base .const import (
2729 LOG_CATEGORY_VENDOR ,
@@ -73,6 +75,10 @@ async def on_init(self, ten_env: AsyncTenEnv) -> None:
7375 try :
7476 self .config = DeepgramASRConfig .model_validate_json (config_json )
7577 self .config .update (self .config .params )
78+
79+ # Apply default params based on model type (nova or flux)
80+ self .config .apply_defaults ()
81+
7682 ten_env .log_info (
7783 f"config: { self .config .to_json (sensitive_handling = True )} " ,
7884 category = LOG_CATEGORY_KEY_POINT ,
@@ -164,12 +170,73 @@ async def finalize(self, _session_id: str | None) -> None:
164170 )
165171
166172 finalize_mode = self .config .finalize_mode
167- if finalize_mode == "disconnect" :
168- await self ._handle_finalize_disconnect ()
169- elif finalize_mode == "mute_pkg" :
170- await self ._handle_finalize_mute_pkg ()
173+ if self .config .is_flux_model :
174+ if finalize_mode == "ignore" :
175+ await self ._finalize_end ()
176+ else :
177+ await self ._handle_finalize_mute_pkg ()
171178 else :
172- raise ValueError (f"invalid finalize mode: { finalize_mode } " )
179+ if finalize_mode == "flush_api" :
180+ await self ._handle_finalize_flush_api ()
181+ elif finalize_mode == "mute_pkg" :
182+ await self ._handle_finalize_mute_pkg ()
183+ else :
184+ raise ValueError (f"invalid finalize mode: { finalize_mode } " )
185+
async def _handle_event_result(self, event: str) -> None:
    """Log an ASR turn event and forward it downstream as a data message.

    The event is always logged. Only the three recognised turn events
    produce an outgoing message; any other value is logged and ignored.

    Args:
        event: Event name taken from the recognition result, e.g.
            "StartOfTurn", "EndOfTurn" or "EagerEndOfTurn".
    """
    self.ten_env.log_info(
        f"_handle_event_result: {event}",
        category=LOG_CATEGORY_KEY_POINT,
    )

    # Turn event -> name of the data message emitted for it.
    data_name_by_event = {
        "StartOfTurn": "sos",
        "EndOfTurn": "eos",
        "EagerEndOfTurn": "eager_eos",
    }

    data_name = data_name_by_event.get(event)
    if data_name is None:
        # Not a turn boundary we signal; nothing to send.
        return

    await self.ten_env.send_data(Data.create(data_name))
201+
202+ def _build_metadata_with_asr_info (
203+ self ,
204+ additional_fields : Dict [str , Any ] | None = None ,
205+ ) -> Dict [str , Any ]:
206+ """Build metadata according to protocol: session_id at root, others in asr_info.
207+
208+ Args:
209+ additional_fields: Additional fields to add to asr_info
210+
211+ Returns:
212+ Metadata dict with structure: {"session_id": "...", "asr_info": {...}}
213+ """
214+ # Start with a copy of base metadata if available
215+ base_metadata = (
216+ copy .deepcopy (self .metadata ) if self .metadata is not None else {}
217+ )
218+
219+ # Extract session_id from base metadata if present
220+ session_id = base_metadata .pop ("session_id" , None )
221+
222+ # Collect all other fields into asr_info
223+ asr_info = copy .deepcopy (base_metadata )
224+
225+ # Add vendor field to asr_info
226+ asr_info ["vendor" ] = self .vendor ()
227+ asr_info ["model" ] = self .config .params .get ("model" , "unknown" )
228+
229+ # Add additional fields to asr_info if provided
230+ if additional_fields :
231+ asr_info .update (additional_fields )
232+
233+ # Build final metadata structure
234+ metadata : Dict [str , Any ] = {}
235+ if session_id is not None :
236+ metadata ["session_id" ] = session_id
237+ metadata ["asr_info" ] = asr_info
238+
239+ return metadata
173240
174241 async def _handle_asr_result (
175242 self ,
@@ -178,6 +245,7 @@ async def _handle_asr_result(
178245 start_ms : int = 0 ,
179246 duration_ms : int = 0 ,
180247 language : str = "" ,
248+ metadata : Dict [str , Any ] | None = None ,
181249 ):
182250 """Process ASR recognition result"""
183251 assert self .config is not None
@@ -192,12 +260,22 @@ async def _handle_asr_result(
192260 duration_ms = duration_ms ,
193261 language = language ,
194262 words = [],
263+ metadata = metadata if metadata is not None else {},
195264 )
196265
197266 await self .send_asr_result (asr_result )
198267
199268 async def _handle_finalize_disconnect (self ):
200- """Handle disconnect mode finalization"""
269+ """Handle disconnect mode finalization.
270+
271+ Deprecated: This method uses flush_api for finalization.
272+ """
273+ if self .recognition :
274+ await self .recognition .stop ()
275+ self .ten_env .log_debug ("Deepgram finalize completed" )
276+
277+ async def _handle_finalize_flush_api (self ):
278+ """Handle flush API mode finalization"""
201279 if self .recognition :
202280 await self .recognition .stop ()
203281 self .ten_env .log_debug ("Deepgram finalize completed" )
@@ -324,41 +402,90 @@ async def on_open(self) -> None:
324402 @override
325403 async def on_result (self , message_data : Dict [str , Any ]) -> None :
326404 """Handle recognition result callback"""
405+ assert self .config is not None
327406
328407 try :
329408 # Extract basic fields
330- is_final = message_data .get ("is_final" , False )
409+ if self .config .is_flux_model :
410+ event = message_data .get ("event" , "" )
411+ result_to_send = message_data .get ("transcript" , "" )
331412
332- # Extract transcript and words from channel.alternatives[0]
333- channel = message_data .get ("channel" , {})
334- alternatives = channel .get ("alternatives" , [])
335- if not alternatives :
336- self .ten_env .log_debug ("No alternatives in Deepgram result" )
337- return
413+ is_final = event == "EndOfTurn"
338414
339- first_alt = alternatives [0 ]
340- result_to_send = first_alt .get ("transcript" , "" )
415+ start_ms = int (
416+ message_data .get ("audio_window_start" , 0 ) * 1000 or 0
417+ )
418+ end_ms = int (
419+ message_data .get ("audio_window_end" , 0 ) * 1000 or 0
420+ )
421+ duration_ms = end_ms - start_ms
341422
342- # Extract timing information (in seconds, convert to milliseconds)
343- start_seconds = message_data . get ( "start" , 0 )
344- duration_seconds = message_data . get ( "duration" , 0 )
345- start_ms = int ( start_seconds * 1000 )
346- duration_ms = int ( duration_seconds * 1000 )
423+ actual_start_ms = int (
424+ self . audio_timeline . get_audio_duration_before_time ( start_ms )
425+ + self . sent_user_audio_duration_ms_before_last_reset
426+ )
427+ await self . _handle_event_result ( event )
347428
348- # Calculate actual start time using audio timeline
349- actual_start_ms = int (
350- self . audio_timeline . get_audio_duration_before_time ( start_ms )
351- + self . sent_user_audio_duration_ms_before_last_reset
352- )
429+ # Build metadata with asr_info for flux model
430+ turn_index = message_data . get ( "turn_index" )
431+ end_of_turn_confidence = message_data . get (
432+ "end_of_turn_confidence"
433+ )
353434
354- # Process ASR result
355- await self ._handle_asr_result (
356- text = result_to_send ,
357- final = is_final ,
358- start_ms = actual_start_ms ,
359- duration_ms = duration_ms ,
360- language = self .config .normalized_language ,
361- )
435+ asr_info_fields : Dict [str , Any ] = {
436+ "turn_event" : event ,
437+ }
438+ if turn_index is not None :
439+ asr_info_fields ["turn_index" ] = turn_index
440+ if end_of_turn_confidence is not None :
441+ asr_info_fields ["end_of_turn_confidence" ] = (
442+ end_of_turn_confidence
443+ )
444+
445+ metadata = self ._build_metadata_with_asr_info (asr_info_fields )
446+
447+ # Process ASR result
448+ await self ._handle_asr_result (
449+ text = result_to_send ,
450+ final = is_final ,
451+ start_ms = actual_start_ms ,
452+ duration_ms = duration_ms ,
453+ language = self .config .normalized_language ,
454+ metadata = metadata ,
455+ )
456+
457+ else :
458+ is_final = message_data .get ("is_final" , False )
459+ # Extract transcript and words from channel.alternatives[0]
460+ channel = message_data .get ("channel" , {})
461+ alternatives = channel .get ("alternatives" , [])
462+ if not alternatives :
463+ self .ten_env .log_debug ("No alternatives in Deepgram result" )
464+ return
465+
466+ first_alt = alternatives [0 ]
467+ result_to_send = first_alt .get ("transcript" , "" )
468+
469+ # Extract timing information (in seconds, convert to milliseconds)
470+ start_seconds = message_data .get ("start" , 0 )
471+ duration_seconds = message_data .get ("duration" , 0 )
472+ start_ms = int (start_seconds * 1000 )
473+ duration_ms = int (duration_seconds * 1000 )
474+
475+ # Calculate actual start time using audio timeline
476+ actual_start_ms = int (
477+ self .audio_timeline .get_audio_duration_before_time (start_ms )
478+ + self .sent_user_audio_duration_ms_before_last_reset
479+ )
480+
481+ # Process ASR result
482+ await self ._handle_asr_result (
483+ text = result_to_send ,
484+ final = is_final ,
485+ start_ms = actual_start_ms ,
486+ duration_ms = duration_ms ,
487+ language = self .config .normalized_language ,
488+ )
362489
363490 except Exception as e :
364491 self .ten_env .log_error (f"Error processing Deepgram result: { e } " )
0 commit comments