Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 58 additions & 32 deletions openless-all/app/src-tauri/src/asr/volcengine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use parking_lot::Mutex as ParkingMutex;
use serde_json::{json, Value};
use tokio::net::TcpStream;
use tokio::runtime::Handle;
use tokio::sync::{oneshot, Mutex as AsyncMutex, Notify};
use tokio::sync::{mpsc, oneshot, Mutex as AsyncMutex, Notify};
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::http::header::HeaderValue;
use tokio_tungstenite::tungstenite::Message;
Expand Down Expand Up @@ -88,7 +88,12 @@ pub struct VolcengineStreamingASR {
/// of the lifetime of any particular `&self` borrow.
writer: SharedWriter,
final_rx: ParkingMutex<Option<oneshot::Receiver<Result<RawTranscript, VolcengineASRError>>>>,
/// 在飞的 audio 帧 spawn 数。consume_pcm_chunk +1,spawn 内 send 完成 -1。
/// 单 worker 模式:consume_pcm_chunk 把 (seq, chunk) 入队这个 channel,
/// open_session 里 spawn 出的唯一 worker 串行 recv + send_binary,
/// 保证 seq 顺序严格等于实际发送顺序。session 结束时 take() 掉这个 sender,
/// worker 的 recv() 返回 None 自动退出。
audio_tx: ParkingMutex<Option<mpsc::UnboundedSender<(i32, Vec<u8>)>>>,
/// 队列里 + worker 在飞的 audio 帧总数。consume +N,worker send 完一帧 -1。
/// send_last_frame 必须等它降到 0 才能安全发末帧,否则末帧可能被服务端先收到
/// 而把后续 chunk 当成「stream 已结束」之后的多余数据丢弃 → 尾句丢失。
pending_sends: Arc<AtomicUsize>,
Expand All @@ -103,6 +108,7 @@ impl VolcengineStreamingASR {
state: ParkingMutex::new(SyncState::default()),
writer: Arc::new(AsyncMutex::new(None)),
final_rx: ParkingMutex::new(None),
audio_tx: ParkingMutex::new(None),
pending_sends: Arc::new(AtomicUsize::new(0)),
send_done: Arc::new(Notify::new()),
}
Expand Down Expand Up @@ -170,6 +176,33 @@ impl VolcengineStreamingASR {
*self.final_rx.lock() = Some(rx);
*self.writer.lock().await = Some(write);

// 起一个唯一的 audio worker:consume_pcm_chunk 把 (seq, chunk) 推到 audio_tx,
// worker 这边 FIFO recv 然后串行 send_binary。session 结束后调用方
// (cancel / handle_frame error / fallback_to_partial_or_error) 会 take 掉
// self.audio_tx,channel 关闭,worker 自然退出。
let (audio_tx, mut audio_rx) = mpsc::unbounded_channel::<(i32, Vec<u8>)>();
*self.audio_tx.lock() = Some(audio_tx);
let writer_for_worker = Arc::clone(&self.writer);
let pending_for_worker = Arc::clone(&self.pending_sends);
let notify_for_worker = Arc::clone(&self.send_done);
tokio::spawn(async move {
while let Some((seq, chunk)) = audio_rx.recv().await {
let frame = frame::build(
MessageType::AudioOnlyRequest,
Flags::PositiveSequence,
Serialization::None,
&chunk,
Some(seq),
);
if let Err(e) = send_binary(&writer_for_worker, frame).await {
log::error!("[asr] audio frame seq={} send 失败: {}", seq, e);
}
if pending_for_worker.fetch_sub(1, Ordering::SeqCst) == 1 {
notify_for_worker.notify_waiters();
}
}
});

// Send the first frame: full client request with seq=1.
let payload_json = self.build_first_frame_payload(&connect_id);
let payload_bytes = serde_json::to_vec(&payload_json)
Expand Down Expand Up @@ -318,6 +351,8 @@ impl VolcengineStreamingASR {
st.pending_audio.clear();
st.runtime.clone()
};
// Drop audio sender → worker.recv() 返回 None → worker 退出,不再 hold writer。
*self.audio_tx.lock() = None;
if let Some(runtime) = runtime {
// Close the writer asynchronously so the receive loop sees EOF.
let writer = Arc::clone(&self.writer);
Expand Down Expand Up @@ -384,6 +419,7 @@ impl VolcengineStreamingASR {
code, body
)));
self.state.lock().is_connected = false;
*self.audio_tx.lock() = None;
return false;
}

Expand Down Expand Up @@ -447,6 +483,7 @@ impl VolcengineStreamingASR {
};
self.signal_success(transcript);
self.state.lock().is_connected = false;
*self.audio_tx.lock() = None;
return false;
}
true
Expand Down Expand Up @@ -492,16 +529,16 @@ impl VolcengineStreamingASR {
self.signal_error(err);
}
self.state.lock().is_connected = false;
*self.audio_tx.lock() = None;
}
}

impl AudioConsumer for VolcengineStreamingASR {
fn consume_pcm_chunk(&self, pcm: &[u8]) {
// 一次性把就绪 chunk 全部 drain 出来(同一把 state 锁内分配 seq,保证 seq 单调)。
// 然后 spawn 一个串行 send 的 task —— 不要每块一个 spawn,否则 burst flush 时多
// 个 task 异步竞争 writer 锁,发送顺序和 seq 顺序对不上,服务端会报
// "autoAssignedSequence (N) mismatch sequence in request (N+1)" 直接断连。
let (runtime, chunks) = {
// 单 worker 串行 send 模式:在 state 锁内 drain 并分配 seq(seq 单调),
// 然后把 (seq, chunk) push 进 mpsc。worker 端按入队顺序 send,
// 哪怕跨多个 consume 调用、多个 spawn 也不会再有 writer 锁竞争。
let chunks: Vec<(i32, Vec<u8>)> = {
let mut st = self.state.lock();
if !st.is_connected {
return;
Expand All @@ -517,41 +554,30 @@ impl AudioConsumer for VolcengineStreamingASR {
st.frames_sent += 1;
out.push((seq, chunk));
}
(st.runtime.clone(), out)
out
};

if chunks.is_empty() {
return;
}
let Some(runtime) = runtime else {
let Some(tx) = self.audio_tx.lock().as_ref().cloned() else {
return;
};

// pending_sends + Notify 让 send_last_frame 知道何时所有 chunk 都已发出。
// 单 task 内串行 send,所以一次性 +N、收尾 -N。
let count = chunks.len();
self.pending_sends.fetch_add(count, Ordering::SeqCst);
let writer = Arc::clone(&self.writer);
let pending = Arc::clone(&self.pending_sends);
let notify = Arc::clone(&self.send_done);
runtime.spawn(async move {
for (seq, chunk) in chunks {
let frame = frame::build(
MessageType::AudioOnlyRequest,
Flags::PositiveSequence,
Serialization::None,
&chunk,
Some(seq),
);
if let Err(e) = send_binary(&writer, frame).await {
// 把丢帧错误顶到日志里,定位"为什么服务端只收到 100ms"
log::error!("[asr] audio frame seq={} send 失败: {}", seq, e);
for entry in chunks {
// pending_sends 必须在 tx.send 之前 +1:否则 worker 可能先 recv + 发送 +
// 减 1,把 usize 计数器 underflow。
self.pending_sends.fetch_add(1, Ordering::SeqCst);
if tx.send(entry).is_err() {
// worker 已退出(cancel / 错误路径里 audio_tx 被 take)。
// 撤销刚才的 +1,避免 send_last_frame 的 wait 永远等不到 0。
if self.pending_sends.fetch_sub(1, Ordering::SeqCst) == 1 {
self.send_done.notify_waiters();
}
log::warn!("[asr] audio queue closed; dropping subsequent frames");
return;
}
if pending.fetch_sub(count, Ordering::SeqCst) == count {
notify.notify_waiters();
}
});
}
}
}

Expand Down
118 changes: 112 additions & 6 deletions openless-all/app/src-tauri/src/coordinator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ struct SessionState {
/// 跳过 history.append。issue #52。
cancelled: bool,
focus_target: Option<usize>,
/// 单调递增的 session id。begin_session 自增。
/// recorder error monitor 持有 captured id,处理时若与当前不等说明
/// 是上一 session 的迟到错误,必须 drop,不要 abort 当前 active session。
session_id: u64,
}

impl Default for SessionState {
Expand All @@ -70,6 +74,7 @@ impl Default for SessionState {
pending_stop: false,
cancelled: false,
focus_target: None,
session_id: 0,
}
}
}
Expand All @@ -90,6 +95,11 @@ struct Inner {
hotkey: Mutex<Option<HotkeyMonitor>>,
hotkey_status: Mutex<HotkeyStatus>,
hotkey_trigger_held: AtomicBool,
/// 翻译模式触发标志。每次 begin_session 重置为 false;hotkey 监听器在
/// Listening / Starting 阶段看到 Shift down 边沿时 set true。
/// end_session 在调 polish/translate 前读这个 flag + translation_target_language
/// 决定走哪条管线。详见 issue #4。
translation_modifier_seen: AtomicBool,
}

impl Coordinator {
Expand All @@ -114,6 +124,7 @@ impl Coordinator {
hotkey: Mutex::new(None),
hotkey_status: Mutex::new(HotkeyStatus::default()),
hotkey_trigger_held: AtomicBool::new(false),
translation_modifier_seen: AtomicBool::new(false),
}),
}
}
Expand Down Expand Up @@ -197,7 +208,8 @@ impl Coordinator {

pub async fn repolish(&self, raw_text: String, mode: PolishMode) -> Result<String, String> {
let hotwords = enabled_phrases(&self.inner);
polish_text(&raw_text, mode, &hotwords)
let working_languages = self.inner.prefs.get().working_languages;
polish_text(&raw_text, mode, &hotwords, &working_languages)
.await
.map_err(|e| e.to_string())
}
Expand Down Expand Up @@ -275,6 +287,18 @@ fn hotkey_bridge_loop(inner: Arc<Inner>, rx: mpsc::Receiver<HotkeyEvent>) {
HotkeyEvent::Cancelled => {
cancel_session(&inner_cloned);
}
HotkeyEvent::TranslationModifierPressed => {
// 仅在 Starting / Listening 阶段把 Shift 边沿计入"翻译模式触发"。
// Idle 阶段按 Shift 不应该影响下一段录音;Processing/Inserting 已经过了
// 决定走哪条管线的检查点,再 set 也没意义。
let phase = inner_cloned.state.lock().phase;
if matches!(phase, SessionPhase::Starting | SessionPhase::Listening) {
inner_cloned
.translation_modifier_seen
.store(true, Ordering::SeqCst);
log::info!("[coord] translation modifier seen during {phase:?}");
}
}
}
}
}
Expand Down Expand Up @@ -437,7 +461,12 @@ async fn begin_session(inner: &Arc<Inner>) -> Result<(), String> {
state.pending_stop = false;
state.cancelled = false;
state.focus_target = capture_focus_target();
// 自增 session_id;spawn 出去的 recorder error monitor 会捕获这个值,
// 如果迟到错误到达时 id 已不匹配就 drop,不会误中止后续 session。
state.session_id = state.session_id.wrapping_add(1);
}
// 翻译模式标志重置;hotkey 监听器在 Shift down 时再 set true。
inner.translation_modifier_seen.store(false, Ordering::SeqCst);

#[cfg(any(debug_assertions, test))]
if hotkey_injection_dry_run_enabled() {
Expand Down Expand Up @@ -618,11 +647,24 @@ fn start_recorder_for_starting(
}

fn spawn_recorder_error_monitor(inner: &Arc<Inner>, rx: mpsc::Receiver<RecorderError>) {
// 捕获当前 session_id:err 来时若 id 已经不一致说明是上一 session 的迟到事件,
// 不能去 abort 当前 active 的新 session(它录得好好的)。
let captured_session_id = inner.state.lock().session_id;
let inner = Arc::clone(inner);
std::thread::Builder::new()
.name("openless-recorder-error-monitor".into())
.spawn(move || {
if let Ok(err) = rx.recv() {
let current_session_id = inner.state.lock().session_id;
if captured_session_id != current_session_id {
log::warn!(
"[coord] recorder error from stale session {} dropped (current={}, err={})",
captured_session_id,
current_session_id,
err
);
return;
}
log::error!("[coord] recorder runtime error: {err}");
abort_recording_with_error(&inner, format!("录音中断: {err}"));
}
Expand Down Expand Up @@ -831,7 +873,20 @@ async fn end_session(inner: &Arc<Inner>) -> Result<(), String> {
let prefs = inner.prefs.get();
let mode = prefs.default_mode;
let hotword_strs = enabled_phrases(inner);
let (polished, polish_error) = polish_or_passthrough(&raw, mode, &hotword_strs).await;
let working_languages = prefs.working_languages.clone();
let translation_target = prefs.translation_target_language.trim().to_string();
let translation_active = inner.translation_modifier_seen.load(Ordering::SeqCst)
&& !translation_target.is_empty();
let (polished, polish_error) = if translation_active {
log::info!(
"[coord] translation mode → target=\u{300C}{}\u{300D} working={:?}",
translation_target,
working_languages
);
translate_or_passthrough(&raw, &translation_target, &working_languages).await
} else {
polish_or_passthrough(&raw, mode, &hotword_strs, &working_languages).await
};

// 原子化最后一次 cancel 检查 + 转 Inserting:
// 在同一 lock 内决定「丢弃」还是「进入 Inserting」。一旦设到 Inserting,
Expand All @@ -857,7 +912,8 @@ async fn end_session(inner: &Arc<Inner>) -> Result<(), String> {

let focus_target = inner.state.lock().focus_target;
restore_focus_target_if_possible(focus_target);
let status = inner.inserter.insert(&polished);
let restore_clipboard = inner.prefs.get().restore_clipboard_after_paste;
let status = inner.inserter.insert(&polished, restore_clipboard);
let inserted_chars = polished.chars().count() as u32;

// 累计每条 enabled 词条在最终文本中的命中次数。
Expand Down Expand Up @@ -1041,11 +1097,12 @@ async fn polish_or_passthrough(
raw: &RawTranscript,
mode: PolishMode,
hotwords: &[String],
working_languages: &[String],
) -> (String, Option<String>) {
if mode == PolishMode::Raw {
return (raw.text.clone(), None);
}
match polish_text(&raw.text, mode, hotwords).await {
match polish_text(&raw.text, mode, hotwords, working_languages).await {
Ok(s) => (s, None),
Err(e) => {
let reason = e.to_string();
Expand All @@ -1055,7 +1112,53 @@ async fn polish_or_passthrough(
}
}

async fn polish_text(raw: &str, mode: PolishMode, hotwords: &[String]) -> anyhow::Result<String> {
async fn polish_text(
raw: &str,
mode: PolishMode,
hotwords: &[String],
working_languages: &[String],
) -> anyhow::Result<String> {
let api_key = CredentialsVault::get(CredentialAccount::ArkApiKey)?.unwrap_or_default();
if api_key.is_empty() {
anyhow::bail!("ark api key missing");
}
let model = CredentialsVault::get(CredentialAccount::ArkModelId)?
.filter(|s| !s.is_empty())
.unwrap_or_else(|| "deepseek-v3-2".to_string());
let endpoint = CredentialsVault::get(CredentialAccount::ArkEndpoint)?
.filter(|s| !s.is_empty())
.unwrap_or_else(|| "https://ark.cn-beijing.volces.com/api/v3/chat/completions".to_string());
let base_url = endpoint
.trim_end_matches("/chat/completions")
.trim_end_matches('/')
.to_string();

let config = OpenAICompatibleConfig::new("ark", "Doubao Ark", base_url, api_key, model);
let provider = OpenAICompatibleLLMProvider::new(config);
Ok(provider.polish(raw, mode, hotwords, working_languages).await?)
}

/// 翻译路径——和 polish 一样失败时返回原文 + 失败原因,避免"不丢字"约定被违反(CLAUDE.md)。
async fn translate_or_passthrough(
raw: &RawTranscript,
target_language: &str,
working_languages: &[String],
) -> (String, Option<String>) {
match translate_text(&raw.text, target_language, working_languages).await {
Ok(s) => (s, None),
Err(e) => {
let reason = e.to_string();
log::error!("[coord] translate failed, falling back to raw: {reason}");
(raw.text.clone(), Some(reason))
}
}
}

async fn translate_text(
raw: &str,
target_language: &str,
working_languages: &[String],
) -> anyhow::Result<String> {
let api_key = CredentialsVault::get(CredentialAccount::ArkApiKey)?.unwrap_or_default();
if api_key.is_empty() {
anyhow::bail!("ark api key missing");
Expand All @@ -1073,7 +1176,9 @@ async fn polish_text(raw: &str, mode: PolishMode, hotwords: &[String]) -> anyhow

let config = OpenAICompatibleConfig::new("ark", "Doubao Ark", base_url, api_key, model);
let provider = OpenAICompatibleLLMProvider::new(config);
Ok(provider.polish(raw, mode, hotwords).await?)
Ok(provider
.translate_to(raw, target_language, working_languages)
.await?)
}

fn read_whisper_credentials() -> (String, String, String) {
Expand Down Expand Up @@ -1419,6 +1524,7 @@ fn emit_capsule(
elapsed_ms,
message,
inserted_chars,
translation: inner.translation_modifier_seen.load(Ordering::SeqCst),
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Gate translation badge on enabled translation target

The capsule translation flag is emitted from translation_modifier_seen alone, but actual translation execution also requires a non-empty translation_target_language in end_session. This means when target is disabled (documented as “Shift does nothing”), pressing Shift still shows the blue “Translating” indicator even though the session will run normal polish. Please derive the emitted badge flag from the same predicate as the translation pipeline to avoid misleading status.

Useful? React with 👍 / 👎.

};

let show_capsule = inner.prefs.get().show_capsule;
Expand Down
Loading
Loading