Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 81 additions & 120 deletions app/src/main/java/com/urik/keyboard/service/SpellCheckManager.kt
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,14 @@ class SpellCheckManager
private val initScope = CoroutineScope(SupervisorJob() + Dispatchers.Main)

private val spellCheckers = ConcurrentHashMap<String, SpellChecker>()
private val spellCheckerAccessOrder = ConcurrentHashMap<String, Long>()
private var currentLanguage: String = "en"

private val wordFrequencies = ConcurrentHashMap<String, Long>()

private val parsedDictionaryWords = ConcurrentHashMap<String, List<Pair<String, Int>>>()
private val indexedDictionaryWords = ConcurrentHashMap<String, List<CachedWord>>()

private val suggestionCache: ManagedCache<String, List<SpellingSuggestion>> =
cacheMemoryManager.createCache(
name = "spell_suggestions",
Expand All @@ -93,15 +97,6 @@ class SpellCheckManager
val accentStrippedWord: String,
)

@Volatile
private var commonWordsCache = emptyList<Pair<String, Int>>()

@Volatile
private var commonWordsCacheIndexed = emptyList<CachedWord>()

@Volatile
private var commonWordsCacheLanguage = ""

@Volatile
private var isInitialized = false

Expand Down Expand Up @@ -210,7 +205,6 @@ class SpellCheckManager
val spellChecker = createSpellChecker(context, languageCode)
if (spellChecker != null) {
spellCheckers[languageCode] = spellChecker
loadCommonWordsCache(languageCode)
}
}
} catch (e: Exception) {
Expand Down Expand Up @@ -268,18 +262,9 @@ class SpellCheckManager
}
}

private suspend fun loadCommonWordsCache(languageCode: String) {
if (languageCode == commonWordsCacheLanguage && commonWordsCache.isNotEmpty()) {
return
}

try {
val words = getCommonWords(languageCode)
commonWordsCache = words
commonWordsCacheLanguage = languageCode
} catch (_: Exception) {
commonWordsCache = emptyList()
}
private suspend fun ensureDictionaryParsed(languageCode: String) {
if (parsedDictionaryWords.containsKey(languageCode)) return
getSpellCheckerForLanguage(languageCode)
}

private suspend fun createSpellChecker(
Expand All @@ -299,6 +284,7 @@ class SpellCheckManager
val symSpell = SymSpell(settings)
val dictionaryFile = "dictionaries/${languageCode}_symspell.txt"
val inputStream = context.assets.open(dictionaryFile)
val collectedWords = ArrayList<Pair<String, Int>>(INITIAL_WORD_LIST_CAPACITY)

inputStream.bufferedReader().use { reader ->
val lines = reader.readLines()
Expand All @@ -314,16 +300,35 @@ class SpellCheckManager
val frequency = parts[1].toLongOrNull() ?: 1L
wordFrequencies[word.lowercase()] = frequency
symSpell.createDictionaryEntry(word, frequency.toInt())
collectedWords.add(word.lowercase() to frequency.toInt())
} else if (parts.size == 1) {
wordFrequencies[parts[0].lowercase()] = 1L
symSpell.createDictionaryEntry(parts[0], 1)
collectedWords.add(parts[0].lowercase() to 1)
}
}
}

yield()
}

val sorted = collectedWords.sortedByDescending { it.second }
parsedDictionaryWords[languageCode] = sorted
indexedDictionaryWords[languageCode] =
sorted.map { (word, freq) ->
CachedWord(
word = word,
frequency = freq,
strippedWord =
com.urik.keyboard.utils.TextMatchingUtils
.stripWordPunctuation(word),
accentStrippedWord = wordNormalizer.stripDiacritics(word).lowercase(),
)
}

evictExcessSpellCheckers(languageCode)
spellCheckerAccessOrder[languageCode] = System.nanoTime()

symSpell
}
} catch (e: CancellationException) {
Expand Down Expand Up @@ -355,16 +360,33 @@ class SpellCheckManager
}
}

/**
 * Trims the per-language spell-checker cache until it holds fewer than
 * [MAX_CACHED_SPELL_CHECKERS] entries.
 *
 * Candidates are taken from `spellCheckerAccessOrder` and removed smallest-timestamp
 * first. Neither [preserveLanguage] nor `currentLanguage` is ever chosen. Evicting a
 * language removes its checker, its access timestamp, and its parsed and indexed
 * word lists.
 *
 * @param preserveLanguage language code that must not be evicted
 */
private fun evictExcessSpellCheckers(preserveLanguage: String) {
    // Loop instead of evicting once: if the map has already overshot the cap, a single
    // removal would leave it over the limit indefinitely. Each pass either removes one
    // access-order entry or returns, so the loop always terminates.
    while (spellCheckers.size >= MAX_CACHED_SPELL_CHECKERS) {
        val evictionTarget =
            spellCheckerAccessOrder.entries
                .filter { it.key != preserveLanguage && it.key != currentLanguage }
                .minByOrNull { it.value }
                ?.key ?: return // nothing evictable remains

        spellCheckers.remove(evictionTarget)
        spellCheckerAccessOrder.remove(evictionTarget)
        parsedDictionaryWords.remove(evictionTarget)
        indexedDictionaryWords.remove(evictionTarget)
    }
}

/**
 * Returns the spell checker stored in `spellCheckers` for [languageCode], creating one
 * through `createSpellChecker` when none is stored.
 *
 * @return the checker, or `null` when `isInitialized` is false, the language is not in
 * `KeyboardSettings.SUPPORTED_LANGUAGES`, or `createSpellChecker` returns `null`.
 */
private suspend fun getSpellCheckerForLanguage(languageCode: String): SpellChecker? {
// NOTE(review): this lookup returns on every hit, so the lookup right below it (the one
// that refreshes spellCheckerAccessOrder) can never run for a stored checker. Presumably
// this is the superseded line from the diff - confirm which of the two should remain.
spellCheckers[languageCode]?.let { return it }
spellCheckers[languageCode]?.let {
// Record the time of this hit for the language.
spellCheckerAccessOrder[languageCode] = System.nanoTime()
return it
}

// No stored checker: only build one once initialized and for a supported language.
if (!isInitialized || languageCode !in KeyboardSettings.SUPPORTED_LANGUAGES) {
return null
}

return createSpellChecker(context, languageCode)?.also { newChecker ->
// putIfAbsent leaves an already-stored checker in place; `also` still returns newChecker.
spellCheckers.putIfAbsent(languageCode, newChecker)
// NOTE(review): presumably a leftover call from the pre-change version - verify it is still wanted.
loadCommonWordsCache(languageCode)
}
}

Expand Down Expand Up @@ -886,19 +908,14 @@ class SpellCheckManager
prefix: String,
languageCode: String,
): List<Pair<String, Int>> {
if (languageCode != commonWordsCacheLanguage || commonWordsCacheIndexed.isEmpty()) {
try {
getCommonWords(languageCode)
} catch (_: Exception) {
return emptyList()
}
}
ensureDictionaryParsed(languageCode)
val indexed = indexedDictionaryWords[languageCode] ?: return emptyList()

val hasApostrophe = prefix.contains('\'')

val apostropheMatches =
if (hasApostrophe) {
commonWordsCacheIndexed
indexed
.filter { cached ->
cached.word.startsWith(prefix, ignoreCase = true) &&
cached.word.length > prefix.length
Expand All @@ -918,7 +935,7 @@ class SpellCheckManager

val apostropheWords = apostropheMatches.map { it.first }.toSet()
val exactPrefixMatches =
commonWordsCacheIndexed
indexed
.filter { cached ->
cached.word !in apostropheWords &&
cached.strippedWord.startsWith(strippedPrefix, ignoreCase = true) &&
Expand All @@ -932,7 +949,7 @@ class SpellCheckManager

val seenWords = combined.map { it.first }.toSet()
val accentFallbackMatches =
commonWordsCacheIndexed
indexed
.filter { cached ->
cached.word !in seenWords &&
cached.accentStrippedWord.startsWith(accentStrippedPrefix) &&
Expand All @@ -959,31 +976,6 @@ class SpellCheckManager
return hasValidChars && codePointCount in 1..MAX_INPUT_CODEPOINTS
}

/**
 * Parses one dictionary line of the form `word` or `word frequency`.
 *
 * The word is lowercased. A missing or non-numeric frequency defaults to 1. Counts
 * outside the `Int` range are clamped to it instead of being discarded.
 *
 * @param line raw line read from the dictionary file
 * @return the `(word, frequency)` pair, or `null` when the line is blank, the word
 * length is outside `COMMON_WORD_MIN_LENGTH..COMMON_WORD_MAX_LENGTH`, the word contains
 * a character that is neither a letter nor valid word punctuation, or the word is
 * blacklisted.
 */
private fun parseDictionaryLine(line: String): Pair<String, Int>? {
    if (line.isBlank()) return null

    val parts = line.trim().split(" ", limit = 2)
    val word = parts[0].lowercase().trim()

    // Parse as Long and clamp: toIntOrNull returns null for counts above Int.MAX_VALUE,
    // which would silently demote those words to frequency 1. The trim() tolerates extra
    // padding between the word and its count.
    val frequency =
        parts
            .getOrNull(1)
            ?.trim()
            ?.toLongOrNull()
            ?.coerceIn(Int.MIN_VALUE.toLong(), Int.MAX_VALUE.toLong())
            ?.toInt()
            ?: 1

    val isValid =
        word.length in COMMON_WORD_MIN_LENGTH..COMMON_WORD_MAX_LENGTH &&
            word.all { ch ->
                // Character.isLetter already covers the OTHER_LETTER category, so no
                // separate getType check is needed.
                Character.isLetter(ch.code) ||
                    com.urik.keyboard.utils.TextMatchingUtils
                        .isValidWordPunctuation(ch)
            } &&
            !isWordBlacklisted(word)

    return if (isValid) word to frequency else null
}

private fun getCurrentLanguage(): String =
try {
val currentLanguage = languageManager.currentLanguage.value
Expand Down Expand Up @@ -1071,8 +1063,6 @@ class SpellCheckManager
val cacheKey = buildCacheKey(normalizedWord, currentLang)
dictionaryCache.invalidate(cacheKey)
suggestionCache.invalidateAll()
commonWordsCache = emptyList()
commonWordsCacheIndexed = emptyList()
} catch (_: Exception) {
}
}
Expand All @@ -1095,8 +1085,6 @@ class SpellCheckManager
val cacheKey = buildCacheKey(normalizedWord, currentLang)
dictionaryCache.invalidate(cacheKey)
suggestionCache.invalidateAll()
commonWordsCache = emptyList()
commonWordsCacheIndexed = emptyList()
}
} catch (_: Exception) {
}
Expand Down Expand Up @@ -1124,18 +1112,29 @@ class SpellCheckManager
android.content.ComponentCallbacks2.TRIM_MEMORY_COMPLETE,
-> {
wordFrequencies.clear()
commonWordsCache = emptyList()
commonWordsCacheIndexed = emptyList()
commonWordsCacheLanguage = ""
clearCaches()

val keepLanguage = currentLanguage
val toEvict = spellCheckers.keys.filter { it != keepLanguage }
toEvict.forEach { lang ->
spellCheckers.remove(lang)
spellCheckerAccessOrder.remove(lang)
parsedDictionaryWords.remove(lang)
indexedDictionaryWords.remove(lang)
}
}

android.content.ComponentCallbacks2.TRIM_MEMORY_RUNNING_MODERATE,
android.content.ComponentCallbacks2.TRIM_MEMORY_MODERATE,
-> {
commonWordsCache = emptyList()
commonWordsCacheIndexed = emptyList()
commonWordsCacheLanguage = ""
val activeLangs = languageManager.activeLanguages.value.toSet()
val toEvict = parsedDictionaryWords.keys.filter { it !in activeLangs }
toEvict.forEach { lang ->
parsedDictionaryWords.remove(lang)
indexedDictionaryWords.remove(lang)
spellCheckers.remove(lang)
spellCheckerAccessOrder.remove(lang)
}
}
}
}
Expand All @@ -1159,44 +1158,13 @@ class SpellCheckManager
return@withContext emptyList()
}

if (targetLang == commonWordsCacheLanguage && commonWordsCache.isNotEmpty()) {
return@withContext commonWordsCache
}

val dictionaryFile = "dictionaries/${targetLang}_symspell.txt"
val wordFrequencies = mutableListOf<Pair<String, Int>>()

try {
context.assets.open(dictionaryFile).bufferedReader().use { reader ->
reader.forEachLine { line ->
parseDictionaryLine(line)?.let { wordFrequency ->
wordFrequencies.add(wordFrequency)
}
}
}
} catch (_: IOException) {
return@withContext emptyList()
if (parsedDictionaryWords[targetLang] == null) {
ensureDictionaryParsed(targetLang)
}

val sortedWords = wordFrequencies.sortedByDescending { it.second }

val sortedWordsWithStripped =
sortedWords.map { (word, freq) ->
CachedWord(
word = word,
frequency = freq,
strippedWord =
com.urik.keyboard.utils.TextMatchingUtils
.stripWordPunctuation(word),
accentStrippedWord = wordNormalizer.stripDiacritics(word).lowercase(),
)
}

commonWordsCache = sortedWords
commonWordsCacheIndexed = sortedWordsWithStripped
commonWordsCacheLanguage = targetLang

return@withContext sortedWords
return@withContext parsedDictionaryWords[targetLang]
?.filter { !isWordBlacklisted(it.first) }
?: emptyList()
} catch (_: Exception) {
return@withContext emptyList()
}
Expand All @@ -1209,25 +1177,19 @@ class SpellCheckManager
return@withContext emptyMap()
}

val mergedWords = mutableMapOf<String, Int>()
val mergedWords = HashMap<String, Int>(INITIAL_WORD_LIST_CAPACITY)

languages.forEach { lang ->
if (lang !in KeyboardSettings.SUPPORTED_LANGUAGES) {
return@forEach
}

val dictionaryFile = "dictionaries/${lang}_symspell.txt"

try {
context.assets.open(dictionaryFile).bufferedReader().use { reader ->
reader.forEachLine { line ->
parseDictionaryLine(line)?.let { (word, frequency) ->
val currentFreq = mergedWords[word] ?: 0
mergedWords[word] = maxOf(currentFreq, frequency)
}
}
ensureDictionaryParsed(lang)
parsedDictionaryWords[lang]?.forEach { (word, frequency) ->
if (!isWordBlacklisted(word)) {
val currentFreq = mergedWords[word] ?: 0
mergedWords[word] = maxOf(currentFreq, frequency)
}
} catch (_: IOException) {
}
}

Expand Down Expand Up @@ -1338,6 +1300,8 @@ class SpellCheckManager
const val CONTRACTION_GUARANTEED_CONFIDENCE = 0.995

const val DICTIONARY_BATCH_SIZE = 2000
const val INITIAL_WORD_LIST_CAPACITY = 50000
const val MAX_CACHED_SPELL_CHECKERS = 4
const val INITIALIZATION_TIMEOUT_MS = 5000L

const val FREQUENCY_BOOST_MULTIPLIER = 0.02
Expand Down Expand Up @@ -1366,9 +1330,6 @@ class SpellCheckManager
const val MAX_PREFIX_COMPLETION_RESULTS = 10
const val MAX_INPUT_CODEPOINTS = 100

const val COMMON_WORD_MIN_LENGTH = 2
const val COMMON_WORD_MAX_LENGTH = 15

const val HIGH_FREQUENCY_THRESHOLD = 10
const val MEDIUM_FREQUENCY_THRESHOLD = 3
const val HIGH_FREQUENCY_BASE_BOOST = 0.15
Expand Down
Loading