Skip to content

Commit 529f9d6

Browse files
sdk: update JNI and session generation to return structured result
* Introduced `TokenGenerationResult` (Kotlin) and `Generation` (C++) to return both the generated token and a completion flag from the native layer. * Updated `LlamaSession::generate` to return the new structured result instead of a nullable string. * Implemented JNI reference caching in `llama_session_jni.cpp` for the new result class to improve performance. * Simplified `LlamaEngine` JNI callback logic by removing unnecessary global references during synchronous model loading. * Updated `LlamaChatSessionImpl` and tests to handle the new generation result format. * Refactored JNI parameter naming for consistency (e.g., `kSessionPtr` to `jSessionPtr`).
1 parent 76fc623 commit 529f9d6

9 files changed

Lines changed: 175 additions & 106 deletions

File tree

sdk/src/main/cpp/jni/llama_engine_jni.cpp

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,14 @@
77

88
// ── Shared helper ─────────────────────────────────────────────────────────────
99

10-
static NativeEngineParams readEngineParams(JNIEnv *env, jobject kConfig) {
11-
auto configReader = JniConfigReader(env, kConfig);
10+
namespace jni_refs {
11+
constexpr auto progress_listener_method = "onProgress";
12+
constexpr auto progress_listener_method_sig = "(F)Z";
13+
}
14+
15+
static NativeEngineParams readEngineParams(JNIEnv *env,
16+
jobject jConfig) {
17+
auto configReader = JniConfigReader(env, jConfig);
1218
return NativeEngineParams{
1319
.model_path = configReader.getString("modelPath"),
1420
.threads = configReader.getInt("threads"),
@@ -22,9 +28,10 @@ static NativeEngineParams readEngineParams(JNIEnv *env, jobject kConfig) {
2228
extern "C"
2329
JNIEXPORT jlong JNICALL
2430
Java_com_suhel_llamabro_sdk_internal_LlamaEngineImpl_00024Jni_create(JNIEnv *env, jclass,
25-
jobject kConfig) {
31+
jobject jConfig) {
2632
try {
27-
return reinterpret_cast<jlong>(new LlamaEngine(readEngineParams(env, kConfig)));
33+
auto instance = new LlamaEngine(readEngineParams(env, jConfig));
34+
return reinterpret_cast<jlong>(instance);
2835
} catch (const LlamaException &ex) {
2936
throwLlamaError(env, ex);
3037
return 0L;
@@ -37,39 +44,27 @@ extern "C"
3744
JNIEXPORT jlong JNICALL
3845
Java_com_suhel_llamabro_sdk_internal_LlamaEngineImpl_00024Jni_createWithProgress(JNIEnv *env,
3946
jclass,
40-
jobject kConfig,
41-
jobject kListener) {
42-
auto config = readEngineParams(env, kConfig);
47+
jobject jConfig,
48+
jobject jListener) {
49+
auto config = readEngineParams(env, jConfig);
4350

4451
// Resolve the callback method ID once before entering the blocking load
45-
jclass listenerClass = env->GetObjectClass(kListener);
46-
jmethodID onProgress = env->GetMethodID(listenerClass, "onProgress", "(F)Z");
47-
env->DeleteLocalRef(listenerClass);
48-
49-
// Create a GlobalRef to guarantee the object survives the JNI frame securely
50-
auto globalListener = env->NewGlobalRef(kListener);
51-
52-
// Get the JavaVM so the lambda can attach/detach thread or use the current env
53-
JavaVM* jvm;
54-
env->GetJavaVM(&jvm);
52+
auto jListenerClass = env->GetObjectClass(jListener);
53+
auto jOnProgress = env->GetMethodID(jListenerClass,
54+
jni_refs::progress_listener_method,
55+
jni_refs::progress_listener_method_sig);
56+
env->DeleteLocalRef(jListenerClass);
5557

56-
config.progress_callback = [jvm, globalListener, onProgress](float progress) -> bool {
57-
JNIEnv *currentEnv;
58-
// In this specific codebase, llama_model_load_from_file is synchronous,
59-
// but getting the env from JVM is the correct modern JNI pattern.
60-
auto res = jvm->GetEnv(reinterpret_cast<void**>(&currentEnv), JNI_VERSION_1_6);
61-
if (res == JNI_OK) {
62-
return currentEnv->CallBooleanMethod(globalListener, onProgress, static_cast<jfloat>(progress)) == JNI_TRUE;
63-
}
64-
return false;
58+
// It is safe to pass these refs to the callback because this method is synchronous
59+
config.progress_callback = [env, jListener, jOnProgress](float progress) -> bool {
60+
return env->CallBooleanMethod(jListener, jOnProgress,
61+
static_cast<jfloat>(progress)) == JNI_TRUE;
6562
};
6663

6764
try {
68-
auto *engine = new LlamaEngine(config);
69-
env->DeleteGlobalRef(globalListener); // Safe to delete after synchronous load
70-
return reinterpret_cast<jlong>(engine);
65+
auto instance = new LlamaEngine(config);
66+
return reinterpret_cast<jlong>(instance);
7167
} catch (const LlamaException &ex) {
72-
env->DeleteGlobalRef(globalListener); // Clean up on error
7368
throwLlamaError(env, ex);
7469
return 0L;
7570
}
@@ -78,6 +73,6 @@ Java_com_suhel_llamabro_sdk_internal_LlamaEngineImpl_00024Jni_createWithProgress
7873
extern "C"
7974
JNIEXPORT void JNICALL
8075
Java_com_suhel_llamabro_sdk_internal_LlamaEngineImpl_00024Jni_destroy(JNIEnv *, jclass,
81-
jlong ptr) {
82-
delete reinterpret_cast<LlamaEngine *>(ptr);
76+
jlong jEnginePtr) {
77+
delete reinterpret_cast<LlamaEngine *>(jEnginePtr);
8378
}

sdk/src/main/cpp/jni/llama_session_jni.cpp

Lines changed: 72 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,44 @@
44
#include "utils/llama_exception.h"
55
#include "engine.h"
66

7+
namespace jni_refs {
8+
constexpr auto token_generation_result_class = "com/suhel/llamabro/sdk/internal/LlamaSessionImpl$NativeTokenGenerationResult";
9+
constexpr auto token_generation_result_constructor_sig = "(Ljava/lang/String;Z)V";
10+
}
11+
12+
static jclass jTokenGenerationResultClass = nullptr;
13+
static jmethodID jTokenGenerationResultConstructor = nullptr;
14+
15+
static void cache_refs(JNIEnv *env) {
16+
auto local = env->FindClass(jni_refs::token_generation_result_class);
17+
18+
jTokenGenerationResultClass = reinterpret_cast<jclass>(env->NewGlobalRef(local));
19+
jTokenGenerationResultConstructor = env->GetMethodID(jTokenGenerationResultClass,
20+
"<init>",
21+
jni_refs::token_generation_result_constructor_sig);
22+
env->DeleteLocalRef(local);
23+
}
24+
25+
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM *vm, void *) {
26+
JNIEnv *env;
27+
if (vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6) != JNI_OK) {
28+
return JNI_ERR;
29+
}
30+
31+
cache_refs(env);
32+
return JNI_VERSION_1_6;
33+
}
34+
735
// ── create ────────────────────────────────────────────────────────────────────
836

937
extern "C"
1038
JNIEXPORT jlong JNICALL
1139
Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_create(JNIEnv *env,
1240
jclass,
13-
jlong kEnginePtr,
14-
jobject kParams) {
15-
auto engine = reinterpret_cast<LlamaEngine *>(kEnginePtr);
16-
auto configReader = JniConfigReader(env, kParams);
41+
jlong jEnginePtr,
42+
jobject jParams) {
43+
auto engine = reinterpret_cast<LlamaEngine *>(jEnginePtr);
44+
auto configReader = JniConfigReader(env, jParams);
1745

1846
auto config = NativeSessionParams{
1947
.context_size = configReader.getInt("contextSize"),
@@ -46,16 +74,16 @@ Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_create(JNIEnv *en
4674
extern "C"
4775
JNIEXPORT void JNICALL
4876
Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_setSystemPrompt(JNIEnv *env, jclass,
49-
jlong kSessionPtr,
50-
jstring kText,
51-
jboolean kAddSpecial) {
52-
auto session = reinterpret_cast<LlamaSession *>(kSessionPtr);
53-
auto text = env->GetStringUTFChars(kText, nullptr);
54-
std::string textStr(text);
55-
env->ReleaseStringUTFChars(kText, text);
77+
jlong jSessionPtr,
78+
jstring jText,
79+
jboolean jAddSpecial) {
80+
auto session = reinterpret_cast<LlamaSession *>(jSessionPtr);
81+
auto text = env->GetStringUTFChars(jText, nullptr);
82+
auto textStr = std::string(text);
83+
env->ReleaseStringUTFChars(jText, text);
5684

5785
try {
58-
session->setSystemPrompt(textStr, kAddSpecial);
86+
session->setSystemPrompt(textStr, jAddSpecial);
5987
} catch (const LlamaException &ex) {
6088
throwLlamaError(env, ex);
6189
}
@@ -64,16 +92,16 @@ Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_setSystemPrompt(J
6492
extern "C"
6593
JNIEXPORT void JNICALL
6694
Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_injectPrompt(JNIEnv *env, jclass,
67-
jlong kSessionPtr,
68-
jstring kText,
69-
jboolean kAddSpecial) {
70-
auto session = reinterpret_cast<LlamaSession *>(kSessionPtr);
71-
auto text = env->GetStringUTFChars(kText, nullptr);
72-
std::string textStr(text);
73-
env->ReleaseStringUTFChars(kText, text);
95+
jlong jSessionPtr,
96+
jstring jText,
97+
jboolean jAddSpecial) {
98+
auto session = reinterpret_cast<LlamaSession *>(jSessionPtr);
99+
auto text = env->GetStringUTFChars(jText, nullptr);
100+
auto textStr = std::string(text);
101+
env->ReleaseStringUTFChars(jText, text);
74102

75103
try {
76-
session->injectPrompt(textStr, kAddSpecial);
104+
session->injectPrompt(textStr, jAddSpecial);
77105
} catch (const LlamaException &ex) {
78106
throwLlamaError(env, ex);
79107
}
@@ -84,9 +112,11 @@ Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_injectPrompt(JNIE
84112
extern "C"
85113
JNIEXPORT void JNICALL
86114
Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_clear(JNIEnv *env, jclass,
87-
jlong kSessionPtr) {
115+
jlong jSessionPtr) {
116+
auto session = reinterpret_cast<LlamaSession *>(jSessionPtr);
117+
88118
try {
89-
reinterpret_cast<LlamaSession *>(kSessionPtr)->clear();
119+
session->clear();
90120
} catch (const LlamaException &ex) {
91121
throwLlamaError(env, ex);
92122
}
@@ -97,25 +127,31 @@ Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_clear(JNIEnv *env
97127
extern "C"
98128
JNIEXPORT void JNICALL
99129
Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_abort(JNIEnv *, jclass,
100-
jlong kSessionPtr) {
101-
reinterpret_cast<LlamaSession *>(kSessionPtr)->abort();
130+
jlong jSessionPtr) {
131+
auto session = reinterpret_cast<LlamaSession *>(jSessionPtr);
132+
session->abort();
102133
}
103134

104135
// ── generate ─────────────────────────────────────────────────────────────────
105136

106137
extern "C"
107-
JNIEXPORT jstring JNICALL
138+
JNIEXPORT jobject JNICALL
108139
Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_generate(JNIEnv *env, jclass,
109-
jlong kSessionPtr) {
140+
jlong jSessionPtr) {
141+
auto session = reinterpret_cast<LlamaSession *>(jSessionPtr);
142+
110143
try {
111-
auto result = reinterpret_cast<LlamaSession *>(kSessionPtr)->generate();
144+
auto gen = session->generate();
145+
auto token = gen.token;
112146

113-
if (result.has_value()) {
114-
const auto &utf16 = result.value();
115-
return env->NewString(reinterpret_cast<const jchar *>(utf16.data()),
116-
static_cast<jsize>(utf16.size()));
117-
}
118-
return nullptr;
147+
auto jToken = token.has_value()
148+
? env->NewString(reinterpret_cast<const jchar *>(token.value().data()),
149+
static_cast<jsize>(token.value().size()))
150+
: nullptr;
151+
auto jIsComplete = static_cast<jboolean>(gen.is_complete);
152+
153+
return env->NewObject(jTokenGenerationResultClass, jTokenGenerationResultConstructor,
154+
jToken, jIsComplete);
119155
} catch (const LlamaException &ex) {
120156
throwLlamaError(env, ex);
121157
return nullptr;
@@ -127,6 +163,7 @@ Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_generate(JNIEnv *
127163
extern "C"
128164
JNIEXPORT void JNICALL
129165
Java_com_suhel_llamabro_sdk_internal_LlamaSessionImpl_00024Jni_destroy(JNIEnv *, jclass,
130-
jlong kSessionPtr) {
131-
delete reinterpret_cast<LlamaSession *>(kSessionPtr);
166+
jlong jSessionPtr) {
167+
auto session = reinterpret_cast<LlamaSession *>(jSessionPtr);
168+
delete session;
132169
}

sdk/src/main/cpp/session.cpp

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ bool LlamaSession::is_token_buffer_valid() {
201201
return !token_buffer.empty() && utils::llm_is_valid_utf8(token_buffer);
202202
}
203203

204-
std::u16string LlamaSession::get_token_buffer_as_u16string() {
204+
std::u16string LlamaSession::get_and_clear_token_buffer() {
205205
auto result = utils::llm_utf8_to_utf16_sanitized(token_buffer);
206206
token_buffer.clear();
207207
return result;
@@ -221,7 +221,7 @@ void LlamaSession::injectPrompt(const std::string &user_message, bool add_specia
221221
ingest_prompt(user_message, false, add_special);
222222
}
223223

224-
std::optional<std::u16string> LlamaSession::generate() {
224+
Generation LlamaSession::generate() {
225225
auto ctx = llama_context.get();
226226
auto model = llama_get_model(ctx);
227227
auto vocab = llama_model_get_vocab(model);
@@ -247,20 +247,24 @@ std::optional<std::u16string> LlamaSession::generate() {
247247
}
248248
}
249249

250-
if (is_token_buffer_valid()) {
251-
return get_token_buffer_as_u16string();
252-
}
253-
return std::nullopt;
250+
return Generation{
251+
.token = is_token_buffer_valid()
252+
? std::make_optional(get_and_clear_token_buffer())
253+
: std::nullopt,
254+
.is_complete = true,
255+
};
254256
}
255257

256258
auto piece = utils::token_to_piece(vocab, new_token, true);
257259
token_buffer.append(piece);
258260

259261
if (!roll_kv_cache_if_needed(1)) {
260-
if (is_token_buffer_valid()) {
261-
return get_token_buffer_as_u16string();
262-
}
263-
return std::nullopt;
262+
return Generation{
263+
.token = is_token_buffer_valid()
264+
? std::make_optional(get_and_clear_token_buffer())
265+
: std::nullopt,
266+
.is_complete = true,
267+
};
264268
}
265269

266270
utils::batch_clear(llama_batch);
@@ -275,7 +279,12 @@ std::optional<std::u16string> LlamaSession::generate() {
275279
n_past += 1;
276280

277281
if (is_token_buffer_valid()) {
278-
return get_token_buffer_as_u16string();
282+
return Generation{
283+
.token = is_token_buffer_valid()
284+
? std::make_optional(get_and_clear_token_buffer())
285+
: std::nullopt,
286+
.is_complete = false,
287+
};
279288
}
280289
}
281290
}

sdk/src/main/cpp/session.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ struct NativeSessionParams {
3333
int micro_batch_size;
3434
};
3535

36+
struct Generation {
37+
std::optional<std::u16string> token;
38+
bool is_complete;
39+
};
40+
3641
#include <atomic>
3742

3843
class LlamaSession {
@@ -62,7 +67,7 @@ class LlamaSession {
6267

6368
bool is_token_buffer_valid();
6469

65-
std::u16string get_token_buffer_as_u16string();
70+
std::u16string get_and_clear_token_buffer();
6671

6772
public:
6873
LlamaSession(llama_model *model, int threads, const NativeSessionParams &config);
@@ -81,7 +86,7 @@ class LlamaSession {
8186

8287
void injectPrompt(const std::string &prompt, bool add_special);
8388

84-
std::optional<std::u16string> generate();
89+
Generation generate();
8590

8691
void clear();
8792

sdk/src/main/java/com/suhel/llamabro/sdk/LlamaSession.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package com.suhel.llamabro.sdk
22

33
import com.suhel.llamabro.sdk.model.ResourceState
44
import com.suhel.llamabro.sdk.model.ModelConfig
5+
import com.suhel.llamabro.sdk.model.TokenGenerationResult
56
import kotlinx.coroutines.flow.Flow
67

78
/**
@@ -60,7 +61,7 @@ interface LlamaSession : AutoCloseable {
6061
* End-of-Generation (EOG) token.
6162
* @throws LlamaError.DecodeFailed if the native sampling loop fails.
6263
*/
63-
suspend fun generate(): String?
64+
suspend fun generate(): TokenGenerationResult
6465

6566
/**
6667
* Clears the conversation history from the KV cache.

0 commit comments

Comments
 (0)