Update cascade model #2
base: main
Conversation
Walkthrough
This change introduces a new CascadeAgent that supports pluggable ASR (speech recognition), LLM (language model), and TTS (speech synthesis) providers, extending the existing agent architecture. It also adds several local and API-based provider implementations plus CLI integration. Changes
Sequence Diagram(s)
sequenceDiagram
participant Client
participant CascadeAgent
participant ASRProvider
participant LLMProvider
participant TTSProvider
Client->>CascadeAgent: connect()
CascadeAgent->>ASRProvider: initialize()
CascadeAgent->>LLMProvider: initialize()
CascadeAgent->>TTSProvider: initialize()
Client->>CascadeAgent: publish(audio.chunk)
CascadeAgent->>CascadeAgent: Buffer audio data
Client->>CascadeAgent: publish(audio.done)
CascadeAgent->>ASRProvider: transcribe(buffered_audio)
ASRProvider-->>CascadeAgent: transcript text
CascadeAgent->>LLMProvider: stream_completion(messages)
LLMProvider-->>CascadeAgent: LLMResponse chunks
loop For each sentence
CascadeAgent->>TTSProvider: synthesize_stream(sentence)
TTSProvider-->>CascadeAgent: audio bytes
CascadeAgent->>Client: publish(AudioChunkEvent)
end
CascadeAgent->>Client: publish(AudioDoneEvent)
Client->>CascadeAgent: disconnect()
CascadeAgent->>ASRProvider: shutdown()
CascadeAgent->>LLMProvider: shutdown()
CascadeAgent->>TTSProvider: shutdown()
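To make the flow concrete, here is a minimal, illustrative sketch of what the audio.done handling could look like. This is not the actual CascadeAgent code: the provider methods (transcribe, stream_completion, synthesize_stream) and event names come from the diagram above, while the attribute names, the chunk.content field, the split_sentences helper, and the _publish helper are assumptions made for illustration only.

import base64

async def _handle_audio_done(self) -> None:
    # ASR: transcribe everything buffered since the last audio.done
    audio = b"".join(self._audio_buffer)
    self._audio_buffer.clear()
    transcript = await self.asr_provider.transcribe(audio)

    # LLM: stream a completion over the running conversation
    self._messages.append({"role": "user", "content": transcript})
    reply = ""
    async for chunk in self.llm_provider.stream_completion(self._messages):
        reply += chunk.content or ""

    # TTS: synthesize sentence by sentence, publishing audio chunks as they arrive
    for sentence in split_sentences(reply):  # hypothetical sentence splitter
        async for audio_bytes in self.tts_provider.synthesize_stream(sentence):
            await self._publish(AudioChunkEvent(audio_chunk=base64.b64encode(audio_bytes).decode()))
    await self._publish(AudioDoneEvent())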
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 inconclusive)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
✨ Finishing touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 10
🤖 Fix all issues with AI agents
In `@src/tau2_voice/agent/cascade.py`:
- Around line 175-178: The base64 decoding of audio chunks (base64.b64decode)
can raise exceptions and currently can break the agent loop; wrap the decode and
buffer append in a try/except that catches decoding errors (e.g.,
binascii.Error, ValueError, TypeError), log a warning (using
self._logger.warning or fallback logging.warning) including the problematic
event metadata, drop that chunk and continue without raising, and only append to
self._audio_buffer when decoding succeeds.
In `@src/tau2_voice/agent/gemini_live.py`:
- Around line 18-19: The project imports override from typing_extensions (seen
in gemini_live.py's "from typing_extensions import override") but
typing_extensions is missing from pyproject.toml dependencies; add
typing_extensions to the dependencies section of pyproject.toml (pick a
compatible minimum version, e.g. a modern release that supports override) so the
import of override resolves for supported Python versions.
In `@src/tau2_voice/providers/asr/whisper_local.py`:
- Around line 94-100: The ImportError handler in the try/except that imports
torch and transformers captures the exception as variable e but doesn't chain
it; update the except block to re-raise the ImportError with the same message
using exception chaining (raise ImportError("transformers or torch not
installed. Run: pip install transformers torch") from e) so the original
exception is preserved—modify the except clause surrounding the imports (torch,
AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline) to use "from e".
In `@src/tau2_voice/providers/base.py`:
- Around line 58-64: The empty hook methods (e.g., async def initialize and
async def shutdown in src/tau2_voice/providers/base.py and the other empty hooks
around the referenced ranges) should not use bare `pass` to avoid Ruff B027;
replace each `pass` with an explicit `return None` (or add a trailing `# noqa:
B027` comment if you intentionally want to keep an empty body) so the async
hooks remain no-op but lint-clean; update the bodies of initialize, shutdown and
the other empty async hook methods at the noted locations accordingly.
In `@src/tau2_voice/providers/llm/local.py`:
- Around line 280-305: The code currently logs potentially sensitive prompts and
LLM outputs via logger.info in the block around chat_messages handling and after
generation (see the system_content extraction from chat_messages[0]["content"],
the generate() function and the final logger.info of text); change these
logger.info calls to a lower-verbosity level (logger.debug or logger.trace) or
remove full content logging and instead log redacted/summary metadata (e.g.,
length, hash, or masked preview) to avoid exposing user data, and ensure any
debugging-only logs are gated by a config flag so production defaults do not
emit raw prompt or full LLM output.
- Around line 85-115: parse_tool_calls currently can append the same tool call
twice when both regex patterns match the same JSON block; to fix, in
parse_tool_calls maintain a seen set (e.g., of normalized JSON strings or a
tuple (name, json.dumps(arguments, sort_keys=True))) keyed by the parsed tc_data
before creating/appending a ToolCall and skip duplicates, and ensure the
generated id uses a stable source (like hash of the normalized string) so
duplicate inputs won't produce distinct ToolCall entries; update references to
match, tc_data, and ToolCall in that function accordingly.
- Around line 191-205: The load_model() function currently hardcodes
device_map="auto" which ignores self.device; change the logic in load_model to
honor self.device by conditionally setting model loading args: if self.device ==
"cpu" set model_kwargs["device_map"] = None and model_kwargs["device"] = "cpu"
(or pass device="cpu" into the pipeline call) else keep device_map="auto" for
automatic GPU placement; update the place where model_kwargs or pipe call is
constructed (referencing load_model, model_kwargs, and pipe_kwargs) so the
transformers pipeline receives the correct device configuration instead of
always using "auto".
In `@src/tau2_voice/providers/llm/openai.py`:
- Around line 216-229: When JSON parsing of tool arguments fails in the block
that checks choice.finish_reason == "tool_calls" (and the similar block around
lines 285-297), don’t replace with an empty dict; preserve the original raw
string so downstream code/debugging can see it. On json.JSONDecodeError, log the
warning and set arguments to a dict that contains the original raw argument
string (e.g. a single key like "_raw" or "__raw_arguments" mapped to
tc_data["arguments"]) before creating the ToolCall (the code that builds
response.tool_calls.append(ToolCall(...)) using tc_data["id"], tc_data["name"],
arguments, requestor="assistant"). Apply the same change to both occurrences
that parse tc_data["arguments"].
In `@src/tau2_voice/run.py`:
- Around line 247-271: run_task_with_index currently lets any exception from
run_task propagate and abort the whole batch; restore the try/except around the
run_task invocation in run_task_with_index to catch exceptions, log the error
(including task.id and the exception), and return a failure result dict (e.g.,
simulation_id None, reward 0.0, success False, duration 0, index task_idx) so
the batch continues. Ensure you re-raise critical control exceptions (like
asyncio.CancelledError and KeyboardInterrupt) instead of swallowing them.
Reference run_task_with_index and run_task when applying the change.
In `@src/tau2_voice/test_cascade.py`:
- Around line 162-209: CascadeAgent.publish currently ignores
TranscriptUpdateEvent so your test sending a TranscriptUpdateEvent won't trigger
the LLM/TTS pipeline; replace the test event with a speak.request-style event
(e.g., SpeakRequestEvent) or simulate the ASR audio sequence using audio.chunk
followed by audio.done to drive the pipeline instead, and ensure you call
agent.publish with the correct event class and field names matching the actual
event definitions; look for CascadeAgent.publish and the test_transcript
creation in test_cascade.py to update the event type and payload accordingly.
🧹 Nitpick comments (9)
src/tau2_voice/providers/llm/local.py (1)
393-399: Recommend removing the redundant substitution loop
re.sub already replaces all occurrences by default, so looping over tool_calls is redundant. ♻️ Suggested fix
- for tc in tool_calls:
-     # Remove the JSON from content
-     content = re.sub(r'\{"tool_call":\s*\{[^}]+\}\}', '', content)
-     content = re.sub(r'\{"name":\s*"[^"]+",\s*"arguments":\s*\{[^}]*\}\}', '', content)
+ content = re.sub(r'\{"tool_call":\s*\{[^}]+\}\}', '', content)
+ content = re.sub(r'\{"name":\s*"[^"]+",\s*"arguments":\s*\{[^}]*\}\}', '', content)
src/tau2_voice/providers/llm/__init__.py (1)
1-6: Recommend sorting __all__ to resolve the Ruff warning (RUF022) ♻️ Suggested fix
-__all__ = ["OpenAILLMProvider", "LocalLLMProvider"]
+__all__ = ["LocalLLMProvider", "OpenAILLMProvider"]
src/tau2_voice/config.py (1)
6-6: Now that the OpenAI key is optional, recommend validating it explicitly where it is used
If the key is missing, the OpenAI provider currently fails late; raising a clear error during initialization is better UX. (SDK behavior may vary by version, so verify.) Example (add during provider initialization):
if not self._api_key:
    raise ValueError("OPENAI_API_KEY is not set.")
src/tau2_voice/providers/asr/__init__.py (1)
6-6: Sort __all__ to avoid the Ruff warning
If RUF022 is enabled, an unsorted __all__ can fail lint. Alphabetical order is recommended.
🔧 Suggested fix
-__all__ = ["WhisperLocalProvider", "OpenAIASRProvider"]
+__all__ = ["OpenAIASRProvider", "WhisperLocalProvider"]
src/tau2_voice/agent/__init__.py (1)
15-23: Keep __all__ sorted for lint consistency
If the RUF022 rule is in use, an unsorted __all__ can be a cause of lint failures. Alphabetical order is recommended.
🔧 Suggested fix
 __all__ = [
     "BaseAgent",
-    "RealtimeAgent",
-    "HumanAgent",
-    "UserAgent",
-    "Qwen3OmniAgent",
-    "GeminiLiveAgent",
-    "CascadeAgent",
+    "CascadeAgent",
+    "GeminiLiveAgent",
+    "HumanAgent",
+    "Qwen3OmniAgent",
+    "RealtimeAgent",
+    "UserAgent",
 ]
src/tau2_voice/providers/tts/__init__.py (1)
6-6: Sort __all__ to resolve RUF022
If this Ruff rule is applied, keeping __all__ sorted is the safer choice.
🔧 Suggested fix
-__all__ = ["OpenAITTSProvider", "ChatterboxTTSProvider", "ChatterboxMultilingualTTSProvider"]
+__all__ = ["ChatterboxMultilingualTTSProvider", "ChatterboxTTSProvider", "OpenAITTSProvider"]
src/tau2_voice/providers/tts/openai.py (1)
26-51: Recommend consolidating the duplicate definition of clean_text_for_tts
The same logic also exists in src/tau2_voice/agent/qwen3_omni.py, which risks the two copies drifting apart during maintenance. Moving it to a shared utility and reusing it is safer. ♻️ Example direction (move to a shared utility)
-from tau2_voice.providers.base import BaseTTSProvider
+from tau2_voice.providers.base import BaseTTSProvider
+from tau2_voice.utils.text import clean_text_for_tts
@@
-def clean_text_for_tts(text: str) -> str:
-    ...
src/tau2_voice/orchestrator/orchestrator.py (1)
97-105: Clean up the unused variable and duplicate import
spanish_markers is never used, and the in-function import re duplicates the top-level import. A small cleanup improves readability. ♻️ Suggested fix
-    import re
-    spanish_markers = ["¿", "¡", "políticas", "podrías", "claro"]
     # Check special characters directly
     if any(tok in text for tok in ["¿", "¡"]):
         return True
     # Check Spanish words with word boundaries
     spanish_words = ["políticas", "reserva", "podrías", "claro"]
src/tau2_voice/providers/tts/chatterbox.py (1)
128-134: Use asyncio.get_running_loop() inside async functions
When accessing the event loop from inside an async function, the Python documentation explicitly recommends asyncio.get_running_loop() over asyncio.get_event_loop(). get_running_loop() only returns the currently running loop, so it is clearer and safer, and it avoids the complicated policy behavior of get_event_loop(). All four occurrences are inside async methods, so all can be replaced: lines 129, 171, 287, 326.
♻️ Suggested fix
- loop = asyncio.get_event_loop()
+ loop = asyncio.get_running_loop()
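For reference, the recommended pattern inside an async method looks like the sketch below; the blocking self._model.generate call is a stand-in for the real Chatterbox synthesis call, not its actual API.

import asyncio

async def synthesize(self, text: str) -> bytes:
    # get_running_loop() is only valid inside a running loop, which is exactly
    # the situation in an async method; get_event_loop() may create or select
    # a loop depending on policy, which is what the recommendation avoids.
    loop = asyncio.get_running_loop()
    # Offload the blocking model call to the default thread-pool executor.
    return await loop.run_in_executor(None, self._model.generate, text)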
if event.type == "audio.chunk":
    # Decode and buffer audio
    audio_bytes = base64.b64decode(event.audio_chunk)
    self._audio_buffer.append(audio_bytes)
Missing exception handling for audio.chunk decode failures
If base64.b64decode() fails, the agent loop can be broken. It is safer to log a warning and drop only the offending chunk.
🐛 Suggested fix
-import base64
+import base64
+import binascii
@@
- audio_bytes = base64.b64decode(event.audio_chunk)
- self._audio_buffer.append(audio_bytes)
+ try:
+ audio_bytes = base64.b64decode(event.audio_chunk)
+ except binascii.Error as e:
+ logger.warning(f"[{self.role}] Invalid base64 audio chunk: {e}")
+ return
+ self._audio_buffer.append(audio_bytes)

🤖 Prompt for AI Agents
In `@src/tau2_voice/agent/cascade.py` around lines 175 - 178, The base64 decoding
of audio chunks (base64.b64decode) can raise exceptions and currently can break
the agent loop; wrap the decode and buffer append in a try/except that catches
decoding errors (e.g., binascii.Error, ValueError, TypeError), log a warning
(using self._logger.warning or fallback logging.warning) including the
problematic event metadata, drop that chunk and continue without raising, and
only append to self._audio_buffer when decoding succeeds.
from typing import Optional, Literal, AsyncGenerator
from typing_extensions import override
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Find dependency declaration files
echo "=== Searching for dependency files ==="
fd -a '(pyproject\.toml|requirements.*|setup\.py|setup\.cfg|Pipfile)' . --type f
echo -e "\n=== Checking for typing_extensions in dependency files ==="
rg "typing.extensions|typing_extensions" . -t toml -t txt -t py --type cfg -A 2 -B 2Repository: channel-io/ch-voice-tau
Length of output: 244
🏁 Script executed:
#!/bin/bash
# Read pyproject.toml to check for typing_extensions dependency
cat -n /home/jailuser/git/pyproject.toml | head -150

Repository: channel-io/ch-voice-tau
Length of output: 1636
typing_extensions must be declared explicitly as a dependency
override is imported from typing_extensions (line 19), but typing_extensions is not listed in the dependencies of pyproject.toml. For a project supporting Python 3.10+, override only exists in the stdlib typing module from Python 3.12 onward, so typing_extensions is a required dependency. Add it to the dependencies in pyproject.toml.
🤖 Prompt for AI Agents
In `@src/tau2_voice/agent/gemini_live.py` around lines 18 - 19, The project
imports override from typing_extensions (seen in gemini_live.py's "from
typing_extensions import override") but typing_extensions is missing from
pyproject.toml dependencies; add typing_extensions to the dependencies section
of pyproject.toml (pick a compatible minimum version, e.g. a modern release that
supports override) so the import of override resolves for supported Python
versions.
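If keeping the stdlib import on newer interpreters is preferred, a version-guarded import is another option (a sketch, not what the code currently does); note that typing_extensions still has to be declared for Python < 3.12, for example via an environment marker, so this only changes where override comes from at runtime.

import sys

if sys.version_info >= (3, 12):
    from typing import override  # available in the stdlib since Python 3.12
else:
    from typing_extensions import override  # backport for 3.10/3.11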
try:
    import torch
    from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
except ImportError as e:
    raise ImportError(
        "transformers or torch not installed. Run: pip install transformers torch"
    )
🧩 Analysis chain
🏁 Script executed:
# Check that the file exists and count its lines
wc -l src/tau2_voice/providers/asr/whisper_local.py

Repository: channel-io/ch-voice-tau
Length of output: 116
🏁 Script executed:
# Inspect lines 94-100 of the file
sed -n '90,105p' src/tau2_voice/providers/asr/whisper_local.py

Repository: channel-io/ch-voice-tau
Length of output: 662
🏁 Script executed:
# Check the Ruff configuration
if [ -f "pyproject.toml" ]; then cat pyproject.toml | grep -A 20 "\[tool.ruff"; fi
if [ -f ".flake8" ]; then cat .flake8; fi
if [ -f "ruff.toml" ]; then cat ruff.toml; fi

Repository: channel-io/ch-voice-tau
Length of output: 1087
Add exception chaining to the ImportError
The current code captures the exception as e but never uses it, so the original cause is lost and Ruff rule B904 is violated. Add exception chaining with raise ... from e to improve debugging traceability.
🔧 Suggested fix
except ImportError as e:
raise ImportError(
"transformers or torch not installed. Run: pip install transformers torch"
- )
+ ) from e

🧰 Tools
🪛 Ruff (0.14.13)
97-97: Local variable e is assigned to but never used
Remove assignment to unused variable e
(F841)
98-100: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling
(B904)
98-100: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In `@src/tau2_voice/providers/asr/whisper_local.py` around lines 94 - 100, The
ImportError handler in the try/except that imports torch and transformers
captures the exception as variable e but doesn't chain it; update the except
block to re-raise the ImportError with the same message using exception chaining
(raise ImportError("transformers or torch not installed. Run: pip install
transformers torch") from e) so the original exception is preserved—modify the
except clause surrounding the imports (torch, AutoModelForSpeechSeq2Seq,
AutoProcessor, pipeline) to use "from e".
async def initialize(self) -> None:
    """Initialize the provider (load models, etc.). Override if needed."""
    pass

async def shutdown(self) -> None:
    """Cleanup resources. Override if needed."""
    pass
Possible Ruff B027 warning (empty methods)
To keep these optional hooks, it is safer to avoid lint failures by using an explicit return instead of a bare pass, or by adding # noqa: B027.
🛠️ Suggested fix
@@
async def initialize(self) -> None:
"""Initialize the provider (load models, etc.). Override if needed."""
- pass
+ return None
@@
async def shutdown(self) -> None:
"""Cleanup resources. Override if needed."""
- pass
+ return None
@@
async def initialize(self) -> None:
"""Initialize the provider. Override if needed."""
- pass
+ return None
@@
async def shutdown(self) -> None:
"""Cleanup resources. Override if needed."""
- pass
+ return None
@@
async def initialize(self) -> None:
"""Initialize the provider. Override if needed."""
- pass
+ return None
@@
async def shutdown(self) -> None:
"""Cleanup resources. Override if needed."""
- pass
+ return None

Also applies to: 117-123, 170-176
🧰 Tools
🪛 Ruff (0.14.13)
58-60: BaseASRProvider.initialize is an empty method in an abstract base class, but has no abstract decorator
(B027)
62-64: BaseASRProvider.shutdown is an empty method in an abstract base class, but has no abstract decorator
(B027)
🤖 Prompt for AI Agents
In `@src/tau2_voice/providers/base.py` around lines 58 - 64, The empty hook
methods (e.g., async def initialize and async def shutdown in
src/tau2_voice/providers/base.py and the other empty hooks around the referenced
ranges) should not use bare `pass` to avoid Ruff B027; replace each `pass` with
an explicit `return None` (or add a trailing `# noqa: B027` comment if you
intentionally want to keep an empty body) so the async hooks remain no-op but
lint-clean; update the bodies of initialize, shutdown and the other empty async
hook methods at the noted locations accordingly.
def parse_tool_calls(text: str) -> list[ToolCall]:
    """Parse tool calls from model output text."""
    tool_calls = []

    # Look for JSON tool call patterns
    patterns = [
        r'\{"tool_call":\s*\{[^}]+\}\}',  # {"tool_call": {...}}
        r'\{"name":\s*"[^"]+",\s*"arguments":\s*\{[^}]*\}\}',  # {"name": "...", "arguments": {...}}
    ]

    for pattern in patterns:
        matches = re.findall(pattern, text, re.DOTALL)
        for match in matches:
            try:
                data = json.loads(match)
                if "tool_call" in data:
                    tc_data = data["tool_call"]
                else:
                    tc_data = data

                if "name" in tc_data:
                    tool_calls.append(ToolCall(
                        id=f"local_{hash(match) % 10000}",
                        name=tc_data["name"],
                        arguments=tc_data.get("arguments", {}),
                        requestor="assistant",
                    ))
            except json.JSONDecodeError:
                continue

    return tool_calls
The same tool call can be parsed twice.
If both patterns match the same JSON block, the identical tool call is appended twice, risking duplicate invocations.
🔧 Suggested fix
def parse_tool_calls(text: str) -> list[ToolCall]:
- tool_calls = []
+ tool_calls = []
+ seen = set()
@@
- for pattern in patterns:
+ for pattern in patterns:
matches = re.findall(pattern, text, re.DOTALL)
for match in matches:
try:
data = json.loads(match)
if "tool_call" in data:
tc_data = data["tool_call"]
else:
tc_data = data
if "name" in tc_data:
- tool_calls.append(ToolCall(
+ key = json.dumps(tc_data, sort_keys=True)
+ if key in seen:
+ continue
+ seen.add(key)
+ tool_calls.append(ToolCall(
id=f"local_{hash(match) % 10000}",
name=tc_data["name"],
arguments=tc_data.get("arguments", {}),
requestor="assistant",
))
except json.JSONDecodeError:
continue
+ if tool_calls:
+ break

🤖 Prompt for AI Agents
In `@src/tau2_voice/providers/llm/local.py` around lines 85 - 115,
parse_tool_calls currently can append the same tool call twice when both regex
patterns match the same JSON block; to fix, in parse_tool_calls maintain a seen
set (e.g., of normalized JSON strings or a tuple (name, json.dumps(arguments,
sort_keys=True))) keyed by the parsed tc_data before creating/appending a
ToolCall and skip duplicates, and ensure the generated id uses a stable source
(like hash of the normalized string) so duplicate inputs won't produce distinct
ToolCall entries; update references to match, tc_data, and ToolCall in that
function accordingly.
def load_model():
    # Use official HuggingFace pipeline approach
    model_kwargs = {"torch_dtype": dtype, "device_map": "auto"}

    # Use slow tokenizer to avoid tokenizers library compatibility issues
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        self.model_id,
    )
    tokenizer.pad_token_id = tokenizer.eos_token_id

    # When thinking is off, use do_sample=False as per official Nemotron docs
    effective_do_sample = self.do_sample if self.thinking == "on" else False

    pipe_kwargs = {
        "max_new_tokens": self.max_new_tokens,
🧩 Analysis chain
🏁 Script executed:
# First, let's find and examine the file
fd "local.py" -p "*llm*"Repository: channel-io/ch-voice-tau
Length of output: 159
🏁 Script executed:
# Read the specific file to understand the context
wc -l src/tau2_voice/providers/llm/local.py

Repository: channel-io/ch-voice-tau
Length of output: 108
🏁 Script executed:
# Read the code around lines 191-205 and surrounding context
sed -n '170,220p' src/tau2_voice/providers/llm/local.py | cat -n

Repository: channel-io/ch-voice-tau
Length of output: 2375
🏁 Script executed:
# Search for self.device initialization and usage in the file
rg "self\.device" src/tau2_voice/providers/llm/local.py -B 2 -A 2Repository: channel-io/ch-voice-tau
Length of output: 361
🏁 Script executed:
# Look for the class definition to understand the __init__ method
rg "class.*:" src/tau2_voice/providers/llm/local.py -A 30 | head -50Repository: channel-io/ch-voice-tau
Length of output: 1202
🏁 Script executed:
# Let me search for more context about device handling in the file
sed -n '1,100p' src/tau2_voice/providers/llm/local.py | cat -n

Repository: channel-io/ch-voice-tau
Length of output: 4442
🏁 Script executed:
# Check if there are other uses of device or device_map in the codebase
rg "device_map|self\.device" src/tau2_voice/providers/llm/local.pyRepository: channel-io/ch-voice-tau
Length of output: 243
🏁 Script executed:
# Look for the full load_model function to understand its complete implementation
sed -n '180,250p' src/tau2_voice/providers/llm/local.py | cat -n

Repository: channel-io/ch-voice-tau
Length of output: 3074
🏁 Script executed:
# Check if transformers library is used elsewhere with device handling
rg "device_map" -r ".*" --max-count=20Repository: channel-io/ch-voice-tau
Length of output: 168
🏁 Script executed:
# Let's check the transformers documentation/behavior with device_map parameter
# to understand if device_map="auto" truly ignores explicit device settings
rg "device" src/tau2_voice/providers/llm/local.py -B 3 -A 3 | head -40Repository: channel-io/ch-voice-tau
Length of output: 1424
🏁 Script executed:
# Check the full __init__ method to see all parameter initialization
sed -n '100,170p' src/tau2_voice/providers/llm/local.py | cat -n

Repository: channel-io/ch-voice-tau
Length of output: 3044
🌐 Web query:
HuggingFace transformers device_map parameter valid values
💡 Result:
In transformers (e.g., AutoModel.from_pretrained(..., device_map=...)), device_map can be:
- A string (automatic strategies): "auto", "balanced", "balanced_low_0", "sequential" (these are computed by 🤗 Accelerate). [1]
- A string device name: e.g. "cpu", "cuda:0" (i.e., anything PyTorch recognizes as a device string). [2]
- A non-negative int: treated as "put the whole model on this device index" (internally converted to {"": device_index}). [2]
- A dict (explicit device map): mapping module name prefixes to a device identifier (e.g. {"model.embed_tokens": 0, "model.layers.0": 0, "model.layers.1": 1, ...}), where device identifiers are typically GPU indices (0, 1, …) and/or "cpu" (and can include disk offload when used with Accelerate offloading). [1]
- Invalid string values raise an error unless they are a device name or one of the four strategy strings above. [2]
Sources: [1] [2]
The device parameter is not honored when loading the model.
self.device is accepted in __init__ and recorded in the log at line 180, but inside load_model() device_map is always hardcoded to "auto", so it is ignored. Even with device="cpu" specified explicitly, loading still uses device_map="auto", which can behave unexpectedly on machines without a GPU.
🔧 Suggested fix
- model_kwargs = {"torch_dtype": dtype, "device_map": "auto"}
+ device_map = "auto" if self.device == "auto" else self.device
+ model_kwargs = {"torch_dtype": dtype, "device_map": device_map}📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def load_model():
    # Use official HuggingFace pipeline approach
    device_map = "auto" if self.device == "auto" else self.device
    model_kwargs = {"torch_dtype": dtype, "device_map": device_map}

    # Use slow tokenizer to avoid tokenizers library compatibility issues
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        self.model_id,
    )
    tokenizer.pad_token_id = tokenizer.eos_token_id

    # When thinking is off, use do_sample=False as per official Nemotron docs
    effective_do_sample = self.do_sample if self.thinking == "on" else False

    pipe_kwargs = {
        "max_new_tokens": self.max_new_tokens,
🤖 Prompt for AI Agents
In `@src/tau2_voice/providers/llm/local.py` around lines 191 - 205, The
load_model() function currently hardcodes device_map="auto" which ignores
self.device; change the logic in load_model to honor self.device by
conditionally setting model loading args: if self.device == "cpu" set
model_kwargs["device_map"] = None and model_kwargs["device"] = "cpu" (or pass
device="cpu" into the pipeline call) else keep device_map="auto" for automatic
GPU placement; update the place where model_kwargs or pipe call is constructed
(referencing load_model, model_kwargs, and pipe_kwargs) so the transformers
pipeline receives the correct device configuration instead of always using
"auto".
# Log the system prompt for debugging
if chat_messages and chat_messages[0]["role"] == "system":
    system_content = chat_messages[0]["content"]
    logger.info(f"[LLM SYSTEM PROMPT] (first 500 chars):\n{system_content[:500]}...")

# Run generation in thread pool
loop = asyncio.get_event_loop()

def generate():
    result = self._pipe(chat_messages)
    # Pipeline returns list of dicts with 'generated_text'
    if result and len(result) > 0:
        generated = result[0].get("generated_text", [])
        # Get the assistant's response (last message)
        if isinstance(generated, list) and len(generated) > 0:
            last_msg = generated[-1]
            if isinstance(last_msg, dict) and last_msg.get("role") == "assistant":
                return last_msg.get("content", "")
        elif isinstance(generated, str):
            return generated
    return ""

text = await loop.run_in_executor(None, generate)

# Log raw LLM output
logger.info(f"[LLM RAW OUTPUT]\n{text}")
Prompt/output logging at INFO risks exposing sensitive information
Because the content may include user data, it is safer to lower the default log level or minimize what is logged.
🔧 Suggested fix
- logger.info(f"[LLM SYSTEM PROMPT] (first 500 chars):\n{system_content[:500]}...")
+ logger.debug(f"[LLM SYSTEM PROMPT] (first 500 chars):\n{system_content[:500]}...")
@@
- logger.info(f"[LLM RAW OUTPUT]\n{text}")
+ logger.debug(f"[LLM RAW OUTPUT] (first 500 chars):\n{text[:500]}...")📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
# Log the system prompt for debugging
if chat_messages and chat_messages[0]["role"] == "system":
    system_content = chat_messages[0]["content"]
    logger.debug(f"[LLM SYSTEM PROMPT] (first 500 chars):\n{system_content[:500]}...")

# Run generation in thread pool
loop = asyncio.get_event_loop()

def generate():
    result = self._pipe(chat_messages)
    # Pipeline returns list of dicts with 'generated_text'
    if result and len(result) > 0:
        generated = result[0].get("generated_text", [])
        # Get the assistant's response (last message)
        if isinstance(generated, list) and len(generated) > 0:
            last_msg = generated[-1]
            if isinstance(last_msg, dict) and last_msg.get("role") == "assistant":
                return last_msg.get("content", "")
        elif isinstance(generated, str):
            return generated
    return ""

text = await loop.run_in_executor(None, generate)

# Log raw LLM output
logger.debug(f"[LLM RAW OUTPUT] (first 500 chars):\n{text[:500]}...")
🤖 Prompt for AI Agents
In `@src/tau2_voice/providers/llm/local.py` around lines 280 - 305, The code
currently logs potentially sensitive prompts and LLM outputs via logger.info in
the block around chat_messages handling and after generation (see the
system_content extraction from chat_messages[0]["content"], the generate()
function and the final logger.info of text); change these logger.info calls to a
lower-verbosity level (logger.debug or logger.trace) or remove full content
logging and instead log redacted/summary metadata (e.g., length, hash, or masked
preview) to avoid exposing user data, and ensure any debugging-only logs are
gated by a config flag so production defaults do not emit raw prompt or full LLM
output.
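One way to implement the redacted/summary logging the prompt asks for (a sketch; the helper name, preview length, and the log_prompts config flag are all assumptions, not existing code):

import hashlib

def _redacted_preview(text: str, keep: int = 40) -> str:
    """Summarize potentially sensitive text without logging it in full."""
    digest = hashlib.sha256(text.encode("utf-8")).hexdigest()[:8]
    return f"len={len(text)} sha256={digest} preview={text[:keep]!r}..."

# Inside the provider, gated behind a debug-only config flag:
if self.log_prompts:  # hypothetical config flag, off by default
    logger.debug(f"[LLM SYSTEM PROMPT] {_redacted_preview(system_content)}")
    logger.debug(f"[LLM RAW OUTPUT] {_redacted_preview(text)}")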
if choice.finish_reason == "tool_calls" and tool_calls_in_progress:
    for tc_data in tool_calls_in_progress.values():
        try:
            arguments = json.loads(tc_data["arguments"])
        except json.JSONDecodeError:
            logger.warning(f"Failed to parse tool arguments: {tc_data['arguments']}")
            arguments = {}

        response.tool_calls.append(ToolCall(
            id=tc_data["id"],
            name=tc_data["name"],
            arguments=arguments,
            requestor="assistant",
        ))
Falling back to an empty dict on tool-argument parse failure can lose the arguments
If a JSON parse failure is replaced with {}, the actual arguments disappear and the tool call may execute incorrectly. At minimum, preserving the raw string keeps downstream handling and debugging possible.
🐛 Suggested fix
- except json.JSONDecodeError:
-     logger.warning(f"Failed to parse tool arguments: {tc_data['arguments']}")
-     arguments = {}
+ except json.JSONDecodeError:
+     logger.warning(f"Failed to parse tool arguments: {tc_data['arguments']}")
+     arguments = {"_raw": tc_data["arguments"], "_parse_error": True}

- except json.JSONDecodeError:
-     logger.warning(f"Failed to parse tool arguments: {tc.function.arguments}")
-     arguments = {}
+ except json.JSONDecodeError:
+     logger.warning(f"Failed to parse tool arguments: {tc.function.arguments}")
+     arguments = {"_raw": tc.function.arguments, "_parse_error": True}

Also applies to: 285-297
🤖 Prompt for AI Agents
In `@src/tau2_voice/providers/llm/openai.py` around lines 216 - 229, When JSON
parsing of tool arguments fails in the block that checks choice.finish_reason ==
"tool_calls" (and the similar block around lines 285-297), don’t replace with an
empty dict; preserve the original raw string so downstream code/debugging can
see it. On json.JSONDecodeError, log the warning and set arguments to a dict
that contains the original raw argument string (e.g. a single key like "_raw" or
"__raw_arguments" mapped to tc_data["arguments"]) before creating the ToolCall
(the code that builds response.tool_calls.append(ToolCall(...)) using
tc_data["id"], tc_data["name"], arguments, requestor="assistant"). Apply the
same change to both occurrences that parse tc_data["arguments"].
# try:
logger.info(f"[{task_idx+1}/{total}] Starting task: {task.id}")
simulation_run = await run_task(
    domain,
    task,
    assistant_model=assistant_model,
    user_model=user_model,
    asr_model=asr_model,
    llm_model=llm_model,
    tts_model=tts_model,
)

reward = simulation_run.reward_info.reward if simulation_run.reward_info else 0.0
is_success = reward >= 1.0

logger.info(f"[{task_idx+1}/{total}] Task {task.id} completed: {'✓' if is_success else '✗'} (reward={reward:.3f})")

return {
    'task_id': task.id,
    'simulation_id': simulation_run.id,
    'reward': reward,
    'success': is_success,
    'duration': simulation_run.duration,
    'index': task_idx,
}
A single failure aborts the entire batch run
If run_task_with_index() does not catch exceptions, one failed task stops the whole batch. Restoring the previously commented-out handling block is the more robust option.
🐛 Suggested fix
- # try:
- logger.info(f"[{task_idx+1}/{total}] Starting task: {task.id}")
- simulation_run = await run_task(
- domain,
- task,
- assistant_model=assistant_model,
- user_model=user_model,
- asr_model=asr_model,
- llm_model=llm_model,
- tts_model=tts_model,
- )
-
- reward = simulation_run.reward_info.reward if simulation_run.reward_info else 0.0
- is_success = reward >= 1.0
-
- logger.info(f"[{task_idx+1}/{total}] Task {task.id} completed: {'✓' if is_success else '✗'} (reward={reward:.3f})")
-
- return {
- 'task_id': task.id,
- 'simulation_id': simulation_run.id,
- 'reward': reward,
- 'success': is_success,
- 'duration': simulation_run.duration,
- 'index': task_idx,
- }
- # except Exception as e:
- # logger.error(f"[{task_idx+1}/{total}] Error running task {task.id}: {e}")
- # return {
- # 'task_id': task.id,
- # 'simulation_id': None,
- # 'reward': 0.0,
- # 'success': False,
- # 'duration': 0.0,
- # 'error': str(e),
- # 'index': task_idx,
- # }
+ try:
+ logger.info(f"[{task_idx+1}/{total}] Starting task: {task.id}")
+ simulation_run = await run_task(
+ domain,
+ task,
+ assistant_model=assistant_model,
+ user_model=user_model,
+ asr_model=asr_model,
+ llm_model=llm_model,
+ tts_model=tts_model,
+ )
+
+ reward = simulation_run.reward_info.reward if simulation_run.reward_info else 0.0
+ is_success = reward >= 1.0
+
+ logger.info(f"[{task_idx+1}/{total}] Task {task.id} completed: {'✓' if is_success else '✗'} (reward={reward:.3f})")
+
+ return {
+ 'task_id': task.id,
+ 'simulation_id': simulation_run.id,
+ 'reward': reward,
+ 'success': is_success,
+ 'duration': simulation_run.duration,
+ 'index': task_idx,
+ }
+ except Exception as e:
+ logger.error(f"[{task_idx+1}/{total}] Error running task {task.id}: {e}")
+ return {
+ 'task_id': task.id,
+ 'simulation_id': None,
+ 'reward': 0.0,
+ 'success': False,
+ 'duration': 0.0,
+ 'error': str(e),
+ 'index': task_idx,
+ }

🤖 Prompt for AI Agents
In `@src/tau2_voice/run.py` around lines 247 - 271, run_task_with_index currently
lets any exception from run_task propagate and abort the whole batch; restore
the try/except around the run_task invocation in run_task_with_index to catch
exceptions, log the error (including task.id and the exception), and return a
failure result dict (e.g., simulation_id None, reward 0.0, success False,
duration 0, index task_idx) so the batch continues. Ensure you re-raise critical
control exceptions (like asyncio.CancelledError and KeyboardInterrupt) instead
of swallowing them. Reference run_task_with_index and run_task when applying the
change.
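The prompt additionally asks for control exceptions to be re-raised, which the suggested diff above does not do; a sketch of that shape follows (on modern Python both exceptions derive from BaseException, so the explicit clause mainly documents intent rather than changing what except Exception catches):

import asyncio

try:
    simulation_run = await run_task(
        domain,
        task,
        assistant_model=assistant_model,
        user_model=user_model,
        asr_model=asr_model,
        llm_model=llm_model,
        tts_model=tts_model,
    )
except (asyncio.CancelledError, KeyboardInterrupt):
    # Never swallow cancellation or a user interrupt; let the batch stop.
    raise
except Exception as e:
    logger.error(f"[{task_idx+1}/{total}] Error running task {task.id}: {e}")
    return {
        'task_id': task.id,
        'simulation_id': None,
        'reward': 0.0,
        'success': False,
        'duration': 0.0,
        'error': str(e),
        'index': task_idx,
    }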
from tau2_voice.agent.cascade import CascadeAgent
from tau2_voice.providers.asr import WhisperLocalProvider
from tau2_voice.providers.llm import OpenAILLMProvider, LocalLLMProvider
from tau2_voice.providers.tts import OpenAITTSProvider, ChatterboxTTSProvider
from tau2_voice.models.events import TranscriptUpdateEvent

# Create providers
asr = WhisperLocalProvider(
    model_id="openai/whisper-base",
    device="auto",
    language="en",
)

if use_local_llm:
    llm = LocalLLMProvider(
        model_id="nvidia/Llama-3.1-Nemotron-Nano-4B-v1.1",
        device="auto",
        max_new_tokens=200,
    )
else:
    llm = OpenAILLMProvider(model="gpt-4o-mini")

if use_local_tts:
    tts = ChatterboxTTSProvider(device="auto")
else:
    tts = OpenAITTSProvider(model="gpt-4o-mini-tts", voice="alloy")

# Create cascade agent
agent = CascadeAgent(
    tools=None,
    domain_policy="You are a helpful assistant. Keep responses brief and friendly.",
    asr_provider=asr,
    llm_provider=llm,
    tts_provider=tts,
    role="assistant",
)

await agent.connect()

# Instead of audio, send a transcript directly (simulates ASR output)
logger.info("Sending test transcript...")
test_transcript = TranscriptUpdateEvent(
    role="user",
    message_id="test_1",
    transcript="Hello! Can you tell me a short joke?",
)
await agent.publish(test_transcript)
CascadeAgent ignores transcript.update, so the pipeline never runs
CascadeAgent.publish ignores transcript.update, so the LLM/TTS stages are never exercised with the current approach. Send the instruction as a speak.request event (e.g., SpeakRequestEvent), or drive the test with an audio.chunk/audio.done sequence instead. Adjust the field names to match the actual event definitions.
✅ Example fix (using speak.request)
- from tau2_voice.models.events import TranscriptUpdateEvent
+ from tau2_voice.models.events import SpeakRequestEvent
@@
- test_transcript = TranscriptUpdateEvent(
- role="user",
- message_id="test_1",
- transcript="Hello! Can you tell me a short joke?",
- )
- await agent.publish(test_transcript)
+ test_request = SpeakRequestEvent(
+ role="user",
+ message_id="test_1",
+ instructions="Hello! Can you tell me a short joke?",
+ )
+ await agent.publish(test_request)

🤖 Prompt for AI Agents
In `@src/tau2_voice/test_cascade.py` around lines 162 - 209, CascadeAgent.publish
currently ignores TranscriptUpdateEvent so your test sending a
TranscriptUpdateEvent won't trigger the LLM/TTS pipeline; replace the test event
with a speak.request-style event (e.g., SpeakRequestEvent) or simulate the ASR
audio sequence using audio.chunk followed by audio.done to drive the pipeline
instead, and ensure you call agent.publish with the correct event class and
field names matching the actual event definitions; look for CascadeAgent.publish
and the test_transcript creation in test_cascade.py to update the event type and
payload accordingly.
With the default parameters, the cascade can be evaluated using the OpenAI APIs (ASR/LLM/TTS):
python src/tau2_voice/run.py --domain airline --assistant-model cascade --task-ids 0

Summary by CodeRabbit
Release Notes
New Features
Refactor
Tests
✏️ Tip: You can customize this high-level summary in your review settings.