diff --git a/.gitignore b/.gitignore index 41a33f5f..90d64a8c 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ local.properties .kotlin ktlint -ktlint.bat \ No newline at end of file +ktlint.bat +whisper/build/ \ No newline at end of file diff --git a/.gitmodules b/.gitmodules index 81bc6f38..f614413d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "llama.cpp"] path = llama.cpp url = https://github.com/ggerganov/llama.cpp +[submodule "whisper/src/main/jni/whisper.cpp"] + path = whisper/src/main/jni/whisper.cpp + url = https://github.com/ggml-org/whisper.cpp.git diff --git a/app/build.gradle.kts b/app/build.gradle.kts index 24bd96f3..96ed2a6e 100644 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -3,7 +3,7 @@ plugins { alias(libs.plugins.kotlin.android) alias(libs.plugins.kotlin.compose) id("com.google.devtools.ksp") - kotlin("plugin.serialization") version "2.1.0" + kotlin("plugin.serialization") version "2.0.0" } android { @@ -92,12 +92,17 @@ dependencies { implementation(project(":smollm")) implementation(project(":hf-model-hub-api")) + implementation(project(":whisper")) + + // Android Wave Recorder for speech-to-text + implementation("com.github.squti:Android-Wave-Recorder:2.1.0") // Koin: dependency injection implementation(libs.koin.android) implementation(libs.koin.annotations) implementation(libs.koin.androidx.compose) implementation(libs.androidx.ui.text.google.fonts) + implementation(libs.androidx.compose.foundation) ksp(libs.koin.ksp.compiler) // compose-markdown: Markdown rendering in Compose diff --git a/app/proguard-rules.pro b/app/proguard-rules.pro index 481bb434..f4b4cf2d 100644 --- a/app/proguard-rules.pro +++ b/app/proguard-rules.pro @@ -18,4 +18,7 @@ # If you keep the line number information, uncomment this to # hide the original source file name. 
-#-renamesourcefileattribute SourceFile \ No newline at end of file +#-renamesourcefileattribute SourceFile + +# Keep Whisper native callbacks +-keep class com.whispercpp.whisper.** { *; } \ No newline at end of file diff --git a/app/src/main/AndroidManifest.xml b/app/src/main/AndroidManifest.xml index ea94e633..1ae424c6 100644 --- a/app/src/main/AndroidManifest.xml +++ b/app/src/main/AndroidManifest.xml @@ -3,6 +3,12 @@ xmlns:tools="http://schemas.android.com/tools"> + + + + + + + + + + \ No newline at end of file diff --git a/app/src/main/java/io/shubham0204/smollmandroid/data/PreferencesManager.kt b/app/src/main/java/io/shubham0204/smollmandroid/data/PreferencesManager.kt new file mode 100644 index 00000000..47f5cf11 --- /dev/null +++ b/app/src/main/java/io/shubham0204/smollmandroid/data/PreferencesManager.kt @@ -0,0 +1,77 @@ +/* + * Copyright (C) 2024 Shubham Panchal + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.shubham0204.smollmandroid.data + +import android.content.Context +import org.koin.core.annotation.Single + +@Single +class PreferencesManager(context: Context) { + private val prefs = context.getSharedPreferences("smolchat_prefs", Context.MODE_PRIVATE) + + var ttsEnabled: Boolean + get() = prefs.getBoolean("tts_enabled", false) + set(value) = prefs.edit().putBoolean("tts_enabled", value).apply() + + var autoSubmitEnabled: Boolean + get() = prefs.getBoolean("auto_submit_enabled", false) + set(value) = prefs.edit().putBoolean("auto_submit_enabled", value).apply() + + var autoSubmitDelayMs: Long + get() = prefs.getLong("auto_submit_delay_ms", 2000L) + set(value) = prefs.edit().putLong("auto_submit_delay_ms", value).apply() + + var selectedWhisperModel: String + get() = prefs.getString("selected_whisper_model", DEFAULT_WHISPER_MODEL) ?: DEFAULT_WHISPER_MODEL + set(value) = prefs.edit().putString("selected_whisper_model", value).apply() + + var sttLanguage: String + get() = prefs.getString("stt_language", DEFAULT_STT_LANGUAGE) ?: DEFAULT_STT_LANGUAGE + set(value) = prefs.edit().putString("stt_language", value).apply() + + var autoContextTrimEnabled: Boolean + get() = prefs.getBoolean("auto_context_trim_enabled", false) + set(value) = prefs.edit().putBoolean("auto_context_trim_enabled", value).apply() + + companion object { + const val DEFAULT_WHISPER_MODEL = "ggml-base.en.bin" + const val DEFAULT_STT_LANGUAGE = "en" + + // Whisper supported languages with their display names + val SUPPORTED_LANGUAGES = listOf( + "en" to "English", + "de" to "German", + "fr" to "French", + "es" to "Spanish", + "it" to "Italian", + "pt" to "Portuguese", + "nl" to "Dutch", + "pl" to "Polish", + "ru" to "Russian", + "zh" to "Chinese", + "ja" to "Japanese", + "ko" to "Korean", + "ar" to "Arabic", + "hi" to "Hindi", + "tr" to "Turkish", + "uk" to "Ukrainian", + "cs" to "Czech", + "sv" to "Swedish", + "auto" to "Auto-detect", + ) + } +} diff --git 
a/app/src/main/java/io/shubham0204/smollmandroid/llm/SmolLMManager.kt b/app/src/main/java/io/shubham0204/smollmandroid/llm/SmolLMManager.kt index 0e1476a5..74d04959 100644 --- a/app/src/main/java/io/shubham0204/smollmandroid/llm/SmolLMManager.kt +++ b/app/src/main/java/io/shubham0204/smollmandroid/llm/SmolLMManager.kt @@ -16,6 +16,7 @@ package io.shubham0204.smollmandroid.llm +import android.os.Process import android.util.Log import io.shubham0204.smollm.SmolLM import io.shubham0204.smollmandroid.data.AppDB @@ -163,49 +164,72 @@ class SmolLMManager(private val appDB: AppDB) { responseGenerationJob?.cancel() responseGenerationJob = CoroutineScope(Dispatchers.Default).launch { + // Boost thread priority to reduce CPU throttling when screen is locked + // THREAD_PRIORITY_URGENT_AUDIO (-19) is the highest priority available + // to regular apps and signals to the system this is time-sensitive work + val originalPriority = Process.getThreadPriority(Process.myTid()) try { + Process.setThreadPriority(Process.THREAD_PRIORITY_URGENT_AUDIO) + LOGD(">>> Thread priority boosted from $originalPriority to URGENT_AUDIO") + } catch (e: Exception) { + LOGD(">>> Failed to boost thread priority: ${e.message}") + } + + try { + LOGD(">>> getResponse coroutine started on thread: ${Thread.currentThread().name}") isInferenceOn = true var response = "" val duration = measureTime { + LOGD(">>> Starting response flow collection...") instance.getResponseAsFlow(query).collect { piece -> response += piece - withContext(Dispatchers.Main) { - onPartialResponseGenerated(response) - } + // Don't use Main dispatcher - callbacks are thread-safe + // Using Main blocks when screen is locked + onPartialResponseGenerated(response) } + LOGD(">>> Response flow collection complete") } response = responseTransform(response) + LOGD(">>> Response transformed, length=${response.length}") // Thread-safe access to chat val currentChat = stateLock.withLock { chat } if (currentChat != null) { // Add response to 
database + LOGD(">>> Adding assistant message to DB...") appDB.addAssistantMessage(currentChat.id, response) + LOGD(">>> Assistant message added") } - withContext(Dispatchers.Main) { - isInferenceOn = false - onSuccess( - SmolLMResponse( - response = response, - generationSpeed = instance.getResponseGenerationSpeed(), - generationTimeSecs = duration.inWholeSeconds.toInt(), - contextLengthUsed = instance.getContextLengthUsed(), - ) + LOGD(">>> Calling onSuccess callback...") + isInferenceOn = false + onSuccess( + SmolLMResponse( + response = response, + generationSpeed = instance.getResponseGenerationSpeed(), + generationTimeSecs = duration.inWholeSeconds.toInt(), + contextLengthUsed = instance.getContextLengthUsed(), ) - } + ) + LOGD(">>> onSuccess callback returned") } catch (e: CancellationException) { + LOGD(">>> Response generation cancelled") isInferenceOn = false - withContext(Dispatchers.Main) { - onCancelled() - } + onCancelled() } catch (e: Exception) { + LOGD(">>> Response generation error: ${e.message}") isInferenceOn = false - withContext(Dispatchers.Main) { - onError(e) + onError(e) + } finally { + // Restore original thread priority + try { + Process.setThreadPriority(originalPriority) + LOGD(">>> Thread priority restored to $originalPriority") + } catch (e: Exception) { + LOGD(">>> Failed to restore thread priority: ${e.message}") } } } diff --git a/app/src/main/java/io/shubham0204/smollmandroid/service/VoiceChatService.kt b/app/src/main/java/io/shubham0204/smollmandroid/service/VoiceChatService.kt new file mode 100644 index 00000000..a5b5cb84 --- /dev/null +++ b/app/src/main/java/io/shubham0204/smollmandroid/service/VoiceChatService.kt @@ -0,0 +1,200 @@ +/* + * Copyright (C) 2024 Shubham Panchal + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.shubham0204.smollmandroid.service + +import android.app.Notification +import android.app.NotificationChannel +import android.app.NotificationManager +import android.app.PendingIntent +import android.app.Service +import android.content.Context +import android.content.Intent +import android.net.Uri +import android.os.Build +import android.os.IBinder +import android.os.PowerManager +import android.provider.Settings +import android.util.Log +import androidx.core.app.NotificationCompat +import io.shubham0204.smollmandroid.R +import io.shubham0204.smollmandroid.ui.screens.chat.ChatActivity +import org.koin.core.component.KoinComponent +import org.koin.core.component.inject + +private const val LOG_TAG = "VoiceChatService" + +/** + * Foreground service that keeps voice chat active when the screen is locked. + * Shows a persistent notification with a "Stop" action. + * Uses a partial wake lock to keep CPU active for transcription. + */ +class VoiceChatService : Service(), KoinComponent { + + private val voiceChatServiceManager: VoiceChatServiceManager by inject() + private var wakeLock: PowerManager.WakeLock? 
= null + + companion object { + const val NOTIFICATION_ID = 1001 + const val CHANNEL_ID = "voice_chat_channel" + const val ACTION_STOP = "io.shubham0204.smollmandroid.STOP_VOICE_CHAT" + + fun start(context: Context) { + Log.d(LOG_TAG, ">>> Starting VoiceChatService") + val intent = Intent(context, VoiceChatService::class.java) + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) { + context.startForegroundService(intent) + } else { + context.startService(intent) + } + } + + fun stop(context: Context) { + Log.d(LOG_TAG, ">>> Stopping VoiceChatService") + context.stopService(Intent(context, VoiceChatService::class.java)) + } + + /** + * Check if the app is exempt from battery optimization. + * On Samsung devices, this is required to prevent the app from being frozen. + */ + fun isIgnoringBatteryOptimizations(context: Context): Boolean { + val powerManager = context.getSystemService(Context.POWER_SERVICE) as PowerManager + return powerManager.isIgnoringBatteryOptimizations(context.packageName) + } + + /** + * Request the user to disable battery optimization for this app. + * This is required on Samsung and other OEM devices to prevent aggressive app killing. 
+ */ + fun requestBatteryOptimizationExemption(context: Context) { + if (!isIgnoringBatteryOptimizations(context)) { + Log.d(LOG_TAG, ">>> Requesting battery optimization exemption") + val intent = Intent(Settings.ACTION_REQUEST_IGNORE_BATTERY_OPTIMIZATIONS).apply { + data = Uri.parse("package:${context.packageName}") + addFlags(Intent.FLAG_ACTIVITY_NEW_TASK) + } + context.startActivity(intent) + } + } + } + + override fun onCreate() { + super.onCreate() + Log.d(LOG_TAG, "onCreate") + createNotificationChannel() + } + + override fun onStartCommand(intent: Intent?, flags: Int, startId: Int): Int { + Log.d(LOG_TAG, "onStartCommand, action=${intent?.action}") + + if (intent?.action == ACTION_STOP) { + Log.d(LOG_TAG, "Stop action received") + voiceChatServiceManager.requestStopService() + stopSelf() + return START_NOT_STICKY + } + + val notification = buildNotification() + startForeground(NOTIFICATION_ID, notification) + voiceChatServiceManager.setServiceRunning(true) + + // Acquire partial wake lock to keep CPU active for transcription + acquireWakeLock() + + Log.d(LOG_TAG, "Service started in foreground") + return START_STICKY + } + + override fun onDestroy() { + Log.d(LOG_TAG, "onDestroy") + releaseWakeLock() + voiceChatServiceManager.setServiceRunning(false) + super.onDestroy() + } + + override fun onBind(intent: Intent?): IBinder? 
= null + + @Suppress("DEPRECATION") + private fun acquireWakeLock() { + if (wakeLock == null) { + val powerManager = getSystemService(Context.POWER_SERVICE) as PowerManager + wakeLock = powerManager.newWakeLock( + PowerManager.PARTIAL_WAKE_LOCK, + "SmolChat:VoiceChatWakeLock" + ).apply { + acquire(60 * 60 * 1000L) // 1 hour max, released when service stops + } + Log.d(LOG_TAG, "Wake lock acquired") + } + } + + private fun releaseWakeLock() { + wakeLock?.let { + if (it.isHeld) { + it.release() + Log.d(LOG_TAG, "Wake lock released") + } + } + wakeLock = null + } + + private fun createNotificationChannel() { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) { + val channel = NotificationChannel( + CHANNEL_ID, + getString(R.string.voice_chat_channel_name), + NotificationManager.IMPORTANCE_LOW + ).apply { + description = getString(R.string.voice_chat_notification_text) + setShowBadge(false) + } + val manager = getSystemService(NotificationManager::class.java) + manager.createNotificationChannel(channel) + } + } + + private fun buildNotification(): Notification { + val stopIntent = Intent(this, VoiceChatService::class.java).apply { + action = ACTION_STOP + } + val stopPendingIntent = PendingIntent.getService( + this, 0, stopIntent, + PendingIntent.FLAG_UPDATE_CURRENT or PendingIntent.FLAG_IMMUTABLE + ) + + val openIntent = Intent(this, ChatActivity::class.java).apply { + flags = Intent.FLAG_ACTIVITY_SINGLE_TOP + } + val openPendingIntent = PendingIntent.getActivity( + this, 0, openIntent, + PendingIntent.FLAG_UPDATE_CURRENT or PendingIntent.FLAG_IMMUTABLE + ) + + return NotificationCompat.Builder(this, CHANNEL_ID) + .setContentTitle(getString(R.string.voice_chat_notification_title)) + .setContentText(getString(R.string.voice_chat_notification_text)) + .setSmallIcon(R.drawable.ic_mic_notification) + .setOngoing(true) + .setContentIntent(openPendingIntent) + .addAction( + R.drawable.ic_stop, + getString(R.string.voice_chat_stop), + stopPendingIntent + ) + .build() + } 
+} diff --git a/app/src/main/java/io/shubham0204/smollmandroid/service/VoiceChatServiceManager.kt b/app/src/main/java/io/shubham0204/smollmandroid/service/VoiceChatServiceManager.kt new file mode 100644 index 00000000..c23c6d2e --- /dev/null +++ b/app/src/main/java/io/shubham0204/smollmandroid/service/VoiceChatServiceManager.kt @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2024 Shubham Panchal + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.shubham0204.smollmandroid.service + +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.SharedFlow +import kotlinx.coroutines.flow.StateFlow +import org.koin.core.annotation.Single + +/** + * Manages the state of the VoiceChatService and provides a way for the service + * and UI components to communicate. 
+ */ +@Single +class VoiceChatServiceManager { + private val _isServiceRunning = MutableStateFlow(false) + val isServiceRunning: StateFlow = _isServiceRunning + + private val _stopServiceRequest = MutableSharedFlow(extraBufferCapacity = 1) + val stopServiceRequest: SharedFlow = _stopServiceRequest + + fun setServiceRunning(running: Boolean) { + _isServiceRunning.value = running + } + + fun requestStopService() { + _stopServiceRequest.tryEmit(Unit) + } +} diff --git a/app/src/main/java/io/shubham0204/smollmandroid/stt/SpeechToTextManager.kt b/app/src/main/java/io/shubham0204/smollmandroid/stt/SpeechToTextManager.kt new file mode 100644 index 00000000..4fdc4bc9 --- /dev/null +++ b/app/src/main/java/io/shubham0204/smollmandroid/stt/SpeechToTextManager.kt @@ -0,0 +1,768 @@ +package io.shubham0204.smollmandroid.stt + +import android.Manifest +import android.content.Context +import android.content.pm.PackageManager +import android.media.AudioFormat +import android.os.Process +import android.media.AudioRecord +import android.media.MediaRecorder +import android.os.Environment +import android.util.Log +import androidx.core.content.ContextCompat +import com.github.squti.androidwaverecorder.WaveRecorder +import com.whispercpp.whisper.WhisperCallback +import com.whispercpp.whisper.WhisperContext +import io.shubham0204.smollmandroid.data.PreferencesManager +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.SharedFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.isActive +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import org.koin.core.annotation.Single +import java.io.File +import java.io.FileInputStream +import java.nio.ByteBuffer +import java.nio.ByteOrder + 
+private const val LOG_TAG = "SpeechToTextManager" + +sealed class STTState { + data object Idle : STTState() + data object Recording : STTState() + data object Transcribing : STTState() + data class Error(val message: String) : STTState() +} + +@Single +class SpeechToTextManager( + private val context: Context, + private val preferencesManager: PreferencesManager +) { + + private val scope = CoroutineScope(Dispatchers.IO + SupervisorJob()) + + private var waveRecorder: WaveRecorder? = null + private var whisperContext: WhisperContext? = null + private var currentRecordingPath: String? = null + + private val _state = MutableStateFlow(STTState.Idle) + val state: StateFlow = _state + + private val _transcribedText = MutableStateFlow("") + val transcribedText: StateFlow = _transcribedText + + // Flow for streaming transcription chunks - emits the FULL transcription each time + // extraBufferCapacity=1 allows tryEmit to work without suspending + private val _streamingTranscription = MutableSharedFlow(extraBufferCapacity = 1) + val streamingTranscription: SharedFlow = _streamingTranscription + + // Flow to signal that silence was detected and auto-submit should happen + // extraBufferCapacity=1 allows tryEmit to work without suspending + private val _silenceDetected = MutableSharedFlow(extraBufferCapacity = 1) + val silenceDetected: SharedFlow = _silenceDetected + + private val modelsPath: File? = context.getExternalFilesDir(Environment.DIRECTORY_DOWNLOADS) + + private var isModelLoaded = false + private var loadedModelName: String? = null + + // For streaming transcription + private var streamingJob: Job? = null + private var silenceDetectionJob: Job? = null + private var audioRecord: AudioRecord? 
= null + private var isStreamingMode = false + private var audioBuffer = mutableListOf() + private val audioBufferLock = Any() + + // Interval for periodic transcription (in milliseconds) + private val transcriptionIntervalMs = 1500L + + // Audio recording parameters + private val sampleRate = 16000 + + // Regex to filter out Whisper noise/silence markers + // These markers should not be considered as "speech" for auto-submit purposes + private val noiseMarkerRegex = Regex( + """\[.*?]|\(.*?\)|<\|.*?\|>""", + RegexOption.IGNORE_CASE + ) + + /** + * Cleans transcription by removing Whisper noise markers like [empty audio], [BLANK_AUDIO], + * [noise], [music], (silence), etc. Returns only the actual spoken words. + */ + private fun cleanTranscription(text: String): String { + return noiseMarkerRegex.replace(text, "").trim() + } + + /** + * Checks if the transcription contains only noise markers (no real speech). + */ + private fun isOnlyNoiseMarkers(text: String): Boolean { + return cleanTranscription(text).isBlank() + } + + fun hasRecordingPermission(): Boolean { + return ContextCompat.checkSelfPermission( + context, + Manifest.permission.RECORD_AUDIO + ) == PackageManager.PERMISSION_GRANTED + } + + fun isModelAvailable(modelFileName: String? = null): Boolean { + val selectedModel = modelFileName ?: preferencesManager.selectedWhisperModel + val modelFile = File(modelsPath, selectedModel) + return modelFile.exists() + } + + /** + * Returns a list of available Whisper model files in the models directory. + * Whisper models typically have .bin extension and contain "ggml" in the name. 
+ */ + fun getAvailableModels(): List { + return modelsPath?.listFiles() + ?.filter { it.isFile && it.name.endsWith(".bin") && it.name.contains("ggml") } + ?.map { it.name } + ?.sorted() + ?: emptyList() + } + + fun getSelectedModelName(): String = preferencesManager.selectedWhisperModel + + fun setSelectedModel(modelFileName: String) { + // If a different model is being selected, we need to reload + if (loadedModelName != modelFileName && isModelLoaded) { + scope.launch { + whisperContext?.release() + whisperContext = null + isModelLoaded = false + loadedModelName = null + } + } + preferencesManager.selectedWhisperModel = modelFileName + } + + fun loadModel(modelFileName: String? = null, onComplete: (Boolean) -> Unit = {}) { + val selectedModel = modelFileName ?: preferencesManager.selectedWhisperModel + scope.launch { + try { + // If model is already loaded and it's the same model, just return success + if (isModelLoaded && loadedModelName == selectedModel) { + withContext(Dispatchers.Main) { + onComplete(true) + } + return@launch + } + + // If a different model is loaded, release it first + if (isModelLoaded && loadedModelName != selectedModel) { + whisperContext?.release() + whisperContext = null + isModelLoaded = false + loadedModelName = null + } + + val modelFile = File(modelsPath, selectedModel) + if (!modelFile.exists()) { + Log.e(LOG_TAG, "Model file not found: ${modelFile.absolutePath}") + withContext(Dispatchers.Main) { + onComplete(false) + } + return@launch + } + + Log.d(LOG_TAG, "Loading Whisper model from: ${modelFile.absolutePath}") + whisperContext = WhisperContext.createContextFromFile(modelFile.absolutePath) + isModelLoaded = true + loadedModelName = selectedModel + Log.d(LOG_TAG, "Whisper model loaded successfully") + + withContext(Dispatchers.Main) { + onComplete(true) + } + } catch (e: Exception) { + Log.e(LOG_TAG, "Failed to load Whisper model", e) + withContext(Dispatchers.Main) { + onComplete(false) + } + } + } + } + + // Callback for when 
silence is detected and auto-submit should happen + // This is called directly from the transcription coroutine scope to avoid + // issues with frozen ViewModel coroutines on Samsung devices + // @Volatile ensures visibility across threads (set from Main, read from IO) + @Volatile + private var onSilenceDetectedCallback: ((String) -> Unit)? = null + + /** + * Set a callback that will be called when silence is detected. + * The callback receives the final transcription text. + * This is called directly from the IO dispatcher, so the callback + * should handle any necessary thread dispatching. + */ + fun setOnSilenceDetectedCallback(callback: ((String) -> Unit)?) { + Log.d(LOG_TAG, ">>> setOnSilenceDetectedCallback called, callback is ${if (callback != null) "NOT NULL" else "NULL"}") + onSilenceDetectedCallback = callback + } + + /** + * Start streaming recording with periodic transcription. + * Transcribed text will be emitted via streamingTranscription Flow. + * When transcription stops changing for the configured duration, silenceDetected will emit. 
+ */ + @Suppress("MissingPermission") + fun startStreamingRecording(language: String = "en", autoSubmitDelayMs: Long = 2000L) { + if (!hasRecordingPermission()) { + _state.value = STTState.Error("Recording permission not granted") + return + } + + try { + isStreamingMode = true + synchronized(audioBufferLock) { + audioBuffer.clear() + } + + // Initialize AudioRecord for direct audio capture + val bufferSize = AudioRecord.getMinBufferSize( + sampleRate, + AudioFormat.CHANNEL_IN_MONO, + AudioFormat.ENCODING_PCM_16BIT + ) + + audioRecord = AudioRecord( + MediaRecorder.AudioSource.MIC, + sampleRate, + AudioFormat.CHANNEL_IN_MONO, + AudioFormat.ENCODING_PCM_16BIT, + bufferSize * 2 + ) + + audioRecord?.startRecording() + _state.value = STTState.Recording + Log.d(LOG_TAG, "Streaming recording started with AudioRecord") + + // Start audio capture job + streamingJob = scope.launch { + val readBuffer = ShortArray(bufferSize / 2) + + while (isActive && _state.value == STTState.Recording) { + val readCount = audioRecord?.read(readBuffer, 0, readBuffer.size) ?: 0 + + if (readCount > 0) { + // Add samples to buffer + synchronized(audioBufferLock) { + for (i in 0 until readCount) { + audioBuffer.add(readBuffer[i]) + } + } + } + + delay(10) // Small delay to prevent tight loop + } + } + + // Start periodic transcription job with auto-submit detection + silenceDetectionJob = scope.launch { + var lastCleanTranscription = "" // Only real spoken words (no noise markers) + var lastRawTranscription = "" // Full transcription including markers + var lastSpeechChangeTime = System.currentTimeMillis() + var autoSubmitTriggered = false + + Log.d(LOG_TAG, "Starting periodic transcription loop") + delay(transcriptionIntervalMs) // Initial delay before first transcription + while (isActive && _state.value == STTState.Recording) { + Log.d(LOG_TAG, "Periodic transcription tick, buffer size: ${audioBuffer.size}") + val rawTranscription = transcribeCurrentBuffer(language) + val cleanedTranscription 
= cleanTranscription(rawTranscription) + + Log.d(LOG_TAG, "Raw: '$rawTranscription' | Clean: '$cleanedTranscription'") + + if (rawTranscription.isNotBlank()) { + // Check if the CLEANED transcription (real speech) has changed + if (cleanedTranscription != lastCleanTranscription && cleanedTranscription.isNotBlank()) { + // Real speech changed - reset timer and emit + lastCleanTranscription = cleanedTranscription + lastRawTranscription = rawTranscription + lastSpeechChangeTime = System.currentTimeMillis() + autoSubmitTriggered = false + Log.d(LOG_TAG, "Speech changed, emitting: $cleanedTranscription") + // Emit the cleaned transcription (without noise markers) + _streamingTranscription.tryEmit(cleanedTranscription) + } else if (lastCleanTranscription.isNotBlank()) { + // Real speech hasn't changed - check if we should auto-submit + // Only auto-submit if we have actual spoken words + val timeSinceLastSpeech = System.currentTimeMillis() - lastSpeechChangeTime + Log.d(LOG_TAG, "Speech unchanged for ${timeSinceLastSpeech}ms (threshold: ${autoSubmitDelayMs}ms)") + if (timeSinceLastSpeech >= autoSubmitDelayMs && !autoSubmitTriggered) { + Log.d(LOG_TAG, ">>> SILENCE DETECTED after speech - triggering auto-submit") + autoSubmitTriggered = true + + // Call the callback directly from this coroutine scope + // This bypasses the ViewModel's coroutine which may be frozen + val callback = onSilenceDetectedCallback + Log.d(LOG_TAG, ">>> onSilenceDetectedCallback is ${if (callback != null) "NOT NULL" else "NULL"}") + if (callback != null) { + Log.d(LOG_TAG, ">>> Calling onSilenceDetectedCallback directly") + // Stop recording and get final transcription (cleaned) + val finalTranscription = lastCleanTranscription + stopStreamingRecordingInternal() + Log.d(LOG_TAG, ">>> Recording stopped, calling callback with: $finalTranscription") + callback(finalTranscription) + Log.d(LOG_TAG, ">>> Callback returned") + return@launch // Exit the loop since we stopped recording + } else { + // 
Fallback to flow emission for backward compatibility + val emitted = _silenceDetected.tryEmit(Unit) + Log.d(LOG_TAG, ">>> silenceDetected.tryEmit result: $emitted, subscribers: ${_silenceDetected.subscriptionCount.value}") + } + } + } else { + Log.d(LOG_TAG, "Only noise markers detected, waiting for speech...") + } + } else { + Log.d(LOG_TAG, "Transcription was blank") + } + + delay(transcriptionIntervalMs) + } + Log.d(LOG_TAG, "Periodic transcription loop ended, isActive=$isActive, state=${_state.value}") + } + } catch (e: Exception) { + Log.e(LOG_TAG, "Failed to start streaming recording", e) + _state.value = STTState.Error("Failed to start recording: ${e.message}") + } + } + + /** + * Transcribe the current audio buffer without stopping recording. + * Returns the full transcription text. + */ + private suspend fun transcribeCurrentBuffer(language: String): String { + val audioData: FloatArray + synchronized(audioBufferLock) { + if (audioBuffer.size < sampleRate) { // At least 1 second of audio + return "" + } + // Convert Short buffer to Float array + audioData = FloatArray(audioBuffer.size) + for (i in audioBuffer.indices) { + audioData[i] = audioBuffer[i] / 32768.0f + } + } + + Log.d(LOG_TAG, "Periodic transcription of ${audioData.size} samples") + + // Boost thread priority to reduce CPU throttling when screen is locked + val originalPriority = Process.getThreadPriority(Process.myTid()) + try { + Process.setThreadPriority(Process.THREAD_PRIORITY_URGENT_AUDIO) + } catch (e: Exception) { + Log.d(LOG_TAG, "Failed to boost thread priority: ${e.message}") + } + + return try { + val result = whisperContext?.transcribeData( + data = audioData, + language = language, + printTimestamp = false, + callback = object : WhisperCallback { + override fun onNewSegment(startMs: Long, endMs: Long, text: String) { + Log.d(LOG_TAG, "Streaming segment: $text") + } + + override fun onProgress(progress: Int) { + // Ignore progress for streaming + } + + override fun onComplete() { + 
Log.d(LOG_TAG, "Streaming transcription chunk complete") + } + } + ) ?: "" + + result.trim() + } catch (e: Exception) { + Log.e(LOG_TAG, "Error during periodic transcription", e) + "" + } finally { + // Restore original thread priority + try { + Process.setThreadPriority(originalPriority) + } catch (e: Exception) { + Log.d(LOG_TAG, "Failed to restore thread priority: ${e.message}") + } + } + } + + /** + * Internal method to stop recording without final transcription. + * Used by the silence detection callback. + */ + private fun stopStreamingRecordingInternal() { + streamingJob?.cancel() + streamingJob = null + silenceDetectionJob?.cancel() + silenceDetectionJob = null + isStreamingMode = false + + try { + audioRecord?.stop() + audioRecord?.release() + audioRecord = null + waveRecorder?.stopRecording() + waveRecorder = null + synchronized(audioBufferLock) { + audioBuffer.clear() + } + _state.value = STTState.Idle + Log.d(LOG_TAG, ">>> Recording stopped internally") + } catch (e: Exception) { + Log.e(LOG_TAG, "Failed to stop recording internally", e) + } + } + + /** + * Stop streaming recording and perform final transcription. 
+ */ + fun stopStreamingRecording( + language: String = "en", + onComplete: (String) -> Unit + ) { + streamingJob?.cancel() + streamingJob = null + silenceDetectionJob?.cancel() + silenceDetectionJob = null + isStreamingMode = false + + try { + // Stop AudioRecord + audioRecord?.stop() + audioRecord?.release() + audioRecord = null + + // Also stop WaveRecorder if it was used (for non-streaming mode fallback) + waveRecorder?.stopRecording() + waveRecorder = null + + _state.value = STTState.Transcribing + Log.d(LOG_TAG, "Streaming recording stopped, final transcription") + + scope.launch { + // Get audio data from buffer + val audioData: FloatArray + val bufferEmpty: Boolean + synchronized(audioBufferLock) { + bufferEmpty = audioBuffer.isEmpty() + if (bufferEmpty) { + audioData = FloatArray(0) + } else { + audioData = FloatArray(audioBuffer.size) + for (i in audioBuffer.indices) { + audioData[i] = audioBuffer[i] / 32768.0f + } + audioBuffer.clear() + } + } + + if (bufferEmpty) { + Log.d(LOG_TAG, "Buffer empty, completing with empty string") + _state.value = STTState.Idle + onComplete("") + return@launch + } + + if (audioData.size < sampleRate) { // Less than 1 second + Log.d(LOG_TAG, "Audio too short (${audioData.size} samples), completing with empty string") + _state.value = STTState.Idle + onComplete("") + return@launch + } + + Log.d(LOG_TAG, "Starting final transcription of ${audioData.size} samples") + val result = whisperContext?.transcribeData( + data = audioData, + language = language, + printTimestamp = false, + callback = object : WhisperCallback { + override fun onNewSegment(startMs: Long, endMs: Long, text: String) { + Log.d(LOG_TAG, "Final segment: $text") + } + + override fun onProgress(progress: Int) {} + override fun onComplete() {} + } + ) ?: "" + + val finalText = result.trim() + Log.d(LOG_TAG, "Final transcription complete: $finalText") + + // StateFlow is thread-safe, no need for Main dispatcher + _state.value = STTState.Idle + // Call callback 
directly - caller handles any needed dispatching + onComplete(finalText) + } + } catch (e: Exception) { + Log.e(LOG_TAG, "Failed to stop streaming recording", e) + _state.value = STTState.Error("Failed to stop recording: ${e.message}") + onComplete("") + } + } + + fun startRecording() { + if (!hasRecordingPermission()) { + _state.value = STTState.Error("Recording permission not granted") + return + } + + try { + val recordingFile = generateRecordingFile() + currentRecordingPath = recordingFile.absolutePath + + waveRecorder = WaveRecorder(recordingFile.absolutePath) + .configureWaveSettings { + sampleRate = 16000 // Whisper expects 16kHz + channels = AudioFormat.CHANNEL_IN_MONO + audioEncoding = AudioFormat.ENCODING_PCM_16BIT + } + + waveRecorder?.startRecording() + _state.value = STTState.Recording + Log.d(LOG_TAG, "Recording started: ${recordingFile.absolutePath}") + } catch (e: Exception) { + Log.e(LOG_TAG, "Failed to start recording", e) + _state.value = STTState.Error("Failed to start recording: ${e.message}") + } + } + + fun stopRecordingAndTranscribe( + language: String = "en", + onTranscriptionComplete: (String) -> Unit + ) { + // If in streaming mode, use the streaming stop method + if (isStreamingMode) { + stopStreamingRecording(language, onTranscriptionComplete) + return + } + + try { + waveRecorder?.stopRecording() + waveRecorder = null + + val recordingPath = currentRecordingPath + if (recordingPath == null) { + _state.value = STTState.Error("No recording file found") + onTranscriptionComplete("") + return + } + + _state.value = STTState.Transcribing + Log.d(LOG_TAG, "Recording stopped, starting transcription") + + scope.launch { + transcribeAudio(recordingPath, language, onTranscriptionComplete) + } + } catch (e: Exception) { + Log.e(LOG_TAG, "Failed to stop recording", e) + _state.value = STTState.Error("Failed to stop recording: ${e.message}") + onTranscriptionComplete("") + } + } + + fun cancelRecording() { + streamingJob?.cancel() + streamingJob = 
null + silenceDetectionJob?.cancel() + silenceDetectionJob = null + isStreamingMode = false + + try { + // Stop AudioRecord if used + audioRecord?.stop() + audioRecord?.release() + audioRecord = null + + // Stop WaveRecorder if used + waveRecorder?.stopRecording() + waveRecorder = null + + // Clear the audio buffer + synchronized(audioBufferLock) { + audioBuffer.clear() + } + + // Delete the recording file if any + currentRecordingPath?.let { path -> + File(path).delete() + } + currentRecordingPath = null + + _state.value = STTState.Idle + Log.d(LOG_TAG, "Recording cancelled") + } catch (e: Exception) { + Log.e(LOG_TAG, "Failed to cancel recording", e) + } + } + + private suspend fun transcribeAudio( + audioPath: String, + language: String, + onComplete: (String) -> Unit + ) { + try { + if (whisperContext == null) { + Log.e(LOG_TAG, "Whisper context not initialized") + withContext(Dispatchers.Main) { + _state.value = STTState.Error("Whisper model not loaded") + onComplete("") + } + return + } + + val audioData = readWavFileAsFloatArray(audioPath) + if (audioData.isEmpty()) { + withContext(Dispatchers.Main) { + _state.value = STTState.Error("Failed to read audio file") + onComplete("") + } + return + } + + Log.d(LOG_TAG, "Transcribing ${audioData.size} samples") + + val transcriptionBuilder = StringBuilder() + + val result = whisperContext?.transcribeData( + data = audioData, + language = language, + printTimestamp = false, + callback = object : WhisperCallback { + override fun onNewSegment(startMs: Long, endMs: Long, text: String) { + transcriptionBuilder.append(text) + Log.d(LOG_TAG, "Segment: $text") + } + + override fun onProgress(progress: Int) { + Log.d(LOG_TAG, "Transcription progress: $progress%") + } + + override fun onComplete() { + Log.d(LOG_TAG, "Transcription complete") + } + } + ) ?: "" + + // Clean up the recording file + File(audioPath).delete() + currentRecordingPath = null + + val finalText = result.trim() + Log.d(LOG_TAG, "Transcription result: 
$finalText") + + withContext(Dispatchers.Main) { + _transcribedText.value = finalText + _state.value = STTState.Idle + onComplete(finalText) + } + } catch (e: Exception) { + Log.e(LOG_TAG, "Transcription failed", e) + withContext(Dispatchers.Main) { + _state.value = STTState.Error("Transcription failed: ${e.message}") + onComplete("") + } + } + } + + private fun readWavFileAsFloatArray(filePath: String): FloatArray { + return try { + val file = File(filePath) + if (!file.exists()) { + Log.e(LOG_TAG, "WAV file does not exist: $filePath") + return FloatArray(0) + } + + FileInputStream(file).use { fis -> + val headerBytes = ByteArray(44) + val headerRead = fis.read(headerBytes) + if (headerRead < 44) { + Log.e(LOG_TAG, "WAV header too short") + return FloatArray(0) + } + + // Read the data size from the header (bytes 40-43) + val dataSize = ByteBuffer.wrap(headerBytes, 40, 4) + .order(ByteOrder.LITTLE_ENDIAN) + .int + + // For streaming, the file might still be growing, so read what's available + val availableBytes = file.length().toInt() - 44 + val bytesToRead = minOf(dataSize, availableBytes) + + if (bytesToRead <= 0) { + return FloatArray(0) + } + + val audioBytes = ByteArray(bytesToRead) + fis.read(audioBytes) + + // Convert 16-bit PCM to float array + val samples = bytesToRead / 2 + val floatArray = FloatArray(samples) + val byteBuffer = ByteBuffer.wrap(audioBytes).order(ByteOrder.LITTLE_ENDIAN) + + for (i in 0 until samples) { + val sample = byteBuffer.short.toInt() + floatArray[i] = sample / 32768.0f + } + + floatArray + } + } catch (e: Exception) { + Log.e(LOG_TAG, "Failed to read WAV file", e) + FloatArray(0) + } + } + + private fun generateRecordingFile(): File { + val fileName = "stt_recording_${System.currentTimeMillis()}.wav" + return File(context.cacheDir, fileName) + } + + fun release() { + streamingJob?.cancel() + streamingJob = null + silenceDetectionJob?.cancel() + silenceDetectionJob = null + isStreamingMode = false + + scope.launch { + try { + 
audioRecord?.stop() + audioRecord?.release() + audioRecord = null + waveRecorder?.stopRecording() + waveRecorder = null + whisperContext?.release() + whisperContext = null + isModelLoaded = false + synchronized(audioBufferLock) { + audioBuffer.clear() + } + currentRecordingPath?.let { File(it).delete() } + currentRecordingPath = null + } catch (e: Exception) { + Log.e(LOG_TAG, "Failed to release resources", e) + } + } + } +} diff --git a/app/src/main/java/io/shubham0204/smollmandroid/tts/TextToSpeechManager.kt b/app/src/main/java/io/shubham0204/smollmandroid/tts/TextToSpeechManager.kt new file mode 100644 index 00000000..70970a44 --- /dev/null +++ b/app/src/main/java/io/shubham0204/smollmandroid/tts/TextToSpeechManager.kt @@ -0,0 +1,273 @@ +/* + * Copyright (C) 2024 Shubham Panchal + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.shubham0204.smollmandroid.tts + +import android.content.Context +import android.speech.tts.TextToSpeech +import android.speech.tts.UtteranceProgressListener +import android.util.Log +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.SharedFlow +import kotlinx.coroutines.flow.StateFlow +import org.koin.core.annotation.Single +import java.util.Locale +import java.util.concurrent.ConcurrentLinkedQueue + +private const val LOGTAG = "[TextToSpeechManager-Kt]" +private val LOGD: (String) -> Unit = { Log.d(LOGTAG, it) } + +@Single +class TextToSpeechManager(context: Context) { + + private var tts: TextToSpeech? = null + private var isInitialized = false + private var currentLanguage: String = "en" + + private val sentenceQueue = ConcurrentLinkedQueue() + private var isSpeaking = false + private var utteranceCounter = 0 + + private var previousText = "" + private var pendingBuffer = "" + + private val sentenceEndRegex = Regex("[.!?](?:\\s+|$)") + + // Flow to signal when all speech has finished (queue is empty) + // extraBufferCapacity = 1 ensures tryEmit succeeds even if collector is busy + private val _allSpeechFinished = MutableSharedFlow(extraBufferCapacity = 1) + val allSpeechFinished: SharedFlow = _allSpeechFinished + + // StateFlow to expose whether TTS is currently speaking + private val _isSpeakingFlow = MutableStateFlow(false) + val isSpeakingFlow: StateFlow = _isSpeakingFlow + + // Flag to track if we're in a speech session (from speakChunk to speakRemainingBuffer) + private var isSpeechSessionActive = false + + init { + tts = TextToSpeech(context) { status -> + if (status == TextToSpeech.SUCCESS) { + isInitialized = true + LOGD("TTS initialized successfully") + setupUtteranceListener() + // Set default language + setLanguage(currentLanguage) + } else { + LOGD("TTS initialization failed") + isInitialized = false + } + } + } + + /** + * Set the TTS language using a 
language code (e.g., "en", "de", "fr"). + * Returns true if the language is supported, false otherwise. + */ + fun setLanguage(languageCode: String): Boolean { + if (!isInitialized) return false + + currentLanguage = languageCode + val locale = when (languageCode) { + "en" -> Locale.US + "de" -> Locale.GERMAN + "fr" -> Locale.FRENCH + "es" -> Locale("es", "ES") + "it" -> Locale.ITALIAN + "pt" -> Locale("pt", "PT") + "nl" -> Locale("nl", "NL") + "pl" -> Locale("pl", "PL") + "ru" -> Locale("ru", "RU") + "zh" -> Locale.CHINESE + "ja" -> Locale.JAPANESE + "ko" -> Locale.KOREAN + "ar" -> Locale("ar", "SA") + "hi" -> Locale("hi", "IN") + "tr" -> Locale("tr", "TR") + "uk" -> Locale("uk", "UA") + "cs" -> Locale("cs", "CZ") + "sv" -> Locale("sv", "SE") + else -> Locale.US + } + + val result = tts?.setLanguage(locale) + val isSupported = result != TextToSpeech.LANG_MISSING_DATA && + result != TextToSpeech.LANG_NOT_SUPPORTED + + if (isSupported) { + LOGD("TTS language set to: $languageCode ($locale)") + } else { + LOGD("TTS language not supported: $languageCode, falling back to default") + tts?.setLanguage(Locale.US) + } + + return isSupported + } + + private fun setupUtteranceListener() { + tts?.setOnUtteranceProgressListener(object : UtteranceProgressListener() { + override fun onStart(utteranceId: String?) { + isSpeaking = true + _isSpeakingFlow.value = true + } + + override fun onDone(utteranceId: String?) { + isSpeaking = false + _isSpeakingFlow.value = sentenceQueue.isNotEmpty() // Still speaking if queue has more + speakNextInQueue() + checkAndEmitSpeechFinished() + } + + @Deprecated("Deprecated in Java") + override fun onError(utteranceId: String?) 
{ + isSpeaking = false + _isSpeakingFlow.value = sentenceQueue.isNotEmpty() + speakNextInQueue() + checkAndEmitSpeechFinished() + } + + override fun onError(utteranceId: String?, errorCode: Int) { + isSpeaking = false + _isSpeakingFlow.value = sentenceQueue.isNotEmpty() + speakNextInQueue() + checkAndEmitSpeechFinished() + } + }) + } + + private fun checkAndEmitSpeechFinished() { + LOGD("checkAndEmitSpeechFinished: queueEmpty=${sentenceQueue.isEmpty()}, isSpeaking=$isSpeaking, sessionActive=$isSpeechSessionActive") + // If the queue is empty and we're not speaking, the session is finished + if (sentenceQueue.isEmpty() && !isSpeaking && isSpeechSessionActive) { + isSpeechSessionActive = false + LOGD("All speech finished, emitting signal") + val emitted = _allSpeechFinished.tryEmit(Unit) + LOGD("tryEmit result: $emitted") + } + } + + fun speakChunk(fullText: String) { + if (!isInitialized) return + + // Mark speech session as active when we start receiving chunks + isSpeechSessionActive = true + + val newContent = if (fullText.startsWith(previousText)) { + fullText.removePrefix(previousText) + } else { + fullText + } + previousText = fullText + + val textToProcess = pendingBuffer + newContent + val sentences = extractCompleteSentences(textToProcess) + + pendingBuffer = sentences.remaining + + sentences.completed.forEach { sentence -> + if (sentence.isNotBlank()) { + queueSentence(sentence.trim()) + } + } + } + + fun speakRemainingBuffer() { + LOGD("speakRemainingBuffer called, isInitialized=$isInitialized, pendingBuffer='$pendingBuffer', isSpeaking=$isSpeaking") + if (!isInitialized) return + + // Re-activate the speech session in case it was prematurely marked as finished + // (e.g., when TTS finished speaking queued sentences while generation was still ongoing) + isSpeechSessionActive = true + + if (pendingBuffer.isNotBlank()) { + LOGD("Queueing remaining buffer: '$pendingBuffer'") + queueSentence(pendingBuffer.trim()) + pendingBuffer = "" + } else { + // If 
there's no pending buffer and no ongoing speech, the session is done + LOGD("No pending buffer, checking if speech finished") + checkAndEmitSpeechFinished() + } + } + + private fun queueSentence(sentence: String) { + sentenceQueue.add(sentence) + if (!isSpeaking) { + speakNextInQueue() + } + } + + private fun speakNextInQueue() { + val nextSentence = sentenceQueue.poll() ?: return + + utteranceCounter++ + val utteranceId = "tts_utterance_$utteranceCounter" + + // Set isSpeaking before speak() to avoid race condition where + // checkAndEmitSpeechFinished() is called before onStart callback + isSpeaking = true + _isSpeakingFlow.value = true + tts?.speak( + nextSentence, + TextToSpeech.QUEUE_FLUSH, + null, + utteranceId + ) + } + + private data class SentenceExtraction( + val completed: List, + val remaining: String + ) + + private fun extractCompleteSentences(text: String): SentenceExtraction { + val completed = mutableListOf() + var remaining = text + + var match = sentenceEndRegex.find(remaining) + while (match != null) { + val endIndex = match.range.last + 1 + val sentence = remaining.substring(0, endIndex) + completed.add(sentence) + remaining = remaining.substring(endIndex) + match = sentenceEndRegex.find(remaining) + } + + return SentenceExtraction(completed, remaining) + } + + fun stop() { + sentenceQueue.clear() + pendingBuffer = "" + previousText = "" + isSpeaking = false + _isSpeakingFlow.value = false + isSpeechSessionActive = false + tts?.stop() + } + + fun resetState() { + stop() + } + + fun shutdown() { + stop() + tts?.shutdown() + tts = null + isInitialized = false + } +} diff --git a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/ChatActivity.kt b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/ChatActivity.kt index 3287a347..55275528 100644 --- a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/ChatActivity.kt +++ b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/ChatActivity.kt @@ -17,7 
+17,12 @@ package io.shubham0204.smollmandroid.ui.screens.chat import CustomNavTypes +import android.Manifest +import android.content.pm.PackageManager import android.content.ClipData +import android.os.Build +import android.view.WindowManager +import androidx.core.content.ContextCompat import android.content.ClipboardManager import android.content.Context import android.content.Intent @@ -27,14 +32,18 @@ import android.util.Log import android.widget.Toast import androidx.activity.ComponentActivity import androidx.activity.compose.BackHandler +import androidx.activity.compose.rememberLauncherForActivityResult import androidx.activity.compose.setContent import androidx.activity.enableEdgeToEdge +import androidx.activity.result.contract.ActivityResultContracts import androidx.compose.animation.AnimatedVisibility import androidx.compose.animation.fadeIn import androidx.compose.animation.fadeOut import androidx.compose.foundation.ExperimentalFoundationApi import androidx.compose.foundation.background +import androidx.compose.foundation.border import androidx.compose.foundation.clickable +import androidx.compose.foundation.gestures.detectTapGestures import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column @@ -57,8 +66,10 @@ import androidx.compose.foundation.shape.CircleShape import androidx.compose.foundation.shape.RoundedCornerShape import androidx.compose.foundation.text.KeyboardActions import androidx.compose.foundation.text.KeyboardOptions +import androidx.compose.material3.AlertDialog import androidx.compose.material3.CircularProgressIndicator import androidx.compose.material3.DrawerValue +import androidx.compose.material3.LinearProgressIndicator import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.Icon import androidx.compose.material3.IconButton @@ -66,6 +77,7 @@ import androidx.compose.material3.MaterialTheme import 
androidx.compose.material3.ModalBottomSheet import androidx.compose.material3.ModalNavigationDrawer import androidx.compose.material3.OutlinedButton +import androidx.compose.material3.OutlinedCard import androidx.compose.material3.Scaffold import androidx.compose.material3.Text import androidx.compose.material3.TextField @@ -73,12 +85,18 @@ import androidx.compose.material3.TextFieldDefaults import androidx.compose.material3.TopAppBar import androidx.compose.material3.rememberDrawerState import androidx.compose.runtime.Composable +import androidx.compose.runtime.DisposableEffect import androidx.compose.runtime.LaunchedEffect import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableFloatStateOf import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember import androidx.compose.runtime.rememberCoroutineScope import androidx.compose.runtime.saveable.rememberSaveable import androidx.compose.runtime.setValue +import androidx.compose.ui.input.pointer.pointerInput +import androidx.compose.ui.zIndex +import kotlinx.coroutines.delay import androidx.compose.ui.Alignment import androidx.compose.ui.Modifier import androidx.compose.ui.draw.clip @@ -101,7 +119,10 @@ import androidx.navigation.compose.composable import androidx.navigation.compose.rememberNavController import androidx.navigation.toRoute import compose.icons.FeatherIcons +import compose.icons.feathericons.Lock import compose.icons.feathericons.Menu +import compose.icons.feathericons.Mic +import compose.icons.feathericons.MicOff import compose.icons.feathericons.MoreVertical import compose.icons.feathericons.Send import compose.icons.feathericons.StopCircle @@ -120,6 +141,7 @@ import io.shubham0204.smollmandroid.ui.preview.dummyFolders import io.shubham0204.smollmandroid.ui.preview.dummyLLMModels import io.shubham0204.smollmandroid.ui.preview.dummyTasksList import io.shubham0204.smollmandroid.ui.screens.chat.ChatScreenViewModel.ModelLoadingState +import 
io.shubham0204.smollmandroid.stt.STTState import io.shubham0204.smollmandroid.ui.screens.chat.dialogs.ChangeFolderDialogUI import io.shubham0204.smollmandroid.ui.screens.chat.dialogs.ChatMessageOptionsDialog import io.shubham0204.smollmandroid.ui.screens.chat.dialogs.ChatMoreOptionsPopup @@ -144,7 +166,15 @@ private object ChatRoute private object BenchmarkModelRoute @Serializable -private data class EditChatSettingsRoute(val chat: Chat, val modelContextSize: Int) +private data class EditChatSettingsRoute( + val chat: Chat, + val modelContextSize: Int, + val ttsEnabled: Boolean, + val autoSubmitEnabled: Boolean, + val autoSubmitDelayMs: Long, + val sttLanguage: String, + val autoContextTrimEnabled: Boolean +) class ChatActivity : ComponentActivity() { @@ -203,6 +233,11 @@ class ChatActivity : ComponentActivity() { EditChatSettingsScreen( settings, route.modelContextSize, + route.ttsEnabled, + route.autoSubmitEnabled, + route.autoSubmitDelayMs, + route.sttLanguage, + route.autoContextTrimEnabled, onUpdateChat = { editableChatSettings -> viewModel.onEvent( ChatScreenUIEvent.ChatEvents.UpdateChatSettings( @@ -211,6 +246,31 @@ class ChatActivity : ComponentActivity() { ) ) }, + onToggleTTS = { enabled -> + viewModel.onEvent( + ChatScreenUIEvent.TTSEvents.ToggleTTS(enabled) + ) + }, + onToggleAutoSubmit = { enabled -> + viewModel.onEvent( + ChatScreenUIEvent.AutoSubmitEvents.ToggleAutoSubmit(enabled) + ) + }, + onUpdateAutoSubmitDelay = { delayMs -> + viewModel.onEvent( + ChatScreenUIEvent.AutoSubmitEvents.UpdateAutoSubmitDelay(delayMs) + ) + }, + onUpdateSTTLanguage = { language -> + viewModel.onEvent( + ChatScreenUIEvent.STTEvents.UpdateSTTLanguage(language) + ) + }, + onToggleAutoContextTrim = { enabled -> + viewModel.onEvent( + ChatScreenUIEvent.ContextEvents.ToggleAutoContextTrim(enabled) + ) + }, onBackClicked = { navController.navigateUp() }, ) } @@ -223,11 +283,21 @@ class ChatActivity : ComponentActivity() { uiState, onEditChatParamsClick = { chat, 
modelContextSize -> navController.navigate( - EditChatSettingsRoute(chat, modelContextSize) + EditChatSettingsRoute( + chat, + modelContextSize, + uiState.ttsEnabled, + uiState.autoSubmitEnabled, + uiState.autoSubmitDelayMs, + uiState.sttLanguage, + uiState.autoContextTrimEnabled + ) ) }, onBenchmarkModelClick = { navController.navigate(BenchmarkModelRoute) }, viewModel::onEvent, + viewModel::resetAutoSubmitTrigger, + viewModel::resetClearInputFlag, ) } } @@ -251,8 +321,13 @@ class ChatActivity : ComponentActivity() { override fun onStop() { super.onStop() if (!isChangingConfigurations) { - modelUnloaded = viewModel.unloadModel() - LOGD("onStop() called - model unloaded result: $modelUnloaded") + // Don't unload model if voice chat service is active (screen locked but voice mode on) + if (viewModel.voiceChatServiceManager.isServiceRunning.value) { + LOGD("onStop() called - keeping model loaded for voice chat service") + } else { + modelUnloaded = viewModel.unloadModel() + LOGD("onStop() called - model unloaded result: $modelUnloaded") + } } } } @@ -282,9 +357,77 @@ fun ChatActivityScreenUI( onEditChatParamsClick: (Chat, Int) -> Unit, onBenchmarkModelClick: () -> Unit, onEvent: (ChatScreenUIEvent) -> Unit, + onAutoSubmitHandled: () -> Unit = {}, + onClearInputHandled: () -> Unit = {}, ) { val drawerState = rememberDrawerState(DrawerValue.Closed) val scope = rememberCoroutineScope() + val context = LocalContext.current + + // Permission launcher for RECORD_AUDIO + val permissionLauncher = rememberLauncherForActivityResult( + contract = ActivityResultContracts.RequestPermission() + ) { isGranted -> + onEvent(ChatScreenUIEvent.ChatEvents.RecordingPermissionHandled) + if (isGranted) { + onEvent(ChatScreenUIEvent.ChatEvents.RecordingPermissionGranted) + } else { + Toast.makeText( + context, + context.getString(R.string.stt_permission_denied), + Toast.LENGTH_LONG + ).show() + } + } + + // Permission launcher for POST_NOTIFICATIONS (Android 13+) + val 
notificationPermissionLauncher = rememberLauncherForActivityResult( + contract = ActivityResultContracts.RequestPermission() + ) { _ -> + // Continue with voice mode regardless of result (notification is optional but nice to have) + onEvent(ChatScreenUIEvent.ChatEvents.NotificationPermissionHandled) + } + + // Request notification permission when the state indicates it's needed + LaunchedEffect(uiState.requestNotificationPermission) { + if (uiState.requestNotificationPermission && Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) { + notificationPermissionLauncher.launch(Manifest.permission.POST_NOTIFICATIONS) + } + } + + // Request permission when the state indicates it's needed + LaunchedEffect(uiState.requestRecordingPermission) { + if (uiState.requestRecordingPermission) { + permissionLauncher.launch(Manifest.permission.RECORD_AUDIO) + } + } + + // Handle pocket mode screen dimming - keeps screen on but at minimum brightness + // This prevents CPU throttling that happens when the screen is locked + // Samsung's "Accidental touch protection" will block touches when phone is in pocket + val activity = context as? 
ComponentActivity + DisposableEffect(uiState.isPocketModeEnabled) { + if (uiState.isPocketModeEnabled && activity != null) { + Log.d("ChatActivity", "Pocket mode enabled - keeping screen on with min brightness") + activity.window.addFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON) + // Set brightness to minimum (0.01f) - almost black but screen stays "on" + // This keeps CPU at full speed while saving battery + val layoutParams = activity.window.attributes + layoutParams.screenBrightness = 0.01f + activity.window.attributes = layoutParams + } + onDispose { + if (activity != null) { + Log.d("ChatActivity", "Pocket mode disabled - restoring normal screen behavior") + activity.window.clearFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON) + // Restore automatic brightness (-1f means use system default) + val layoutParams = activity.window.attributes + layoutParams.screenBrightness = WindowManager.LayoutParams.BRIGHTNESS_OVERRIDE_NONE + activity.window.attributes = layoutParams + } + } + } + SmolLMAndroidTheme { ModalNavigationDrawer( drawerState = drawerState, @@ -350,6 +493,8 @@ fun ChatActivityScreenUI( uiState.chat, uiState.showMoreOptionsPopup, uiState.memoryUsage != null, + uiState.ttsEnabled, + uiState.autoSubmitEnabled, onEditChatSettingsClick = { onEditChatParamsClick( uiState.chat, @@ -370,7 +515,7 @@ fun ChatActivityScreenUI( .padding(innerPadding) .background(MaterialTheme.colorScheme.surface) ) { - ScreenUI(uiState, onEvent) + ScreenUI(uiState, onEvent, onAutoSubmitHandled, onClearInputHandled) } } @@ -412,12 +557,273 @@ fun ChatActivityScreenUI( FolderOptionsDialog() TextFieldDialog() ChatMessageOptionsDialog() + + // Pocket mode overlay - covers entire screen when active + if (uiState.isPocketModeEnabled) { + PocketModeOverlay( + onExitPocketMode = { + onEvent(ChatScreenUIEvent.ChatEvents.DisablePocketMode) + } + ) + } + + // Context warning dialog + if (uiState.showContextWarningDialog) { + ContextWarningDialog( + contextUsed = 
uiState.chat.contextSizeConsumed, + contextMax = uiState.chat.contextSize, + onTrimMessages = { + onEvent(ChatScreenUIEvent.ChatEvents.TrimOldMessages) + }, + onNewChat = { + onEvent(ChatScreenUIEvent.ChatEvents.DismissContextWarning) + onEvent(ChatScreenUIEvent.ChatEvents.NewChat) + }, + onContinue = { + onEvent(ChatScreenUIEvent.ChatEvents.ContinueAnywayContextWarning) + }, + onDismiss = { + onEvent(ChatScreenUIEvent.ChatEvents.DismissContextWarning) + } + ) + } } } } +/** + * Full-screen overlay for pocket mode. Requires long-press (2 seconds) to exit. + */ @Composable -private fun ColumnScope.ScreenUI(uiState: ChatScreenUIState, onEvent: (ChatScreenUIEvent) -> Unit) { +private fun PocketModeOverlay( + onExitPocketMode: () -> Unit +) { + val holdDurationMs = 2000L + var isHolding by remember { mutableStateOf(false) } + var holdProgress by remember { mutableFloatStateOf(0f) } + val scope = rememberCoroutineScope() + + // Progress animation while holding + LaunchedEffect(isHolding) { + if (isHolding) { + holdProgress = 0f + val startTime = System.currentTimeMillis() + while (isHolding && holdProgress < 1f) { + delay(50) + val elapsed = System.currentTimeMillis() - startTime + holdProgress = (elapsed.toFloat() / holdDurationMs).coerceIn(0f, 1f) + if (holdProgress >= 1f) { + onExitPocketMode() + } + } + } else { + holdProgress = 0f + } + } + + Box( + modifier = Modifier + .fillMaxSize() + .background(Color.Black.copy(alpha = 0.95f)) + .zIndex(100f), + contentAlignment = Alignment.Center + ) { + Column( + horizontalAlignment = Alignment.CenterHorizontally, + verticalArrangement = Arrangement.Center, + modifier = Modifier.padding(32.dp) + ) { + // Lock icon + Icon( + imageVector = FeatherIcons.Lock, + contentDescription = null, + tint = Color.White, + modifier = Modifier.size(64.dp) + ) + + Spacer(modifier = Modifier.height(24.dp)) + + Text( + text = stringResource(R.string.pocket_mode_active), + color = Color.White, + style = MaterialTheme.typography.headlineMedium, 
+ textAlign = TextAlign.Center + ) + + Spacer(modifier = Modifier.height(48.dp)) + + // Exit button with long-press requirement + Box( + contentAlignment = Alignment.Center, + modifier = Modifier + .size(120.dp) + .clip(CircleShape) + .background( + if (isHolding) MaterialTheme.colorScheme.error.copy(alpha = 0.8f) + else MaterialTheme.colorScheme.surfaceVariant.copy(alpha = 0.3f) + ) + .border( + width = 3.dp, + color = if (isHolding) MaterialTheme.colorScheme.error else Color.White.copy(alpha = 0.5f), + shape = CircleShape + ) + .pointerInput(Unit) { + detectTapGestures( + onPress = { + isHolding = true + tryAwaitRelease() + isHolding = false + } + ) + } + ) { + Column( + horizontalAlignment = Alignment.CenterHorizontally + ) { + if (isHolding) { + CircularProgressIndicator( + progress = { holdProgress }, + modifier = Modifier.size(60.dp), + color = Color.White, + strokeWidth = 4.dp + ) + } else { + Icon( + imageVector = FeatherIcons.StopCircle, + contentDescription = "Exit", + tint = Color.White, + modifier = Modifier.size(48.dp) + ) + } + } + } + + Spacer(modifier = Modifier.height(16.dp)) + + Text( + text = if (isHolding) { + stringResource(R.string.pocket_mode_exiting) + } else { + stringResource(R.string.pocket_mode_hold_to_exit) + }, + color = Color.White.copy(alpha = 0.7f), + style = MaterialTheme.typography.bodyMedium, + textAlign = TextAlign.Center + ) + } + } +} + +@Composable +private fun ContextWarningDialog( + contextUsed: Int, + contextMax: Int, + onTrimMessages: () -> Unit, + onNewChat: () -> Unit, + onContinue: () -> Unit, + onDismiss: () -> Unit +) { + val usagePercent = if (contextMax > 0) { + ((contextUsed.toFloat() / contextMax.toFloat()) * 100).toInt() + } else 0 + + AlertDialog( + onDismissRequest = onDismiss, + title = { + Text( + text = stringResource(R.string.context_nearly_full_title), + style = MaterialTheme.typography.headlineSmall + ) + }, + text = { + Column { + Text( + text = stringResource(R.string.context_nearly_full_message, 
usagePercent), + style = MaterialTheme.typography.bodyMedium + ) + Spacer(modifier = Modifier.height(16.dp)) + + // Option 1: Trim old messages + OutlinedCard( + onClick = onTrimMessages, + modifier = Modifier.fillMaxWidth() + ) { + Column(modifier = Modifier.padding(12.dp)) { + Text( + text = stringResource(R.string.context_option_trim), + style = MaterialTheme.typography.titleSmall, + color = MaterialTheme.colorScheme.primary + ) + Text( + text = stringResource(R.string.context_option_trim_desc), + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + + Spacer(modifier = Modifier.height(8.dp)) + + // Option 2: New chat + OutlinedCard( + onClick = onNewChat, + modifier = Modifier.fillMaxWidth() + ) { + Column(modifier = Modifier.padding(12.dp)) { + Text( + text = stringResource(R.string.context_option_new_chat), + style = MaterialTheme.typography.titleSmall, + color = MaterialTheme.colorScheme.primary + ) + Text( + text = stringResource(R.string.context_option_new_chat_desc), + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + + Spacer(modifier = Modifier.height(8.dp)) + + // Option 3: Continue anyway + OutlinedCard( + onClick = onContinue, + modifier = Modifier.fillMaxWidth() + ) { + Column(modifier = Modifier.padding(12.dp)) { + Text( + text = stringResource(R.string.context_option_continue), + style = MaterialTheme.typography.titleSmall, + color = MaterialTheme.colorScheme.error + ) + Text( + text = stringResource(R.string.context_option_continue_desc), + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + } + } + } + }, + confirmButton = {}, + dismissButton = {} + ) +} + +@Composable +private fun ColumnScope.ScreenUI( + uiState: ChatScreenUIState, + onEvent: (ChatScreenUIEvent) -> Unit, + onAutoSubmitHandled: () -> Unit, + onClearInputHandled: () -> Unit +) { + // Context usage progress bar + 
if (uiState.chat.contextSize > 0) { + ContextUsageBar( + contextUsed = uiState.chat.contextSizeConsumed, + contextMax = uiState.chat.contextSize + ) + } if (uiState.memoryUsage != null) { RAMUsageLabel(uiState.memoryUsage) } @@ -431,7 +837,73 @@ private fun ColumnScope.ScreenUI(uiState: ChatScreenUIState, onEvent: (ChatScree uiState.responseGenerationTimeSecs, onEvent, ) - MessageInput(uiState.chat, uiState.modelLoadingState, uiState.isGeneratingResponse, onEvent) + MessageInput( + uiState.chat, + uiState.modelLoadingState, + uiState.isGeneratingResponse, + uiState.autoSubmitEnabled, + uiState.autoSubmitDelayMs, + uiState.sttState, + uiState.pendingTranscribedText, + uiState.triggerAutoSubmit, + uiState.shouldClearInput, + uiState.isTTSSpeaking, + onEvent, + onAutoSubmitHandled, + onClearInputHandled + ) +} + +@Composable +private fun ContextUsageBar( + contextUsed: Int, + contextMax: Int +) { + val progress = if (contextMax > 0) { + (contextUsed.toFloat() / contextMax.toFloat()).coerceIn(0f, 1f) + } else 0f + + val progressPercent = (progress * 100).toInt() + + // Color changes based on usage: green -> yellow -> red + val progressColor = when { + progressPercent < 60 -> MaterialTheme.colorScheme.primary + progressPercent < 85 -> MaterialTheme.colorScheme.tertiary + else -> MaterialTheme.colorScheme.error + } + + Column( + modifier = Modifier + .fillMaxWidth() + .padding(horizontal = 16.dp, vertical = 4.dp) + ) { + Row( + modifier = Modifier.fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Text( + text = stringResource(R.string.context_usage, contextUsed, contextMax), + style = MaterialTheme.typography.labelSmall, + color = MaterialTheme.colorScheme.onSurfaceVariant + ) + Text( + text = "$progressPercent%", + style = MaterialTheme.typography.labelSmall, + color = progressColor + ) + } + Spacer(modifier = Modifier.height(2.dp)) + LinearProgressIndicator( + progress = { progress }, + modifier 
= Modifier + .fillMaxWidth() + .height(4.dp) + .clip(RoundedCornerShape(2.dp)), + color = progressColor, + trackColor = MaterialTheme.colorScheme.surfaceVariant, + ) + } } @Composable @@ -464,9 +936,11 @@ private fun ColumnScope.MessagesList( listState.animateScrollToItem(messages.size) } } - LazyColumn(state = listState, modifier = Modifier - .fillMaxSize() - .weight(1f)) { + LazyColumn( + state = listState, modifier = Modifier + .fillMaxSize() + .weight(1f) + ) { itemsIndexed(messages) { i, chatMessage -> MessageListItem( chatMessage.renderedMessage, @@ -481,9 +955,9 @@ private fun ColumnScope.MessagesList( val clip = ClipData.newPlainText("Copied message", chatMessage.message) clipboard.setPrimaryClip(clip) Toast.makeText( - context, - context.getString(R.string.chat_message_copied), - Toast.LENGTH_SHORT, + context, + context.getString(R.string.chat_message_copied), + Toast.LENGTH_SHORT, ) .show() }, @@ -569,9 +1043,11 @@ private fun LazyItemScope.MessageListItem( var isEditing by rememberSaveable { mutableStateOf(false) } val context = LocalContext.current if (!isUserMessage) { - Row(modifier = modifier - .fillMaxWidth() - .animateItem()) { + Row( + modifier = modifier + .fillMaxWidth() + .animateItem() + ) { Spacer(modifier = Modifier.width(8.dp)) Column { ChatMessageText( @@ -707,7 +1183,16 @@ private fun MessageInput( currChat: Chat, modelLoadingState: ModelLoadingState, isGeneratingResponse: Boolean, + autoSubmitEnabled: Boolean, + autoSubmitDelayMs: Long, + sttState: STTState, + pendingTranscribedText: String?, + triggerAutoSubmit: Boolean, + shouldClearInput: Boolean, + isTTSSpeaking: Boolean, onEvent: (ChatScreenUIEvent) -> Unit, + onAutoSubmitHandled: () -> Unit, + onClearInputHandled: () -> Unit, defaultQuestion: String? 
= null, ) { if (currChat.llmModelId == -1L) { @@ -715,6 +1200,38 @@ private fun MessageInput( } else { var questionText by rememberSaveable { mutableStateOf(defaultQuestion ?: "") } val keyboardController = LocalSoftwareKeyboardController.current + + // Update text field with transcribed text when available (full replacement for streaming) + LaunchedEffect(pendingTranscribedText) { + if (pendingTranscribedText != null && pendingTranscribedText.isNotBlank()) { + // Replace with full transcription during streaming + questionText = pendingTranscribedText + } + } + + val isRecording = sttState is STTState.Recording + val isTranscribing = sttState is STTState.Transcribing + + // Handle auto-submit trigger from silence detection + LaunchedEffect(triggerAutoSubmit) { + if (triggerAutoSubmit && questionText.isNotBlank() && !isGeneratingResponse) { + Log.d("MessageInput", "Auto-submitting after silence detection: $questionText") + keyboardController?.hide() + onEvent(ChatScreenUIEvent.ChatEvents.SendUserQuery(questionText, fromVoice = true)) + questionText = "" + onAutoSubmitHandled() + } + } + + // Handle clear input signal from direct callback auto-submit + LaunchedEffect(shouldClearInput) { + if (shouldClearInput) { + Log.d("MessageInput", "Clearing input after auto-submit") + questionText = "" + onClearInputHandled() + } + } + Row( verticalAlignment = Alignment.CenterVertically, horizontalArrangement = Arrangement.Center, @@ -734,13 +1251,70 @@ private fun MessageInput( } AnimatedVisibility(modelLoadingState == ModelLoadingState.SUCCESS) { Row( - modifier = Modifier.fillMaxWidth(), + modifier = Modifier + .fillMaxWidth() + .padding(8.dp), verticalAlignment = Alignment.CenterVertically, ) { - TextField( + // Mic button - toggles recording + IconButton( modifier = Modifier - .fillMaxWidth() - .weight(1f), + .size(40.dp) + .background( + if (isRecording) MaterialTheme.colorScheme.error + else MaterialTheme.colorScheme.primaryContainer, + CircleShape + ), + onClick = { + 
onEvent(ChatScreenUIEvent.ChatEvents.ToggleMicRecording) + }, + enabled = !isTranscribing && !isTTSSpeaking + ) { + if (isTranscribing || isTTSSpeaking) { + CircularProgressIndicator( + modifier = Modifier.size(20.dp), + strokeWidth = 2.dp, + color = MaterialTheme.colorScheme.onPrimaryContainer + ) + } else { + Icon( + imageVector = if (isRecording) FeatherIcons.MicOff else FeatherIcons.Mic, + contentDescription = if (isRecording) "Stop Recording" else "Start Recording", + tint = if (isRecording) MaterialTheme.colorScheme.onError + else MaterialTheme.colorScheme.onPrimaryContainer, + ) + } + } + + // Pocket mode button - visible when recording + AnimatedVisibility(visible = isRecording) { + Row { + Spacer(modifier = Modifier.width(4.dp)) + IconButton( + modifier = Modifier + .size(40.dp) + .background( + MaterialTheme.colorScheme.secondaryContainer, + CircleShape + ), + onClick = { + onEvent(ChatScreenUIEvent.ChatEvents.EnablePocketMode) + } + ) { + Icon( + imageVector = FeatherIcons.Lock, + contentDescription = stringResource(R.string.pocket_mode), + tint = MaterialTheme.colorScheme.onSecondaryContainer, + ) + } + } + } + + Spacer(modifier = Modifier.width(8.dp)) + + // Text field takes remaining space + TextField( + modifier = Modifier.weight(1f), value = questionText, onValueChange = { questionText = it }, shape = RoundedCornerShape(16.dp), @@ -751,7 +1325,15 @@ private fun MessageInput( unfocusedIndicatorColor = Color.Transparent, disabledIndicatorColor = Color.Transparent, ), - placeholder = { Text(text = stringResource(R.string.chat_ask_question)) }, + placeholder = { + Text( + text = when { + isRecording -> stringResource(R.string.stt_recording) + isTranscribing -> stringResource(R.string.stt_transcribing) + else -> stringResource(R.string.chat_ask_question) + } + ) + }, keyboardOptions = KeyboardOptions.Default.copy( capitalization = KeyboardCapitalization.Sentences, diff --git 
a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/ChatListDrawer.kt b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/ChatListDrawer.kt index 9825b060..7ea9fd77 100644 --- a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/ChatListDrawer.kt +++ b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/ChatListDrawer.kt @@ -69,6 +69,7 @@ import compose.icons.feathericons.List import compose.icons.feathericons.MoreVertical import compose.icons.feathericons.Plus import compose.icons.feathericons.PlusSquare +import compose.icons.feathericons.Trash2 import io.shubham0204.smollmandroid.R import io.shubham0204.smollmandroid.data.Chat import io.shubham0204.smollmandroid.data.Folder @@ -93,6 +94,7 @@ private fun PreviewChatsAndFoldersList() { folders = dummyFolders.toImmutableList(), onManageTasksClick = {}, onItemClick = {}, + onDeleteChatClick = {}, onDeleteFolderClick = {}, onDeleteFolderWithChatsClick = {}, onUpdateFolder = {}, @@ -166,6 +168,9 @@ fun DrawerUI( onEvent(ChatScreenUIEvent.ChatEvents.SwitchChat(it)) onCloseDrawer() }, + onDeleteChatClick = { + onEvent(ChatScreenUIEvent.ChatEvents.OnDeleteChat(it)) + }, onDeleteFolderClick = { onEvent(ChatScreenUIEvent.FolderEvents.DeleteFolder(it.id)) }, onDeleteFolderWithChatsClick = { onEvent( @@ -189,6 +194,7 @@ private fun ChatsAndFoldersList( folders: ImmutableList, onManageTasksClick: () -> Unit, onItemClick: (Chat) -> Unit, + onDeleteChatClick: (Chat) -> Unit, onDeleteFolderClick: (Folder) -> Unit, onDeleteFolderWithChatsClick: (Folder) -> Unit, onUpdateFolder: (Folder) -> Unit, @@ -219,18 +225,24 @@ private fun ChatsAndFoldersList( chats, folders, onItemClick, + onDeleteChatClick, onDeleteFolderClick, onDeleteFolderWithChatsClick, onUpdateFolder, onAddFolder, ) Spacer(modifier = Modifier.height(8.dp)) - ChatsList(currentChat, chats, onItemClick) + ChatsList(currentChat, chats, onItemClick, onDeleteChatClick) } } @Composable -private fun ChatsList(currentChat: 
Chat?, chats: ImmutableList, onItemClick: (Chat) -> Unit) { +private fun ChatsList( + currentChat: Chat?, + chats: ImmutableList, + onItemClick: (Chat) -> Unit, + onDeleteChatClick: (Chat) -> Unit +) { LazyColumn { item { Text( @@ -239,7 +251,14 @@ private fun ChatsList(currentChat: Chat?, chats: ImmutableList, onItemClic ) Spacer(modifier = Modifier.height(8.dp)) } - items(chats) { chat -> ChatListItem(chat, onItemClick, currentChat?.id == chat.id) } + items(chats) { chat -> + ChatListItem( + chat = chat, + onItemClick = onItemClick, + onDeleteClick = onDeleteChatClick, + isCurrentlySelected = currentChat?.id == chat.id + ) + } } } @@ -247,6 +266,7 @@ private fun ChatsList(currentChat: Chat?, chats: ImmutableList, onItemClic private fun LazyItemScope.ChatListItem( chat: Chat, onItemClick: (Chat) -> Unit, + onDeleteClick: (Chat) -> Unit, isCurrentlySelected: Boolean, ) { Row( @@ -259,32 +279,41 @@ private fun LazyItemScope.ChatListItem( .animateItem(), verticalAlignment = Alignment.CenterVertically, ) { - Row(verticalAlignment = Alignment.CenterVertically) { - Column(modifier = Modifier.weight(1f)) { - Text( - if (chat.isTask) { - "[Task] " + chat.name - } else { - chat.name - }, - style = MaterialTheme.typography.bodyLarge, - maxLines = 1, - overflow = TextOverflow.Ellipsis, - ) - Text( - text = DateUtils.getRelativeTimeSpanString(chat.dateUsed.time).toString(), - style = MaterialTheme.typography.labelSmall, - ) - } - if (isCurrentlySelected) { - Box( - modifier = - Modifier - .padding(start = 4.dp) - .background(MaterialTheme.colorScheme.tertiary, CircleShape) - .size(10.dp) - ) {} - } + Column(modifier = Modifier.weight(1f)) { + Text( + if (chat.isTask) { + "[Task] " + chat.name + } else { + chat.name + }, + style = MaterialTheme.typography.bodyLarge, + maxLines = 1, + overflow = TextOverflow.Ellipsis, + ) + Text( + text = DateUtils.getRelativeTimeSpanString(chat.dateUsed.time).toString(), + style = MaterialTheme.typography.labelSmall, + ) + } + if 
(isCurrentlySelected) { + Box( + modifier = + Modifier + .padding(start = 4.dp) + .background(MaterialTheme.colorScheme.tertiary, CircleShape) + .size(10.dp) + ) {} + } + IconButton( + onClick = { onDeleteClick(chat) }, + modifier = Modifier.size(32.dp) + ) { + Icon( + FeatherIcons.Trash2, + contentDescription = "Delete chat", + tint = MaterialTheme.colorScheme.error, + modifier = Modifier.size(16.dp) + ) } } } @@ -294,6 +323,7 @@ private fun FoldersList( allChats: ImmutableList, folders: ImmutableList, onItemClick: (Chat) -> Unit, + onDeleteChatClick: (Chat) -> Unit, onDeleteFolderClick: (Folder) -> Unit, onDeleteFolderWithChatsClick: (Folder) -> Unit, onUpdateFolder: (Folder) -> Unit, @@ -329,7 +359,8 @@ private fun FoldersList( FolderListItem( folder = folder, chatsInFolder = folder.getChats(allChats).toImmutableList(), - onItemClick, + onChatItemClick = onItemClick, + onDeleteChatClick = onDeleteChatClick, onEditFolderNameClick = { newName -> onUpdateFolder(folder.copy(name = newName)) }, onDeleteFolderClick = { onDeleteFolderClick(folder) }, onDeleteFolderWithChatsClick = { onDeleteFolderWithChatsClick(folder) }, @@ -342,6 +373,7 @@ private fun FolderListItem( folder: Folder, chatsInFolder: ImmutableList, onChatItemClick: (Chat) -> Unit, + onDeleteChatClick: (Chat) -> Unit, onEditFolderNameClick: (String) -> Unit, onDeleteFolderClick: () -> Unit, onDeleteFolderWithChatsClick: () -> Unit, @@ -425,6 +457,7 @@ private fun FolderListItem( ChatListItem( chat = it, onItemClick = onChatItemClick, + onDeleteClick = onDeleteChatClick, isCurrentlySelected = false, ) } diff --git a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/ChatScreenViewModel.kt b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/ChatScreenViewModel.kt index 2a655167..8770bcd6 100644 --- a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/ChatScreenViewModel.kt +++ b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/ChatScreenViewModel.kt @@ 
-16,15 +16,21 @@ package io.shubham0204.smollmandroid.ui.screens.chat +import android.Manifest import android.annotation.SuppressLint import android.app.ActivityManager import android.app.ActivityManager.MemoryInfo import android.content.Context +import android.content.Intent +import android.content.pm.PackageManager +import android.os.Build import android.text.Spanned import android.util.Log import android.widget.Toast +import androidx.core.content.ContextCompat import androidx.lifecycle.ViewModel import androidx.lifecycle.viewModelScope +import io.shubham0204.smollmandroid.ui.screens.whisper_download.DownloadWhisperModelActivity import io.shubham0204.smollm.SmolLM import io.shubham0204.smollmandroid.R import io.shubham0204.smollmandroid.data.AppDB @@ -32,12 +38,19 @@ import io.shubham0204.smollmandroid.data.Chat import io.shubham0204.smollmandroid.data.ChatMessage import io.shubham0204.smollmandroid.data.Folder import io.shubham0204.smollmandroid.data.LLMModel +import io.shubham0204.smollmandroid.data.PreferencesManager import io.shubham0204.smollmandroid.data.Task import io.shubham0204.smollmandroid.llm.ModelsRepository import io.shubham0204.smollmandroid.llm.SmolLMManager +import io.shubham0204.smollmandroid.service.VoiceChatService +import io.shubham0204.smollmandroid.service.VoiceChatServiceManager +import io.shubham0204.smollmandroid.stt.SpeechToTextManager +import io.shubham0204.smollmandroid.stt.STTState +import io.shubham0204.smollmandroid.tts.TextToSpeechManager import io.shubham0204.smollmandroid.ui.components.createAlertDialog import kotlinx.collections.immutable.ImmutableList import kotlinx.collections.immutable.toImmutableList +import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.flow.MutableStateFlow import kotlinx.coroutines.flow.StateFlow import kotlinx.coroutines.flow.collectLatest @@ -60,7 +73,26 @@ sealed class ChatScreenUIEvent { data class DeleteModel(val model: LLMModel) : ChatScreenUIEvent() - data class SendUserQuery(val query: 
String) : ChatScreenUIEvent() + data class SendUserQuery(val query: String, val fromVoice: Boolean = false) : + ChatScreenUIEvent() + + data object ToggleMicRecording : ChatScreenUIEvent() + + data object RecordingPermissionGranted : ChatScreenUIEvent() + + data object RecordingPermissionHandled : ChatScreenUIEvent() + + data object NotificationPermissionHandled : ChatScreenUIEvent() + + data object EnablePocketMode : ChatScreenUIEvent() + + data object DisablePocketMode : ChatScreenUIEvent() + + data object TrimOldMessages : ChatScreenUIEvent() + + data object DismissContextWarning : ChatScreenUIEvent() + + data object ContinueAnywayContextWarning : ChatScreenUIEvent() data object StopGeneration : ChatScreenUIEvent() @@ -112,6 +144,23 @@ sealed class ChatScreenUIEvent { data class ShowContextLengthUsageDialog(val chat: Chat) : ChatScreenUIEvent() } + + sealed class TTSEvents { + data class ToggleTTS(val enabled: Boolean) : ChatScreenUIEvent() + } + + sealed class AutoSubmitEvents { + data class ToggleAutoSubmit(val enabled: Boolean) : ChatScreenUIEvent() + data class UpdateAutoSubmitDelay(val delayMs: Long) : ChatScreenUIEvent() + } + + sealed class ContextEvents { + data class ToggleAutoContextTrim(val enabled: Boolean) : ChatScreenUIEvent() + } + + sealed class STTEvents { + data class UpdateSTTLanguage(val language: String) : ChatScreenUIEvent() + } } data class ChatScreenUIState( @@ -133,6 +182,26 @@ data class ChatScreenUIState( val showSelectModelListDialog: Boolean = false, val showMoreOptionsPopup: Boolean = false, val showTasksBottomSheet: Boolean = false, + val ttsEnabled: Boolean = false, + val autoSubmitEnabled: Boolean = false, + val autoSubmitDelayMs: Long = 2000L, + val sttState: STTState = STTState.Idle, + val pendingTranscribedText: String? 
= null, + val requestRecordingPermission: Boolean = false, + val requestNotificationPermission: Boolean = false, + val triggerAutoSubmit: Boolean = false, + val sttLanguage: String = "en", + val lastInputWasVoice: Boolean = false, + val isVoiceModeActive: Boolean = false, + val isPocketModeEnabled: Boolean = false, + val shouldClearInput: Boolean = false, + val isTTSSpeaking: Boolean = false, + val showContextWarningDialog: Boolean = false, + val contextWarningShownForThisChat: Boolean = false, + val contextTrimLevel: Int = 0, // 0=none, 1=light, 2=medium, 3=aggressive + val pendingRetryQuery: String? = null, // Query to retry after context trim + val pendingRetryFromVoice: Boolean = false, + val autoContextTrimEnabled: Boolean = false, // Auto-trim context without prompting ) @KoinViewModel @@ -142,6 +211,10 @@ class ChatScreenViewModel( val modelsRepository: ModelsRepository, val smolLMManager: SmolLMManager, val mdRenderer: MDRenderer, + val preferencesManager: PreferencesManager, + val ttsManager: TextToSpeechManager, + val sttManager: SpeechToTextManager, + val voiceChatServiceManager: VoiceChatServiceManager, ) : ViewModel() { enum class ModelLoadingState { NOT_LOADED, // model loading not started @@ -166,6 +239,100 @@ class ChatScreenViewModel( setupCollectors() loadModel() activityManager = context.getSystemService(Context.ACTIVITY_SERVICE) as ActivityManager + _uiState.update { + it.copy( + ttsEnabled = preferencesManager.ttsEnabled, + autoSubmitEnabled = preferencesManager.autoSubmitEnabled, + autoSubmitDelayMs = preferencesManager.autoSubmitDelayMs, + sttLanguage = preferencesManager.sttLanguage, + autoContextTrimEnabled = preferencesManager.autoContextTrimEnabled + ) + } + // Set TTS language from saved preference + ttsManager.setLanguage(preferencesManager.sttLanguage) + // Collect STT state changes + viewModelScope.launch { + sttManager.state.collect { sttState -> + _uiState.update { it.copy(sttState = sttState) } + } + } + // Collect streaming 
transcription - this now emits the FULL transcription + viewModelScope.launch { + sttManager.streamingTranscription.collect { fullTranscription -> + if (fullTranscription.isNotBlank()) { + _uiState.update { it.copy(pendingTranscribedText = fullTranscription) } + } + } + } + // Set up direct callback for silence detection + // This callback is called directly from SpeechToTextManager's IO coroutine scope + // which bypasses the frozen ViewModel coroutines on Samsung devices + sttManager.setOnSilenceDetectedCallback { finalText -> + LOGD(">>> onSilenceDetectedCallback called on thread: ${Thread.currentThread().name}") + LOGD(">>> finalText='$finalText', autoSubmitEnabled=${_uiState.value.autoSubmitEnabled}") + + if (_uiState.value.autoSubmitEnabled && finalText.isNotBlank()) { + LOGD(">>> Calling sendUserQuery from callback...") + _uiState.update { it.copy(pendingTranscribedText = null, shouldClearInput = true) } + sendUserQuery(finalText, fromVoice = true) + LOGD(">>> sendUserQuery returned from callback") + } else { + LOGD(">>> Not auto-submitting: autoSubmitEnabled=${_uiState.value.autoSubmitEnabled}, finalText.isNotBlank=${finalText.isNotBlank()}") + } + } + + // Keep the flow collector as fallback (for manual stop recording) + viewModelScope.launch(Dispatchers.Default) { + LOGD(">>> Starting silence detection flow collector (fallback)") + sttManager.silenceDetected.collect { + LOGD(">>> silenceDetected flow RECEIVED (fallback path)") + // This is now only used as fallback when callback is not set + } + } + // Collect TTS speaking state to disable mic while speaking + viewModelScope.launch { + ttsManager.isSpeakingFlow.collect { isSpeaking -> + _uiState.update { it.copy(isTTSSpeaking = isSpeaking) } + } + } + // Collect TTS completion events to resume recording if the input was from voice + viewModelScope.launch { + ttsManager.allSpeechFinished.collect { + LOGD("TTS finished. 
lastInputWasVoice=${_uiState.value.lastInputWasVoice}, ttsEnabled=${_uiState.value.ttsEnabled}, isGeneratingResponse=${_uiState.value.isGeneratingResponse}, isTTSSpeaking=${_uiState.value.isTTSSpeaking}") + + // Check context usage after TTS finishes + checkContextUsage() + + // Resume recording if: + // - Voice mode is active + // - The last input was from voice + // - TTS is enabled (indicating voice conversation mode) + // - We're not currently generating a response + // - Context warning dialog is not showing + // - TTS is not currently speaking (e.g., context warning message) + if (_uiState.value.isVoiceModeActive && + _uiState.value.lastInputWasVoice && + _uiState.value.ttsEnabled && + !_uiState.value.isGeneratingResponse && + !_uiState.value.showContextWarningDialog && + !_uiState.value.isTTSSpeaking + ) { + LOGD("Resuming recording after TTS finished") + _uiState.update { it.copy(lastInputWasVoice = false) } + // Use toggleMicRecording which handles all state checks properly + toggleMicRecording() + } else { + LOGD("NOT resuming recording: isVoiceModeActive=${_uiState.value.isVoiceModeActive}, lastInputWasVoice=${_uiState.value.lastInputWasVoice}, isTTSSpeaking=${_uiState.value.isTTSSpeaking}") + } + } + } + // Collect stop requests from the notification action + viewModelScope.launch { + voiceChatServiceManager.stopServiceRequest.collect { + LOGD("Stop voice mode requested from notification") + stopVoiceMode() + } + } } /** @@ -224,6 +391,19 @@ class ChatScreenViewModel( ) } onComplete(ModelLoadingState.SUCCESS) + + // Check if there's a pending retry query after context trim + val pendingQuery = _uiState.value.pendingRetryQuery + val pendingFromVoice = _uiState.value.pendingRetryFromVoice + if (pendingQuery != null && _uiState.value.isPocketModeEnabled) { + LOGD("Retrying pending query after context trim: '$pendingQuery'") + _uiState.update { it.copy(pendingRetryQuery = null, pendingRetryFromVoice = false) } + // Small delay to let TTS finish announcing 
the trim + viewModelScope.launch { + kotlinx.coroutines.delay(500) + sendUserQuery(pendingQuery, addMessageToDB = false, fromVoice = pendingFromVoice) + } + } }, ) } @@ -326,7 +506,7 @@ class ChatScreenViewModel( ChatScreenUIEvent.ChatEvents.LoadChatModel -> {} is ChatScreenUIEvent.ChatEvents.SendUserQuery -> { - sendUserQuery(event.query) + sendUserQuery(event.query, fromVoice = event.fromVoice) } ChatScreenUIEvent.ChatEvents.StopGeneration -> { @@ -427,6 +607,64 @@ class ChatScreenViewModel( switchChat(newChat) } + ChatScreenUIEvent.ChatEvents.ToggleMicRecording -> { + toggleMicRecording() + } + + ChatScreenUIEvent.ChatEvents.RecordingPermissionGranted -> { + // Permission was granted, now start recording + startRecordingAfterPermission() + } + + ChatScreenUIEvent.ChatEvents.RecordingPermissionHandled -> { + // Reset the permission request flag + _uiState.update { it.copy(requestRecordingPermission = false) } + } + + ChatScreenUIEvent.ChatEvents.NotificationPermissionHandled -> { + // Reset the permission request flag and continue with voice mode + _uiState.update { it.copy(requestNotificationPermission = false) } + // Now continue with the voice mode start + startVoiceModeAfterPermissions() + } + + ChatScreenUIEvent.ChatEvents.EnablePocketMode -> { + LOGD("Enabling pocket mode") + _uiState.update { it.copy(isPocketModeEnabled = true) } + } + + ChatScreenUIEvent.ChatEvents.DisablePocketMode -> { + LOGD("Disabling pocket mode") + _uiState.update { it.copy(isPocketModeEnabled = false) } + } + + ChatScreenUIEvent.ChatEvents.TrimOldMessages -> { + LOGD("Trimming old messages to free context") + val wasVoiceModeActive = _uiState.value.isVoiceModeActive + trimOldMessages() + // Reset the warning flag so it can show again when context fills up + _uiState.update { it.copy(showContextWarningDialog = false, contextWarningShownForThisChat = false) } + // Resume recording if voice mode was active + if (wasVoiceModeActive && _uiState.value.ttsEnabled) { + LOGD("Resuming 
recording after context trim") + toggleMicRecording() + } + } + + ChatScreenUIEvent.ChatEvents.DismissContextWarning -> { + _uiState.update { it.copy(showContextWarningDialog = false) } + } + + ChatScreenUIEvent.ChatEvents.ContinueAnywayContextWarning -> { + val wasVoiceModeActive = _uiState.value.isVoiceModeActive + _uiState.update { it.copy(showContextWarningDialog = false, contextWarningShownForThisChat = true) } + // Resume recording if voice mode was active + if (wasVoiceModeActive && _uiState.value.ttsEnabled) { + LOGD("Resuming recording after context warning dismissed") + toggleMicRecording() + } + } + is ChatScreenUIEvent.ChatEvents.SwitchChat -> { switchChat(event.chat) } @@ -443,6 +681,36 @@ class ChatScreenViewModel( event.onResult(result) } } + + is ChatScreenUIEvent.TTSEvents.ToggleTTS -> { + preferencesManager.ttsEnabled = event.enabled + _uiState.update { it.copy(ttsEnabled = event.enabled) } + if (!event.enabled) { + ttsManager.stop() + } + } + + is ChatScreenUIEvent.AutoSubmitEvents.ToggleAutoSubmit -> { + preferencesManager.autoSubmitEnabled = event.enabled + _uiState.update { it.copy(autoSubmitEnabled = event.enabled) } + } + + is ChatScreenUIEvent.AutoSubmitEvents.UpdateAutoSubmitDelay -> { + preferencesManager.autoSubmitDelayMs = event.delayMs + _uiState.update { it.copy(autoSubmitDelayMs = event.delayMs) } + } + + is ChatScreenUIEvent.STTEvents.UpdateSTTLanguage -> { + preferencesManager.sttLanguage = event.language + _uiState.update { it.copy(sttLanguage = event.language) } + // Update TTS language to match + ttsManager.setLanguage(event.language) + } + + is ChatScreenUIEvent.ContextEvents.ToggleAutoContextTrim -> { + preferencesManager.autoContextTrimEnabled = event.enabled + _uiState.update { it.copy(autoContextTrimEnabled = event.enabled) } + } } } @@ -531,12 +799,40 @@ class ChatScreenViewModel( appDB.deleteMessage(messageId) } - private fun sendUserQuery(query: String, addMessageToDB: Boolean = true) { + private fun sendUserQuery( + 
query: String, + addMessageToDB: Boolean = true, + fromVoice: Boolean = false + ) { + LOGD(">>> sendUserQuery START on thread: ${Thread.currentThread().name}") + LOGD(">>> query='$query', fromVoice=$fromVoice") + val chat = uiState.value.chat + + // Pre-query context check in pocket mode - trim before sending if needed + if (_uiState.value.isPocketModeEnabled && chat.contextSize > 0) { + val usagePercent = (chat.contextSizeConsumed.toFloat() / chat.contextSize.toFloat()) * 100 + LOGD(">>> Pre-query context check: ${usagePercent.toInt()}%") + if (usagePercent >= 70) { + LOGD(">>> Pocket mode: pre-emptive context trim before query") + val message = context.getString(R.string.context_auto_trim_voice) + ttsManager.speakChunk(message) + ttsManager.speakRemainingBuffer() + trimOldMessages() + // Reset context consumed to prevent immediate re-triggering + val updatedChat = _uiState.value.chat.copy(contextSizeConsumed = 0) + _uiState.update { it.copy(chat = updatedChat) } + appDB.updateChat(updatedChat) + LOGD(">>> Context consumed reset to 0 after pre-query trim") + } + } + // Update the 'dateUsed' attribute of the current Chat instance // when a query is sent by the user chat.dateUsed = Date() + LOGD(">>> Updating chat in DB...") appDB.updateChat(chat) + LOGD(">>> Chat updated") if (chat.isTask) { // If the chat is a 'task', delete all existing messages @@ -545,9 +841,26 @@ class ChatScreenViewModel( } if (addMessageToDB) { + LOGD(">>> Adding user message to DB...") appDB.addUserMessage(chat.id, query) + LOGD(">>> User message added") } - _uiState.update { it.copy(isGeneratingResponse = true, renderedPartialResponse = null) } + + // Stop any ongoing TTS before starting new response + LOGD(">>> Resetting TTS state...") + ttsManager.resetState() + LOGD(">>> TTS state reset") + + // Track if this input came from voice for resuming recording after TTS + LOGD(">>> Updating UI state...") + _uiState.update { + it.copy( + isGeneratingResponse = true, + renderedPartialResponse = 
null, + lastInputWasVoice = fromVoice + ) + } + LOGD(">>> UI state updated, calling smolLMManager.getResponse...") smolLMManager.getResponse( query, responseTransform = { @@ -559,6 +872,10 @@ class ChatScreenViewModel( }, onPartialResponseGenerated = { resp -> _uiState.update { it.copy(renderedPartialResponse = mdRenderer.render(resp)) } + // Speak the response chunk if TTS is enabled + if (_uiState.value.ttsEnabled) { + ttsManager.speakChunk(resp) + } }, onSuccess = { response -> val updatedChat = chat.copy(contextSizeConsumed = response.contextLengthUsed) @@ -572,10 +889,21 @@ class ChatScreenViewModel( getCurrentMemoryUsage() } else { null - } + }, + // Reset trim level on success - context is healthy + contextTrimLevel = 0, + pendingRetryQuery = null, + pendingRetryFromVoice = false, ) } appDB.updateChat(updatedChat) + // Speak any remaining buffered text + if (_uiState.value.ttsEnabled) { + ttsManager.speakRemainingBuffer() + } else { + // If TTS is disabled, check context usage now + checkContextUsage() + } }, onCancelled = { // ignore CancellationException, as it was called because @@ -583,21 +911,68 @@ class ChatScreenViewModel( }, onError = { exception -> _uiState.update { it.copy(isGeneratingResponse = false) } - createAlertDialog( - dialogTitle = "An error occurred", - dialogText = - "The app is unable to process the query. The error message is: ${exception.message}", - dialogPositiveButtonText = "Change model", - onPositiveButtonClick = {}, - dialogNegativeButtonText = "", - onNegativeButtonClick = {}, - ) + + // Check if this is a context overflow error in pocket mode + val isContextError = exception.message?.contains("context", ignoreCase = true) == true + if (isContextError && _uiState.value.isPocketModeEnabled) { + val currentLevel = _uiState.value.contextTrimLevel + val nextLevel = (currentLevel + 1).coerceAtMost(3) + LOGD("Context error in pocket mode. 
Current trim level: $currentLevel, next: $nextLevel") + + if (nextLevel <= 3 && currentLevel < 3) { + // Store the query for retry after trim + _uiState.update { + it.copy( + contextTrimLevel = nextLevel, + pendingRetryQuery = query, + pendingRetryFromVoice = fromVoice + ) + } + + // Announce the trimming via TTS + val message = if (nextLevel == 3) { + "Context full. Clearing most of the conversation to continue." + } else { + context.getString(R.string.context_auto_trim_voice) + } + ttsManager.speakChunk(message) + ttsManager.speakRemainingBuffer() + + // Trim with the new level - this calls loadModel which will trigger retry + trimOldMessages(nextLevel) + + // Reset context consumed + val updatedChat = _uiState.value.chat.copy(contextSizeConsumed = 0) + _uiState.update { it.copy(chat = updatedChat) } + appDB.updateChat(updatedChat) + + LOGD("Trimmed at level $nextLevel, will retry after model reload") + } else { + // Already at max trim level, give up + LOGD("Already at max trim level, cannot recover") + _uiState.update { it.copy(contextTrimLevel = 0, pendingRetryQuery = null) } + ttsManager.speakChunk("Unable to process. Context too limited.") + ttsManager.speakRemainingBuffer() + } + } else { + // Non-context error or not in pocket mode: show dialog + createAlertDialog( + dialogTitle = "An error occurred", + dialogText = + "The app is unable to process the query. 
The error message is: ${exception.message}", + dialogPositiveButtonText = "Change model", + onPositiveButtonClick = {}, + dialogNegativeButtonText = "", + onNegativeButtonClick = {}, + ) + } }, ) } private fun stopGeneration() { smolLMManager.stopResponseGeneration() + ttsManager.stop() _uiState.update { it.copy(isGeneratingResponse = false, renderedPartialResponse = null) } } @@ -625,6 +1000,253 @@ class ChatScreenViewModel( _uiState.update { it.copy(chat = newChat) } } + private fun toggleMicRecording() { + when (_uiState.value.sttState) { + is STTState.Idle -> { + // Check if Whisper model is available + if (!sttManager.isModelAvailable()) { + createAlertDialog( + dialogTitle = context.getString(R.string.stt_model_not_found_title), + dialogText = context.getString(R.string.stt_model_not_found_message), + dialogPositiveButtonText = context.getString(R.string.stt_download_model), + onPositiveButtonClick = { + Intent(context, DownloadWhisperModelActivity::class.java).apply { + addFlags(Intent.FLAG_ACTIVITY_NEW_TASK) + context.startActivity(this) + } + }, + dialogNegativeButtonText = context.getString(R.string.dialog_neg_cancel), + onNegativeButtonClick = {}, + ) + return + } + + // Check recording permission - request it if not granted + if (!sttManager.hasRecordingPermission()) { + _uiState.update { it.copy(requestRecordingPermission = true) } + return + } + + // Check notification permission on Android 13+ (required for foreground service notification) + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) { + val hasNotificationPermission = ContextCompat.checkSelfPermission( + context, + Manifest.permission.POST_NOTIFICATIONS + ) == PackageManager.PERMISSION_GRANTED + if (!hasNotificationPermission) { + _uiState.update { it.copy(requestNotificationPermission = true) } + return + } + } + + // All permissions granted, start voice mode + startVoiceModeAfterPermissions() + } + + is STTState.Recording -> { + // Stop streaming recording - final transcription 
will be emitted via Flow + sttManager.stopStreamingRecording(language = _uiState.value.sttLanguage) { /* final result handled via Flow */ } + } + + is STTState.Transcribing -> { + // Already transcribing, do nothing + } + + is STTState.Error -> { + // Reset to idle state + _uiState.update { it.copy(sttState = STTState.Idle) } + } + } + } + + /** + * Starts voice mode after all permissions have been granted. + */ + private fun startVoiceModeAfterPermissions() { + // Request battery optimization exemption (required for Samsung and other OEMs) + // This shows a dialog to the user on first use + if (!VoiceChatService.isIgnoringBatteryOptimizations(context)) { + LOGD(">>> Requesting battery optimization exemption") + VoiceChatService.requestBatteryOptimizationExemption(context) + } + + // Start foreground service for locked screen support (if not already running) + if (!voiceChatServiceManager.isServiceRunning.value) { + LOGD(">>> Starting VoiceChatService") + VoiceChatService.start(context) + } + _uiState.update { it.copy(isVoiceModeActive = true) } + + // Load model if not loaded, then start streaming recording + sttManager.loadModel { success -> + if (success) { + sttManager.startStreamingRecording( + language = _uiState.value.sttLanguage, + autoSubmitDelayMs = _uiState.value.autoSubmitDelayMs + ) + } else { + Toast.makeText( + context, + context.getString(R.string.stt_model_load_failed), + Toast.LENGTH_LONG + ).show() + } + } + } + + private fun startRecordingAfterPermission() { + // Check again if model is available (user might have navigated away) + if (!sttManager.isModelAvailable()) { + return + } + + // Check notification permission on Android 13+ before starting service + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) { + val hasNotificationPermission = ContextCompat.checkSelfPermission( + context, + Manifest.permission.POST_NOTIFICATIONS + ) == PackageManager.PERMISSION_GRANTED + if (!hasNotificationPermission) { + _uiState.update { 
it.copy(requestNotificationPermission = true) } + return + } + } + + // All permissions granted, start voice mode + startVoiceModeAfterPermissions() + } + + /** + * Starts recording for voice conversation mode (after TTS finishes). + * This assumes the Whisper model is already loaded and permission is granted. + */ + private fun startRecordingForVoiceConversation() { + if (!sttManager.isModelAvailable()) { + LOGD("Cannot resume recording: Whisper model not available") + return + } + + if (!sttManager.hasRecordingPermission()) { + LOGD("Cannot resume recording: no recording permission") + return + } + + // Model should already be loaded, but load just in case + sttManager.loadModel { success -> + if (success) { + LOGD("Starting streaming recording for voice conversation") + sttManager.startStreamingRecording( + language = _uiState.value.sttLanguage, + autoSubmitDelayMs = _uiState.value.autoSubmitDelayMs + ) + } else { + LOGD("Failed to load Whisper model for voice conversation") + } + } + } + + fun consumePendingTranscribedText(): String? { + val text = _uiState.value.pendingTranscribedText + _uiState.update { it.copy(pendingTranscribedText = null) } + return text + } + + fun resetAutoSubmitTrigger() { + _uiState.update { it.copy(triggerAutoSubmit = false) } + } + + fun resetClearInputFlag() { + _uiState.update { it.copy(shouldClearInput = false) } + } + + /** + * Stops voice mode completely - stops recording, TTS, and the foreground service. + */ + fun stopVoiceMode() { + LOGD("Stopping voice mode") + sttManager.cancelRecording() + ttsManager.stop() + VoiceChatService.stop(context) + _uiState.update { + it.copy( + isVoiceModeActive = false, + lastInputWasVoice = false, + sttState = STTState.Idle + ) + } + } + + /** + * Trim old messages to free up context space. + * Keeps the system prompt and the most recent messages. 
+ * + * @param level Trim aggressiveness: 1=light (remove 2), 2=medium (remove 4), 3=aggressive (keep only 2) + */ + private fun trimOldMessages(level: Int = 1) { + val chatId = _uiState.value.chat.id + val messages = appDB.getMessagesForModel(chatId) + + val messagesToKeep = when (level) { + 1 -> messages.size - 2 // Light: remove 2 oldest messages + 2 -> messages.size - 4 // Medium: remove 4 oldest messages + else -> 2 // Aggressive: keep only last 2 messages (1 exchange) + }.coerceAtLeast(2) // Always keep at least 2 messages + + LOGD("Trimming context at level $level: keeping $messagesToKeep of ${messages.size} messages") + + if (messages.size > messagesToKeep) { + val messagesToDelete = messages.dropLast(messagesToKeep) + messagesToDelete.forEach { message -> + appDB.deleteMessage(message.id) + } + LOGD("Trimmed ${messagesToDelete.size} old messages") + + // Reload model to apply the trimmed context + loadModel() + } + } + + /** + * Check if context usage is high and show warning if needed. + * Called after each response is generated. + * In pocket mode or with auto-trim enabled, automatically trims old messages. 
+ */ + private fun checkContextUsage() { + val chat = _uiState.value.chat + val contextUsed = chat.contextSizeConsumed + val contextMax = chat.contextSize + + if (contextMax > 0) { + val usagePercent = (contextUsed.toFloat() / contextMax.toFloat()) * 100 + LOGD("Context usage: $contextUsed / $contextMax (${usagePercent.toInt()}%)") + + // Check at 75% usage if not already handled for this chat + if (usagePercent >= 75 && !_uiState.value.contextWarningShownForThisChat) { + val shouldAutoTrim = _uiState.value.isPocketModeEnabled || _uiState.value.autoContextTrimEnabled + + if (shouldAutoTrim) { + // Auto-trim mode: trim and notify via voice if TTS enabled + LOGD("Auto-trim: trimming context (pocket=${_uiState.value.isPocketModeEnabled}, autoTrim=${_uiState.value.autoContextTrimEnabled})") + if (_uiState.value.ttsEnabled) { + val message = context.getString(R.string.context_auto_trim_voice) + ttsManager.speakChunk(message) + ttsManager.speakRemainingBuffer() + } + trimOldMessages() + // Reset context consumed to prevent immediate re-triggering + // The actual value will be updated on next response + val updatedChat = _uiState.value.chat.copy(contextSizeConsumed = 0) + _uiState.update { it.copy(chat = updatedChat) } + appDB.updateChat(updatedChat) + LOGD("Context consumed reset to 0 after trim") + } else { + // Normal mode: show dialog + _uiState.update { it.copy(showContextWarningDialog = true) } + } + } + } + } + /** * Get the current memory usage of the device. 
This method returns the memory consumed (in GBs) * and the total memory available on the device (in GBs) diff --git a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/EditChatSettingsScreen.kt b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/EditChatSettingsScreen.kt index 8034d430..f7ecfedd 100644 --- a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/EditChatSettingsScreen.kt +++ b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/EditChatSettingsScreen.kt @@ -18,6 +18,7 @@ package io.shubham0204.smollmandroid.ui.screens.chat import android.widget.Toast import androidx.compose.foundation.background +import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Spacer @@ -29,12 +30,17 @@ import androidx.compose.foundation.rememberScrollState import androidx.compose.foundation.text.KeyboardOptions import androidx.compose.foundation.verticalScroll import androidx.compose.material3.Checkbox +import androidx.compose.material3.DropdownMenu +import androidx.compose.material3.DropdownMenuItem import androidx.compose.material3.ExperimentalMaterial3Api +import androidx.compose.material3.HorizontalDivider import androidx.compose.material3.Icon import androidx.compose.material3.IconButton import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.OutlinedButton import androidx.compose.material3.Scaffold import androidx.compose.material3.Slider +import androidx.compose.material3.Switch import androidx.compose.material3.Text import androidx.compose.material3.TextField import androidx.compose.material3.TopAppBar @@ -56,8 +62,12 @@ import androidx.compose.ui.unit.dp import compose.icons.FeatherIcons import compose.icons.feathericons.ArrowLeft import compose.icons.feathericons.Check +import compose.icons.feathericons.Mic +import compose.icons.feathericons.Scissors +import 
compose.icons.feathericons.Volume2 import io.shubham0204.smollmandroid.R import io.shubham0204.smollmandroid.data.Chat +import io.shubham0204.smollmandroid.data.PreferencesManager import io.shubham0204.smollmandroid.ui.components.AppBarTitleText import io.shubham0204.smollmandroid.ui.theme.SmolLMAndroidTheme import kotlinx.serialization.Serializable @@ -112,7 +122,17 @@ private fun PreviewEditChatSettingsScreen() { EditChatSettingsScreen( settings = EditableChatSettings.fromChat(Chat()), llmModelContextSize = 2048, + ttsEnabled = false, + autoSubmitEnabled = false, + autoSubmitDelayMs = 2000L, + sttLanguage = "en", + autoContextTrimEnabled = false, onUpdateChat = {}, + onToggleTTS = {}, + onToggleAutoSubmit = {}, + onUpdateAutoSubmitDelay = {}, + onUpdateSTTLanguage = {}, + onToggleAutoContextTrim = {}, onBackClicked = {}, ) } @@ -122,7 +142,17 @@ private fun PreviewEditChatSettingsScreen() { fun EditChatSettingsScreen( settings: EditableChatSettings, llmModelContextSize: Int, + ttsEnabled: Boolean, + autoSubmitEnabled: Boolean, + autoSubmitDelayMs: Long, + sttLanguage: String, + autoContextTrimEnabled: Boolean, onUpdateChat: (EditableChatSettings) -> Unit, + onToggleTTS: (Boolean) -> Unit, + onToggleAutoSubmit: (Boolean) -> Unit, + onUpdateAutoSubmitDelay: (Long) -> Unit, + onUpdateSTTLanguage: (String) -> Unit, + onToggleAutoContextTrim: (Boolean) -> Unit, onBackClicked: () -> Unit, ) { var chatName by remember { mutableStateOf(settings.name) } @@ -373,6 +403,118 @@ fun EditChatSettingsScreen( } } Spacer(modifier = Modifier.height(24.dp)) + + HorizontalDivider() + Spacer(modifier = Modifier.height(24.dp)) + + Row( + modifier = Modifier.fillMaxWidth(), + verticalAlignment = Alignment.CenterVertically, + ) { + Icon( + FeatherIcons.Volume2, + contentDescription = "TTS", + modifier = Modifier.padding(end = 16.dp), + tint = MaterialTheme.colorScheme.primary, + ) + Column(modifier = Modifier.weight(1f)) { + Text( + text = stringResource(R.string.tts_settings_title), + 
style = MaterialTheme.typography.titleMedium, + ) + Text( + text = stringResource(R.string.tts_settings_desc), + style = MaterialTheme.typography.labelSmall, + ) + } + Switch( + checked = ttsEnabled, + onCheckedChange = { onToggleTTS(it) }, + ) + } + + Spacer(modifier = Modifier.height(24.dp)) + HorizontalDivider() + Spacer(modifier = Modifier.height(24.dp)) + + // STT Language Selection + var languageDropdownExpanded by remember { mutableStateOf(false) } + val selectedLanguageName = PreferencesManager.SUPPORTED_LANGUAGES + .find { it.first == sttLanguage }?.second ?: "English" + + Row( + modifier = Modifier.fillMaxWidth(), + verticalAlignment = Alignment.CenterVertically, + ) { + Icon( + FeatherIcons.Mic, + contentDescription = "STT Language", + modifier = Modifier.padding(end = 16.dp), + tint = MaterialTheme.colorScheme.primary, + ) + Column(modifier = Modifier.weight(1f)) { + Text( + text = stringResource(R.string.stt_language_title), + style = MaterialTheme.typography.titleMedium, + ) + Text( + text = stringResource(R.string.stt_language_desc), + style = MaterialTheme.typography.labelSmall, + ) + } + Box { + OutlinedButton( + onClick = { languageDropdownExpanded = true } + ) { + Text(selectedLanguageName) + } + DropdownMenu( + expanded = languageDropdownExpanded, + onDismissRequest = { languageDropdownExpanded = false } + ) { + PreferencesManager.SUPPORTED_LANGUAGES.forEach { (code, name) -> + DropdownMenuItem( + text = { Text(name) }, + onClick = { + onUpdateSTTLanguage(code) + languageDropdownExpanded = false + } + ) + } + } + } + } + Spacer(modifier = Modifier.height(24.dp)) + HorizontalDivider() + Spacer(modifier = Modifier.height(24.dp)) + + // Auto Context Trim + Row( + modifier = Modifier.fillMaxWidth(), + verticalAlignment = Alignment.CenterVertically, + ) { + Icon( + FeatherIcons.Scissors, + contentDescription = "Auto Context Trim", + modifier = Modifier.padding(end = 16.dp), + tint = MaterialTheme.colorScheme.primary, + ) + Column(modifier = 
Modifier.weight(1f)) { + Text( + text = stringResource(R.string.auto_context_trim_title), + style = MaterialTheme.typography.titleMedium, + ) + Text( + text = stringResource(R.string.auto_context_trim_desc), + style = MaterialTheme.typography.labelSmall, + ) + } + Switch( + checked = autoContextTrimEnabled, + onCheckedChange = { onToggleAutoContextTrim(it) }, + ) + } + Spacer(modifier = Modifier.height(24.dp)) } } } diff --git a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/dialogs/ChatMoreOptionsPopup.kt b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/dialogs/ChatMoreOptionsPopup.kt index 85435c2f..cb9e7682 100644 --- a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/dialogs/ChatMoreOptionsPopup.kt +++ b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/chat/dialogs/ChatMoreOptionsPopup.kt @@ -16,6 +16,7 @@ package io.shubham0204.smollmandroid.ui.screens.chat.dialogs +import android.content.Intent import androidx.compose.foundation.layout.Spacer import androidx.compose.foundation.layout.fillMaxWidth import androidx.compose.foundation.layout.height @@ -27,23 +28,29 @@ import androidx.compose.material3.MaterialTheme import androidx.compose.material3.Text import androidx.compose.runtime.Composable import androidx.compose.ui.Modifier +import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.graphics.vector.ImageVector import androidx.compose.ui.res.stringResource import androidx.compose.ui.tooling.preview.Preview import androidx.compose.ui.unit.dp import compose.icons.FeatherIcons +import compose.icons.feathericons.Clock import compose.icons.feathericons.Cpu import compose.icons.feathericons.Delete import compose.icons.feathericons.Folder import compose.icons.feathericons.Layout +import compose.icons.feathericons.Mic import compose.icons.feathericons.Package import compose.icons.feathericons.Settings +import compose.icons.feathericons.Volume2 +import compose.icons.feathericons.VolumeX import 
compose.icons.feathericons.XCircle import compose.icons.feathericons.Zap import io.shubham0204.smollmandroid.R import io.shubham0204.smollmandroid.data.Chat import io.shubham0204.smollmandroid.ui.preview.dummyChats import io.shubham0204.smollmandroid.ui.screens.chat.ChatScreenUIEvent +import io.shubham0204.smollmandroid.ui.screens.whisper_download.DownloadWhisperModelActivity @Preview @Composable @@ -52,6 +59,8 @@ private fun PreviewChatMoreOptionsPopup() { chat = dummyChats[0], isExpanded = true, showRAMUsageLabel = true, + ttsEnabled = false, + autoSubmitEnabled = false, onEditChatSettingsClick = {}, onBenchmarkModelClick = {}, onEvent = {}, @@ -63,10 +72,13 @@ fun ChatMoreOptionsPopup( chat: Chat, isExpanded: Boolean, showRAMUsageLabel: Boolean, + ttsEnabled: Boolean, + autoSubmitEnabled: Boolean, onEditChatSettingsClick: () -> Unit, onBenchmarkModelClick: () -> Unit, onEvent: (ChatScreenUIEvent) -> Unit, ) { + val context = LocalContext.current DropdownMenu( expanded = isExpanded, onDismissRequest = { @@ -134,6 +146,24 @@ fun ChatMoreOptionsPopup( ) { onEvent(ChatScreenUIEvent.DialogEvents.ToggleRAMUsageLabel) } + PopupMenuItem( + icon = if (ttsEnabled) FeatherIcons.VolumeX else FeatherIcons.Volume2, + text = stringResource(if (ttsEnabled) R.string.tts_disable else R.string.tts_enable), + ) { + onEvent(ChatScreenUIEvent.TTSEvents.ToggleTTS(!ttsEnabled)) + } + PopupMenuItem( + icon = FeatherIcons.Clock, + text = stringResource(if (autoSubmitEnabled) R.string.auto_submit_disable else R.string.auto_submit_enable), + ) { + onEvent(ChatScreenUIEvent.AutoSubmitEvents.ToggleAutoSubmit(!autoSubmitEnabled)) + } + PopupMenuItem( + icon = FeatherIcons.Mic, + text = stringResource(R.string.stt_manage_models), + ) { + context.startActivity(Intent(context, DownloadWhisperModelActivity::class.java)) + } } } } diff --git a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/whisper_download/DownloadWhisperModelActivity.kt 
b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/whisper_download/DownloadWhisperModelActivity.kt new file mode 100644 index 00000000..a8e10135 --- /dev/null +++ b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/whisper_download/DownloadWhisperModelActivity.kt @@ -0,0 +1,293 @@ +/* + * Copyright (C) 2024 Shubham Panchal + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.shubham0204.smollmandroid.ui.screens.whisper_download + +import android.app.DownloadManager +import android.content.Context +import android.net.Uri +import android.os.Bundle +import android.os.Environment +import android.widget.Toast +import androidx.activity.ComponentActivity +import androidx.activity.compose.setContent +import androidx.activity.enableEdgeToEdge +import androidx.compose.foundation.background +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Box +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.Spacer +import androidx.compose.foundation.layout.fillMaxSize +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.height +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.safeDrawingPadding +import androidx.compose.foundation.layout.width +import 
androidx.compose.foundation.rememberScrollState +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.foundation.verticalScroll +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.automirrored.filled.ArrowBack +import androidx.compose.material3.Button +import androidx.compose.material3.ExperimentalMaterial3Api +import androidx.compose.material3.HorizontalDivider +import androidx.compose.material3.Icon +import androidx.compose.material3.IconButton +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Scaffold +import androidx.compose.material3.Surface +import androidx.compose.material3.Text +import androidx.compose.material3.TopAppBar +import androidx.compose.runtime.Composable +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.setValue +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.res.stringResource +import androidx.compose.ui.unit.dp +import compose.icons.FeatherIcons +import compose.icons.feathericons.Check +import io.shubham0204.smollmandroid.R +import io.shubham0204.smollmandroid.stt.SpeechToTextManager +import io.shubham0204.smollmandroid.ui.components.AppBarTitleText +import io.shubham0204.smollmandroid.ui.theme.SmolLMAndroidTheme +import org.koin.android.ext.android.inject + +class DownloadWhisperModelActivity : ComponentActivity() { + + private val sttManager: SpeechToTextManager by inject() + + override fun onCreate(savedInstanceState: Bundle?) 
{ + super.onCreate(savedInstanceState) + enableEdgeToEdge() + setContent { + val downloadedModels = remember { sttManager.getAvailableModels() } + val selectedModel = remember { sttManager.getSelectedModelName() } + + Box(modifier = Modifier.safeDrawingPadding()) { + DownloadWhisperModelScreen( + downloadedModels = downloadedModels, + selectedModel = selectedModel, + onBackClick = { finish() }, + onDownloadModel = { model -> downloadWhisperModel(model) }, + onSelectModel = { modelName -> + sttManager.setSelectedModel(modelName) + Toast.makeText( + this@DownloadWhisperModelActivity, + getString(R.string.whisper_model_selected, modelName), + Toast.LENGTH_SHORT + ).show() + } + ) + } + } + } + + private fun downloadWhisperModel(model: WhisperModel) { + val downloadManager = getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager + val request = DownloadManager.Request(Uri.parse(model.url)).apply { + setTitle(model.name) + setDescription(getString(R.string.whisper_download_notification_desc)) + setNotificationVisibility(DownloadManager.Request.VISIBILITY_VISIBLE_NOTIFY_COMPLETED) + setDestinationInExternalFilesDir( + this@DownloadWhisperModelActivity, + Environment.DIRECTORY_DOWNLOADS, + model.fileName + ) + setAllowedNetworkTypes( + DownloadManager.Request.NETWORK_WIFI or DownloadManager.Request.NETWORK_MOBILE + ) + } + downloadManager.enqueue(request) + Toast.makeText( + this, + getString(R.string.whisper_download_started, model.name), + Toast.LENGTH_LONG + ).show() + } +} + +@OptIn(ExperimentalMaterial3Api::class) +@Composable +private fun DownloadWhisperModelScreen( + downloadedModels: List, + selectedModel: String, + onBackClick: () -> Unit, + onDownloadModel: (WhisperModel) -> Unit, + onSelectModel: (String) -> Unit, +) { + var selectedModelIndex by remember { mutableStateOf(1) } // Default to base.en model + var currentSelectedModel by remember { mutableStateOf(selectedModel) } + + SmolLMAndroidTheme { + Scaffold( + modifier = Modifier.fillMaxSize(), + 
topBar = { + TopAppBar( + title = { AppBarTitleText(stringResource(R.string.whisper_download_title)) }, + navigationIcon = { + IconButton(onClick = onBackClick) { + Icon( + Icons.AutoMirrored.Filled.ArrowBack, + contentDescription = stringResource(R.string.button_text_back) + ) + } + } + ) + }, + ) { innerPadding -> + Surface( + modifier = Modifier + .padding(innerPadding) + .verticalScroll(rememberScrollState()) + ) { + Column( + modifier = Modifier + .fillMaxSize() + .padding(16.dp) + ) { + // Downloaded Models Section + if (downloadedModels.isNotEmpty()) { + Text( + text = stringResource(R.string.whisper_downloaded_models), + style = MaterialTheme.typography.titleMedium, + ) + Text( + text = stringResource(R.string.whisper_downloaded_models_desc), + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + ) + + Spacer(modifier = Modifier.height(8.dp)) + + DownloadedModelsList( + models = downloadedModels, + selectedModel = currentSelectedModel, + onModelSelected = { modelName -> + currentSelectedModel = modelName + onSelectModel(modelName) + } + ) + + Spacer(modifier = Modifier.height(24.dp)) + HorizontalDivider() + Spacer(modifier = Modifier.height(24.dp)) + } + + // Download New Models Section + Text( + text = stringResource(R.string.whisper_download_new_model), + style = MaterialTheme.typography.titleMedium, + ) + Text( + text = stringResource(R.string.whisper_download_desc), + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + ) + + Spacer(modifier = Modifier.height(16.dp)) + + Text( + text = stringResource(R.string.whisper_select_model), + style = MaterialTheme.typography.titleSmall, + ) + + Spacer(modifier = Modifier.height(8.dp)) + + PopularWhisperModelsList( + selectedModelIndex = selectedModelIndex, + onModelSelected = { selectedModelIndex = it } + ) + + Spacer(modifier = Modifier.height(24.dp)) + + Button( + onClick = { + selectedModelIndex?.let { index -> + 
getPopularWhisperModel(index)?.let { model -> + onDownloadModel(model) + } + } + }, + enabled = selectedModelIndex != null, + modifier = Modifier.align(Alignment.CenterHorizontally) + ) { + Text(stringResource(R.string.whisper_download_button)) + } + + Spacer(modifier = Modifier.height(16.dp)) + + Text( + text = stringResource(R.string.whisper_download_location_info), + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + modifier = Modifier.fillMaxWidth() + ) + } + } + } + } +} + +@Composable +private fun DownloadedModelsList( + models: List, + selectedModel: String, + onModelSelected: (String) -> Unit, +) { + Column(verticalArrangement = Arrangement.Center) { + models.forEach { modelName -> + val isSelected = modelName == selectedModel + Row( + Modifier + .clickable { onModelSelected(modelName) } + .fillMaxWidth() + .background( + if (isSelected) { + MaterialTheme.colorScheme.primaryContainer + } else { + MaterialTheme.colorScheme.surface + }, + RoundedCornerShape(8.dp), + ) + .padding(12.dp), + verticalAlignment = Alignment.CenterVertically, + ) { + if (isSelected) { + Icon( + FeatherIcons.Check, + contentDescription = null, + tint = MaterialTheme.colorScheme.onPrimaryContainer, + ) + Spacer(modifier = Modifier.width(8.dp)) + } + Text( + color = if (isSelected) { + MaterialTheme.colorScheme.onPrimaryContainer + } else { + MaterialTheme.colorScheme.onSurface + }, + text = modelName, + style = MaterialTheme.typography.bodyMedium, + ) + } + } + } +} diff --git a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/whisper_download/PopularWhisperModelsList.kt b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/whisper_download/PopularWhisperModelsList.kt new file mode 100644 index 00000000..1856dd8b --- /dev/null +++ b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/whisper_download/PopularWhisperModelsList.kt @@ -0,0 +1,154 @@ +/* + * Copyright (C) 2024 Shubham Panchal + * + * Licensed under the 
Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.shubham0204.smollmandroid.ui.screens.whisper_download + +import androidx.compose.foundation.background +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.Spacer +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.width +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material3.Icon +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Text +import androidx.compose.runtime.Composable +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.tooling.preview.Preview +import androidx.compose.ui.unit.dp +import compose.icons.FeatherIcons +import compose.icons.feathericons.Check + +data class WhisperModel( + val name: String, + val url: String, + val fileName: String, + val sizeDescription: String, +) + +@Preview +@Composable +fun PreviewPopularWhisperModelsList() { + PopularWhisperModelsList(selectedModelIndex = 0, onModelSelected = {}) +} + +@Composable +fun PopularWhisperModelsList(selectedModelIndex: Int?, onModelSelected: (Int) -> Unit) { + Column(verticalArrangement = Arrangement.Center) { + 
popularWhisperModelsList.forEachIndexed { idx, model -> + Row( + Modifier + .clickable { onModelSelected(idx) } + .fillMaxWidth() + .background( + if (idx == selectedModelIndex) { + MaterialTheme.colorScheme.surfaceContainer + } else { + MaterialTheme.colorScheme.surface + }, + RoundedCornerShape(8.dp), + ) + .padding(8.dp), + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.SpaceBetween, + ) { + Row( + verticalAlignment = Alignment.CenterVertically, + modifier = Modifier.weight(1f) + ) { + if (idx == selectedModelIndex) { + Icon( + FeatherIcons.Check, + contentDescription = null, + tint = MaterialTheme.colorScheme.onSurface, + ) + Spacer(modifier = Modifier.width(4.dp)) + } + Column { + Text( + color = MaterialTheme.colorScheme.onSurface, + text = model.name, + style = MaterialTheme.typography.bodySmall, + ) + Text( + color = MaterialTheme.colorScheme.onSurfaceVariant, + text = model.sizeDescription, + style = MaterialTheme.typography.labelSmall, + ) + } + } + } + } + } +} + +fun getPopularWhisperModel(index: Int?): WhisperModel? = + if (index != null) popularWhisperModelsList[index] else null + +/** + * A list of Whisper models for speech-to-text functionality. + * Models are from the ggerganov/whisper.cpp repository. 
+ * See: https://huggingface.co/ggerganov/whisper.cpp + */ +val popularWhisperModelsList = listOf( + WhisperModel( + name = "Whisper Tiny (English)", + url = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin", + fileName = "ggml-tiny.en.bin", + sizeDescription = "~75 MB - Fastest, less accurate", + ), + WhisperModel( + name = "Whisper Base (English)", + url = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin", + fileName = "ggml-base.en.bin", + sizeDescription = "~142 MB - Good balance of speed/accuracy", + ), + WhisperModel( + name = "Whisper Small (English)", + url = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.en.bin", + fileName = "ggml-small.en.bin", + sizeDescription = "~466 MB - Better accuracy, slower", + ), + WhisperModel( + name = "Whisper Tiny (Multilingual)", + url = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.bin", + fileName = "ggml-tiny.bin", + sizeDescription = "~75 MB - Fastest, supports multiple languages", + ), + WhisperModel( + name = "Whisper Base (Multilingual)", + url = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.bin", + fileName = "ggml-base.bin", + sizeDescription = "~142 MB - Good balance, multilingual", + ), + WhisperModel( + name = "Whisper Small (Multilingual)", + url = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.bin", + fileName = "ggml-small.bin", + sizeDescription = "~466 MB - Better accuracy, multilingual", + ), + WhisperModel( + name = "Whisper Large v3-turbo (Multilingual)", + url = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3-turbo.bin", + fileName = "ggml-large-v3-turbo.bin", + // Largest download in this list; size per the Hugging Face file listing, not the Small model's. + sizeDescription = "~1.6 GB - Best accuracy, slowest, multilingual", + ), +) diff --git a/app/src/main/res/drawable/ic_mic_notification.xml b/app/src/main/res/drawable/ic_mic_notification.xml new file mode 100644 index 00000000..a6472085 --- /dev/null +++ 
b/app/src/main/res/drawable/ic_mic_notification.xml @@ -0,0 +1,10 @@ + + + diff --git a/app/src/main/res/drawable/ic_stop.xml b/app/src/main/res/drawable/ic_stop.xml new file mode 100644 index 00000000..c4e79ca4 --- /dev/null +++ b/app/src/main/res/drawable/ic_stop.xml @@ -0,0 +1,10 @@ + + + diff --git a/app/src/main/res/values/strings.xml b/app/src/main/res/values/strings.xml index a40662cc..47817485 100644 --- a/app/src/main/res/values/strings.xml +++ b/app/src/main/res/values/strings.xml @@ -112,4 +112,58 @@ Invalid File The selected file is not a valid GGUF file. Benchmark Model + Enable TTS + Disable TTS + Text-to-Speech + Enable text-to-speech to have responses read aloud as they are generated. + Enable Auto-Submit + Disable Auto-Submit + Auto-Submit + Automatically send messages after you stop typing for a configured delay. + Auto-Submit Delay (seconds) + Whisper model not found. Please download ggml-base.en.bin to Downloads folder. + Microphone permission is required for speech-to-text. + Microphone permission was denied. Speech-to-text requires this permission. + Failed to load Whisper model. + Recording... + Transcribing... + Speech-to-Text Model Required + A Whisper model is required for speech-to-text functionality. Would you like to download one now? + Download Model + Download Whisper Model + Whisper models enable speech-to-text transcription. Choose a model based on your needs: smaller models are faster but less accurate, larger models are more accurate but slower. + Select a model: + Download Model + Models are downloaded to the app\'s internal storage and will be ready to use once the download completes. + Download started for %1$s + Downloading Whisper speech-to-text model + Manage STT Models + Downloaded Models + Select which model to use for speech-to-text. The selected model will be loaded when you use the microphone. + Download New Model + Selected: %1$s + Transcription Language + Select the language for speech-to-text transcription. 
+ Voice Chat Active + Tap to return to chat + Stop + Voice Chat + Pocket Mode + Pocket Mode Active + Hold for 2 seconds to exit + Exiting pocket mode... + Context: %1$d / %2$d tokens + Context Almost Full + The conversation is using %1$d%% of the model\'s context window. What would you like to do? + Remove older messages + Keep system prompt and recent messages + Start new chat + Begin a fresh conversation + Continue anyway + May cause errors if context overflows + Context nearly full. Removing older messages to continue. + Enable Auto Context Trim + Disable Auto Context Trim + Auto Context Trim + Automatically remove older messages when context is nearly full, without prompting. \ No newline at end of file diff --git a/build.gradle.kts b/build.gradle.kts index 0da250c3..2a2930f9 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -6,5 +6,16 @@ plugins { alias(libs.plugins.android.library) apply false id("com.google.devtools.ksp") version "2.0.0-1.0.24" apply false alias(libs.plugins.jetbrains.kotlin.jvm) apply false - kotlin("plugin.serialization") version "2.1.0" apply false + kotlin("plugin.serialization") version "2.0.0" apply false +} + +subprojects { + plugins.withType { + extensions.configure { + jvmToolchain { + languageVersion.set(JavaLanguageVersion.of(17)) + vendor.set(JvmVendorSpec.AZUL) + } + } + } } diff --git a/gradle.properties b/gradle.properties index 20e2a015..4ce248cb 100644 --- a/gradle.properties +++ b/gradle.properties @@ -7,6 +7,8 @@ # Specifies the JVM arguments used for the daemon process. # The setting is particularly useful for tweaking memory settings. org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8 +# Enable JVM toolchain auto-provisioning +org.gradle.java.installations.auto-download=true # When configured, Gradle will run in incubating parallel mode. # This option should only be used with decoupled projects. 
For more details, visit # https://developer.android.com/r/tools/gradle-multi-project-decoupled-projects diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 4a397786..67ed2a9e 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -16,6 +16,7 @@ uiTextGoogleFonts = "1.7.7" composeIcons = "1.1.1" appcompat = "1.6.1" material = "1.10.0" +foundation = "1.10.1" [libraries] androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" } @@ -45,6 +46,7 @@ androidx-ui-text-google-fonts = { group = "androidx.compose.ui", name = "ui-text composeIcons-feather = { module = "br.com.devsrsouza.compose.icons:feather", version.ref = "composeIcons" } androidx-appcompat = { group = "androidx.appcompat", name = "appcompat", version.ref = "appcompat" } material = { group = "com.google.android.material", name = "material", version.ref = "material" } +androidx-compose-foundation = { group = "androidx.compose.foundation", name = "foundation", version.ref = "foundation" } [plugins] android-application = { id = "com.android.application", version.ref = "agp" } diff --git a/hf-model-hub-api/build.gradle.kts b/hf-model-hub-api/build.gradle.kts index 6c4a2c05..157ad491 100644 --- a/hf-model-hub-api/build.gradle.kts +++ b/hf-model-hub-api/build.gradle.kts @@ -1,7 +1,7 @@ plugins { id("java-library") alias(libs.plugins.jetbrains.kotlin.jvm) - kotlin("plugin.serialization") version "2.1.0" + kotlin("plugin.serialization") version "2.0.0" } val ktorVersion = "3.0.2" diff --git a/settings.gradle.kts b/settings.gradle.kts index 11d9a43d..9a8c8bf9 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -16,6 +16,10 @@ pluginManagement { gradlePluginPortal() } } + +plugins { + id("org.gradle.toolchains.foojay-resolver-convention") version "0.9.0" +} dependencyResolutionManagement { repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS) repositories { @@ -30,3 +34,4 @@ rootProject.name = "SmolChat Android" include(":app") 
include(":smollm") include(":hf-model-hub-api") +include(":whisper") diff --git a/whisper/build.gradle.kts b/whisper/build.gradle.kts new file mode 100644 index 00000000..b5c7dd35 --- /dev/null +++ b/whisper/build.gradle.kts @@ -0,0 +1,58 @@ +plugins { + alias(libs.plugins.android.library) + alias(libs.plugins.kotlin.android) +} + +android { + namespace = "com.whispercpp" + compileSdk = 35 + + defaultConfig { + minSdk = 26 + + ndk { + abiFilters += listOf("arm64-v8a", "armeabi-v7a", "x86", "x86_64") + } + externalNativeBuild { + cmake { + arguments("-DCMAKE_BUILD_TYPE=Release") + cppFlags("-ffile-prefix-map=${projectDir}=.") + cFlags("-ffile-prefix-map=${projectDir}=.") + } + } + } + + buildTypes { + release { + isMinifyEnabled = false + } + } + + compileOptions { + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 + } + + kotlinOptions { + jvmTarget = "17" + } + + externalNativeBuild { + cmake { + path = file("src/main/jni/whisper/CMakeLists.txt") + } + } + + packaging { + resources { + excludes += "/META-INF/{AL2.0,LGPL2.1}" + } + } + + ndkVersion = "27.2.12479018" +} + +dependencies { + implementation(libs.androidx.core.ktx) + implementation(libs.androidx.appcompat) +} diff --git a/whisper/src/main/AndroidManifest.xml b/whisper/src/main/AndroidManifest.xml new file mode 100644 index 00000000..8bdb7e14 --- /dev/null +++ b/whisper/src/main/AndroidManifest.xml @@ -0,0 +1,4 @@ + + + + diff --git a/whisper/src/main/java/com/whispercpp/whisper/LibWhisper.kt b/whisper/src/main/java/com/whispercpp/whisper/LibWhisper.kt new file mode 100644 index 00000000..c09aee6e --- /dev/null +++ b/whisper/src/main/java/com/whispercpp/whisper/LibWhisper.kt @@ -0,0 +1,212 @@ +package com.whispercpp.whisper + +import android.content.res.AssetManager +import android.os.Build +import android.util.Log +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.asCoroutineDispatcher +import kotlinx.coroutines.runBlocking +import 
kotlinx.coroutines.withContext +import java.io.File +import java.io.InputStream +import java.util.concurrent.Executors + +private const val LOG_TAG = "LibWhisper" + + +class WhisperContext private constructor(private var ptr: Long) { + // Meet Whisper C++ constraint: Don't access from more than one thread at a time. + val scope: CoroutineScope = CoroutineScope( + Executors.newSingleThreadExecutor().asCoroutineDispatcher() + ) + fun stopTranscription(){ + WhisperLib.stopTranscription() + } + + + suspend fun transcribeData( + data: FloatArray, + language: String, + printTimestamp: Boolean = true, + callback: WhisperCallback + ): String = withContext(scope.coroutineContext) { + require(ptr != 0L) + + try { + val numThreads = WhisperCpuConfig.preferredThreadCount + Log.d(LOG_TAG, "Selecting $numThreads threads") + + WhisperLib.fullTranscribe(ptr, numThreads, data, language, callback) + + val textCount = WhisperLib.getTextSegmentCount(ptr) + return@withContext buildString { + for (i in 0 until textCount) { + if (printTimestamp) { + val textTimestamp = "[${toTimestamp(WhisperLib.getTextSegmentT0(ptr, i))} --> ${ + toTimestamp(WhisperLib.getTextSegmentT1(ptr, i)) + }]" + val textSegment = WhisperLib.getTextSegment(ptr, i) + append("$textTimestamp: $textSegment\n") + } else { + append(WhisperLib.getTextSegment(ptr, i)) + } + } + } + + } catch (e: Exception) { + Log.e(LOG_TAG, "Error during transcription", e) + return@withContext "" + } + } + + suspend fun benchMemory(nthreads: Int): String = withContext(scope.coroutineContext) { + return@withContext WhisperLib.benchMemcpy(nthreads) + } + + suspend fun benchGgmlMulMat(nthreads: Int): String = withContext(scope.coroutineContext) { + return@withContext WhisperLib.benchGgmlMulMat(nthreads) + } + + suspend fun release() = withContext(scope.coroutineContext) { + if (ptr != 0L) { + WhisperLib.freeContext(ptr) + ptr = 0 + } + } + + protected fun finalize() { + runBlocking { + release() + } + } + + companion object { + fun 
createContextFromFile(filePath: String): WhisperContext { + val ptr = WhisperLib.initContext(filePath) + if (ptr == 0L) { + throw java.lang.RuntimeException("Couldn't create context with path $filePath") + } + return WhisperContext(ptr) + } + + fun createContextFromInputStream(stream: InputStream): WhisperContext { + val ptr = WhisperLib.initContextFromInputStream(stream) + + if (ptr == 0L) { + throw java.lang.RuntimeException("Couldn't create context from input stream") + } + return WhisperContext(ptr) + } + + fun createContextFromAsset(assetManager: AssetManager, assetPath: String): WhisperContext { + val ptr = WhisperLib.initContextFromAsset(assetManager, assetPath) + + if (ptr == 0L) { + throw java.lang.RuntimeException("Couldn't create context from asset $assetPath") + } + return WhisperContext(ptr) + } + + fun getSystemInfo(): String { + return WhisperLib.getSystemInfo() + } + } +} + +private class WhisperLib { + companion object { + init { + Log.d(LOG_TAG, "Primary ABI: ${Build.SUPPORTED_ABIS[0]}") + var loadVfpv4 = false + var loadV8fp16 = false + if (isArmEabiV7a()) { + // armeabi-v7a needs runtime detection support + val cpuInfo = cpuInfo() + cpuInfo?.let { + Log.d(LOG_TAG, "CPU info: $cpuInfo") + if (cpuInfo.contains("vfpv4")) { + Log.d(LOG_TAG, "CPU supports vfpv4") + loadVfpv4 = true + } + } + } else if (isArmEabiV8a()) { + // ARMv8.2a needs runtime detection support + val cpuInfo = cpuInfo() + cpuInfo?.let { + Log.d(LOG_TAG, "CPU info: $cpuInfo") + if (cpuInfo.contains("fphp")) { + Log.d(LOG_TAG, "CPU supports fp16 arithmetic") + loadV8fp16 = true + } + } + } + + if (loadVfpv4) { + Log.d(LOG_TAG, "Loading libwhisper_vfpv4.so") + System.loadLibrary("whisper_vfpv4") + } else if (loadV8fp16) { + Log.d(LOG_TAG, "Loading libwhisper_v8fp16_va.so") + System.loadLibrary("whisper_v8fp16_va") + } else { + Log.d(LOG_TAG, "Loading libwhisper.so") + System.loadLibrary("whisper") + } + } + + // JNI methods + external fun initContextFromInputStream(inputStream: 
InputStream): Long + external fun initContextFromAsset(assetManager: AssetManager, assetPath: String): Long + external fun initContext(modelPath: String): Long + external fun freeContext(contextPtr: Long) + external fun stopTranscription() + external fun fullTranscribe( + contextPtr: Long, + numThreads: Int, + audioData: FloatArray, + language: String, + callback: WhisperCallback + ) + + external fun getTextSegmentCount(contextPtr: Long): Int + external fun getTextSegment(contextPtr: Long, index: Int): String + external fun getTextSegmentT0(contextPtr: Long, index: Int): Long + external fun getTextSegmentT1(contextPtr: Long, index: Int): Long + external fun getSystemInfo(): String + external fun benchMemcpy(nthread: Int): String + external fun benchGgmlMulMat(nthread: Int): String + } +} + +// Formats a segment time given in 10 ms units (t * 10 = milliseconds) as HH:MM:SS.mmm, +// or HH:MM:SS,mmm when comma is true. +// 500 -> 00:00:05.000 +// 6000 -> 00:01:00.000 +private fun toTimestamp(t: Long, comma: Boolean = false): String { + var msec = t * 10 + val hr = msec / (1000 * 60 * 60) + msec -= hr * (1000 * 60 * 60) + val min = msec / (1000 * 60) + msec -= min * (1000 * 60) + val sec = msec / 1000 + msec -= sec * 1000 + + val delimiter = if (comma) "," else "." + // Locale.ROOT keeps the digits ASCII regardless of the device's default locale. + return String.format(java.util.Locale.ROOT, "%02d:%02d:%02d%s%03d", hr, min, sec, delimiter, msec) +} + +private fun isArmEabiV7a(): Boolean { + return Build.SUPPORTED_ABIS[0].equals("armeabi-v7a") +} + +private fun isArmEabiV8a(): Boolean { + return Build.SUPPORTED_ABIS[0].equals("arm64-v8a") +} + +private fun cpuInfo(): String? 
{ + return try { + File("/proc/cpuinfo").inputStream().bufferedReader().use { + it.readText() + } + } catch (e: Exception) { + Log.w(LOG_TAG, "Couldn't read /proc/cpuinfo", e) + null + } +} diff --git a/whisper/src/main/java/com/whispercpp/whisper/WhisperCpuConfig.kt b/whisper/src/main/java/com/whispercpp/whisper/WhisperCpuConfig.kt new file mode 100644 index 00000000..45e7370d --- /dev/null +++ b/whisper/src/main/java/com/whispercpp/whisper/WhisperCpuConfig.kt @@ -0,0 +1,73 @@ +package com.whispercpp.whisper + +import android.util.Log +import java.io.BufferedReader +import java.io.FileReader + +object WhisperCpuConfig { + val preferredThreadCount: Int + // Always use at least 2 threads: + get() = CpuInfo.getHighPerfCpuCount().coerceAtLeast(2) +} + +// Wraps the lines of /proc/cpuinfo and derives a count of high-performance cores from them. +private class CpuInfo(private val lines: List<String>) { + // Prefer the per-core max-frequency files; fall back to the "CPU variant" field if reading them fails. + private fun getHighPerfCpuCount(): Int = try { + getHighPerfCpuCountByFrequencies() + } catch (e: Exception) { + Log.d(LOG_TAG, "Couldn't read CPU frequencies", e) + getHighPerfCpuCountByVariant() + } + + private fun getHighPerfCpuCountByFrequencies(): Int = + getCpuValues(property = "processor") { getMaxCpuFrequency(it.toInt()) } + .also { Log.d(LOG_TAG, "Binned cpu frequencies (frequency, count): ${it.binnedValues()}") } + .countDroppingMin() + + private fun getHighPerfCpuCountByVariant(): Int = + getCpuValues(property = "CPU variant") { it.substringAfter("0x").toInt(radix = 16) } + .also { Log.d(LOG_TAG, "Binned cpu variants (variant, count): ${it.binnedValues()}") } + .countKeepingMin() + + // Histogram (value -> occurrences); only used for the debug log lines above. + private fun List<Int>.binnedValues() = groupingBy { it }.eachCount() + + // Collects, sorted ascending, the mapped value of every cpuinfo line that starts with [property]. + private fun getCpuValues(property: String, mapper: (String) -> Int) = lines + .asSequence() + .filter { it.startsWith(property) } + .map { mapper(it.substringAfter(':').trim()) } + .sorted() + .toList() + + + // Number of entries strictly greater than the smallest one; min() throws on an empty list. + private fun List<Int>.countDroppingMin(): Int { + val min = min() + return count { it > min } + } + + // Number of entries equal to the smallest one; min() throws on an empty list. + private fun List<Int>.countKeepingMin(): Int { + val min = min() + return count { it == min } + } + + companion 
object { + private const val LOG_TAG = "WhisperCpuConfig" + + fun getHighPerfCpuCount(): Int = try { + readCpuInfo().getHighPerfCpuCount() + } catch (e: Exception) { + Log.d(LOG_TAG, "Couldn't read CPU info", e) + // Our best guess -- just return the # of CPUs minus 4. + (Runtime.getRuntime().availableProcessors() - 4).coerceAtLeast(0) + } + + private fun readCpuInfo() = CpuInfo( + BufferedReader(FileReader("/proc/cpuinfo")) + .useLines { it.toList() } + ) + + private fun getMaxCpuFrequency(cpuIndex: Int): Int { + val path = "/sys/devices/system/cpu/cpu${cpuIndex}/cpufreq/cpuinfo_max_freq" + val maxFreq = BufferedReader(FileReader(path)).use { it.readLine() } + return maxFreq.toInt() + } + } +} diff --git a/whisper/src/main/java/com/whispercpp/whisper/WishperCallBack.kt b/whisper/src/main/java/com/whispercpp/whisper/WishperCallBack.kt new file mode 100644 index 00000000..2f0a98ad --- /dev/null +++ b/whisper/src/main/java/com/whispercpp/whisper/WishperCallBack.kt @@ -0,0 +1,24 @@ +package com.whispercpp.whisper + +import androidx.annotation.Keep + +interface WhisperCallback { + fun onNewSegment(startMs: Long, endMs: Long, text: String) + fun onProgress(progress: Int) + fun onComplete() +} + +@Keep +class WishperCallBack : WhisperCallback { + override fun onNewSegment(startMs: Long, endMs: Long, text: String) { + println(text) + } + + override fun onProgress(progress: Int) { + println(progress) + } + + override fun onComplete() { + println("Completed") + } +} diff --git a/whisper/src/main/jni/whisper.cpp b/whisper/src/main/jni/whisper.cpp new file mode 160000 index 00000000..e990d1b7 --- /dev/null +++ b/whisper/src/main/jni/whisper.cpp @@ -0,0 +1 @@ +Subproject commit e990d1b791e7bda546866e82e094a8969ea86c6d diff --git a/whisper/src/main/jni/whisper/CMakeLists.txt b/whisper/src/main/jni/whisper/CMakeLists.txt new file mode 100644 index 00000000..9a6d488e --- /dev/null +++ b/whisper/src/main/jni/whisper/CMakeLists.txt @@ -0,0 +1,126 @@ +cmake_minimum_required(VERSION 
3.10) +add_link_options("LINKER:--build-id=none") + +# 16KB Page Size Support, linker flags for 16KB page size alignment on Android 15+ devices +add_link_options("LINKER:-z,max-page-size=16384") + +project(whisper.cpp) + +set(CMAKE_CXX_STANDARD 17) +set(WHISPER_LIB_DIR ${CMAKE_SOURCE_DIR}/../whisper.cpp) + +# Remove file paths from debug info +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -ffile-prefix-map=${CMAKE_SOURCE_DIR}=.") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ffile-prefix-map=${CMAKE_SOURCE_DIR}=.") + +# Set consistent release flags - ADDED +set(CMAKE_C_FLAGS_RELEASE "-O3 -DNDEBUG") +set(CMAKE_CXX_FLAGS_RELEASE "-O3 -DNDEBUG") + +# Disable timestamps - ADDED +add_definitions(-DGGML_NO_TIMESTAMPS=1) + +# 16KB Page Size Support: Add NDK r27 flexible page size compilation flag +add_definitions(-DANDROID_SUPPORT_FLEXIBLE_PAGE_SIZES=ON) + +# Path to external GGML, otherwise uses the copy in whisper.cpp. +option(GGML_HOME "whisper: Path to external GGML source" OFF) + +set( + SOURCE_FILES + ${WHISPER_LIB_DIR}/src/whisper.cpp + ${CMAKE_SOURCE_DIR}/jni.c +) + +# TODO: this needs to be updated to work with the new ggml CMakeLists + +if (NOT GGML_HOME) + set( + SOURCE_FILES + ${SOURCE_FILES} + ${WHISPER_LIB_DIR}/ggml/src/ggml.c + ${WHISPER_LIB_DIR}/ggml/src/ggml-alloc.c + ${WHISPER_LIB_DIR}/ggml/src/ggml-backend.cpp + ${WHISPER_LIB_DIR}/ggml/src/ggml-backend-reg.cpp + ${WHISPER_LIB_DIR}/ggml/src/ggml-quants.c + ${WHISPER_LIB_DIR}/ggml/src/ggml-threading.cpp + ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu.c + ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu.cpp + ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp + ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp + ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu-quants.c + ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu-traits.cpp + ) +endif() + +find_library(LOG_LIB log) + +function(build_library target_name) + add_library( + ${target_name} + SHARED + ${SOURCE_FILES} + ) + + target_compile_definitions(${target_name} 
PUBLIC GGML_USE_CPU) + + # Add reproducible build definitions - ADDED + target_compile_definitions(${target_name} PRIVATE GGML_NO_TIMESTAMPS=1) + + if (${target_name} STREQUAL "whisper_v8fp16_va") + target_compile_options(${target_name} PRIVATE -march=armv8.2-a+fp16) + set(GGML_COMPILE_OPTIONS -march=armv8.2-a+fp16) + elseif (${target_name} STREQUAL "whisper_vfpv4") + target_compile_options(${target_name} PRIVATE -mfpu=neon-vfpv4) + set(GGML_COMPILE_OPTIONS -mfpu=neon-vfpv4) + endif () + + if (NOT ${CMAKE_BUILD_TYPE} STREQUAL "Debug") + # Use consistent optimization flags - MODIFIED + target_compile_options(${target_name} PRIVATE -O3 -DNDEBUG) + target_compile_options(${target_name} PRIVATE -fvisibility=hidden -fvisibility-inlines-hidden) + target_compile_options(${target_name} PRIVATE -ffunction-sections -fdata-sections) + + target_link_options(${target_name} PRIVATE -Wl,--gc-sections) + target_link_options(${target_name} PRIVATE -Wl,--exclude-libs,ALL) + target_link_options(${target_name} PRIVATE -flto) + + # 16KB Page Size Support, linker flags for 16KB page size alignment on Android 15+ devices + target_link_options(${target_name} PRIVATE -Wl,-z,max-page-size=16384) + endif () + + if (GGML_HOME) + include(FetchContent) + FetchContent_Declare(ggml SOURCE_DIR ${GGML_HOME}) + FetchContent_MakeAvailable(ggml) + + target_compile_options(ggml PRIVATE ${GGML_COMPILE_OPTIONS}) + # Add reproducible build flags to ggml as well - ADDED + target_compile_definitions(ggml PRIVATE GGML_NO_TIMESTAMPS=1) + target_link_libraries(${target_name} ${LOG_LIB} android ggml) + else() + target_link_libraries(${target_name} ${LOG_LIB} android) + endif() + + # Strip .comment section for reproducible builds + add_custom_command(TARGET ${target_name} POST_BUILD + COMMAND ${CMAKE_OBJCOPY} --remove-section .comment $ + COMMENT "Removing .comment section from ${target_name} for reproducible builds" + ) + +endfunction() + +if (${ANDROID_ABI} STREQUAL "arm64-v8a") + 
build_library("whisper_v8fp16_va") +elseif (${ANDROID_ABI} STREQUAL "armeabi-v7a") + build_library("whisper_vfpv4") +endif () + +build_library("whisper") # Default target + +include_directories(${WHISPER_LIB_DIR}) +include_directories(${WHISPER_LIB_DIR}/src) +include_directories(${WHISPER_LIB_DIR}/include) +include_directories(${WHISPER_LIB_DIR}/ggml/include) +include_directories(${WHISPER_LIB_DIR}/ggml/src) +include_directories(${WHISPER_LIB_DIR}/ggml/src/ggml-cpu) diff --git a/whisper/src/main/jni/whisper/jni.c b/whisper/src/main/jni/whisper/jni.c new file mode 100644 index 00000000..7883a900 --- /dev/null +++ b/whisper/src/main/jni/whisper/jni.c @@ -0,0 +1,397 @@ +#include +#include +#include +#include +#include +#include +#include +#include "whisper.h" +#include "ggml.h" + +#include +#include + +static bool g_should_abort = false; +static pthread_mutex_t g_abort_mutex = PTHREAD_MUTEX_INITIALIZER; + +#define UNUSED(x) (void)(x) +#define TAG "JNI" + +#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__) +#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__) + +// Global references for callback handling +static JavaVM* g_jvm = NULL; +static jobject g_callback = NULL; + +// Method IDs for callback functions +static jmethodID g_onNewSegmentMethod = NULL; +static jmethodID g_onProgressMethod = NULL; +static jmethodID g_onCompleteMethod = NULL; + + + +static inline int min(int a, int b) { + return (a < b) ? a : b; +} + +static inline int max(int a, int b) { + return (a > b) ? a : b; +} + +struct input_stream_context { + size_t offset; + JNIEnv * env; + jobject thiz; + jobject input_stream; + + jmethodID mid_available; + jmethodID mid_read; +}; + +size_t inputStreamRead(void * ctx, void * output, size_t read_size) { + struct input_stream_context* is = (struct input_stream_context*)ctx; + + jint avail_size = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available); + jint size_to_copy = read_size < avail_size ? 
(jint)read_size : avail_size; + + jbyteArray byte_array = (*is->env)->NewByteArray(is->env, size_to_copy); + + jint n_read = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_read, byte_array, 0, size_to_copy); + + if (size_to_copy != read_size || size_to_copy != n_read) { + LOGI("Insufficient Read: Req=%zu, ToCopy=%d, Available=%d", read_size, size_to_copy, n_read); + } + + jbyte* byte_array_elements = (*is->env)->GetByteArrayElements(is->env, byte_array, NULL); + memcpy(output, byte_array_elements, size_to_copy); + (*is->env)->ReleaseByteArrayElements(is->env, byte_array, byte_array_elements, JNI_ABORT); + + (*is->env)->DeleteLocalRef(is->env, byte_array); + + is->offset += size_to_copy; + + return size_to_copy; +} +bool inputStreamEof(void * ctx) { + struct input_stream_context* is = (struct input_stream_context*)ctx; + + jint result = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available); + return result <= 0; +} +void inputStreamClose(void * ctx) { + +} + +// JNI entry point for WhisperLib.Companion.initContextFromInputStream. +// The exported symbol must spell the Kotlin package (com.whispercpp.whisper) exactly, +// otherwise the runtime cannot resolve the native method and throws UnsatisfiedLinkError. +// Returns the whisper_context pointer as a jlong (0 when loading failed). +JNIEXPORT jlong JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_initContextFromInputStream( + JNIEnv *env, jobject thiz, jobject input_stream) { + UNUSED(thiz); + + struct whisper_context *context = NULL; + struct whisper_model_loader loader = {}; + struct input_stream_context inp_ctx = {}; + + inp_ctx.offset = 0; + inp_ctx.env = env; + inp_ctx.thiz = thiz; + inp_ctx.input_stream = input_stream; + + jclass cls = (*env)->GetObjectClass(env, input_stream); + inp_ctx.mid_available = (*env)->GetMethodID(env, cls, "available", "()I"); + inp_ctx.mid_read = (*env)->GetMethodID(env, cls, "read", "([BII)I"); + + loader.context = &inp_ctx; + loader.read = inputStreamRead; + loader.eof = inputStreamEof; + loader.close = inputStreamClose; + + loader.eof(loader.context); + + // Same initializer the asset and file loaders in this file use. + context = whisper_init_with_params(&loader, whisper_context_default_params()); + return (jlong) context; +} + +static size_t asset_read(void *ctx, void *output, size_t read_size) { + return AAsset_read((AAsset *) ctx, output, read_size); +} + 
+static bool asset_is_eof(void *ctx) { + return AAsset_getRemainingLength64((AAsset *) ctx) <= 0; +} + +static void asset_close(void *ctx) { + AAsset_close((AAsset *) ctx); +} + +static struct whisper_context *whisper_init_from_asset( + JNIEnv *env, + jobject assetManager, + const char *asset_path +) { + LOGI("Loading model from asset '%s'\n", asset_path); + AAssetManager *asset_manager = AAssetManager_fromJava(env, assetManager); + AAsset *asset = AAssetManager_open(asset_manager, asset_path, AASSET_MODE_STREAMING); + if (!asset) { + LOGW("Failed to open '%s'\n", asset_path); + return NULL; + } + + whisper_model_loader loader = { + .context = asset, + .read = &asset_read, + .eof = &asset_is_eof, + .close = &asset_close + }; + + return whisper_init_with_params(&loader, whisper_context_default_params()); +} + +JNIEXPORT jlong JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_initContextFromAsset( + JNIEnv *env, jobject thiz, jobject assetManager, jstring asset_path_str) { + UNUSED(thiz); + struct whisper_context *context = NULL; + const char *asset_path_chars = (*env)->GetStringUTFChars(env, asset_path_str, NULL); + context = whisper_init_from_asset(env, assetManager, asset_path_chars); + (*env)->ReleaseStringUTFChars(env, asset_path_str, asset_path_chars); + return (jlong) context; +} + +JNIEXPORT jlong JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_initContext( + JNIEnv *env, jobject thiz, jstring model_path_str) { + UNUSED(thiz); + struct whisper_context *context = NULL; + const char *model_path_chars = (*env)->GetStringUTFChars(env, model_path_str, NULL); + context = whisper_init_from_file_with_params(model_path_chars, whisper_context_default_params()); + (*env)->ReleaseStringUTFChars(env, model_path_str, model_path_chars); + return (jlong) context; +} + +JNIEXPORT void JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_freeContext( + JNIEnv *env, jobject thiz, jlong context_ptr) { + UNUSED(env); + UNUSED(thiz); + struct 
whisper_context *context = (struct whisper_context *) context_ptr; + whisper_free(context); +} +// Callback for new segments +void new_segment_callback(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { + JNIEnv* env; + (*g_jvm)->AttachCurrentThread(g_jvm, &env, NULL); + + for (int i = 0; i < n_new; i++) { + const int segment_id = whisper_full_n_segments(ctx) - n_new + i; + const char* text = whisper_full_get_segment_text(ctx, segment_id); + const int64_t t0 = whisper_full_get_segment_t0(ctx, segment_id); + const int64_t t1 = whisper_full_get_segment_t1(ctx, segment_id); + + jstring jtext = (*env)->NewStringUTF(env, text); + (*env)->CallVoidMethod( + env, + g_callback, + g_onNewSegmentMethod, + (jlong)t0, + (jlong)t1, + jtext + ); + (*env)->DeleteLocalRef(env, jtext); + } +} + +// Progress callback +void progress_callback(struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data) { + JNIEnv* env; + (*g_jvm)->AttachCurrentThread(g_jvm, &env, NULL); + (*env)->CallVoidMethod(env, g_callback, g_onProgressMethod, (jint)progress); +} + +static bool abort_callback(void* user_data) { + bool should_abort; + pthread_mutex_lock(&g_abort_mutex); + should_abort = g_should_abort; + pthread_mutex_unlock(&g_abort_mutex); + return should_abort; +} + +JNIEXPORT void JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_resetAbort(JNIEnv* env, jobject thiz) { + pthread_mutex_lock(&g_abort_mutex); + g_should_abort = false; + pthread_mutex_unlock(&g_abort_mutex); +} + +JNIEXPORT void JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_stopTranscription(JNIEnv* env, jobject thiz) { + pthread_mutex_lock(&g_abort_mutex); + g_should_abort = true; + pthread_mutex_unlock(&g_abort_mutex); +} + + +JNIEXPORT void JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_fullTranscribe( + JNIEnv *env, jobject thiz, jlong context_ptr, jint num_threads, + jfloatArray audio_data, jstring language, 
jobject callback) { + UNUSED(thiz); +// Reset abort state + Java_com_whispercpp_whisper_WhisperLib_00024Companion_resetAbort(env, thiz); + // Store JavaVM for later callbacks + if (g_jvm == NULL) { + (*env)->GetJavaVM(env, &g_jvm); + } + + // Clean up previous callback if exists + if (g_callback != NULL) { + (*env)->DeleteGlobalRef(env, g_callback); + } + + // Create new global reference + g_callback = (*env)->NewGlobalRef(env, callback); + + // Get method IDs + jclass callbackClass = (*env)->GetObjectClass(env, g_callback); + + jmethodID toStringMethod = (*env)->GetMethodID(env, callbackClass, "toString", "()Ljava/lang/String;"); + jstring str = (*env)->CallObjectMethod(env, g_callback, toStringMethod); + const char *cStr = (*env)->GetStringUTFChars(env, str, NULL); + __android_log_print(ANDROID_LOG_DEBUG, "JNI", "Callback class: %s", cStr); + (*env)->ReleaseStringUTFChars(env, str, cStr); + g_onNewSegmentMethod = (*env)->GetMethodID( + env, + callbackClass, + "onNewSegment", + "(JJLjava/lang/String;)V" + ); + g_onProgressMethod = (*env)->GetMethodID( + env, + callbackClass, + "onProgress", + "(I)V" + ); + g_onCompleteMethod = (*env)->GetMethodID( + env, + callbackClass, + "onComplete", + "()V" + ); + + struct whisper_context *context = (struct whisper_context *) context_ptr; + jfloat *audio_data_arr = (*env)->GetFloatArrayElements(env, audio_data, NULL); + const jsize audio_data_length = (*env)->GetArrayLength(env, audio_data); + + // Get language parameter (default to "auto" if null) + const char *language_str = "auto"; + if (language != NULL) { + language_str = (*env)->GetStringUTFChars(env, language, NULL); + } + + // Configure whisper parameters with callbacks + struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + params.print_realtime = false; // We handle callbacks ourselves + params.print_progress = false; + params.print_timestamps = false; + params.print_special = false; + params.translate = false; + params.language 
= language_str; + params.n_threads = num_threads; + params.offset_ms = 0; + params.no_context = true; + params.single_segment = false; + + // Set our callbacks + params.new_segment_callback = new_segment_callback; + params.new_segment_callback_user_data = NULL; + params.progress_callback = progress_callback; + params.progress_callback_user_data = NULL; + params.abort_callback = abort_callback; + params.abort_callback_user_data = NULL; + + + whisper_reset_timings(context); + + LOGI("About to run whisper_full with callbacks (language: %s)", language_str); + int result = whisper_full(context, params, audio_data_arr, audio_data_length); + + // Cleanup language string if we allocated it + if (language != NULL) { + (*env)->ReleaseStringUTFChars(env, language, language_str); + } + + // Notify completion + if (result == 0) { + (*env)->CallVoidMethod(env, g_callback, g_onCompleteMethod); + } else { + LOGI("Failed to run the model"); + } + + // Cleanup + (*env)->ReleaseFloatArrayElements(env, audio_data, audio_data_arr, JNI_ABORT); + (*env)->DeleteGlobalRef(env, g_callback); + g_callback = NULL; +} + +JNIEXPORT jint JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_getTextSegmentCount( + JNIEnv *env, jobject thiz, jlong context_ptr) { + UNUSED(env); + UNUSED(thiz); + struct whisper_context *context = (struct whisper_context *) context_ptr; + return whisper_full_n_segments(context); +} + +JNIEXPORT jstring JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_getTextSegment( + JNIEnv *env, jobject thiz, jlong context_ptr, jint index) { + UNUSED(thiz); + struct whisper_context *context = (struct whisper_context *) context_ptr; + const char *text = whisper_full_get_segment_text(context, index); + jstring string = (*env)->NewStringUTF(env, text); + return string; +} + +JNIEXPORT jlong JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_getTextSegmentT0( + JNIEnv *env, jobject thiz, jlong context_ptr, jint index) { + UNUSED(thiz); + struct 
whisper_context *context = (struct whisper_context *) context_ptr; + return whisper_full_get_segment_t0(context, index); +} + +JNIEXPORT jlong JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_getTextSegmentT1( + JNIEnv *env, jobject thiz, jlong context_ptr, jint index) { + UNUSED(thiz); + struct whisper_context *context = (struct whisper_context *) context_ptr; + return whisper_full_get_segment_t1(context, index); +} + +JNIEXPORT jstring JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_getSystemInfo( + JNIEnv *env, jobject thiz +) { + UNUSED(thiz); + const char *sysinfo = whisper_print_system_info(); + jstring string = (*env)->NewStringUTF(env, sysinfo); + return string; +} + +JNIEXPORT jstring JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_benchMemcpy(JNIEnv *env, jobject thiz, + jint n_threads) { + UNUSED(thiz); + const char *bench_ggml_memcpy = whisper_bench_memcpy_str(n_threads); + jstring string = (*env)->NewStringUTF(env, bench_ggml_memcpy); + return string; +} + +JNIEXPORT jstring JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_benchGgmlMulMat(JNIEnv *env, jobject thiz, + jint n_threads) { + UNUSED(thiz); + const char *bench_ggml_mul_mat = whisper_bench_ggml_mul_mat_str(n_threads); + jstring string = (*env)->NewStringUTF(env, bench_ggml_mul_mat); + return string; +}