Skip to content

Commit 2b05365

Browse files
Refactored TokenStreamParser and LlamaChatSessionImpl to optimize memory allocation and improve prompt handling.
* **Refactored `TokenStreamParser`**: Replaced `String` buffering and `StreamAction` objects with a `StringBuilder`-based approach to eliminate heap allocations during token processing.
* **Renamed API**: Changed `injectPrompt` to `ingestPrompt` across C++, JNI, and Kotlin layers for consistency.
* **Optimized `LlamaChatSessionImpl`**: Integrated the new zero-allocation parser and updated the generation loop to emit snapshots only when content or thinking state changes.
* **Fixed Formatting**: Removed leading newlines in `ChatML` prompt formatting and updated tests to reflect correct turn boundaries.
* **Improved Error Handling**: Added a `catch` block in `ChatViewModel` to handle and display streaming errors in the UI.
1 parent 529f9d6 commit 2b05365

14 files changed

Lines changed: 215 additions & 266 deletions

File tree

app/src/main/java/com/suhel/llamabro/demo/ui/screens/chat/ChatViewModel.kt

Lines changed: 58 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ import kotlinx.coroutines.ExperimentalCoroutinesApi
2424
import kotlinx.coroutines.FlowPreview
2525
import kotlinx.coroutines.flow.MutableSharedFlow
2626
import kotlinx.coroutines.flow.SharingStarted
27+
import kotlinx.coroutines.flow.catch
2728
import kotlinx.coroutines.flow.distinctUntilChanged
2829
import kotlinx.coroutines.flow.emitAll
2930
import kotlinx.coroutines.flow.filterNotNull
30-
import kotlinx.coroutines.flow.first
3131
import kotlinx.coroutines.flow.flatMapLatest
3232
import kotlinx.coroutines.flow.flow
3333
import kotlinx.coroutines.flow.flowOf
@@ -78,7 +78,10 @@ class ChatViewModel @Inject constructor(
7878
val history = chatRepository.getMessages(args.conversationId)
7979
.map { chatMessage ->
8080
when (chatMessage.role) {
81-
MessageRole.User -> Message.User(chatMessage.content)
81+
MessageRole.User -> Message.User(
82+
content = chatMessage.content
83+
)
84+
8285
MessageRole.Assistant -> Message.Assistant(
8386
content = chatMessage.content,
8487
thinking = chatMessage.thinking
@@ -116,50 +119,64 @@ class ChatViewModel @Inject constructor(
116119
val incomingMessage = sendMessageTrigger
117120
.distinctUntilChanged()
118121
.flatMapLatest { message ->
119-
if (message != null) {
120-
flow {
121-
emit(
122-
UiChatMessage(
123-
id = "streaming",
124-
role = MessageRole.Assistant,
125-
isProcessing = true,
126-
)
127-
)
122+
if (message == null) {
123+
return@flatMapLatest flowOf(null)
124+
}
128125

129-
chatRepository.addMessage(
130-
conversationId = args.conversationId,
131-
role = MessageRole.User,
132-
content = message
126+
flow<UiChatMessage?> {
127+
emit(
128+
UiChatMessage(
129+
id = "streaming",
130+
role = MessageRole.Assistant,
131+
isProcessing = true,
133132
)
133+
)
134134

135-
val session = chatSessionFlow.filterNotNull().first()
136-
137-
emitAll(
138-
session.completion(message)
139-
.map { chunk ->
140-
if (chunk.isComplete && chunk.contentText != null) {
141-
chatRepository.addMessage(
142-
conversationId = args.conversationId,
143-
role = MessageRole.Assistant,
144-
content = chunk.contentText!!,
145-
thinking = chunk.thinkingText,
146-
tokensPerSecond = chunk.tokensPerSecond
147-
)
135+
chatRepository.addMessage(
136+
conversationId = args.conversationId,
137+
role = MessageRole.User,
138+
content = message
139+
)
148140

149-
null
150-
} else {
151-
UiChatMessage(
152-
id = "streaming",
153-
role = MessageRole.Assistant,
154-
content = chunk.contentText,
155-
thinking = chunk.thinkingText
156-
)
157-
}
141+
emitAll(
142+
chatSessionFlow
143+
.filterNotNull()
144+
.flatMapLatest { chatSession ->
145+
chatSession.completion(message)
146+
}
147+
.onEach { chunk ->
148+
if (chunk.isComplete && chunk.contentText != null) {
149+
chatRepository.addMessage(
150+
conversationId = args.conversationId,
151+
role = MessageRole.Assistant,
152+
content = chunk.contentText!!,
153+
thinking = chunk.thinkingText,
154+
tokensPerSecond = chunk.tokensPerSecond
155+
)
158156
}
159-
)
160-
}
161-
} else {
162-
flowOf(null as UiChatMessage?)
157+
}
158+
.map { chunk ->
159+
if (chunk.isComplete) {
160+
null
161+
} else {
162+
UiChatMessage(
163+
id = "streaming",
164+
role = MessageRole.Assistant,
165+
content = chunk.contentText,
166+
thinking = chunk.thinkingText
167+
)
168+
}
169+
}
170+
.catch { e ->
171+
emit(
172+
UiChatMessage(
173+
id = "streaming",
174+
role = MessageRole.Assistant,
175+
error = e.message
176+
)
177+
)
178+
}
179+
)
163180
}
164181
}
165182
.stateIn(viewModelScope, SharingStarted.Eagerly, null)

sdk/src/main/cpp/jni/llama_session_jni.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_setSystemPrompt(J
9191

9292
extern "C"
9393
JNIEXPORT void JNICALL
94-
Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_injectPrompt(JNIEnv *env, jclass,
94+
Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_ingestPrompt(JNIEnv *env, jclass,
9595
jlong jSessionPtr,
9696
jstring jText,
9797
jboolean jAddSpecial) {
@@ -101,7 +101,7 @@ Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_injectPrompt(JNIE
101101
env->ReleaseStringUTFChars(jText, text);
102102

103103
try {
104-
session->injectPrompt(textStr, jAddSpecial);
104+
session->ingestPrompt(textStr, jAddSpecial);
105105
} catch (const LlamaException &ex) {
106106
throwLlamaError(env, ex);
107107
}

sdk/src/main/cpp/session.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,8 @@ void LlamaSession::setSystemPrompt(const std::string &prompt, bool add_special)
217217
ingest_prompt(prompt, true, add_special);
218218
}
219219

220-
void LlamaSession::injectPrompt(const std::string &user_message, bool add_special) {
221-
ingest_prompt(user_message, false, add_special);
220+
void LlamaSession::ingestPrompt(const std::string &prompt, bool add_special) {
221+
ingest_prompt(prompt, false, add_special);
222222
}
223223

224224
Generation LlamaSession::generate() {

sdk/src/main/cpp/session.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class LlamaSession {
8484

8585
void setSystemPrompt(const std::string &prompt, bool add_special);
8686

87-
void injectPrompt(const std::string &prompt, bool add_special);
87+
void ingestPrompt(const std::string &prompt, bool add_special);
8888

8989
Generation generate();
9090

sdk/src/main/java/com/suhel/llamabro/sdk/LlamaChatSession.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ interface LlamaChatSession {
3535
* If the collector's coroutine is cancelled, the underlying native generation
3636
* is automatically aborted.
3737
*
38-
* @param message The user's input text.
38+
* @param prompt The user's input text.
3939
* @return A flow of [Completion] updates.
4040
*/
41-
fun completion(message: String): Flow<Completion>
41+
fun completion(prompt: String): Flow<Completion>
4242

4343
/**
4444
* Clears the current conversation history while retaining the system prompt.

sdk/src/main/java/com/suhel/llamabro/sdk/LlamaSession.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@ interface LlamaSession : AutoCloseable {
4646
* It is cancellable; if the coroutine is cancelled, the native pre-fill
4747
* loop will be interrupted.
4848
*
49-
* @param text Raw text to add to the context.
49+
* @param prompt Raw text to add to the context.
5050
* @param addSpecial If true, prepends the model's default BOS token.
5151
* @throws LlamaError.ContextOverflow if the context is full and cannot be recovered.
5252
*/
53-
suspend fun prompt(text: String, addSpecial: Boolean = false)
53+
suspend fun ingestPrompt(prompt: String, addSpecial: Boolean = false)
5454

5555
/**
5656
* Samples the next token from the model based on the current context.
@@ -73,7 +73,7 @@ interface LlamaSession : AutoCloseable {
7373
/**
7474
* Asynchronously signals the native engine to stop any active computation.
7575
*
76-
* Use this to immediately halt a long-running [prompt] or [generate] call
76+
* Use this to immediately halt a long-running [ingestPrompt] or [generate] call
7777
* from another thread or UI action.
7878
*/
7979
fun abort()

sdk/src/main/java/com/suhel/llamabro/sdk/internal/LlamaChatSessionImpl.kt

Lines changed: 53 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -14,66 +14,78 @@ import kotlinx.coroutines.flow.onCompletion
1414
import kotlinx.coroutines.isActive
1515
import kotlinx.coroutines.withContext
1616

17-
/**
18-
* High-level implementation of [LlamaChatSession].
19-
*
20-
* This class coordinates between the raw [LlamaSession], the [PromptFormatter],
21-
* and the [TokenStreamParser] to provide a conversational experience.
22-
* It manages the token generation loop and transforms raw tokens into
23-
* structured [Completion] snapshots.
24-
*
25-
* ### Turn lifecycle
26-
* The C++ layer decodes EOG tokens into the KV cache, so every assistant
27-
* turn is automatically closed at the native level. This class does not
28-
* need to track or inject turn-closing tokens.
29-
*/
3017
internal class LlamaChatSessionImpl(
3118
private val session: LlamaSession,
3219
private val systemPrompt: String
3320
) : LlamaChatSession {
34-
21+
private val parser = TokenStreamParser()
3522
private val promptFormatter = PromptFormatter(session.modelConfig.promptFormat)
3623

37-
override fun completion(message: String): Flow<Completion> = flow {
38-
val parser = TokenStreamParser(session.modelConfig.promptFormat.assistantSuffix)
24+
override fun completion(prompt: String): Flow<Completion> = flow {
3925
var completionState = Completion()
4026
var tokenCount = 0
27+
val contentBuilder = StringBuilder()
28+
val thinkingBuilder = StringBuilder()
29+
30+
parser.reset()
31+
session.ingestPrompt(promptFormatter.user(prompt) + promptFormatter.assistantStart())
4132

42-
// Inject user turn + assistant turn prefix
43-
session.prompt(promptFormatter.user(message) + promptFormatter.assistantStart())
4433
val startTime = System.nanoTime()
4534

4635
while (currentCoroutineContext().isActive) {
4736
val generation = try {
4837
session.generate()
4938
} catch (_: LlamaError.Cancelled) {
50-
emit(completionState.finalize(tokenCount, startTime, true))
39+
emit(
40+
completionState.finalize(
41+
tokenCount = tokenCount,
42+
startTime = startTime,
43+
isInterrupted = true,
44+
contentBuilder = contentBuilder,
45+
thinkingBuilder = thinkingBuilder
46+
)
47+
)
5148
return@flow
5249
} catch (e: LlamaError) {
5350
throw e
5451
}
5552

56-
if (generation.isComplete) {
57-
completionState = completionState.applyActions(parser.flush())
58-
emit(completionState.finalize(tokenCount, startTime))
59-
break
60-
}
61-
62-
generation.token?.let { generatedToken ->
53+
generation.token?.let { token ->
6354
tokenCount++
64-
val actions = parser.process(generatedToken)
6555

66-
if (actions.isNotEmpty()) {
67-
completionState = completionState.applyActions(actions)
56+
val contentLenBefore = contentBuilder.length
57+
val thinkingLenBefore = thinkingBuilder.length
58+
val stateBefore = parser.isThinking
59+
60+
// The parser directly modifies the builders. 0 allocations.
61+
parser.process(token, contentBuilder, thinkingBuilder)
62+
63+
// Only emit a new state if the parser actually appended text or flipped state
64+
if (
65+
contentBuilder.length > contentLenBefore ||
66+
thinkingBuilder.length > thinkingLenBefore ||
67+
parser.isThinking != stateBefore
68+
) {
69+
completionState = completionState.copy(
70+
contentText = if (contentBuilder.isEmpty()) null else contentBuilder.toString(),
71+
thinkingText = if (thinkingBuilder.isEmpty()) null else thinkingBuilder.toString()
72+
)
6873
emit(completionState)
6974
}
75+
}
7076

71-
// Stop if the parser intercepted a configured stop sequence
72-
// (e.g. assistant suffix for custom formats where suffix ≠ EOG).
73-
if (actions.any { it is StreamAction.Stop }) {
74-
emit(completionState.finalize(tokenCount, startTime))
75-
break
76-
}
77+
if (generation.isComplete) {
78+
parser.flush(contentBuilder, thinkingBuilder)
79+
emit(
80+
completionState.finalize(
81+
tokenCount = tokenCount,
82+
startTime = startTime,
83+
isInterrupted = false,
84+
contentBuilder = contentBuilder,
85+
thinkingBuilder = thinkingBuilder
86+
)
87+
)
88+
break
7789
}
7890
}
7991
}
@@ -84,42 +96,21 @@ internal class LlamaChatSessionImpl(
8496
}
8597
.flowOn(Dispatchers.IO)
8698

87-
/** Appends parser actions to the current completion snapshot. */
88-
private fun Completion.applyActions(actions: List<StreamAction>): Completion {
89-
var newContent = this.contentText
90-
var newThinking = this.thinkingText
91-
92-
for (action in actions) {
93-
when (action) {
94-
is StreamAction.Content -> {
95-
newContent = (newContent ?: "") + action.text
96-
}
97-
98-
is StreamAction.Thinking -> {
99-
newThinking = (newThinking ?: "") + action.text
100-
}
101-
102-
is StreamAction.Stop -> {
103-
}
104-
}
105-
}
106-
107-
return this.copy(contentText = newContent, thinkingText = newThinking)
108-
}
109-
11099
/** Finalizes completion state with performance metrics and trimming. */
111100
private fun Completion.finalize(
112101
tokenCount: Int,
113102
startTime: Long,
114-
isInterrupted: Boolean = false
103+
isInterrupted: Boolean,
104+
contentBuilder: StringBuilder,
105+
thinkingBuilder: StringBuilder
115106
): Completion {
116107
val endTime = System.nanoTime()
117108
val durationNs = (endTime - startTime).coerceAtLeast(1)
118109
val tps = (tokenCount.toDouble() / durationNs * 1e9).toFloat()
119110

120111
return this.copy(
121-
thinkingText = if (this.thinkingText.isNullOrBlank()) null else this.thinkingText.trim(),
122-
contentText = if (this.contentText.isNullOrBlank()) null else this.contentText.trim(),
112+
thinkingText = thinkingBuilder.ifBlank { null }?.toString()?.trim(),
113+
contentText = contentBuilder.ifBlank { null }?.toString()?.trim(),
123114
tokensPerSecond = tps,
124115
isComplete = true,
125116
isInterrupted = isInterrupted,
@@ -134,7 +125,7 @@ internal class LlamaChatSessionImpl(
134125
override suspend fun loadHistory(messages: List<Message>) =
135126
withContext(Dispatchers.IO) {
136127
messages.forEach { msg ->
137-
session.prompt(promptFormatter.format(msg))
128+
session.ingestPrompt(promptFormatter.format(msg))
138129
}
139130
}
140131

sdk/src/main/java/com/suhel/llamabro/sdk/internal/LlamaSessionImpl.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,12 @@ internal class LlamaSessionImpl(
7575
}
7676
}
7777

78-
override suspend fun prompt(text: String, addSpecial: Boolean) =
78+
override suspend fun ingestPrompt(prompt: String, addSpecial: Boolean) =
7979
withContext(Dispatchers.IO) {
8080
mutex.withLock {
8181
try {
8282
runInterruptible {
83-
Jni.injectPrompt(ptr, text, addSpecial)
83+
Jni.ingestPrompt(ptr, prompt, addSpecial)
8484
}
8585
} catch (e: RuntimeException) {
8686
throw mapNativeError(e)
@@ -182,7 +182,7 @@ internal class LlamaSessionImpl(
182182
external fun setSystemPrompt(sessionPtr: Long, text: String, addSpecial: Boolean)
183183

184184
@JvmStatic
185-
external fun injectPrompt(sessionPtr: Long, text: String, addSpecial: Boolean)
185+
external fun ingestPrompt(sessionPtr: Long, text: String, addSpecial: Boolean)
186186

187187
@JvmStatic
188188
external fun clear(sessionPtr: Long)

0 commit comments

Comments
 (0)