diff --git a/.gitignore b/.gitignore index 41a33f5f..90d64a8c 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ local.properties .kotlin ktlint -ktlint.bat \ No newline at end of file +ktlint.bat +whisper/build/ \ No newline at end of file diff --git a/.gitmodules b/.gitmodules index 81bc6f38..f614413d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "llama.cpp"] path = llama.cpp url = https://github.com/ggerganov/llama.cpp +[submodule "whisper/src/main/jni/whisper.cpp"] + path = whisper/src/main/jni/whisper.cpp + url = https://github.com/ggml-org/whisper.cpp.git diff --git a/app/build.gradle.kts b/app/build.gradle.kts index 24bd96f3..96ed2a6e 100644 --- a/app/build.gradle.kts +++ b/app/build.gradle.kts @@ -3,7 +3,7 @@ plugins { alias(libs.plugins.kotlin.android) alias(libs.plugins.kotlin.compose) id("com.google.devtools.ksp") - kotlin("plugin.serialization") version "2.1.0" + kotlin("plugin.serialization") version "2.0.0" } android { @@ -92,12 +92,17 @@ dependencies { implementation(project(":smollm")) implementation(project(":hf-model-hub-api")) + implementation(project(":whisper")) + + // Android Wave Recorder for speech-to-text + implementation("com.github.squti:Android-Wave-Recorder:2.1.0") // Koin: dependency injection implementation(libs.koin.android) implementation(libs.koin.annotations) implementation(libs.koin.androidx.compose) implementation(libs.androidx.ui.text.google.fonts) + implementation(libs.androidx.compose.foundation) ksp(libs.koin.ksp.compiler) // compose-markdown: Markdown rendering in Compose diff --git a/app/src/main/java/io/shubham0204/smollmandroid/data/PreferencesManager.kt b/app/src/main/java/io/shubham0204/smollmandroid/data/PreferencesManager.kt new file mode 100644 index 00000000..0272567a --- /dev/null +++ b/app/src/main/java/io/shubham0204/smollmandroid/data/PreferencesManager.kt @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2024 Shubham Panchal + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.shubham0204.smollmandroid.data + +import android.content.Context +import org.koin.core.annotation.Single + +@Single +class PreferencesManager(context: Context) { + private val prefs = context.getSharedPreferences("smolchat_prefs", Context.MODE_PRIVATE) + + var ttsEnabled: Boolean + get() = prefs.getBoolean("tts_enabled", false) + set(value) = prefs.edit().putBoolean("tts_enabled", value).apply() + + var autoSubmitEnabled: Boolean + get() = prefs.getBoolean("auto_submit_enabled", false) + set(value) = prefs.edit().putBoolean("auto_submit_enabled", value).apply() + + var autoSubmitDelayMs: Long + get() = prefs.getLong("auto_submit_delay_ms", 2000L) + set(value) = prefs.edit().putLong("auto_submit_delay_ms", value).apply() + + var selectedWhisperModel: String + get() = prefs.getString("selected_whisper_model", DEFAULT_WHISPER_MODEL) ?: DEFAULT_WHISPER_MODEL + set(value) = prefs.edit().putString("selected_whisper_model", value).apply() + + var sttLanguage: String + get() = prefs.getString("stt_language", DEFAULT_STT_LANGUAGE) ?: DEFAULT_STT_LANGUAGE + set(value) = prefs.edit().putString("stt_language", value).apply() + + companion object { + const val DEFAULT_WHISPER_MODEL = "ggml-base.en.bin" + const val DEFAULT_STT_LANGUAGE = "en" + + // Whisper supported languages with their display names + val SUPPORTED_LANGUAGES = listOf( + "en" to "English", + "de" to "German", + "fr" to "French", + "es" to "Spanish", + "it" to "Italian", + "pt" to "Portuguese", + "nl" to "Dutch", + "pl" to "Polish", + "ru" to "Russian", + "zh" to "Chinese", + "ja" to "Japanese", + "ko" to "Korean", + "ar" to "Arabic", + "hi" to "Hindi", + "tr" to "Turkish", + "uk" to "Ukrainian", + "cs" to "Czech", + "sv" to "Swedish", + "auto" to "Auto-detect", + ) + } +} diff --git a/app/src/main/java/io/shubham0204/smollmandroid/stt/SpeechToTextManager.kt b/app/src/main/java/io/shubham0204/smollmandroid/stt/SpeechToTextManager.kt new file mode 100644 index 00000000..4fdc4bc9 --- /dev/null +++ b/app/src/main/java/io/shubham0204/smollmandroid/stt/SpeechToTextManager.kt @@ -0,0 +1,768 @@ +package io.shubham0204.smollmandroid.stt + +import android.Manifest +import android.content.Context +import android.content.pm.PackageManager +import android.media.AudioFormat +import android.os.Process +import android.media.AudioRecord +import android.media.MediaRecorder +import android.os.Environment +import android.util.Log +import androidx.core.content.ContextCompat +import com.github.squti.androidwaverecorder.WaveRecorder +import com.whispercpp.whisper.WhisperCallback +import com.whispercpp.whisper.WhisperContext +import io.shubham0204.smollmandroid.data.PreferencesManager +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.SharedFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.isActive +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import org.koin.core.annotation.Single +import java.io.File +import java.io.FileInputStream +import java.nio.ByteBuffer +import java.nio.ByteOrder + +private const val LOG_TAG = "SpeechToTextManager" + +sealed class STTState { + data object Idle : STTState() + data object Recording : STTState() + data object Transcribing : STTState() + data class Error(val message: String) : STTState() +} 
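+// Illustrative call sequence (a minimal sketch, not code added by this diff; `sttManager` is
+// an injected SpeechToTextManager, and `uiScope`, `textField`, `sendMessage`, `showError`
+// are placeholder names):
+//
+//   sttManager.loadModel { ok -> if (!ok) showError() }
+//   sttManager.setOnSilenceDetectedCallback { finalText -> sendMessage(finalText) }
+//   sttManager.startStreamingRecording(language = "en", autoSubmitDelayMs = 2000L)
+//   uiScope.launch {
+//       sttManager.streamingTranscription.collect { partial -> textField.value = partial }
+//   }
+//   // Manual stop, e.g. when the mic button is tapped again:
+//   sttManager.stopStreamingRecording(language = "en") { finalText -> sendMessage(finalText) }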
+ +@Single +class SpeechToTextManager( + private val context: Context, + private val preferencesManager: PreferencesManager +) { + + private val scope = CoroutineScope(Dispatchers.IO + SupervisorJob()) + + private var waveRecorder: WaveRecorder? = null + private var whisperContext: WhisperContext? = null + private var currentRecordingPath: String? = null + + private val _state = MutableStateFlow<STTState>(STTState.Idle) + val state: StateFlow<STTState> = _state + + private val _transcribedText = MutableStateFlow("") + val transcribedText: StateFlow<String> = _transcribedText + + // Flow for streaming transcription chunks - emits the FULL transcription each time + // extraBufferCapacity=1 allows tryEmit to work without suspending + private val _streamingTranscription = MutableSharedFlow<String>(extraBufferCapacity = 1) + val streamingTranscription: SharedFlow<String> = _streamingTranscription + + // Flow to signal that silence was detected and auto-submit should happen + // extraBufferCapacity=1 allows tryEmit to work without suspending + private val _silenceDetected = MutableSharedFlow<Unit>(extraBufferCapacity = 1) + val silenceDetected: SharedFlow<Unit> = _silenceDetected + + private val modelsPath: File? = context.getExternalFilesDir(Environment.DIRECTORY_DOWNLOADS) + + private var isModelLoaded = false + private var loadedModelName: String? = null + + // For streaming transcription + private var streamingJob: Job? = null + private var silenceDetectionJob: Job? = null + private var audioRecord: AudioRecord? = null + private var isStreamingMode = false + private var audioBuffer = mutableListOf<Short>() + private val audioBufferLock = Any() + + // Interval for periodic transcription (in milliseconds) + private val transcriptionIntervalMs = 1500L + + // Audio recording parameters + private val sampleRate = 16000 + + // Regex to filter out Whisper noise/silence markers + // These markers should not be considered as "speech" for auto-submit purposes + private val noiseMarkerRegex = Regex( + """\[.*?]|\(.*?\)|<\|.*?\|>""", + RegexOption.IGNORE_CASE + ) + + /** + * Cleans transcription by removing Whisper noise markers like [empty audio], [BLANK_AUDIO], + * [noise], [music], (silence), etc. Returns only the actual spoken words. + */ + private fun cleanTranscription(text: String): String { + return noiseMarkerRegex.replace(text, "").trim() + } + + /** + * Checks if the transcription contains only noise markers (no real speech). + */ + private fun isOnlyNoiseMarkers(text: String): Boolean { + return cleanTranscription(text).isBlank() + } + + fun hasRecordingPermission(): Boolean { + return ContextCompat.checkSelfPermission( + context, + Manifest.permission.RECORD_AUDIO + ) == PackageManager.PERMISSION_GRANTED + } + + fun isModelAvailable(modelFileName: String? = null): Boolean { + val selectedModel = modelFileName ?: preferencesManager.selectedWhisperModel + val modelFile = File(modelsPath, selectedModel) + return modelFile.exists() + } + + /** + * Returns a list of available Whisper model files in the models directory. + * Whisper models typically have .bin extension and contain "ggml" in the name.
+ */ + fun getAvailableModels(): List { + return modelsPath?.listFiles() + ?.filter { it.isFile && it.name.endsWith(".bin") && it.name.contains("ggml") } + ?.map { it.name } + ?.sorted() + ?: emptyList() + } + + fun getSelectedModelName(): String = preferencesManager.selectedWhisperModel + + fun setSelectedModel(modelFileName: String) { + // If a different model is being selected, we need to reload + if (loadedModelName != modelFileName && isModelLoaded) { + scope.launch { + whisperContext?.release() + whisperContext = null + isModelLoaded = false + loadedModelName = null + } + } + preferencesManager.selectedWhisperModel = modelFileName + } + + fun loadModel(modelFileName: String? = null, onComplete: (Boolean) -> Unit = {}) { + val selectedModel = modelFileName ?: preferencesManager.selectedWhisperModel + scope.launch { + try { + // If model is already loaded and it's the same model, just return success + if (isModelLoaded && loadedModelName == selectedModel) { + withContext(Dispatchers.Main) { + onComplete(true) + } + return@launch + } + + // If a different model is loaded, release it first + if (isModelLoaded && loadedModelName != selectedModel) { + whisperContext?.release() + whisperContext = null + isModelLoaded = false + loadedModelName = null + } + + val modelFile = File(modelsPath, selectedModel) + if (!modelFile.exists()) { + Log.e(LOG_TAG, "Model file not found: ${modelFile.absolutePath}") + withContext(Dispatchers.Main) { + onComplete(false) + } + return@launch + } + + Log.d(LOG_TAG, "Loading Whisper model from: ${modelFile.absolutePath}") + whisperContext = WhisperContext.createContextFromFile(modelFile.absolutePath) + isModelLoaded = true + loadedModelName = selectedModel + Log.d(LOG_TAG, "Whisper model loaded successfully") + + withContext(Dispatchers.Main) { + onComplete(true) + } + } catch (e: Exception) { + Log.e(LOG_TAG, "Failed to load Whisper model", e) + withContext(Dispatchers.Main) { + onComplete(false) + } + } + } + } + + // Callback for when silence is detected and auto-submit should happen + // This is called directly from the transcription coroutine scope to avoid + // issues with frozen ViewModel coroutines on Samsung devices + // @Volatile ensures visibility across threads (set from Main, read from IO) + @Volatile + private var onSilenceDetectedCallback: ((String) -> Unit)? = null + + /** + * Set a callback that will be called when silence is detected. + * The callback receives the final transcription text. + * This is called directly from the IO dispatcher, so the callback + * should handle any necessary thread dispatching. + */ + fun setOnSilenceDetectedCallback(callback: ((String) -> Unit)?) { + Log.d(LOG_TAG, ">>> setOnSilenceDetectedCallback called, callback is ${if (callback != null) "NOT NULL" else "NULL"}") + onSilenceDetectedCallback = callback + } + + /** + * Start streaming recording with periodic transcription. + * Transcribed text will be emitted via streamingTranscription Flow. + * When transcription stops changing for the configured duration, silenceDetected will emit. 
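+ * Callers should have granted RECORD_AUDIO and loaded a Whisper model via [loadModel]
+ * beforehand; the permission is checked here, but a missing model is not, and the periodic
+ * transcription passes then simply produce empty text.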
+ */ + @Suppress("MissingPermission") + fun startStreamingRecording(language: String = "en", autoSubmitDelayMs: Long = 2000L) { + if (!hasRecordingPermission()) { + _state.value = STTState.Error("Recording permission not granted") + return + } + + try { + isStreamingMode = true + synchronized(audioBufferLock) { + audioBuffer.clear() + } + + // Initialize AudioRecord for direct audio capture + val bufferSize = AudioRecord.getMinBufferSize( + sampleRate, + AudioFormat.CHANNEL_IN_MONO, + AudioFormat.ENCODING_PCM_16BIT + ) + + audioRecord = AudioRecord( + MediaRecorder.AudioSource.MIC, + sampleRate, + AudioFormat.CHANNEL_IN_MONO, + AudioFormat.ENCODING_PCM_16BIT, + bufferSize * 2 + ) + + audioRecord?.startRecording() + _state.value = STTState.Recording + Log.d(LOG_TAG, "Streaming recording started with AudioRecord") + + // Start audio capture job + streamingJob = scope.launch { + val readBuffer = ShortArray(bufferSize / 2) + + while (isActive && _state.value == STTState.Recording) { + val readCount = audioRecord?.read(readBuffer, 0, readBuffer.size) ?: 0 + + if (readCount > 0) { + // Add samples to buffer + synchronized(audioBufferLock) { + for (i in 0 until readCount) { + audioBuffer.add(readBuffer[i]) + } + } + } + + delay(10) // Small delay to prevent tight loop + } + } + + // Start periodic transcription job with auto-submit detection + silenceDetectionJob = scope.launch { + var lastCleanTranscription = "" // Only real spoken words (no noise markers) + var lastRawTranscription = "" // Full transcription including markers + var lastSpeechChangeTime = System.currentTimeMillis() + var autoSubmitTriggered = false + + Log.d(LOG_TAG, "Starting periodic transcription loop") + delay(transcriptionIntervalMs) // Initial delay before first transcription + while (isActive && _state.value == STTState.Recording) { + Log.d(LOG_TAG, "Periodic transcription tick, buffer size: ${audioBuffer.size}") + val rawTranscription = transcribeCurrentBuffer(language) + val cleanedTranscription = cleanTranscription(rawTranscription) + + Log.d(LOG_TAG, "Raw: '$rawTranscription' | Clean: '$cleanedTranscription'") + + if (rawTranscription.isNotBlank()) { + // Check if the CLEANED transcription (real speech) has changed + if (cleanedTranscription != lastCleanTranscription && cleanedTranscription.isNotBlank()) { + // Real speech changed - reset timer and emit + lastCleanTranscription = cleanedTranscription + lastRawTranscription = rawTranscription + lastSpeechChangeTime = System.currentTimeMillis() + autoSubmitTriggered = false + Log.d(LOG_TAG, "Speech changed, emitting: $cleanedTranscription") + // Emit the cleaned transcription (without noise markers) + _streamingTranscription.tryEmit(cleanedTranscription) + } else if (lastCleanTranscription.isNotBlank()) { + // Real speech hasn't changed - check if we should auto-submit + // Only auto-submit if we have actual spoken words + val timeSinceLastSpeech = System.currentTimeMillis() - lastSpeechChangeTime + Log.d(LOG_TAG, "Speech unchanged for ${timeSinceLastSpeech}ms (threshold: ${autoSubmitDelayMs}ms)") + if (timeSinceLastSpeech >= autoSubmitDelayMs && !autoSubmitTriggered) { + Log.d(LOG_TAG, ">>> SILENCE DETECTED after speech - triggering auto-submit") + autoSubmitTriggered = true + + // Call the callback directly from this coroutine scope + // This bypasses the ViewModel's coroutine which may be frozen + val callback = onSilenceDetectedCallback + Log.d(LOG_TAG, ">>> onSilenceDetectedCallback is ${if (callback != null) "NOT NULL" else "NULL"}") + if (callback != null) { + 
Log.d(LOG_TAG, ">>> Calling onSilenceDetectedCallback directly") + // Stop recording and get final transcription (cleaned) + val finalTranscription = lastCleanTranscription + stopStreamingRecordingInternal() + Log.d(LOG_TAG, ">>> Recording stopped, calling callback with: $finalTranscription") + callback(finalTranscription) + Log.d(LOG_TAG, ">>> Callback returned") + return@launch // Exit the loop since we stopped recording + } else { + // Fallback to flow emission for backward compatibility + val emitted = _silenceDetected.tryEmit(Unit) + Log.d(LOG_TAG, ">>> silenceDetected.tryEmit result: $emitted, subscribers: ${_silenceDetected.subscriptionCount.value}") + } + } + } else { + Log.d(LOG_TAG, "Only noise markers detected, waiting for speech...") + } + } else { + Log.d(LOG_TAG, "Transcription was blank") + } + + delay(transcriptionIntervalMs) + } + Log.d(LOG_TAG, "Periodic transcription loop ended, isActive=$isActive, state=${_state.value}") + } + } catch (e: Exception) { + Log.e(LOG_TAG, "Failed to start streaming recording", e) + _state.value = STTState.Error("Failed to start recording: ${e.message}") + } + } + + /** + * Transcribe the current audio buffer without stopping recording. + * Returns the full transcription text. + */ + private suspend fun transcribeCurrentBuffer(language: String): String { + val audioData: FloatArray + synchronized(audioBufferLock) { + if (audioBuffer.size < sampleRate) { // At least 1 second of audio + return "" + } + // Convert Short buffer to Float array + audioData = FloatArray(audioBuffer.size) + for (i in audioBuffer.indices) { + audioData[i] = audioBuffer[i] / 32768.0f + } + } + + Log.d(LOG_TAG, "Periodic transcription of ${audioData.size} samples") + + // Boost thread priority to reduce CPU throttling when screen is locked + val originalPriority = Process.getThreadPriority(Process.myTid()) + try { + Process.setThreadPriority(Process.THREAD_PRIORITY_URGENT_AUDIO) + } catch (e: Exception) { + Log.d(LOG_TAG, "Failed to boost thread priority: ${e.message}") + } + + return try { + val result = whisperContext?.transcribeData( + data = audioData, + language = language, + printTimestamp = false, + callback = object : WhisperCallback { + override fun onNewSegment(startMs: Long, endMs: Long, text: String) { + Log.d(LOG_TAG, "Streaming segment: $text") + } + + override fun onProgress(progress: Int) { + // Ignore progress for streaming + } + + override fun onComplete() { + Log.d(LOG_TAG, "Streaming transcription chunk complete") + } + } + ) ?: "" + + result.trim() + } catch (e: Exception) { + Log.e(LOG_TAG, "Error during periodic transcription", e) + "" + } finally { + // Restore original thread priority + try { + Process.setThreadPriority(originalPriority) + } catch (e: Exception) { + Log.d(LOG_TAG, "Failed to restore thread priority: ${e.message}") + } + } + } + + /** + * Internal method to stop recording without final transcription. + * Used by the silence detection callback. 
+ */ + private fun stopStreamingRecordingInternal() { + streamingJob?.cancel() + streamingJob = null + silenceDetectionJob?.cancel() + silenceDetectionJob = null + isStreamingMode = false + + try { + audioRecord?.stop() + audioRecord?.release() + audioRecord = null + waveRecorder?.stopRecording() + waveRecorder = null + synchronized(audioBufferLock) { + audioBuffer.clear() + } + _state.value = STTState.Idle + Log.d(LOG_TAG, ">>> Recording stopped internally") + } catch (e: Exception) { + Log.e(LOG_TAG, "Failed to stop recording internally", e) + } + } + + /** + * Stop streaming recording and perform final transcription. + */ + fun stopStreamingRecording( + language: String = "en", + onComplete: (String) -> Unit + ) { + streamingJob?.cancel() + streamingJob = null + silenceDetectionJob?.cancel() + silenceDetectionJob = null + isStreamingMode = false + + try { + // Stop AudioRecord + audioRecord?.stop() + audioRecord?.release() + audioRecord = null + + // Also stop WaveRecorder if it was used (for non-streaming mode fallback) + waveRecorder?.stopRecording() + waveRecorder = null + + _state.value = STTState.Transcribing + Log.d(LOG_TAG, "Streaming recording stopped, final transcription") + + scope.launch { + // Get audio data from buffer + val audioData: FloatArray + val bufferEmpty: Boolean + synchronized(audioBufferLock) { + bufferEmpty = audioBuffer.isEmpty() + if (bufferEmpty) { + audioData = FloatArray(0) + } else { + audioData = FloatArray(audioBuffer.size) + for (i in audioBuffer.indices) { + audioData[i] = audioBuffer[i] / 32768.0f + } + audioBuffer.clear() + } + } + + if (bufferEmpty) { + Log.d(LOG_TAG, "Buffer empty, completing with empty string") + _state.value = STTState.Idle + onComplete("") + return@launch + } + + if (audioData.size < sampleRate) { // Less than 1 second + Log.d(LOG_TAG, "Audio too short (${audioData.size} samples), completing with empty string") + _state.value = STTState.Idle + onComplete("") + return@launch + } + + Log.d(LOG_TAG, "Starting final transcription of ${audioData.size} samples") + val result = whisperContext?.transcribeData( + data = audioData, + language = language, + printTimestamp = false, + callback = object : WhisperCallback { + override fun onNewSegment(startMs: Long, endMs: Long, text: String) { + Log.d(LOG_TAG, "Final segment: $text") + } + + override fun onProgress(progress: Int) {} + override fun onComplete() {} + } + ) ?: "" + + val finalText = result.trim() + Log.d(LOG_TAG, "Final transcription complete: $finalText") + + // StateFlow is thread-safe, no need for Main dispatcher + _state.value = STTState.Idle + // Call callback directly - caller handles any needed dispatching + onComplete(finalText) + } + } catch (e: Exception) { + Log.e(LOG_TAG, "Failed to stop streaming recording", e) + _state.value = STTState.Error("Failed to stop recording: ${e.message}") + onComplete("") + } + } + + fun startRecording() { + if (!hasRecordingPermission()) { + _state.value = STTState.Error("Recording permission not granted") + return + } + + try { + val recordingFile = generateRecordingFile() + currentRecordingPath = recordingFile.absolutePath + + waveRecorder = WaveRecorder(recordingFile.absolutePath) + .configureWaveSettings { + sampleRate = 16000 // Whisper expects 16kHz + channels = AudioFormat.CHANNEL_IN_MONO + audioEncoding = AudioFormat.ENCODING_PCM_16BIT + } + + waveRecorder?.startRecording() + _state.value = STTState.Recording + Log.d(LOG_TAG, "Recording started: ${recordingFile.absolutePath}") + } catch (e: Exception) { + Log.e(LOG_TAG, "Failed 
to start recording", e) + _state.value = STTState.Error("Failed to start recording: ${e.message}") + } + } + + fun stopRecordingAndTranscribe( + language: String = "en", + onTranscriptionComplete: (String) -> Unit + ) { + // If in streaming mode, use the streaming stop method + if (isStreamingMode) { + stopStreamingRecording(language, onTranscriptionComplete) + return + } + + try { + waveRecorder?.stopRecording() + waveRecorder = null + + val recordingPath = currentRecordingPath + if (recordingPath == null) { + _state.value = STTState.Error("No recording file found") + onTranscriptionComplete("") + return + } + + _state.value = STTState.Transcribing + Log.d(LOG_TAG, "Recording stopped, starting transcription") + + scope.launch { + transcribeAudio(recordingPath, language, onTranscriptionComplete) + } + } catch (e: Exception) { + Log.e(LOG_TAG, "Failed to stop recording", e) + _state.value = STTState.Error("Failed to stop recording: ${e.message}") + onTranscriptionComplete("") + } + } + + fun cancelRecording() { + streamingJob?.cancel() + streamingJob = null + silenceDetectionJob?.cancel() + silenceDetectionJob = null + isStreamingMode = false + + try { + // Stop AudioRecord if used + audioRecord?.stop() + audioRecord?.release() + audioRecord = null + + // Stop WaveRecorder if used + waveRecorder?.stopRecording() + waveRecorder = null + + // Clear the audio buffer + synchronized(audioBufferLock) { + audioBuffer.clear() + } + + // Delete the recording file if any + currentRecordingPath?.let { path -> + File(path).delete() + } + currentRecordingPath = null + + _state.value = STTState.Idle + Log.d(LOG_TAG, "Recording cancelled") + } catch (e: Exception) { + Log.e(LOG_TAG, "Failed to cancel recording", e) + } + } + + private suspend fun transcribeAudio( + audioPath: String, + language: String, + onComplete: (String) -> Unit + ) { + try { + if (whisperContext == null) { + Log.e(LOG_TAG, "Whisper context not initialized") + withContext(Dispatchers.Main) { + _state.value = STTState.Error("Whisper model not loaded") + onComplete("") + } + return + } + + val audioData = readWavFileAsFloatArray(audioPath) + if (audioData.isEmpty()) { + withContext(Dispatchers.Main) { + _state.value = STTState.Error("Failed to read audio file") + onComplete("") + } + return + } + + Log.d(LOG_TAG, "Transcribing ${audioData.size} samples") + + val transcriptionBuilder = StringBuilder() + + val result = whisperContext?.transcribeData( + data = audioData, + language = language, + printTimestamp = false, + callback = object : WhisperCallback { + override fun onNewSegment(startMs: Long, endMs: Long, text: String) { + transcriptionBuilder.append(text) + Log.d(LOG_TAG, "Segment: $text") + } + + override fun onProgress(progress: Int) { + Log.d(LOG_TAG, "Transcription progress: $progress%") + } + + override fun onComplete() { + Log.d(LOG_TAG, "Transcription complete") + } + } + ) ?: "" + + // Clean up the recording file + File(audioPath).delete() + currentRecordingPath = null + + val finalText = result.trim() + Log.d(LOG_TAG, "Transcription result: $finalText") + + withContext(Dispatchers.Main) { + _transcribedText.value = finalText + _state.value = STTState.Idle + onComplete(finalText) + } + } catch (e: Exception) { + Log.e(LOG_TAG, "Transcription failed", e) + withContext(Dispatchers.Main) { + _state.value = STTState.Error("Transcription failed: ${e.message}") + onComplete("") + } + } + } + + private fun readWavFileAsFloatArray(filePath: String): FloatArray { + return try { + val file = File(filePath) + if (!file.exists()) { + 
Log.e(LOG_TAG, "WAV file does not exist: $filePath") + return FloatArray(0) + } + + FileInputStream(file).use { fis -> + val headerBytes = ByteArray(44) + val headerRead = fis.read(headerBytes) + if (headerRead < 44) { + Log.e(LOG_TAG, "WAV header too short") + return FloatArray(0) + } + + // Read the data size from the header (bytes 40-43) + val dataSize = ByteBuffer.wrap(headerBytes, 40, 4) + .order(ByteOrder.LITTLE_ENDIAN) + .int + + // For streaming, the file might still be growing, so read what's available + val availableBytes = file.length().toInt() - 44 + val bytesToRead = minOf(dataSize, availableBytes) + + if (bytesToRead <= 0) { + return FloatArray(0) + } + + val audioBytes = ByteArray(bytesToRead) + fis.read(audioBytes) + + // Convert 16-bit PCM to float array + val samples = bytesToRead / 2 + val floatArray = FloatArray(samples) + val byteBuffer = ByteBuffer.wrap(audioBytes).order(ByteOrder.LITTLE_ENDIAN) + + for (i in 0 until samples) { + val sample = byteBuffer.short.toInt() + floatArray[i] = sample / 32768.0f + } + + floatArray + } + } catch (e: Exception) { + Log.e(LOG_TAG, "Failed to read WAV file", e) + FloatArray(0) + } + } + + private fun generateRecordingFile(): File { + val fileName = "stt_recording_${System.currentTimeMillis()}.wav" + return File(context.cacheDir, fileName) + } + + fun release() { + streamingJob?.cancel() + streamingJob = null + silenceDetectionJob?.cancel() + silenceDetectionJob = null + isStreamingMode = false + + scope.launch { + try { + audioRecord?.stop() + audioRecord?.release() + audioRecord = null + waveRecorder?.stopRecording() + waveRecorder = null + whisperContext?.release() + whisperContext = null + isModelLoaded = false + synchronized(audioBufferLock) { + audioBuffer.clear() + } + currentRecordingPath?.let { File(it).delete() } + currentRecordingPath = null + } catch (e: Exception) { + Log.e(LOG_TAG, "Failed to release resources", e) + } + } + } +} diff --git a/app/src/main/java/io/shubham0204/smollmandroid/tts/TextToSpeechManager.kt b/app/src/main/java/io/shubham0204/smollmandroid/tts/TextToSpeechManager.kt new file mode 100644 index 00000000..70970a44 --- /dev/null +++ b/app/src/main/java/io/shubham0204/smollmandroid/tts/TextToSpeechManager.kt @@ -0,0 +1,273 @@ +/* + * Copyright (C) 2024 Shubham Panchal + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.shubham0204.smollmandroid.tts + +import android.content.Context +import android.speech.tts.TextToSpeech +import android.speech.tts.UtteranceProgressListener +import android.util.Log +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.SharedFlow +import kotlinx.coroutines.flow.StateFlow +import org.koin.core.annotation.Single +import java.util.Locale +import java.util.concurrent.ConcurrentLinkedQueue + +private const val LOGTAG = "[TextToSpeechManager-Kt]" +private val LOGD: (String) -> Unit = { Log.d(LOGTAG, it) } + +@Single +class TextToSpeechManager(context: Context) { + + private var tts: TextToSpeech? = null + private var isInitialized = false + private var currentLanguage: String = "en" + + private val sentenceQueue = ConcurrentLinkedQueue<String>() + private var isSpeaking = false + private var utteranceCounter = 0 + + private var previousText = "" + private var pendingBuffer = "" + + private val sentenceEndRegex = Regex("[.!?](?:\\s+|$)") + + // Flow to signal when all speech has finished (queue is empty) + // extraBufferCapacity = 1 ensures tryEmit succeeds even if collector is busy + private val _allSpeechFinished = MutableSharedFlow<Unit>(extraBufferCapacity = 1) + val allSpeechFinished: SharedFlow<Unit> = _allSpeechFinished + + // StateFlow to expose whether TTS is currently speaking + private val _isSpeakingFlow = MutableStateFlow(false) + val isSpeakingFlow: StateFlow<Boolean> = _isSpeakingFlow + + // Flag to track if we're in a speech session (from speakChunk to speakRemainingBuffer) + private var isSpeechSessionActive = false + + init { + tts = TextToSpeech(context) { status -> + if (status == TextToSpeech.SUCCESS) { + isInitialized = true + LOGD("TTS initialized successfully") + setupUtteranceListener() + // Set default language + setLanguage(currentLanguage) + } else { + LOGD("TTS initialization failed") + isInitialized = false + } + } + } + + /** + * Set the TTS language using a language code (e.g., "en", "de", "fr"). + * Returns true if the language is supported, false otherwise. + */ + fun setLanguage(languageCode: String): Boolean { + if (!isInitialized) return false + + currentLanguage = languageCode + val locale = when (languageCode) { + "en" -> Locale.US + "de" -> Locale.GERMAN + "fr" -> Locale.FRENCH + "es" -> Locale("es", "ES") + "it" -> Locale.ITALIAN + "pt" -> Locale("pt", "PT") + "nl" -> Locale("nl", "NL") + "pl" -> Locale("pl", "PL") + "ru" -> Locale("ru", "RU") + "zh" -> Locale.CHINESE + "ja" -> Locale.JAPANESE + "ko" -> Locale.KOREAN + "ar" -> Locale("ar", "SA") + "hi" -> Locale("hi", "IN") + "tr" -> Locale("tr", "TR") + "uk" -> Locale("uk", "UA") + "cs" -> Locale("cs", "CZ") + "sv" -> Locale("sv", "SE") + else -> Locale.US + } + + val result = tts?.setLanguage(locale) + val isSupported = result != TextToSpeech.LANG_MISSING_DATA && + result != TextToSpeech.LANG_NOT_SUPPORTED + + if (isSupported) { + LOGD("TTS language set to: $languageCode ($locale)") + } else { + LOGD("TTS language not supported: $languageCode, falling back to default") + tts?.setLanguage(Locale.US) + } + + return isSupported + } + + private fun setupUtteranceListener() { + tts?.setOnUtteranceProgressListener(object : UtteranceProgressListener() { + override fun onStart(utteranceId: String?) { + isSpeaking = true + _isSpeakingFlow.value = true + } + + override fun onDone(utteranceId: String?)
{ + isSpeaking = false + _isSpeakingFlow.value = sentenceQueue.isNotEmpty() // Still speaking if queue has more + speakNextInQueue() + checkAndEmitSpeechFinished() + } + + @Deprecated("Deprecated in Java") + override fun onError(utteranceId: String?) { + isSpeaking = false + _isSpeakingFlow.value = sentenceQueue.isNotEmpty() + speakNextInQueue() + checkAndEmitSpeechFinished() + } + + override fun onError(utteranceId: String?, errorCode: Int) { + isSpeaking = false + _isSpeakingFlow.value = sentenceQueue.isNotEmpty() + speakNextInQueue() + checkAndEmitSpeechFinished() + } + }) + } + + private fun checkAndEmitSpeechFinished() { + LOGD("checkAndEmitSpeechFinished: queueEmpty=${sentenceQueue.isEmpty()}, isSpeaking=$isSpeaking, sessionActive=$isSpeechSessionActive") + // If the queue is empty and we're not speaking, the session is finished + if (sentenceQueue.isEmpty() && !isSpeaking && isSpeechSessionActive) { + isSpeechSessionActive = false + LOGD("All speech finished, emitting signal") + val emitted = _allSpeechFinished.tryEmit(Unit) + LOGD("tryEmit result: $emitted") + } + } + + fun speakChunk(fullText: String) { + if (!isInitialized) return + + // Mark speech session as active when we start receiving chunks + isSpeechSessionActive = true + + val newContent = if (fullText.startsWith(previousText)) { + fullText.removePrefix(previousText) + } else { + fullText + } + previousText = fullText + + val textToProcess = pendingBuffer + newContent + val sentences = extractCompleteSentences(textToProcess) + + pendingBuffer = sentences.remaining + + sentences.completed.forEach { sentence -> + if (sentence.isNotBlank()) { + queueSentence(sentence.trim()) + } + } + } + + fun speakRemainingBuffer() { + LOGD("speakRemainingBuffer called, isInitialized=$isInitialized, pendingBuffer='$pendingBuffer', isSpeaking=$isSpeaking") + if (!isInitialized) return + + // Re-activate the speech session in case it was prematurely marked as finished + // (e.g., when TTS finished speaking queued sentences while generation was still ongoing) + isSpeechSessionActive = true + + if (pendingBuffer.isNotBlank()) { + LOGD("Queueing remaining buffer: '$pendingBuffer'") + queueSentence(pendingBuffer.trim()) + pendingBuffer = "" + } else { + // If there's no pending buffer and no ongoing speech, the session is done + LOGD("No pending buffer, checking if speech finished") + checkAndEmitSpeechFinished() + } + } + + private fun queueSentence(sentence: String) { + sentenceQueue.add(sentence) + if (!isSpeaking) { + speakNextInQueue() + } + } + + private fun speakNextInQueue() { + val nextSentence = sentenceQueue.poll() ?: return + + utteranceCounter++ + val utteranceId = "tts_utterance_$utteranceCounter" + + // Set isSpeaking before speak() to avoid race condition where + // checkAndEmitSpeechFinished() is called before onStart callback + isSpeaking = true + _isSpeakingFlow.value = true + tts?.speak( + nextSentence, + TextToSpeech.QUEUE_FLUSH, + null, + utteranceId + ) + } + + private data class SentenceExtraction( + val completed: List, + val remaining: String + ) + + private fun extractCompleteSentences(text: String): SentenceExtraction { + val completed = mutableListOf() + var remaining = text + + var match = sentenceEndRegex.find(remaining) + while (match != null) { + val endIndex = match.range.last + 1 + val sentence = remaining.substring(0, endIndex) + completed.add(sentence) + remaining = remaining.substring(endIndex) + match = sentenceEndRegex.find(remaining) + } + + return SentenceExtraction(completed, remaining) + } + + fun 
stop() { + sentenceQueue.clear() + pendingBuffer = "" + previousText = "" + isSpeaking = false + _isSpeakingFlow.value = false + isSpeechSessionActive = false + tts?.stop() + } + + fun resetState() { + stop() + } + + fun shutdown() { + stop() + tts?.shutdown() + tts = null + isInitialized = false + } +} diff --git a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/whisper_download/DownloadWhisperModelActivity.kt b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/whisper_download/DownloadWhisperModelActivity.kt new file mode 100644 index 00000000..a8e10135 --- /dev/null +++ b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/whisper_download/DownloadWhisperModelActivity.kt @@ -0,0 +1,293 @@ +/* + * Copyright (C) 2024 Shubham Panchal + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.shubham0204.smollmandroid.ui.screens.whisper_download + +import android.app.DownloadManager +import android.content.Context +import android.net.Uri +import android.os.Bundle +import android.os.Environment +import android.widget.Toast +import androidx.activity.ComponentActivity +import androidx.activity.compose.setContent +import androidx.activity.enableEdgeToEdge +import androidx.compose.foundation.background +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Box +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.Spacer +import androidx.compose.foundation.layout.fillMaxSize +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.height +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.safeDrawingPadding +import androidx.compose.foundation.layout.width +import androidx.compose.foundation.rememberScrollState +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.foundation.verticalScroll +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.automirrored.filled.ArrowBack +import androidx.compose.material3.Button +import androidx.compose.material3.ExperimentalMaterial3Api +import androidx.compose.material3.HorizontalDivider +import androidx.compose.material3.Icon +import androidx.compose.material3.IconButton +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Scaffold +import androidx.compose.material3.Surface +import androidx.compose.material3.Text +import androidx.compose.material3.TopAppBar +import androidx.compose.runtime.Composable +import androidx.compose.runtime.getValue +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.remember +import androidx.compose.runtime.setValue +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.res.stringResource +import androidx.compose.ui.unit.dp +import compose.icons.FeatherIcons +import 
compose.icons.feathericons.Check +import io.shubham0204.smollmandroid.R +import io.shubham0204.smollmandroid.stt.SpeechToTextManager +import io.shubham0204.smollmandroid.ui.components.AppBarTitleText +import io.shubham0204.smollmandroid.ui.theme.SmolLMAndroidTheme +import org.koin.android.ext.android.inject + +class DownloadWhisperModelActivity : ComponentActivity() { + + private val sttManager: SpeechToTextManager by inject() + + override fun onCreate(savedInstanceState: Bundle?) { + super.onCreate(savedInstanceState) + enableEdgeToEdge() + setContent { + val downloadedModels = remember { sttManager.getAvailableModels() } + val selectedModel = remember { sttManager.getSelectedModelName() } + + Box(modifier = Modifier.safeDrawingPadding()) { + DownloadWhisperModelScreen( + downloadedModels = downloadedModels, + selectedModel = selectedModel, + onBackClick = { finish() }, + onDownloadModel = { model -> downloadWhisperModel(model) }, + onSelectModel = { modelName -> + sttManager.setSelectedModel(modelName) + Toast.makeText( + this@DownloadWhisperModelActivity, + getString(R.string.whisper_model_selected, modelName), + Toast.LENGTH_SHORT + ).show() + } + ) + } + } + } + + private fun downloadWhisperModel(model: WhisperModel) { + val downloadManager = getSystemService(Context.DOWNLOAD_SERVICE) as DownloadManager + val request = DownloadManager.Request(Uri.parse(model.url)).apply { + setTitle(model.name) + setDescription(getString(R.string.whisper_download_notification_desc)) + setNotificationVisibility(DownloadManager.Request.VISIBILITY_VISIBLE_NOTIFY_COMPLETED) + setDestinationInExternalFilesDir( + this@DownloadWhisperModelActivity, + Environment.DIRECTORY_DOWNLOADS, + model.fileName + ) + setAllowedNetworkTypes( + DownloadManager.Request.NETWORK_WIFI or DownloadManager.Request.NETWORK_MOBILE + ) + } + downloadManager.enqueue(request) + Toast.makeText( + this, + getString(R.string.whisper_download_started, model.name), + Toast.LENGTH_LONG + ).show() + } +} + +@OptIn(ExperimentalMaterial3Api::class) +@Composable +private fun DownloadWhisperModelScreen( + downloadedModels: List, + selectedModel: String, + onBackClick: () -> Unit, + onDownloadModel: (WhisperModel) -> Unit, + onSelectModel: (String) -> Unit, +) { + var selectedModelIndex by remember { mutableStateOf(1) } // Default to base.en model + var currentSelectedModel by remember { mutableStateOf(selectedModel) } + + SmolLMAndroidTheme { + Scaffold( + modifier = Modifier.fillMaxSize(), + topBar = { + TopAppBar( + title = { AppBarTitleText(stringResource(R.string.whisper_download_title)) }, + navigationIcon = { + IconButton(onClick = onBackClick) { + Icon( + Icons.AutoMirrored.Filled.ArrowBack, + contentDescription = stringResource(R.string.button_text_back) + ) + } + } + ) + }, + ) { innerPadding -> + Surface( + modifier = Modifier + .padding(innerPadding) + .verticalScroll(rememberScrollState()) + ) { + Column( + modifier = Modifier + .fillMaxSize() + .padding(16.dp) + ) { + // Downloaded Models Section + if (downloadedModels.isNotEmpty()) { + Text( + text = stringResource(R.string.whisper_downloaded_models), + style = MaterialTheme.typography.titleMedium, + ) + Text( + text = stringResource(R.string.whisper_downloaded_models_desc), + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + ) + + Spacer(modifier = Modifier.height(8.dp)) + + DownloadedModelsList( + models = downloadedModels, + selectedModel = currentSelectedModel, + onModelSelected = { modelName -> + currentSelectedModel = 
modelName + onSelectModel(modelName) + } + ) + + Spacer(modifier = Modifier.height(24.dp)) + HorizontalDivider() + Spacer(modifier = Modifier.height(24.dp)) + } + + // Download New Models Section + Text( + text = stringResource(R.string.whisper_download_new_model), + style = MaterialTheme.typography.titleMedium, + ) + Text( + text = stringResource(R.string.whisper_download_desc), + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + ) + + Spacer(modifier = Modifier.height(16.dp)) + + Text( + text = stringResource(R.string.whisper_select_model), + style = MaterialTheme.typography.titleSmall, + ) + + Spacer(modifier = Modifier.height(8.dp)) + + PopularWhisperModelsList( + selectedModelIndex = selectedModelIndex, + onModelSelected = { selectedModelIndex = it } + ) + + Spacer(modifier = Modifier.height(24.dp)) + + Button( + onClick = { + selectedModelIndex?.let { index -> + getPopularWhisperModel(index)?.let { model -> + onDownloadModel(model) + } + } + }, + enabled = selectedModelIndex != null, + modifier = Modifier.align(Alignment.CenterHorizontally) + ) { + Text(stringResource(R.string.whisper_download_button)) + } + + Spacer(modifier = Modifier.height(16.dp)) + + Text( + text = stringResource(R.string.whisper_download_location_info), + style = MaterialTheme.typography.bodySmall, + color = MaterialTheme.colorScheme.onSurfaceVariant, + modifier = Modifier.fillMaxWidth() + ) + } + } + } + } +} + +@Composable +private fun DownloadedModelsList( + models: List, + selectedModel: String, + onModelSelected: (String) -> Unit, +) { + Column(verticalArrangement = Arrangement.Center) { + models.forEach { modelName -> + val isSelected = modelName == selectedModel + Row( + Modifier + .clickable { onModelSelected(modelName) } + .fillMaxWidth() + .background( + if (isSelected) { + MaterialTheme.colorScheme.primaryContainer + } else { + MaterialTheme.colorScheme.surface + }, + RoundedCornerShape(8.dp), + ) + .padding(12.dp), + verticalAlignment = Alignment.CenterVertically, + ) { + if (isSelected) { + Icon( + FeatherIcons.Check, + contentDescription = null, + tint = MaterialTheme.colorScheme.onPrimaryContainer, + ) + Spacer(modifier = Modifier.width(8.dp)) + } + Text( + color = if (isSelected) { + MaterialTheme.colorScheme.onPrimaryContainer + } else { + MaterialTheme.colorScheme.onSurface + }, + text = modelName, + style = MaterialTheme.typography.bodyMedium, + ) + } + } + } +} diff --git a/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/whisper_download/PopularWhisperModelsList.kt b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/whisper_download/PopularWhisperModelsList.kt new file mode 100644 index 00000000..1856dd8b --- /dev/null +++ b/app/src/main/java/io/shubham0204/smollmandroid/ui/screens/whisper_download/PopularWhisperModelsList.kt @@ -0,0 +1,154 @@ +/* + * Copyright (C) 2024 Shubham Panchal + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.shubham0204.smollmandroid.ui.screens.whisper_download + +import androidx.compose.foundation.background +import androidx.compose.foundation.clickable +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.Spacer +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.width +import androidx.compose.foundation.shape.RoundedCornerShape +import androidx.compose.material3.Icon +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Text +import androidx.compose.runtime.Composable +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.tooling.preview.Preview +import androidx.compose.ui.unit.dp +import compose.icons.FeatherIcons +import compose.icons.feathericons.Check + +data class WhisperModel( + val name: String, + val url: String, + val fileName: String, + val sizeDescription: String, +) + +@Preview +@Composable +fun PreviewPopularWhisperModelsList() { + PopularWhisperModelsList(selectedModelIndex = 0, onModelSelected = {}) +} + +@Composable +fun PopularWhisperModelsList(selectedModelIndex: Int?, onModelSelected: (Int) -> Unit) { + Column(verticalArrangement = Arrangement.Center) { + popularWhisperModelsList.forEachIndexed { idx, model -> + Row( + Modifier + .clickable { onModelSelected(idx) } + .fillMaxWidth() + .background( + if (idx == selectedModelIndex) { + MaterialTheme.colorScheme.surfaceContainer + } else { + MaterialTheme.colorScheme.surface + }, + RoundedCornerShape(8.dp), + ) + .padding(8.dp), + verticalAlignment = Alignment.CenterVertically, + horizontalArrangement = Arrangement.SpaceBetween, + ) { + Row( + verticalAlignment = Alignment.CenterVertically, + modifier = Modifier.weight(1f) + ) { + if (idx == selectedModelIndex) { + Icon( + FeatherIcons.Check, + contentDescription = null, + tint = MaterialTheme.colorScheme.onSurface, + ) + Spacer(modifier = Modifier.width(4.dp)) + } + Column { + Text( + color = MaterialTheme.colorScheme.onSurface, + text = model.name, + style = MaterialTheme.typography.bodySmall, + ) + Text( + color = MaterialTheme.colorScheme.onSurfaceVariant, + text = model.sizeDescription, + style = MaterialTheme.typography.labelSmall, + ) + } + } + } + } + } +} + +fun getPopularWhisperModel(index: Int?): WhisperModel? = + if (index != null) popularWhisperModelsList[index] else null + +/** + * A list of Whisper models for speech-to-text functionality. + * Models are from the ggerganov/whisper.cpp repository. 
+ * See: https://huggingface.co/ggerganov/whisper.cpp + */ +val popularWhisperModelsList = listOf( + WhisperModel( + name = "Whisper Tiny (English)", + url = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.en.bin", + fileName = "ggml-tiny.en.bin", + sizeDescription = "~75 MB - Fastest, less accurate", + ), + WhisperModel( + name = "Whisper Base (English)", + url = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin", + fileName = "ggml-base.en.bin", + sizeDescription = "~142 MB - Good balance of speed/accuracy", + ), + WhisperModel( + name = "Whisper Small (English)", + url = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.en.bin", + fileName = "ggml-small.en.bin", + sizeDescription = "~466 MB - Better accuracy, slower", + ), + WhisperModel( + name = "Whisper Tiny (Multilingual)", + url = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-tiny.bin", + fileName = "ggml-tiny.bin", + sizeDescription = "~75 MB - Fastest, supports multiple languages", + ), + WhisperModel( + name = "Whisper Base (Multilingual)", + url = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.bin", + fileName = "ggml-base.bin", + sizeDescription = "~142 MB - Good balance, multilingual", + ), + WhisperModel( + name = "Whisper Small (Multilingual)", + url = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-small.bin", + fileName = "ggml-small.bin", + sizeDescription = "~466 MB - Better accuracy, multilingual", + ), + WhisperModel( + name = "Whisper Large v3-turbo (Multilingual)", + url = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3-turbo.bin", + fileName = "ggml-large-v3-turbo.bin", + sizeDescription = "~1.6 GB - Most accurate, multilingual", + ), +) diff --git a/app/src/main/res/drawable/ic_mic_notification.xml b/app/src/main/res/drawable/ic_mic_notification.xml new file mode 100644 index 00000000..a6472085 --- /dev/null +++ b/app/src/main/res/drawable/ic_mic_notification.xml @@ -0,0 +1,10 @@ + + + diff --git a/app/src/main/res/drawable/ic_stop.xml b/app/src/main/res/drawable/ic_stop.xml new file mode 100644 index 00000000..c4e79ca4 --- /dev/null +++ b/app/src/main/res/drawable/ic_stop.xml @@ -0,0 +1,10 @@ + + + diff --git a/app/src/main/res/values/strings.xml b/app/src/main/res/values/strings.xml index a40662cc..c662a9e6 100644 --- a/app/src/main/res/values/strings.xml +++ b/app/src/main/res/values/strings.xml @@ -112,4 +112,36 @@ Invalid File The selected file is not a valid GGUF file. Benchmark Model + Enable TTS + Disable TTS + Text-to-Speech + Enable text-to-speech to have responses read aloud as they are generated. + Enable Auto-Submit + Disable Auto-Submit + Auto-Submit + Automatically send messages after you stop speaking for a configured delay. + Auto-Submit Delay (seconds) + Whisper model not found. Please download ggml-base.en.bin to Downloads folder. + Microphone permission is required for speech-to-text. + Microphone permission was denied. Speech-to-text requires this permission. + Failed to load Whisper model. + Recording... + Transcribing... + Speech-to-Text Model Required + A Whisper model is required for speech-to-text functionality. Would you like to download one now? + Download Model + Download Whisper Model + Whisper models enable speech-to-text transcription. Choose a model based on your needs: smaller models are faster but less accurate, larger models are more accurate but slower.
+ Select a model: + Download Model + Models are downloaded to the app\'s internal storage and will be ready to use once the download completes. + Download started for %1$s + Downloading Whisper speech-to-text model + Manage STT Models + Downloaded Models + Select which model to use for speech-to-text. The selected model will be loaded when you use the microphone. + Download New Model + Selected: %1$s + Transcription Language + Select the language for speech-to-text transcription. \ No newline at end of file diff --git a/build.gradle.kts b/build.gradle.kts index 0da250c3..2a2930f9 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -6,5 +6,16 @@ plugins { alias(libs.plugins.android.library) apply false id("com.google.devtools.ksp") version "2.0.0-1.0.24" apply false alias(libs.plugins.jetbrains.kotlin.jvm) apply false - kotlin("plugin.serialization") version "2.1.0" apply false + kotlin("plugin.serialization") version "2.0.0" apply false +} + +subprojects { + plugins.withType { + extensions.configure { + jvmToolchain { + languageVersion.set(JavaLanguageVersion.of(17)) + vendor.set(JvmVendorSpec.AZUL) + } + } + } } diff --git a/gradle.properties b/gradle.properties index 20e2a015..4ce248cb 100644 --- a/gradle.properties +++ b/gradle.properties @@ -7,6 +7,8 @@ # Specifies the JVM arguments used for the daemon process. # The setting is particularly useful for tweaking memory settings. org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8 +# Enable JVM toolchain auto-provisioning +org.gradle.java.installations.auto-download=true # When configured, Gradle will run in incubating parallel mode. # This option should only be used with decoupled projects. For more details, visit # https://developer.android.com/r/tools/gradle-multi-project-decoupled-projects diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 4a397786..67ed2a9e 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -16,6 +16,7 @@ uiTextGoogleFonts = "1.7.7" composeIcons = "1.1.1" appcompat = "1.6.1" material = "1.10.0" +foundation = "1.10.1" [libraries] androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "coreKtx" } @@ -45,6 +46,7 @@ androidx-ui-text-google-fonts = { group = "androidx.compose.ui", name = "ui-text composeIcons-feather = { module = "br.com.devsrsouza.compose.icons:feather", version.ref = "composeIcons" } androidx-appcompat = { group = "androidx.appcompat", name = "appcompat", version.ref = "appcompat" } material = { group = "com.google.android.material", name = "material", version.ref = "material" } +androidx-compose-foundation = { group = "androidx.compose.foundation", name = "foundation", version.ref = "foundation" } [plugins] android-application = { id = "com.android.application", version.ref = "agp" } diff --git a/hf-model-hub-api/build.gradle.kts b/hf-model-hub-api/build.gradle.kts index 6c4a2c05..157ad491 100644 --- a/hf-model-hub-api/build.gradle.kts +++ b/hf-model-hub-api/build.gradle.kts @@ -1,7 +1,7 @@ plugins { id("java-library") alias(libs.plugins.jetbrains.kotlin.jvm) - kotlin("plugin.serialization") version "2.1.0" + kotlin("plugin.serialization") version "2.0.0" } val ktorVersion = "3.0.2" diff --git a/settings.gradle.kts b/settings.gradle.kts index 11d9a43d..9a8c8bf9 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -16,6 +16,10 @@ pluginManagement { gradlePluginPortal() } } + +plugins { + id("org.gradle.toolchains.foojay-resolver-convention") version "0.9.0" +} dependencyResolutionManagement { 
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS) repositories { @@ -30,3 +34,4 @@ rootProject.name = "SmolChat Android" include(":app") include(":smollm") include(":hf-model-hub-api") +include(":whisper") diff --git a/whisper/build.gradle.kts b/whisper/build.gradle.kts new file mode 100644 index 00000000..b5c7dd35 --- /dev/null +++ b/whisper/build.gradle.kts @@ -0,0 +1,58 @@ +plugins { + alias(libs.plugins.android.library) + alias(libs.plugins.kotlin.android) +} + +android { + namespace = "com.whispercpp" + compileSdk = 35 + + defaultConfig { + minSdk = 26 + + ndk { + abiFilters += listOf("arm64-v8a", "armeabi-v7a", "x86", "x86_64") + } + externalNativeBuild { + cmake { + arguments("-DCMAKE_BUILD_TYPE=Release") + cppFlags("-ffile-prefix-map=${projectDir}=.") + cFlags("-ffile-prefix-map=${projectDir}=.") + } + } + } + + buildTypes { + release { + isMinifyEnabled = false + } + } + + compileOptions { + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 + } + + kotlinOptions { + jvmTarget = "17" + } + + externalNativeBuild { + cmake { + path = file("src/main/jni/whisper/CMakeLists.txt") + } + } + + packaging { + resources { + excludes += "/META-INF/{AL2.0,LGPL2.1}" + } + } + + ndkVersion = "27.2.12479018" +} + +dependencies { + implementation(libs.androidx.core.ktx) + implementation(libs.androidx.appcompat) +} diff --git a/whisper/src/main/AndroidManifest.xml b/whisper/src/main/AndroidManifest.xml new file mode 100644 index 00000000..8bdb7e14 --- /dev/null +++ b/whisper/src/main/AndroidManifest.xml @@ -0,0 +1,4 @@ + + + + diff --git a/whisper/src/main/java/com/whispercpp/whisper/LibWhisper.kt b/whisper/src/main/java/com/whispercpp/whisper/LibWhisper.kt new file mode 100644 index 00000000..c09aee6e --- /dev/null +++ b/whisper/src/main/java/com/whispercpp/whisper/LibWhisper.kt @@ -0,0 +1,212 @@ +package com.whispercpp.whisper + +import android.content.res.AssetManager +import android.os.Build +import android.util.Log +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.asCoroutineDispatcher +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withContext +import java.io.File +import java.io.InputStream +import java.util.concurrent.Executors + +private const val LOG_TAG = "LibWhisper" + + +class WhisperContext private constructor(private var ptr: Long) { + // Meet Whisper C++ constraint: Don't access from more than one thread at a time. 
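+    // The single-threaded dispatcher below serializes every native call routed through this
+    // context (transcribeData, the benchmarks, release), which is how that constraint is met.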
+ val scope: CoroutineScope = CoroutineScope( + Executors.newSingleThreadExecutor().asCoroutineDispatcher() + ) + fun stopTranscription(){ + WhisperLib.stopTranscription() + } + + + suspend fun transcribeData( + data: FloatArray, + language: String, + printTimestamp: Boolean = true, + callback: WhisperCallback + ): String = withContext(scope.coroutineContext) { + require(ptr != 0L) + + try { + val numThreads = WhisperCpuConfig.preferredThreadCount + Log.d(LOG_TAG, "Selecting $numThreads threads") + + WhisperLib.fullTranscribe(ptr, numThreads, data, language, callback) + + val textCount = WhisperLib.getTextSegmentCount(ptr) + return@withContext buildString { + for (i in 0 until textCount) { + if (printTimestamp) { + val textTimestamp = "[${toTimestamp(WhisperLib.getTextSegmentT0(ptr, i))} --> ${ + toTimestamp(WhisperLib.getTextSegmentT1(ptr, i)) + }]" + val textSegment = WhisperLib.getTextSegment(ptr, i) + append("$textTimestamp: $textSegment\n") + } else { + append(WhisperLib.getTextSegment(ptr, i)) + } + } + } + + } catch (e: Exception) { + Log.e(LOG_TAG, "Error during transcription", e) + return@withContext "" + } + } + + suspend fun benchMemory(nthreads: Int): String = withContext(scope.coroutineContext) { + return@withContext WhisperLib.benchMemcpy(nthreads) + } + + suspend fun benchGgmlMulMat(nthreads: Int): String = withContext(scope.coroutineContext) { + return@withContext WhisperLib.benchGgmlMulMat(nthreads) + } + + suspend fun release() = withContext(scope.coroutineContext) { + if (ptr != 0L) { + WhisperLib.freeContext(ptr) + ptr = 0 + } + } + + protected fun finalize() { + runBlocking { + release() + } + } + + companion object { + fun createContextFromFile(filePath: String): WhisperContext { + val ptr = WhisperLib.initContext(filePath) + if (ptr == 0L) { + throw java.lang.RuntimeException("Couldn't create context with path $filePath") + } + return WhisperContext(ptr) + } + + fun createContextFromInputStream(stream: InputStream): WhisperContext { + val ptr = WhisperLib.initContextFromInputStream(stream) + + if (ptr == 0L) { + throw java.lang.RuntimeException("Couldn't create context from input stream") + } + return WhisperContext(ptr) + } + + fun createContextFromAsset(assetManager: AssetManager, assetPath: String): WhisperContext { + val ptr = WhisperLib.initContextFromAsset(assetManager, assetPath) + + if (ptr == 0L) { + throw java.lang.RuntimeException("Couldn't create context from asset $assetPath") + } + return WhisperContext(ptr) + } + + fun getSystemInfo(): String { + return WhisperLib.getSystemInfo() + } + } +} + +private class WhisperLib { + companion object { + init { + Log.d(LOG_TAG, "Primary ABI: ${Build.SUPPORTED_ABIS[0]}") + var loadVfpv4 = false + var loadV8fp16 = false + if (isArmEabiV7a()) { + // armeabi-v7a needs runtime detection support + val cpuInfo = cpuInfo() + cpuInfo?.let { + Log.d(LOG_TAG, "CPU info: $cpuInfo") + if (cpuInfo.contains("vfpv4")) { + Log.d(LOG_TAG, "CPU supports vfpv4") + loadVfpv4 = true + } + } + } else if (isArmEabiV8a()) { + // ARMv8.2a needs runtime detection support + val cpuInfo = cpuInfo() + cpuInfo?.let { + Log.d(LOG_TAG, "CPU info: $cpuInfo") + if (cpuInfo.contains("fphp")) { + Log.d(LOG_TAG, "CPU supports fp16 arithmetic") + loadV8fp16 = true + } + } + } + + if (loadVfpv4) { + Log.d(LOG_TAG, "Loading libwhisper_vfpv4.so") + System.loadLibrary("whisper_vfpv4") + } else if (loadV8fp16) { + Log.d(LOG_TAG, "Loading libwhisper_v8fp16_va.so") + System.loadLibrary("whisper_v8fp16_va") + } else { + Log.d(LOG_TAG, "Loading libwhisper.so") 
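+                // Fallback: neither vfpv4 nor fp16 extensions were detected at runtime,
+                // so load the generic libwhisper.so that is built for every ABI.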
+                System.loadLibrary("whisper")
+            }
+        }
+
+        // JNI methods
+        external fun initContextFromInputStream(inputStream: InputStream): Long
+        external fun initContextFromAsset(assetManager: AssetManager, assetPath: String): Long
+        external fun initContext(modelPath: String): Long
+        external fun freeContext(contextPtr: Long)
+        external fun stopTranscription()
+        external fun fullTranscribe(
+            contextPtr: Long,
+            numThreads: Int,
+            audioData: FloatArray,
+            language: String,
+            callback: WhisperCallback
+        )
+
+        external fun getTextSegmentCount(contextPtr: Long): Int
+        external fun getTextSegment(contextPtr: Long, index: Int): String
+        external fun getTextSegmentT0(contextPtr: Long, index: Int): Long
+        external fun getTextSegmentT1(contextPtr: Long, index: Int): Long
+        external fun getSystemInfo(): String
+        external fun benchMemcpy(nthread: Int): String
+        external fun benchGgmlMulMat(nthread: Int): String
+    }
+}
+
+// 500 -> 00:05.000
+// 6000 -> 01:00.000
+private fun toTimestamp(t: Long, comma: Boolean = false): String {
+    var msec = t * 10
+    val hr = msec / (1000 * 60 * 60)
+    msec -= hr * (1000 * 60 * 60)
+    val min = msec / (1000 * 60)
+    msec -= min * (1000 * 60)
+    val sec = msec / 1000
+    msec -= sec * 1000
+
+    val delimiter = if (comma) "," else "."
+    return String.format("%02d:%02d:%02d%s%03d", hr, min, sec, delimiter, msec)
+}
+
+private fun isArmEabiV7a(): Boolean {
+    return Build.SUPPORTED_ABIS[0].equals("armeabi-v7a")
+}
+
+private fun isArmEabiV8a(): Boolean {
+    return Build.SUPPORTED_ABIS[0].equals("arm64-v8a")
+}
+
+private fun cpuInfo(): String? {
+    return try {
+        File("/proc/cpuinfo").inputStream().bufferedReader().use {
+            it.readText()
+        }
+    } catch (e: Exception) {
+        Log.w(LOG_TAG, "Couldn't read /proc/cpuinfo", e)
+        null
+    }
+}
diff --git a/whisper/src/main/java/com/whispercpp/whisper/WhisperCpuConfig.kt b/whisper/src/main/java/com/whispercpp/whisper/WhisperCpuConfig.kt
new file mode 100644
index 00000000..45e7370d
--- /dev/null
+++ b/whisper/src/main/java/com/whispercpp/whisper/WhisperCpuConfig.kt
@@ -0,0 +1,73 @@
+package com.whispercpp.whisper
+
+import android.util.Log
+import java.io.BufferedReader
+import java.io.FileReader
+
+object WhisperCpuConfig {
+    val preferredThreadCount: Int
+        // Always use at least 2 threads:
+        get() = CpuInfo.getHighPerfCpuCount().coerceAtLeast(2)
+}
+
+private class CpuInfo(private val lines: List<String>) {
+    private fun getHighPerfCpuCount(): Int = try {
+        getHighPerfCpuCountByFrequencies()
+    } catch (e: Exception) {
+        Log.d(LOG_TAG, "Couldn't read CPU frequencies", e)
+        getHighPerfCpuCountByVariant()
+    }
+
+    private fun getHighPerfCpuCountByFrequencies(): Int =
+        getCpuValues(property = "processor") { getMaxCpuFrequency(it.toInt()) }
+            .also { Log.d(LOG_TAG, "Binned cpu frequencies (frequency, count): ${it.binnedValues()}") }
+            .countDroppingMin()
+
+    private fun getHighPerfCpuCountByVariant(): Int =
+        getCpuValues(property = "CPU variant") { it.substringAfter("0x").toInt(radix = 16) }
+            .also { Log.d(LOG_TAG, "Binned cpu variants (variant, count): ${it.binnedValues()}") }
+            .countKeepingMin()
+
+    private fun List<Int>.binnedValues() = groupingBy { it }.eachCount()
+
+    private fun getCpuValues(property: String, mapper: (String) -> Int) = lines
+        .asSequence()
+        .filter { it.startsWith(property) }
+        .map { mapper(it.substringAfter(':').trim()) }
+        .sorted()
+        .toList()
+
+
+    private fun List<Int>.countDroppingMin(): Int {
+        val min = min()
+        return count { it > min }
+    }
+
+    private fun List<Int>.countKeepingMin(): Int {
+        val min = min()
+        return count { it == min }
+    }
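+    // How the heuristic reads: with per-core max frequencies available,
+    // countDroppingMin() returns the number of cores clocked above the slowest
+    // frequency bin; when frequencies are unreadable, the "CPU variant" field is
+    // binned instead and countKeepingMin() counts the cores in the lowest variant bin.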
+ + companion object { + private const val LOG_TAG = "WhisperCpuConfig" + + fun getHighPerfCpuCount(): Int = try { + readCpuInfo().getHighPerfCpuCount() + } catch (e: Exception) { + Log.d(LOG_TAG, "Couldn't read CPU info", e) + // Our best guess -- just return the # of CPUs minus 4. + (Runtime.getRuntime().availableProcessors() - 4).coerceAtLeast(0) + } + + private fun readCpuInfo() = CpuInfo( + BufferedReader(FileReader("/proc/cpuinfo")) + .useLines { it.toList() } + ) + + private fun getMaxCpuFrequency(cpuIndex: Int): Int { + val path = "/sys/devices/system/cpu/cpu${cpuIndex}/cpufreq/cpuinfo_max_freq" + val maxFreq = BufferedReader(FileReader(path)).use { it.readLine() } + return maxFreq.toInt() + } + } +} diff --git a/whisper/src/main/java/com/whispercpp/whisper/WishperCallBack.kt b/whisper/src/main/java/com/whispercpp/whisper/WishperCallBack.kt new file mode 100644 index 00000000..2f0a98ad --- /dev/null +++ b/whisper/src/main/java/com/whispercpp/whisper/WishperCallBack.kt @@ -0,0 +1,24 @@ +package com.whispercpp.whisper + +import androidx.annotation.Keep + +interface WhisperCallback { + fun onNewSegment(startMs: Long, endMs: Long, text: String) + fun onProgress(progress: Int) + fun onComplete() +} + +@Keep +class WishperCallBack : WhisperCallback { + override fun onNewSegment(startMs: Long, endMs: Long, text: String) { + println(text) + } + + override fun onProgress(progress: Int) { + println(progress) + } + + override fun onComplete() { + println("Completed") + } +} diff --git a/whisper/src/main/jni/whisper.cpp b/whisper/src/main/jni/whisper.cpp new file mode 160000 index 00000000..e990d1b7 --- /dev/null +++ b/whisper/src/main/jni/whisper.cpp @@ -0,0 +1 @@ +Subproject commit e990d1b791e7bda546866e82e094a8969ea86c6d diff --git a/whisper/src/main/jni/whisper/CMakeLists.txt b/whisper/src/main/jni/whisper/CMakeLists.txt new file mode 100644 index 00000000..9a6d488e --- /dev/null +++ b/whisper/src/main/jni/whisper/CMakeLists.txt @@ -0,0 +1,126 @@ +cmake_minimum_required(VERSION 3.10) +add_link_options("LINKER:--build-id=none") + +# 16KB Page Size Support, linker flags for 16KB page size alignment on Android 15+ devices +add_link_options("LINKER:-z,max-page-size=16384") + +project(whisper.cpp) + +set(CMAKE_CXX_STANDARD 17) +set(WHISPER_LIB_DIR ${CMAKE_SOURCE_DIR}/../whisper.cpp) + +# Remove file paths from debug info +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -ffile-prefix-map=${CMAKE_SOURCE_DIR}=.") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ffile-prefix-map=${CMAKE_SOURCE_DIR}=.") + +# Set consistent release flags - ADDED +set(CMAKE_C_FLAGS_RELEASE "-O3 -DNDEBUG") +set(CMAKE_CXX_FLAGS_RELEASE "-O3 -DNDEBUG") + +# Disable timestamps - ADDED +add_definitions(-DGGML_NO_TIMESTAMPS=1) + +# 16KB Page Size Support: Add NDK r27 flexible page size compilation flag +add_definitions(-DANDROID_SUPPORT_FLEXIBLE_PAGE_SIZES=ON) + +# Path to external GGML, otherwise uses the copy in whisper.cpp. 
+option(GGML_HOME "whisper: Path to external GGML source" OFF) + +set( + SOURCE_FILES + ${WHISPER_LIB_DIR}/src/whisper.cpp + ${CMAKE_SOURCE_DIR}/jni.c +) + +# TODO: this needs to be updated to work with the new ggml CMakeLists + +if (NOT GGML_HOME) + set( + SOURCE_FILES + ${SOURCE_FILES} + ${WHISPER_LIB_DIR}/ggml/src/ggml.c + ${WHISPER_LIB_DIR}/ggml/src/ggml-alloc.c + ${WHISPER_LIB_DIR}/ggml/src/ggml-backend.cpp + ${WHISPER_LIB_DIR}/ggml/src/ggml-backend-reg.cpp + ${WHISPER_LIB_DIR}/ggml/src/ggml-quants.c + ${WHISPER_LIB_DIR}/ggml/src/ggml-threading.cpp + ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu.c + ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu.cpp + ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp + ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp + ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu-quants.c + ${WHISPER_LIB_DIR}/ggml/src/ggml-cpu/ggml-cpu-traits.cpp + ) +endif() + +find_library(LOG_LIB log) + +function(build_library target_name) + add_library( + ${target_name} + SHARED + ${SOURCE_FILES} + ) + + target_compile_definitions(${target_name} PUBLIC GGML_USE_CPU) + + # Add reproducible build definitions - ADDED + target_compile_definitions(${target_name} PRIVATE GGML_NO_TIMESTAMPS=1) + + if (${target_name} STREQUAL "whisper_v8fp16_va") + target_compile_options(${target_name} PRIVATE -march=armv8.2-a+fp16) + set(GGML_COMPILE_OPTIONS -march=armv8.2-a+fp16) + elseif (${target_name} STREQUAL "whisper_vfpv4") + target_compile_options(${target_name} PRIVATE -mfpu=neon-vfpv4) + set(GGML_COMPILE_OPTIONS -mfpu=neon-vfpv4) + endif () + + if (NOT ${CMAKE_BUILD_TYPE} STREQUAL "Debug") + # Use consistent optimization flags - MODIFIED + target_compile_options(${target_name} PRIVATE -O3 -DNDEBUG) + target_compile_options(${target_name} PRIVATE -fvisibility=hidden -fvisibility-inlines-hidden) + target_compile_options(${target_name} PRIVATE -ffunction-sections -fdata-sections) + + target_link_options(${target_name} PRIVATE -Wl,--gc-sections) + target_link_options(${target_name} PRIVATE -Wl,--exclude-libs,ALL) + target_link_options(${target_name} PRIVATE -flto) + + # 16KB Page Size Support, linker flags for 16KB page size alignment on Android 15+ devices + target_link_options(${target_name} PRIVATE -Wl,-z,max-page-size=16384) + endif () + + if (GGML_HOME) + include(FetchContent) + FetchContent_Declare(ggml SOURCE_DIR ${GGML_HOME}) + FetchContent_MakeAvailable(ggml) + + target_compile_options(ggml PRIVATE ${GGML_COMPILE_OPTIONS}) + # Add reproducible build flags to ggml as well - ADDED + target_compile_definitions(ggml PRIVATE GGML_NO_TIMESTAMPS=1) + target_link_libraries(${target_name} ${LOG_LIB} android ggml) + else() + target_link_libraries(${target_name} ${LOG_LIB} android) + endif() + + # Strip .comment section for reproducible builds + add_custom_command(TARGET ${target_name} POST_BUILD + COMMAND ${CMAKE_OBJCOPY} --remove-section .comment $ + COMMENT "Removing .comment section from ${target_name} for reproducible builds" + ) + +endfunction() + +if (${ANDROID_ABI} STREQUAL "arm64-v8a") + build_library("whisper_v8fp16_va") +elseif (${ANDROID_ABI} STREQUAL "armeabi-v7a") + build_library("whisper_vfpv4") +endif () + +build_library("whisper") # Default target + +include_directories(${WHISPER_LIB_DIR}) +include_directories(${WHISPER_LIB_DIR}/src) +include_directories(${WHISPER_LIB_DIR}/include) +include_directories(${WHISPER_LIB_DIR}/ggml/include) +include_directories(${WHISPER_LIB_DIR}/ggml/src) +include_directories(${WHISPER_LIB_DIR}/ggml/src/ggml-cpu) diff --git 
a/whisper/src/main/jni/whisper/jni.c b/whisper/src/main/jni/whisper/jni.c
new file mode 100644
index 00000000..7883a900
--- /dev/null
+++ b/whisper/src/main/jni/whisper/jni.c
@@ -0,0 +1,397 @@
+#include <jni.h>
+#include <android/asset_manager.h>
+#include <android/asset_manager_jni.h>
+#include <android/log.h>
+#include <stdlib.h>
+#include <sys/sysinfo.h>
+#include <string.h>
+#include "whisper.h"
+#include "ggml.h"
+
+#include <pthread.h>
+#include <stdbool.h>
+
+static bool g_should_abort = false;
+static pthread_mutex_t g_abort_mutex = PTHREAD_MUTEX_INITIALIZER;
+
+#define UNUSED(x) (void)(x)
+#define TAG "JNI"
+
+#define LOGI(...) __android_log_print(ANDROID_LOG_INFO, TAG, __VA_ARGS__)
+#define LOGW(...) __android_log_print(ANDROID_LOG_WARN, TAG, __VA_ARGS__)
+
+// Global references for callback handling
+static JavaVM* g_jvm = NULL;
+static jobject g_callback = NULL;
+
+// Method IDs for callback functions
+static jmethodID g_onNewSegmentMethod = NULL;
+static jmethodID g_onProgressMethod = NULL;
+static jmethodID g_onCompleteMethod = NULL;
+
+
+
+static inline int min(int a, int b) {
+    return (a < b) ? a : b;
+}
+
+static inline int max(int a, int b) {
+    return (a > b) ? a : b;
+}
+
+struct input_stream_context {
+    size_t offset;
+    JNIEnv * env;
+    jobject thiz;
+    jobject input_stream;
+
+    jmethodID mid_available;
+    jmethodID mid_read;
+};
+
+size_t inputStreamRead(void * ctx, void * output, size_t read_size) {
+    struct input_stream_context* is = (struct input_stream_context*)ctx;
+
+    jint avail_size = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available);
+    jint size_to_copy = read_size < avail_size ? (jint)read_size : avail_size;
+
+    jbyteArray byte_array = (*is->env)->NewByteArray(is->env, size_to_copy);
+
+    jint n_read = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_read, byte_array, 0, size_to_copy);
+
+    if (size_to_copy != read_size || size_to_copy != n_read) {
+        LOGI("Insufficient Read: Req=%zu, ToCopy=%d, Available=%d", read_size, size_to_copy, n_read);
+    }
+
+    jbyte* byte_array_elements = (*is->env)->GetByteArrayElements(is->env, byte_array, NULL);
+    memcpy(output, byte_array_elements, size_to_copy);
+    (*is->env)->ReleaseByteArrayElements(is->env, byte_array, byte_array_elements, JNI_ABORT);
+
+    (*is->env)->DeleteLocalRef(is->env, byte_array);
+
+    is->offset += size_to_copy;
+
+    return size_to_copy;
+}
+bool inputStreamEof(void * ctx) {
+    struct input_stream_context* is = (struct input_stream_context*)ctx;
+
+    jint result = (*is->env)->CallIntMethod(is->env, is->input_stream, is->mid_available);
+    return result <= 0;
+}
+void inputStreamClose(void * ctx) {
+
+}
+
+JNIEXPORT jlong JNICALL
+Java_com_whispercpp_whisper_WhisperLib_00024Companion_initContextFromInputStream(
+        JNIEnv *env, jobject thiz, jobject input_stream) {
+    UNUSED(thiz);
+
+    struct whisper_context *context = NULL;
+    struct whisper_model_loader loader = {};
+    struct input_stream_context inp_ctx = {};
+
+    inp_ctx.offset = 0;
+    inp_ctx.env = env;
+    inp_ctx.thiz = thiz;
+    inp_ctx.input_stream = input_stream;
+
+    jclass cls = (*env)->GetObjectClass(env, input_stream);
+    inp_ctx.mid_available = (*env)->GetMethodID(env, cls, "available", "()I");
+    inp_ctx.mid_read = (*env)->GetMethodID(env, cls, "read", "([BII)I");
+
+    loader.context = &inp_ctx;
+    loader.read = inputStreamRead;
+    loader.eof = inputStreamEof;
+    loader.close = inputStreamClose;
+
+    loader.eof(loader.context);
+
+    context = whisper_init(&loader);
+    return (jlong) context;
+}
+
+static size_t asset_read(void *ctx, void *output, size_t read_size) {
+    return AAsset_read((AAsset *) ctx, output, read_size);
+}
+
+static bool asset_is_eof(void *ctx) {
+    return
AAsset_getRemainingLength64((AAsset *) ctx) <= 0; +} + +static void asset_close(void *ctx) { + AAsset_close((AAsset *) ctx); +} + +static struct whisper_context *whisper_init_from_asset( + JNIEnv *env, + jobject assetManager, + const char *asset_path +) { + LOGI("Loading model from asset '%s'\n", asset_path); + AAssetManager *asset_manager = AAssetManager_fromJava(env, assetManager); + AAsset *asset = AAssetManager_open(asset_manager, asset_path, AASSET_MODE_STREAMING); + if (!asset) { + LOGW("Failed to open '%s'\n", asset_path); + return NULL; + } + + whisper_model_loader loader = { + .context = asset, + .read = &asset_read, + .eof = &asset_is_eof, + .close = &asset_close + }; + + return whisper_init_with_params(&loader, whisper_context_default_params()); +} + +JNIEXPORT jlong JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_initContextFromAsset( + JNIEnv *env, jobject thiz, jobject assetManager, jstring asset_path_str) { + UNUSED(thiz); + struct whisper_context *context = NULL; + const char *asset_path_chars = (*env)->GetStringUTFChars(env, asset_path_str, NULL); + context = whisper_init_from_asset(env, assetManager, asset_path_chars); + (*env)->ReleaseStringUTFChars(env, asset_path_str, asset_path_chars); + return (jlong) context; +} + +JNIEXPORT jlong JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_initContext( + JNIEnv *env, jobject thiz, jstring model_path_str) { + UNUSED(thiz); + struct whisper_context *context = NULL; + const char *model_path_chars = (*env)->GetStringUTFChars(env, model_path_str, NULL); + context = whisper_init_from_file_with_params(model_path_chars, whisper_context_default_params()); + (*env)->ReleaseStringUTFChars(env, model_path_str, model_path_chars); + return (jlong) context; +} + +JNIEXPORT void JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_freeContext( + JNIEnv *env, jobject thiz, jlong context_ptr) { + UNUSED(env); + UNUSED(thiz); + struct whisper_context *context = (struct whisper_context *) context_ptr; + whisper_free(context); +} +// Callback for new segments +void new_segment_callback(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { + JNIEnv* env; + (*g_jvm)->AttachCurrentThread(g_jvm, &env, NULL); + + for (int i = 0; i < n_new; i++) { + const int segment_id = whisper_full_n_segments(ctx) - n_new + i; + const char* text = whisper_full_get_segment_text(ctx, segment_id); + const int64_t t0 = whisper_full_get_segment_t0(ctx, segment_id); + const int64_t t1 = whisper_full_get_segment_t1(ctx, segment_id); + + jstring jtext = (*env)->NewStringUTF(env, text); + (*env)->CallVoidMethod( + env, + g_callback, + g_onNewSegmentMethod, + (jlong)t0, + (jlong)t1, + jtext + ); + (*env)->DeleteLocalRef(env, jtext); + } +} + +// Progress callback +void progress_callback(struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data) { + JNIEnv* env; + (*g_jvm)->AttachCurrentThread(g_jvm, &env, NULL); + (*env)->CallVoidMethod(env, g_callback, g_onProgressMethod, (jint)progress); +} + +static bool abort_callback(void* user_data) { + bool should_abort; + pthread_mutex_lock(&g_abort_mutex); + should_abort = g_should_abort; + pthread_mutex_unlock(&g_abort_mutex); + return should_abort; +} + +JNIEXPORT void JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_resetAbort(JNIEnv* env, jobject thiz) { + pthread_mutex_lock(&g_abort_mutex); + g_should_abort = false; + pthread_mutex_unlock(&g_abort_mutex); +} + +JNIEXPORT void JNICALL 
+Java_com_whispercpp_whisper_WhisperLib_00024Companion_stopTranscription(JNIEnv* env, jobject thiz) { + pthread_mutex_lock(&g_abort_mutex); + g_should_abort = true; + pthread_mutex_unlock(&g_abort_mutex); +} + + +JNIEXPORT void JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_fullTranscribe( + JNIEnv *env, jobject thiz, jlong context_ptr, jint num_threads, + jfloatArray audio_data, jstring language, jobject callback) { + UNUSED(thiz); +// Reset abort state + Java_com_whispercpp_whisper_WhisperLib_00024Companion_resetAbort(env, thiz); + // Store JavaVM for later callbacks + if (g_jvm == NULL) { + (*env)->GetJavaVM(env, &g_jvm); + } + + // Clean up previous callback if exists + if (g_callback != NULL) { + (*env)->DeleteGlobalRef(env, g_callback); + } + + // Create new global reference + g_callback = (*env)->NewGlobalRef(env, callback); + + // Get method IDs + jclass callbackClass = (*env)->GetObjectClass(env, g_callback); + + jmethodID toStringMethod = (*env)->GetMethodID(env, callbackClass, "toString", "()Ljava/lang/String;"); + jstring str = (*env)->CallObjectMethod(env, g_callback, toStringMethod); + const char *cStr = (*env)->GetStringUTFChars(env, str, NULL); + __android_log_print(ANDROID_LOG_DEBUG, "JNI", "Callback class: %s", cStr); + (*env)->ReleaseStringUTFChars(env, str, cStr); + g_onNewSegmentMethod = (*env)->GetMethodID( + env, + callbackClass, + "onNewSegment", + "(JJLjava/lang/String;)V" + ); + g_onProgressMethod = (*env)->GetMethodID( + env, + callbackClass, + "onProgress", + "(I)V" + ); + g_onCompleteMethod = (*env)->GetMethodID( + env, + callbackClass, + "onComplete", + "()V" + ); + + struct whisper_context *context = (struct whisper_context *) context_ptr; + jfloat *audio_data_arr = (*env)->GetFloatArrayElements(env, audio_data, NULL); + const jsize audio_data_length = (*env)->GetArrayLength(env, audio_data); + + // Get language parameter (default to "auto" if null) + const char *language_str = "auto"; + if (language != NULL) { + language_str = (*env)->GetStringUTFChars(env, language, NULL); + } + + // Configure whisper parameters with callbacks + struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + params.print_realtime = false; // We handle callbacks ourselves + params.print_progress = false; + params.print_timestamps = false; + params.print_special = false; + params.translate = false; + params.language = language_str; + params.n_threads = num_threads; + params.offset_ms = 0; + params.no_context = true; + params.single_segment = false; + + // Set our callbacks + params.new_segment_callback = new_segment_callback; + params.new_segment_callback_user_data = NULL; + params.progress_callback = progress_callback; + params.progress_callback_user_data = NULL; + params.abort_callback = abort_callback; + params.abort_callback_user_data = NULL; + + + whisper_reset_timings(context); + + LOGI("About to run whisper_full with callbacks (language: %s)", language_str); + int result = whisper_full(context, params, audio_data_arr, audio_data_length); + + // Cleanup language string if we allocated it + if (language != NULL) { + (*env)->ReleaseStringUTFChars(env, language, language_str); + } + + // Notify completion + if (result == 0) { + (*env)->CallVoidMethod(env, g_callback, g_onCompleteMethod); + } else { + LOGI("Failed to run the model"); + } + + // Cleanup + (*env)->ReleaseFloatArrayElements(env, audio_data, audio_data_arr, JNI_ABORT); + (*env)->DeleteGlobalRef(env, g_callback); + g_callback = NULL; +} + +JNIEXPORT jint JNICALL 
+Java_com_whispercpp_whisper_WhisperLib_00024Companion_getTextSegmentCount( + JNIEnv *env, jobject thiz, jlong context_ptr) { + UNUSED(env); + UNUSED(thiz); + struct whisper_context *context = (struct whisper_context *) context_ptr; + return whisper_full_n_segments(context); +} + +JNIEXPORT jstring JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_getTextSegment( + JNIEnv *env, jobject thiz, jlong context_ptr, jint index) { + UNUSED(thiz); + struct whisper_context *context = (struct whisper_context *) context_ptr; + const char *text = whisper_full_get_segment_text(context, index); + jstring string = (*env)->NewStringUTF(env, text); + return string; +} + +JNIEXPORT jlong JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_getTextSegmentT0( + JNIEnv *env, jobject thiz, jlong context_ptr, jint index) { + UNUSED(thiz); + struct whisper_context *context = (struct whisper_context *) context_ptr; + return whisper_full_get_segment_t0(context, index); +} + +JNIEXPORT jlong JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_getTextSegmentT1( + JNIEnv *env, jobject thiz, jlong context_ptr, jint index) { + UNUSED(thiz); + struct whisper_context *context = (struct whisper_context *) context_ptr; + return whisper_full_get_segment_t1(context, index); +} + +JNIEXPORT jstring JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_getSystemInfo( + JNIEnv *env, jobject thiz +) { + UNUSED(thiz); + const char *sysinfo = whisper_print_system_info(); + jstring string = (*env)->NewStringUTF(env, sysinfo); + return string; +} + +JNIEXPORT jstring JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_benchMemcpy(JNIEnv *env, jobject thiz, + jint n_threads) { + UNUSED(thiz); + const char *bench_ggml_memcpy = whisper_bench_memcpy_str(n_threads); + jstring string = (*env)->NewStringUTF(env, bench_ggml_memcpy); + return string; +} + +JNIEXPORT jstring JNICALL +Java_com_whispercpp_whisper_WhisperLib_00024Companion_benchGgmlMulMat(JNIEnv *env, jobject thiz, + jint n_threads) { + UNUSED(thiz); + const char *bench_ggml_mul_mat = whisper_bench_ggml_mul_mat_str(n_threads); + jstring string = (*env)->NewStringUTF(env, bench_ggml_mul_mat); + return string; +}
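For reference, a minimal Kotlin sketch of how the new whisper module is meant to be driven from the app side, assuming a ggml model file is already on disk and the recording is 16 kHz mono 16-bit PCM WAV (whisper.cpp consumes it as normalized floats). transcribeFile and decodeWaveToPcmFloats are hypothetical helper names, not part of this change; only WhisperContext, WhisperCallback and their members come from the diff above.

import com.whispercpp.whisper.WhisperCallback
import com.whispercpp.whisper.WhisperContext
import java.io.File
import java.nio.ByteBuffer
import java.nio.ByteOrder

// Hypothetical end-to-end helper: load a model, transcribe one WAV file, free the context.
suspend fun transcribeFile(modelPath: String, wavFile: File, language: String): String {
    val whisper = WhisperContext.createContextFromFile(modelPath)
    try {
        val samples = decodeWaveToPcmFloats(wavFile)
        return whisper.transcribeData(
            data = samples,
            language = language,            // e.g. "en" or "auto", as stored in PreferencesManager
            printTimestamp = false,
            callback = object : WhisperCallback {
                override fun onNewSegment(startMs: Long, endMs: Long, text: String) {
                    // Stream partial results to the UI as whisper.cpp emits segments.
                }
                override fun onProgress(progress: Int) { /* 0..100 */ }
                override fun onComplete() { /* transcription finished */ }
            },
        )
    } finally {
        whisper.release()
    }
}

// Hypothetical 16-bit PCM -> FloatArray conversion, assuming a standard 44-byte WAV header
// and 16 kHz mono audio (what the recorder in the app module is configured to produce).
fun decodeWaveToPcmFloats(file: File): FloatArray {
    val bytes = file.readBytes()
    require(bytes.size > 44) { "WAV file too short: ${file.name}" }
    val pcm = ByteBuffer.wrap(bytes, 44, bytes.size - 44)
        .order(ByteOrder.LITTLE_ENDIAN)
        .asShortBuffer()
    val floats = FloatArray(pcm.remaining())
    for (i in floats.indices) {
        floats[i] = (pcm.get(i) / 32767.0f).coerceIn(-1f, 1f)
    }
    return floats
}

A long-running transcription can be cancelled from another thread with whisper.stopTranscription(), which sets the native abort flag checked by whisper_full through the abort callback wired up in jni.c.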