diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index 5e080e0c369..54494979766 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -8,8 +8,6 @@
 
 package org.pytorch.executorch.extension.llm;
 
-import com.facebook.jni.HybridData;
-import com.facebook.jni.annotations.DoNotStrip;
 import java.io.File;
 import java.util.List;
 import org.pytorch.executorch.ExecuTorchRuntime;
@@ -28,18 +26,19 @@ public class LlmModule {
   public static final int MODEL_TYPE_TEXT_VISION = 2;
   public static final int MODEL_TYPE_MULTIMODAL = 2;
 
-  private final HybridData mHybridData;
+  private long mNativeHandle;
   private static final int DEFAULT_SEQ_LEN = 128;
   private static final boolean DEFAULT_ECHO = true;
 
-  @DoNotStrip
-  private static native HybridData initHybrid(
+  private static native long nativeCreate(
       int modelType,
       String modulePath,
       String tokenizerPath,
       float temperature,
       List<String> dataFiles);
 
+  private static native void nativeDestroy(long nativeHandle);
+
   /**
    * Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and
    * dataFiles.
@@ -61,7 +60,7 @@ public LlmModule(
       throw new RuntimeException("Cannot load tokenizer path " + tokenizerPath);
     }
 
-    mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, dataFiles);
+    mNativeHandle = nativeCreate(modelType, modulePath, tokenizerPath, temperature, dataFiles);
   }
 
   /**
@@ -107,7 +106,16 @@ public LlmModule(LlmModuleConfig config) {
   }
 
   public void resetNative() {
-    mHybridData.resetNative();
+    if (mNativeHandle != 0) {
+      nativeDestroy(mNativeHandle);
+      mNativeHandle = 0;
+    }
+  }
+
+  @Override
+  protected void finalize() throws Throwable {
+    resetNative();
+    super.finalize();
   }
 
   /**
@@ -150,7 +158,12 @@ public int generate(String prompt, LlmCallback llmCallback, boolean echo) {
    * @param llmCallback callback object to receive results
    * @param echo indicate whether to echo the input prompt or not (text completion vs chat)
    */
-  public native int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo);
+  public int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) {
+    return nativeGenerate(mNativeHandle, prompt, seqLen, llmCallback, echo);
+  }
+
+  private static native int nativeGenerate(
+      long nativeHandle, String prompt, int seqLen, LlmCallback llmCallback, boolean echo);
 
   /**
    * Start generating tokens from the module.
@@ -206,14 +219,15 @@ public int generate(
    */
   @Experimental
   public long prefillImages(int[] image, int width, int height, int channels) {
-    int nativeResult = appendImagesInput(image, width, height, channels);
+    int nativeResult = nativeAppendImagesInput(mNativeHandle, image, width, height, channels);
     if (nativeResult != 0) {
       throw new RuntimeException("Prefill failed with error code: " + nativeResult);
     }
     return 0;
   }
 
-  private native int appendImagesInput(int[] image, int width, int height, int channels);
+  private static native int nativeAppendImagesInput(
+      long nativeHandle, int[] image, int width, int height, int channels);
 
   /**
    * Prefill a multimodal Module with the given images input.
@@ -228,15 +242,16 @@ public long prefillImages(int[] image, int width, int height, int channels) {
    */
   @Experimental
   public long prefillImages(float[] image, int width, int height, int channels) {
-    int nativeResult = appendNormalizedImagesInput(image, width, height, channels);
+    int nativeResult =
+        nativeAppendNormalizedImagesInput(mNativeHandle, image, width, height, channels);
     if (nativeResult != 0) {
       throw new RuntimeException("Prefill failed with error code: " + nativeResult);
     }
     return 0;
   }
 
-  private native int appendNormalizedImagesInput(
-      float[] image, int width, int height, int channels);
+  private static native int nativeAppendNormalizedImagesInput(
+      long nativeHandle, float[] image, int width, int height, int channels);
 
   /**
    * Prefill a multimodal Module with the given audio input.
@@ -251,14 +266,15 @@ private native int appendNormalizedImagesInput(
    */
   @Experimental
   public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) {
-    int nativeResult = appendAudioInput(audio, batch_size, n_bins, n_frames);
+    int nativeResult = nativeAppendAudioInput(mNativeHandle, audio, batch_size, n_bins, n_frames);
     if (nativeResult != 0) {
       throw new RuntimeException("Prefill failed with error code: " + nativeResult);
     }
     return 0;
   }
 
-  private native int appendAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames);
+  private static native int nativeAppendAudioInput(
+      long nativeHandle, byte[] audio, int batch_size, int n_bins, int n_frames);
 
   /**
    * Prefill a multimodal Module with the given audio input.
@@ -273,14 +289,16 @@ public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames)
    */
   @Experimental
   public long prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) {
-    int nativeResult = appendAudioInputFloat(audio, batch_size, n_bins, n_frames);
+    int nativeResult =
+        nativeAppendAudioInputFloat(mNativeHandle, audio, batch_size, n_bins, n_frames);
     if (nativeResult != 0) {
       throw new RuntimeException("Prefill failed with error code: " + nativeResult);
     }
     return 0;
   }
 
-  private native int appendAudioInputFloat(float[] audio, int batch_size, int n_bins, int n_frames);
+  private static native int nativeAppendAudioInputFloat(
+      long nativeHandle, float[] audio, int batch_size, int n_bins, int n_frames);
 
   /**
    * Prefill a multimodal Module with the given raw audio input.
@@ -295,15 +313,16 @@ public long prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames)
    */
   @Experimental
   public long prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) {
-    int nativeResult = appendRawAudioInput(audio, batch_size, n_channels, n_samples);
+    int nativeResult =
+        nativeAppendRawAudioInput(mNativeHandle, audio, batch_size, n_channels, n_samples);
     if (nativeResult != 0) {
       throw new RuntimeException("Prefill failed with error code: " + nativeResult);
     }
     return 0;
   }
 
-  private native int appendRawAudioInput(
-      byte[] audio, int batch_size, int n_channels, int n_samples);
+  private static native int nativeAppendRawAudioInput(
+      long nativeHandle, byte[] audio, int batch_size, int n_channels, int n_samples);
 
   /**
    * Prefill a multimodal Module with the given text input.
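Note on the prefill API semantics: the `nativeAppend*` functions only queue `MultimodalInput`s into `prefill_inputs_`; nothing executes until `nativeGenerate` drains the queue. A hedged usage sketch, assuming the existing `(modelType, modulePath, tokenizerPath, temperature)` constructor overload and a caller-supplied pixel packing for `pixels` (neither is specified by this diff; `llmCallback` is an `LlmCallback` implementation, see the sketch after the end of this file's diff):

```java
// Hedged multimodal sketch; constructor overload and pixel packing are assumptions.
LlmModule mm = new LlmModule(
    LlmModule.MODEL_TYPE_MULTIMODAL,
    "/data/local/tmp/model.pte",       // placeholder path
    "/data/local/tmp/tokenizer.bin",   // placeholder path
    0.8f);
mm.load();
mm.prefillImages(pixels, width, height, /* channels */ 3); // queued only
mm.prefillPrompt("Describe the image.");                   // queued only
mm.generate("", 128, llmCallback, /* echo */ false);       // consumes the queue
```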
@@ -315,7 +334,7 @@ private native int appendRawAudioInput(
    */
   @Experimental
   public long prefillPrompt(String prompt) {
-    int nativeResult = appendTextInput(prompt);
+    int nativeResult = nativeAppendTextInput(mNativeHandle, prompt);
     if (nativeResult != 0) {
       throw new RuntimeException("Prefill failed with error code: " + nativeResult);
     }
@@ -323,20 +342,30 @@ public long prefillPrompt(String prompt) {
   }
 
   // returns status
-  private native int appendTextInput(String prompt);
+  private static native int nativeAppendTextInput(long nativeHandle, String prompt);
 
   /**
    * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.
    *
    * <p>The startPos will be reset to 0.
    */
-  public native void resetContext();
+  public void resetContext() {
+    nativeResetContext(mNativeHandle);
+  }
+
+  private static native void nativeResetContext(long nativeHandle);
 
   /** Stop current generate() before it finishes. */
-  @DoNotStrip
-  public native void stop();
+  public void stop() {
+    nativeStop(mNativeHandle);
+  }
+
+  private static native void nativeStop(long nativeHandle);
 
   /** Force loading the module. Otherwise the model is loaded during first generate(). */
-  @DoNotStrip
-  public native int load();
+  public int load() {
+    return nativeLoad(mNativeHandle);
+  }
+
+  private static native int nativeLoad(long nativeHandle);
 }
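With the fbjni `HybridData` gone, native memory is tied to `mNativeHandle` and released by `resetNative()`; `finalize()` is only a last-resort safety net. A hedged lifecycle sketch, assuming the existing `MODEL_TYPE_TEXT` constant and the 4-argument constructor overload (not shown in this hunk):

```java
// Hedged usage sketch; constant and constructor overload are assumptions.
LlmCallback llmCallback = new LlmCallback() {
  @Override
  public void onResult(String token) {
    System.out.print(token); // invoked once per UTF-8-complete token chunk
  }

  @Override
  public void onStats(String statsJson) {
    System.out.println(statsJson); // JSON produced by stats_to_json_string()
  }
};

LlmModule module = new LlmModule(
    LlmModule.MODEL_TYPE_TEXT,
    "/data/local/tmp/llama.pte",      // placeholder path
    "/data/local/tmp/tokenizer.bin",  // placeholder path
    0.8f);
try {
  module.load(); // can now raise an exception via jni_helper (see below)
  module.generate("Hello", 128, llmCallback, /* echo */ true);
} finally {
  module.resetNative(); // deterministic release; do not rely on finalize()
}
```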
diff --git a/extension/android/jni/jni_helper.cpp b/extension/android/jni/jni_helper.cpp
index 6491524c7ac..37f9b271e52 100644
--- a/extension/android/jni/jni_helper.cpp
+++ b/extension/android/jni/jni_helper.cpp
@@ -10,6 +10,60 @@
 
 namespace executorch::jni_helper {
 
+void throwExecutorchException(
+    JNIEnv* env,
+    uint32_t errorCode,
+    const std::string& details) {
+  if (!env) {
+    return;
+  }
+
+  // Find the exception class
+  jclass exceptionClass =
+      env->FindClass("org/pytorch/executorch/ExecutorchRuntimeException");
+  if (exceptionClass == nullptr) {
+    // Class not found, clear the exception and return
+    env->ExceptionClear();
+    return;
+  }
+
+  // Find the static factory method: makeExecutorchException(int, String)
+  jmethodID makeExceptionMethod = env->GetStaticMethodID(
+      exceptionClass,
+      "makeExecutorchException",
+      "(ILjava/lang/String;)Ljava/lang/RuntimeException;");
+  if (makeExceptionMethod == nullptr) {
+    env->ExceptionClear();
+    env->DeleteLocalRef(exceptionClass);
+    return;
+  }
+
+  // Create the details string
+  jstring jDetails = env->NewStringUTF(details.c_str());
+  if (jDetails == nullptr) {
+    env->ExceptionClear();
+    env->DeleteLocalRef(exceptionClass);
+    return;
+  }
+
+  // Call the factory method to create the exception object
+  jobject exception = env->CallStaticObjectMethod(
+      exceptionClass,
+      makeExceptionMethod,
+      static_cast<jint>(errorCode),
+      jDetails);
+
+  env->DeleteLocalRef(jDetails);
+
+  if (exception != nullptr) {
+    env->Throw(static_cast<jthrowable>(exception));
+    env->DeleteLocalRef(exception);
+  }
+
+  env->DeleteLocalRef(exceptionClass);
+}
+
+#if EXECUTORCH_HAS_FBJNI
 void throwExecutorchException(uint32_t errorCode, const std::string& details) {
   // Get the current JNI environment
   auto env = facebook::jni::Environment::current();
@@ -34,5 +88,6 @@ void throwExecutorchException(uint32_t errorCode, const std::string& details) {
   auto exception = makeExceptionMethod(exceptionClass, errorCode, jDetails);
   facebook::jni::throwNewJavaException(exception.get());
 }
+#endif
 
 } // namespace executorch::jni_helper
diff --git a/extension/android/jni/jni_helper.h b/extension/android/jni/jni_helper.h
index 898c1619d9c..683a3cfe447 100644
--- a/extension/android/jni/jni_helper.h
+++ b/extension/android/jni/jni_helper.h
@@ -8,9 +8,16 @@
 
 #pragma once
 
-#include <fbjni/fbjni.h>
+#include <jni.h>
 #include <string>
 
+#if __has_include(<fbjni/fbjni.h>)
+#include <fbjni/fbjni.h>
+#define EXECUTORCH_HAS_FBJNI 1
+#else
+#define EXECUTORCH_HAS_FBJNI 0
+#endif
+
 namespace executorch::jni_helper {
 
 /**
@@ -18,6 +25,25 @@ namespace executorch::jni_helper {
  * code and details. Uses the Java factory method
  * ExecutorchRuntimeException.makeExecutorchException(int, String).
  *
+ * This version takes JNIEnv* directly and works with pure JNI.
+ *
+ * @param env The JNI environment.
+ * @param errorCode The error code from the C++ Executorch runtime.
+ * @param details Additional details to include in the exception message.
+ */
+void throwExecutorchException(
+    JNIEnv* env,
+    uint32_t errorCode,
+    const std::string& details);
+
+#if EXECUTORCH_HAS_FBJNI
+/**
+ * Throws a Java ExecutorchRuntimeException corresponding to the given error
+ * code and details. Uses the Java factory method
+ * ExecutorchRuntimeException.makeExecutorchException(int, String).
+ *
+ * This version uses fbjni to get the current JNI environment.
+ *
  * @param errorCode The error code from the C++ Executorch runtime.
 * @param details Additional details to include in the exception message.
  */
@@ -29,5 +55,6 @@ struct JExecutorchRuntimeException
   static constexpr auto kJavaDescriptor =
       "Lorg/pytorch/executorch/ExecutorchRuntimeException;";
 };
+#endif
 
 } // namespace executorch::jni_helper
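Since `nativeLoad` now reports failures through `throwExecutorchException`, Java callers get a real exception instead of only a status code. A hedged sketch of catching it, assuming (per the factory descriptor `(ILjava/lang/String;)Ljava/lang/RuntimeException;`) that `ExecutorchRuntimeException` is a `RuntimeException` subtype:

```java
try {
  module.load();
} catch (RuntimeException e) {
  // makeExecutorchException(int, String) builds the exception from the
  // native error code plus the detail string assembled in nativeLoad.
  android.util.Log.e("LlmModule", "load failed: " + e.getMessage());
}
```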
diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp
index 1f8457e00c5..0fbc0f14e54 100644
--- a/extension/android/jni/jni_layer.cpp
+++ b/extension/android/jni/jni_layer.cpp
@@ -535,10 +535,10 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
 } // namespace executorch::extension
 
 #ifdef EXECUTORCH_BUILD_LLAMA_JNI
-extern void register_natives_for_llm();
+extern void register_natives_for_llm(JNIEnv* env);
 #else
 // No op if we don't build LLM
-void register_natives_for_llm() {}
+void register_natives_for_llm(JNIEnv* /* env */) {}
 #endif
 
 extern void register_natives_for_runtime();
@@ -552,7 +552,9 @@ void register_natives_for_training() {}
 JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
   return facebook::jni::initialize(vm, [] {
     executorch::extension::ExecuTorchJni::registerNatives();
-    register_natives_for_llm();
+    // Get JNIEnv for pure JNI registration in LLM
+    JNIEnv* env = facebook::jni::Environment::current();
+    register_natives_for_llm(env);
     register_natives_for_runtime();
     register_natives_for_training();
   });
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 888e09e7989..c6844552523 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -6,9 +6,12 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include <jni.h>
+
 #include <memory>
 #include <string>
 #include <unordered_map>
+#include <vector>
 #include <executorch/extension/llm/runner/image.h>
 #include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/extension/llm/runner/multimodal_runner.h>
@@ -30,9 +33,6 @@
 #include <executorch/extension/threadpool/threadpool.h>
 #endif
 
-#include <fbjni/ByteBuffer.h>
-#include <fbjni/fbjni.h>
-
 #if defined(EXECUTORCH_BUILD_QNN)
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
 #endif
@@ -45,6 +45,10 @@
 namespace llm = ::executorch::extension::llm;
 using ::executorch::runtime::Error;
 
 namespace {
+
+// Global JavaVM pointer for obtaining JNIEnv in callbacks
+JavaVM* g_jvm = nullptr;
+
 bool utf8_check_validity(const char* str, size_t length) {
   for (size_t i = 0; i < length; ++i) {
     uint8_t byte = static_cast<uint8_t>(str[i]);
@@ -79,47 +83,70 @@ bool utf8_check_validity(const char* str, size_t length) {
 }
 
 std::string token_buffer;
-} // namespace
 
-namespace executorch_jni {
+// Helper to convert jstring to std::string
+std::string jstring_to_string(JNIEnv* env, jstring jstr) {
+  if (jstr == nullptr) {
+    return "";
+  }
+  const char* chars = env->GetStringUTFChars(jstr, nullptr);
+  if (chars == nullptr) {
+    return "";
+  }
+  std::string result(chars);
+  env->ReleaseStringUTFChars(jstr, chars);
+  return result;
+}
 
-class ExecuTorchLlmCallbackJni
-    : public facebook::jni::JavaClass<ExecuTorchLlmCallbackJni> {
- public:
-  constexpr static const char* kJavaDescriptor =
-      "Lorg/pytorch/executorch/extension/llm/LlmCallback;";
+// Helper to convert Java List<String> to std::vector<std::string>
+std::vector<std::string> jlist_to_string_vector(JNIEnv* env, jobject jlist) {
+  std::vector<std::string> result;
+  if (jlist == nullptr) {
+    return result;
+  }
 
-  void onResult(std::string result) const {
-    static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic();
-    static const auto method =
-        cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onResult");
+  jclass list_class = env->FindClass("java/util/List");
+  if (list_class == nullptr) {
+    env->ExceptionClear();
+    return result;
+  }
 
-    token_buffer += result;
-    if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) {
-      ET_LOG(
-          Info, "Current token buffer is not valid UTF-8. Waiting for more.");
-      return;
-    }
-    result = token_buffer;
-    token_buffer = "";
-    facebook::jni::local_ref<jstring> s = facebook::jni::make_jstring(result);
-    method(self(), s);
+  jmethodID size_method = env->GetMethodID(list_class, "size", "()I");
+  jmethodID get_method =
+      env->GetMethodID(list_class, "get", "(I)Ljava/lang/Object;");
+
+  if (size_method == nullptr || get_method == nullptr) {
+    env->ExceptionClear();
+    env->DeleteLocalRef(list_class);
+    return result;
   }
 
-  void onStats(const llm::Stats& result) const {
-    static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic();
-    static const auto on_stats_method =
-        cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onStats");
-    on_stats_method(
-        self(),
-        facebook::jni::make_jstring(
-            executorch::extension::llm::stats_to_json_string(result)));
+  jint size = env->CallIntMethod(jlist, size_method);
+  for (jint i = 0; i < size; ++i) {
+    jobject str_obj = env->CallObjectMethod(jlist, get_method, i);
+    if (str_obj != nullptr) {
+      result.push_back(jstring_to_string(env, static_cast<jstring>(str_obj)));
+      env->DeleteLocalRef(str_obj);
+    }
   }
-};
 
-class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
- private:
-  friend HybridBase;
+  env->DeleteLocalRef(list_class);
+  return result;
+}
+
+} // namespace
+
+namespace executorch_jni {
+
+// Model type category constants
+constexpr int MODEL_TYPE_CATEGORY_LLM = 1;
+constexpr int MODEL_TYPE_CATEGORY_MULTIMODAL = 2;
+constexpr int MODEL_TYPE_MEDIATEK_LLAMA = 3;
+constexpr int MODEL_TYPE_QNN_LLAMA = 4;
+
+// Native handle class that holds the runner state
+class ExecuTorchLlmNative {
+ public:
   float temperature_ = 0.0f;
   int model_type_category_;
   std::unique_ptr<llm::IRunner> runner_;
@@ -127,37 +154,13 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       multi_modal_runner_;
   std::vector<llm::MultimodalInput> prefill_inputs_;
 
- public:
-  constexpr static auto kJavaDescriptor =
-      "Lorg/pytorch/executorch/extension/llm/LlmModule;";
-
-  constexpr static int MODEL_TYPE_CATEGORY_LLM = 1;
-  constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2;
-  constexpr static int MODEL_TYPE_MEDIATEK_LLAMA = 3;
-  constexpr static int MODEL_TYPE_QNN_LLAMA = 4;
-
-  static facebook::jni::local_ref<jhybriddata> initHybrid(
-      facebook::jni::alias_ref<jclass>,
+  ExecuTorchLlmNative(
+      JNIEnv* env,
       jint model_type_category,
-      facebook::jni::alias_ref<jstring> model_path,
-      facebook::jni::alias_ref<jstring> tokenizer_path,
+      jstring model_path,
+      jstring tokenizer_path,
       jfloat temperature,
-      facebook::jni::alias_ref<facebook::jni::JList<jstring>::javaobject>
-          data_files) {
-    return makeCxxInstance(
-        model_type_category,
-        model_path,
-        tokenizer_path,
-        temperature,
-        data_files);
-  }
-
-  ExecuTorchLlmJni(
-      jint model_type_category,
-      facebook::jni::alias_ref<jstring> model_path,
-      facebook::jni::alias_ref<jstring> tokenizer_path,
-      jfloat temperature,
-      facebook::jni::alias_ref<facebook::jni::JList<jstring>::javaobject>
-          data_files = nullptr) {
+      jobject data_files) {
     temperature_ = temperature;
 #if defined(ET_USE_THREADPOOL)
     // Reserve 1 thread for the main thread.
@@ -171,44 +174,30 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
 #endif
     model_type_category_ = model_type_category;
 
-    std::vector<std::string> data_files_vector;
+    std::string model_path_str = jstring_to_string(env, model_path);
+    std::string tokenizer_path_str = jstring_to_string(env, tokenizer_path);
+    std::vector<std::string> data_files_vector =
+        jlist_to_string_vector(env, data_files);
+
     if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
       multi_modal_runner_ = llm::create_multimodal_runner(
-          model_path->toStdString().c_str(),
-          llm::load_tokenizer(tokenizer_path->toStdString()));
+          model_path_str.c_str(), llm::load_tokenizer(tokenizer_path_str));
     } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
-      if (data_files != nullptr) {
-        // Convert Java List<String> to C++ std::vector<std::string>
-        auto list_class = facebook::jni::findClassStatic("java/util/List");
-        auto size_method = list_class->getMethod<jint()>("size");
-        auto get_method =
-            list_class->getMethod<facebook::jni::local_ref<jobject>(jint)>(
-                "get");
-
-        jint size = size_method(data_files);
-        for (jint i = 0; i < size; ++i) {
-          auto str_obj = get_method(data_files, i);
-          auto jstr = facebook::jni::static_ref_cast<jstring>(str_obj);
-          data_files_vector.push_back(jstr->toStdString());
-        }
-      }
       runner_ = executorch::extension::llm::create_text_llm_runner(
-          model_path->toStdString(),
-          llm::load_tokenizer(tokenizer_path->toStdString()),
-          data_files_vector);
+          model_path_str, llm::load_tokenizer(tokenizer_path_str), data_files_vector);
 #if defined(EXECUTORCH_BUILD_QNN)
     } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) {
       std::unique_ptr<executorch::extension::Module> module = std::make_unique<
          executorch::extension::Module>(
-          model_path->toStdString().c_str(),
+          model_path_str.c_str(),
           data_files_vector,
           executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors);
       std::string decoder_model = "llama3"; // use llama3 for now
       runner_ = std::make_unique<example::Runner>( // QNN runner
           std::move(module),
           decoder_model.c_str(),
-          model_path->toStdString().c_str(),
-          tokenizer_path->toStdString().c_str(),
+          model_path_str.c_str(),
+          tokenizer_path_str.c_str(),
           "",
           "");
       model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
@@ -216,249 +205,528 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
 #if defined(EXECUTORCH_BUILD_MEDIATEK)
     } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
       runner_ = std::make_unique<MTKLlamaRunner>(
-          model_path->toStdString().c_str(),
-          tokenizer_path->toStdString().c_str());
+          model_path_str.c_str(), tokenizer_path_str.c_str());
       // Interpret the model type as LLM
       model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
 #endif
     }
   }
+};
 
-  jint generate(
-      facebook::jni::alias_ref<jstring> prompt,
-      jint seq_len,
-      facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
-      jboolean echo) {
-    if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
-      prefill_inputs_.clear();
-      if (!prompt->toStdString().empty()) {
-        inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
-      }
-      executorch::extension::llm::GenerationConfig config{
-          .echo = static_cast<bool>(echo),
-          .seq_len = seq_len,
-          .temperature = temperature_,
-      };
-      multi_modal_runner_->generate(
-          std::move(inputs),
-          config,
-          [callback](const std::string& result) { callback->onResult(result); },
-          [callback](const llm::Stats& result) { callback->onStats(result); });
-    } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
-      executorch::extension::llm::GenerationConfig config{
-          .echo = static_cast<bool>(echo),
-          .seq_len = seq_len,
-          .temperature = temperature_,
-      };
-      runner_->generate(
-          prompt->toStdString(),
-          config,
-          [callback](std::string result) { callback->onResult(result); },
-          [callback](const llm::Stats& result) { callback->onStats(result); });
-    }
-    return 0;
-  }
-
-  // Returns status_code
-  // Contract is valid within an AAR (JNI + corresponding Java code)
-  jint append_text_input(facebook::jni::alias_ref<jstring> prompt) {
-    prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
-    return 0;
-  }
-
-  // Returns status_code
-  jint append_images_input(
-      facebook::jni::alias_ref<jintArray> image,
-      jint width,
-      jint height,
-      jint channels) {
-    std::vector<llm::Image> images;
-    if (image == nullptr) {
-      return static_cast<jint>(Error::EndOfMethod);
-    }
-    auto image_size = image->size();
-    if (image_size != 0) {
-      std::vector<jint> image_data_jint(image_size);
-      std::vector<uint8_t> image_data(image_size);
-      image->getRegion(0, image_size, image_data_jint.data());
-      for (int i = 0; i < image_size; i++) {
-        image_data[i] = image_data_jint[i];
-      }
-      llm::Image image_runner{std::move(image_data), width, height, channels};
-      prefill_inputs_.emplace_back(
-          llm::MultimodalInput{std::move(image_runner)});
-    }
-
-    return 0;
-  }
-
-  // Returns status_code
-  jint append_normalized_images_input(
-      facebook::jni::alias_ref<jfloatArray> image,
-      jint width,
-      jint height,
-      jint channels) {
-    std::vector<llm::Image> images;
-    if (image == nullptr) {
-      return static_cast<jint>(Error::EndOfMethod);
-    }
-    auto image_size = image->size();
-    if (image_size != 0) {
-      std::vector<jfloat> image_data_jfloat(image_size);
-      std::vector<float> image_data(image_size);
-      image->getRegion(0, image_size, image_data_jfloat.data());
-      for (int i = 0; i < image_size; i++) {
-        image_data[i] = image_data_jfloat[i];
-      }
-      llm::Image image_runner{std::move(image_data), width, height, channels};
-      prefill_inputs_.emplace_back(
-          llm::MultimodalInput{std::move(image_runner)});
-    }
-
-    return 0;
-  }
-
-  // Returns status_code
-  jint append_audio_input(
-      facebook::jni::alias_ref<jbyteArray> data,
-      jint batch_size,
-      jint n_bins,
-      jint n_frames) {
-    if (data == nullptr) {
-      return static_cast<jint>(Error::EndOfMethod);
-    }
-    auto data_size = data->size();
-    if (data_size != 0) {
-      std::vector<jbyte> data_jbyte(data_size);
-      std::vector<uint8_t> data_u8(data_size);
-      data->getRegion(0, data_size, data_jbyte.data());
-      for (int i = 0; i < data_size; i++) {
-        data_u8[i] = data_jbyte[i];
-      }
-      llm::Audio audio{std::move(data_u8), batch_size, n_bins, n_frames};
-      prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)});
-    }
-    return 0;
-  }
-
-  // Returns status_code
-  jint append_audio_input_float(
-      facebook::jni::alias_ref<jfloatArray> data,
-      jint batch_size,
-      jint n_bins,
-      jint n_frames) {
-    if (data == nullptr) {
-      return static_cast<jint>(Error::EndOfMethod);
-    }
-    auto data_size = data->size();
-    if (data_size != 0) {
-      std::vector<jfloat> data_jfloat(data_size);
-      std::vector<float> data_f(data_size);
-      data->getRegion(0, data_size, data_jfloat.data());
-      for (int i = 0; i < data_size; i++) {
-        data_f[i] = data_jfloat[i];
-      }
-      llm::Audio audio{std::move(data_f), batch_size, n_bins, n_frames};
-      prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)});
-    }
-    return 0;
-  }
-
-  // Returns status_code
-  jint append_raw_audio_input(
-      facebook::jni::alias_ref<jbyteArray> data,
-      jint batch_size,
-      jint n_channels,
-      jint n_samples) {
-    if (data == nullptr) {
-      return static_cast<jint>(Error::EndOfMethod);
-    }
-    auto data_size = data->size();
-    if (data_size != 0) {
-      std::vector<jbyte> data_jbyte(data_size);
-      std::vector<uint8_t> data_u8(data_size);
-      data->getRegion(0, data_size, data_jbyte.data());
-      for (int i = 0; i < data_size; i++) {
-        data_u8[i] = data_jbyte[i];
-      }
-      llm::RawAudio audio{
-          std::move(data_u8), batch_size, n_channels, n_samples};
-      prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)});
-    }
-    return 0;
-  }
-
-  void stop() {
-    if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      multi_modal_runner_->stop();
-    } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
-      runner_->stop();
-    }
-  }
-
-  void reset_context() {
-    if (runner_ != nullptr) {
-      runner_->reset();
-    }
-    if (multi_modal_runner_ != nullptr) {
-      multi_modal_runner_->reset();
-    }
-  }
-
-  jint load() {
-    int result = -1;
-    std::stringstream ss;
-
-    if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      result = static_cast<int>(multi_modal_runner_->load());
-      if (result != 0) {
-        ss << "Failed to load multimodal runner: [" << result << "]";
-      }
-    } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
-      result = static_cast<int>(runner_->load());
-      if (result != 0) {
-        ss << "Failed to load llm runner: [" << result << "]";
-      }
-    } else {
-      ss << "Invalid model type category: " << model_type_category_
-         << ". Valid values are: " << MODEL_TYPE_CATEGORY_LLM << " or "
-         << MODEL_TYPE_CATEGORY_MULTIMODAL;
-    }
-    if (result != 0) {
-      executorch::jni_helper::throwExecutorchException(
-          static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
-    }
-    return result; // 0 on success to keep backward compatibility
-  }
-
-  static void registerNatives() {
-    registerHybrid({
-        makeNativeMethod("initHybrid", ExecuTorchLlmJni::initHybrid),
-        makeNativeMethod("generate", ExecuTorchLlmJni::generate),
-        makeNativeMethod("stop", ExecuTorchLlmJni::stop),
-        makeNativeMethod("load", ExecuTorchLlmJni::load),
-        makeNativeMethod(
-            "appendImagesInput", ExecuTorchLlmJni::append_images_input),
-        makeNativeMethod(
-            "appendNormalizedImagesInput",
-            ExecuTorchLlmJni::append_normalized_images_input),
-        makeNativeMethod(
-            "appendAudioInput", ExecuTorchLlmJni::append_audio_input),
-        makeNativeMethod(
-            "appendAudioInputFloat",
-            ExecuTorchLlmJni::append_audio_input_float),
-        makeNativeMethod(
-            "appendRawAudioInput", ExecuTorchLlmJni::append_raw_audio_input),
-        makeNativeMethod(
-            "appendTextInput", ExecuTorchLlmJni::append_text_input),
-        makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),
-    });
-  }
-};
-
-} // namespace executorch_jni
-
-void register_natives_for_llm() {
-  executorch_jni::ExecuTorchLlmJni::registerNatives();
-}
+// Helper class for callback invocation
+class CallbackHelper {
+ public:
+  CallbackHelper(JNIEnv* env, jobject callback)
+      : env_(env), callback_(nullptr), callback_class_(nullptr) {
+    if (callback != nullptr) {
+      callback_ = env_->NewGlobalRef(callback);
+      jclass local_class = env_->GetObjectClass(callback);
+      callback_class_ = static_cast<jclass>(env_->NewGlobalRef(local_class));
+      env_->DeleteLocalRef(local_class);
+      on_result_method_ = env_->GetMethodID(
+          callback_class_, "onResult", "(Ljava/lang/String;)V");
+      on_stats_method_ =
+          env_->GetMethodID(callback_class_, "onStats", "(Ljava/lang/String;)V");
+    }
+  }
+
+  ~CallbackHelper() {
+    if (g_jvm == nullptr) {
+      return;
+    }
+    // Get the current JNIEnv (might be different thread)
+    JNIEnv* env = nullptr;
+    int status = g_jvm->GetEnv((void**)&env, JNI_VERSION_1_6);
+    if (status == JNI_EDETACHED) {
+      g_jvm->AttachCurrentThread(&env, nullptr);
+    }
+    if (env != nullptr) {
+      if (callback_ != nullptr) {
+        env->DeleteGlobalRef(callback_);
+      }
+      if (callback_class_ != nullptr) {
+        env->DeleteGlobalRef(callback_class_);
+      }
+    }
+  }
+
+  void onResult(const std::string& result) {
+    JNIEnv* env = getEnv();
+    if (env == nullptr || callback_ == nullptr || on_result_method_ == nullptr) {
+      return;
+    }
+
+    std::string current_result = result;
+    token_buffer += current_result;
+    if (!utf8_check_validity(token_buffer.c_str(), token_buffer.size())) {
+      ET_LOG(
+          Info, "Current token buffer is not valid UTF-8. Waiting for more.");
+      return;
+    }
+    current_result = token_buffer;
+    token_buffer = "";
+
+    jstring jstr = env->NewStringUTF(current_result.c_str());
+    if (jstr != nullptr) {
+      env->CallVoidMethod(callback_, on_result_method_, jstr);
+      env->DeleteLocalRef(jstr);
+    }
+  }
+
+  void onStats(const llm::Stats& stats) {
+    JNIEnv* env = getEnv();
+    if (env == nullptr || callback_ == nullptr || on_stats_method_ == nullptr) {
+      return;
+    }
+
+    std::string stats_json =
+        executorch::extension::llm::stats_to_json_string(stats);
+    jstring jstr = env->NewStringUTF(stats_json.c_str());
+    if (jstr != nullptr) {
+      env->CallVoidMethod(callback_, on_stats_method_, jstr);
+      env->DeleteLocalRef(jstr);
+    }
+  }
+
+ private:
+  JNIEnv* getEnv() {
+    if (g_jvm == nullptr) {
+      return nullptr;
+    }
+    JNIEnv* env = nullptr;
+    int status = g_jvm->GetEnv((void**)&env, JNI_VERSION_1_6);
+    if (status == JNI_EDETACHED) {
+      g_jvm->AttachCurrentThread(&env, nullptr);
+    }
+    return env;
+  }
+
+  JNIEnv* env_;
+  jobject callback_;
+  jclass callback_class_ = nullptr;
+  jmethodID on_result_method_ = nullptr;
+  jmethodID on_stats_method_ = nullptr;
+};
+
+} // namespace executorch_jni
+
+extern "C" {
+
+JNIEXPORT jlong JNICALL
+Java_org_pytorch_executorch_extension_llm_LlmModule_nativeCreate(
+    JNIEnv* env,
+    jobject /* this */,
+    jint model_type_category,
+    jstring model_path,
+    jstring tokenizer_path,
+    jfloat temperature,
+    jobject data_files) {
+  auto* native = new executorch_jni::ExecuTorchLlmNative(
+      env, model_type_category, model_path, tokenizer_path, temperature, data_files);
+  return reinterpret_cast<jlong>(native);
+}
+
+JNIEXPORT void JNICALL
+Java_org_pytorch_executorch_extension_llm_LlmModule_nativeDestroy(
+    JNIEnv* /* env */,
+    jobject /* this */,
+    jlong native_handle) {
+  auto* native =
+      reinterpret_cast<executorch_jni::ExecuTorchLlmNative*>(native_handle);
+  delete native;
+}
+
+JNIEXPORT jint JNICALL
+Java_org_pytorch_executorch_extension_llm_LlmModule_nativeGenerate(
+    JNIEnv* env,
+    jobject /* this */,
+    jlong native_handle,
+    jstring prompt,
+    jint seq_len,
+    jobject callback,
+    jboolean echo) {
+  auto* native =
+      reinterpret_cast<executorch_jni::ExecuTorchLlmNative*>(native_handle);
+  if (native == nullptr) {
+    return -1;
+  }
+
+  std::string prompt_str = jstring_to_string(env, prompt);
+
+  // Create a shared callback helper for use in lambdas
+  auto callback_helper =
+      std::make_shared<executorch_jni::CallbackHelper>(env, callback);
+
+  if (native->model_type_category_ ==
+      executorch_jni::MODEL_TYPE_CATEGORY_MULTIMODAL) {
+    std::vector<llm::MultimodalInput> inputs = native->prefill_inputs_;
+    native->prefill_inputs_.clear();
+    if (!prompt_str.empty()) {
+      inputs.emplace_back(llm::MultimodalInput{prompt_str});
+    }
+    executorch::extension::llm::GenerationConfig config{
+        .echo = static_cast<bool>(echo),
+        .seq_len = seq_len,
+        .temperature = native->temperature_,
+    };
+    native->multi_modal_runner_->generate(
+        std::move(inputs),
+        config,
+        [callback_helper](const std::string& result) {
+          callback_helper->onResult(result);
+        },
+        [callback_helper](const llm::Stats& result) {
+          callback_helper->onStats(result);
+        });
+  } else if (
+      native->model_type_category_ ==
+      executorch_jni::MODEL_TYPE_CATEGORY_LLM) {
+    executorch::extension::llm::GenerationConfig config{
+        .echo = static_cast<bool>(echo),
+        .seq_len = seq_len,
+        .temperature = native->temperature_,
+    };
+    native->runner_->generate(
+        prompt_str,
+        config,
+        [callback_helper](std::string result) {
+          callback_helper->onResult(result);
+        },
+        [callback_helper](const llm::Stats& result) {
+          callback_helper->onStats(result);
+        });
+  }
+  return 0;
+}
+
+JNIEXPORT void JNICALL
+Java_org_pytorch_executorch_extension_llm_LlmModule_nativeStop(
+    JNIEnv* /* env */,
+    jobject /* this */,
+    jlong native_handle) {
+  auto* native =
+      reinterpret_cast<executorch_jni::ExecuTorchLlmNative*>(native_handle);
+  if (native == nullptr) {
+    return;
+  }
+
+  if (native->model_type_category_ ==
+      executorch_jni::MODEL_TYPE_CATEGORY_MULTIMODAL) {
+    native->multi_modal_runner_->stop();
+  } else if (
+      native->model_type_category_ ==
+      executorch_jni::MODEL_TYPE_CATEGORY_LLM) {
+    native->runner_->stop();
+  }
+}
+
+JNIEXPORT jint JNICALL
+Java_org_pytorch_executorch_extension_llm_LlmModule_nativeLoad(
+    JNIEnv* env,
+    jobject /* this */,
+    jlong native_handle) {
+  auto* native =
+      reinterpret_cast<executorch_jni::ExecuTorchLlmNative*>(native_handle);
+  if (native == nullptr) {
+    return -1;
+  }
+
+  int result = -1;
+  std::stringstream ss;
+
+  if (native->model_type_category_ ==
+      executorch_jni::MODEL_TYPE_CATEGORY_MULTIMODAL) {
+    result = static_cast<int>(native->multi_modal_runner_->load());
+    if (result != 0) {
+      ss << "Failed to load multimodal runner: [" << result << "]";
+    }
+  } else if (
+      native->model_type_category_ ==
+      executorch_jni::MODEL_TYPE_CATEGORY_LLM) {
+    result = static_cast<int>(native->runner_->load());
+    if (result != 0) {
+      ss << "Failed to load llm runner: [" << result << "]";
+    }
+  } else {
+    ss << "Invalid model type category: " << native->model_type_category_
+       << ". Valid values are: "
+       << executorch_jni::MODEL_TYPE_CATEGORY_LLM << " or "
+       << executorch_jni::MODEL_TYPE_CATEGORY_MULTIMODAL;
+  }
+  if (result != 0) {
+    executorch::jni_helper::throwExecutorchException(
+        env, static_cast<uint32_t>(Error::InvalidArgument), ss.str().c_str());
+  }
+  return result; // 0 on success to keep backward compatibility
+}
+
+JNIEXPORT jint JNICALL
+Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendTextInput(
+    JNIEnv* env,
+    jobject /* this */,
+    jlong native_handle,
+    jstring prompt) {
+  auto* native =
+      reinterpret_cast<executorch_jni::ExecuTorchLlmNative*>(native_handle);
+  if (native == nullptr) {
+    return -1;
+  }
+
+  std::string prompt_str = jstring_to_string(env, prompt);
+  native->prefill_inputs_.emplace_back(llm::MultimodalInput{prompt_str});
+  return 0;
+}
+
+JNIEXPORT jint JNICALL
+Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendImagesInput(
+    JNIEnv* env,
+    jobject /* this */,
+    jlong native_handle,
+    jintArray image,
+    jint width,
+    jint height,
+    jint channels) {
+  auto* native =
+      reinterpret_cast<executorch_jni::ExecuTorchLlmNative*>(native_handle);
+  if (native == nullptr) {
+    return -1;
+  }
+
+  if (image == nullptr) {
+    return static_cast<jint>(Error::EndOfMethod);
+  }
+
+  jsize image_size = env->GetArrayLength(image);
+  if (image_size != 0) {
+    std::vector<jint> image_data_jint(image_size);
+    std::vector<uint8_t> image_data(image_size);
+    env->GetIntArrayRegion(image, 0, image_size, image_data_jint.data());
+    for (int i = 0; i < image_size; i++) {
+      image_data[i] = static_cast<uint8_t>(image_data_jint[i]);
+    }
+    llm::Image image_runner{std::move(image_data), width, height, channels};
+    native->prefill_inputs_.emplace_back(
+        llm::MultimodalInput{std::move(image_runner)});
+  }
+
+  return 0;
+}
+
+JNIEXPORT jint JNICALL
+Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendNormalizedImagesInput(
+    JNIEnv* env,
+    jobject /* this */,
+    jlong native_handle,
+    jfloatArray image,
+    jint width,
+    jint height,
+    jint channels) {
+  auto* native =
+      reinterpret_cast<executorch_jni::ExecuTorchLlmNative*>(native_handle);
+  if (native == nullptr) {
+    return -1;
+  }
+
+  if (image == nullptr) {
+    return static_cast<jint>(Error::EndOfMethod);
+  }
+
+  jsize image_size = env->GetArrayLength(image);
+  if (image_size != 0) {
+    std::vector<jfloat> image_data_jfloat(image_size);
+    std::vector<float> image_data(image_size);
+    env->GetFloatArrayRegion(image, 0, image_size, image_data_jfloat.data());
+    for (int i = 0; i < image_size; i++) {
+      image_data[i] = image_data_jfloat[i];
+    }
+    llm::Image image_runner{std::move(image_data), width, height, channels};
+    native->prefill_inputs_.emplace_back(
+        llm::MultimodalInput{std::move(image_runner)});
+  }
+
+  return 0;
+}
+
+JNIEXPORT jint JNICALL
+Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendAudioInput(
+    JNIEnv* env,
+    jobject /* this */,
+    jlong native_handle,
+    jbyteArray data,
+    jint batch_size,
+    jint n_bins,
+    jint n_frames) {
+  auto* native =
+      reinterpret_cast<executorch_jni::ExecuTorchLlmNative*>(native_handle);
+  if (native == nullptr) {
+    return -1;
+  }
+
+  if (data == nullptr) {
+    return static_cast<jint>(Error::EndOfMethod);
+  }
+
+  jsize data_size = env->GetArrayLength(data);
+  if (data_size != 0) {
+    std::vector<jbyte> data_jbyte(data_size);
+    std::vector<uint8_t> data_u8(data_size);
+    env->GetByteArrayRegion(data, 0, data_size, data_jbyte.data());
+    for (int i = 0; i < data_size; i++) {
+      data_u8[i] = static_cast<uint8_t>(data_jbyte[i]);
+    }
+    llm::Audio audio{std::move(data_u8), batch_size, n_bins, n_frames};
+    native->prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)});
+  }
+  return 0;
+}
+
+JNIEXPORT jint JNICALL
+Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendAudioInputFloat(
+    JNIEnv* env,
+    jobject /* this */,
+    jlong native_handle,
+    jfloatArray data,
+    jint batch_size,
+    jint n_bins,
+    jint n_frames) {
+  auto* native =
+      reinterpret_cast<executorch_jni::ExecuTorchLlmNative*>(native_handle);
+  if (native == nullptr) {
+    return -1;
+  }
+
+  if (data == nullptr) {
+    return static_cast<jint>(Error::EndOfMethod);
+  }
+
+  jsize data_size = env->GetArrayLength(data);
+  if (data_size != 0) {
+    std::vector<jfloat> data_jfloat(data_size);
+    std::vector<float> data_f(data_size);
+    env->GetFloatArrayRegion(data, 0, data_size, data_jfloat.data());
+    for (int i = 0; i < data_size; i++) {
+      data_f[i] = data_jfloat[i];
+    }
+    llm::Audio audio{std::move(data_f), batch_size, n_bins, n_frames};
+    native->prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)});
+  }
+  return 0;
+}
+
+JNIEXPORT jint JNICALL
+Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendRawAudioInput(
+    JNIEnv* env,
+    jobject /* this */,
+    jlong native_handle,
+    jbyteArray data,
+    jint batch_size,
+    jint n_channels,
+    jint n_samples) {
+  auto* native =
+      reinterpret_cast<executorch_jni::ExecuTorchLlmNative*>(native_handle);
+  if (native == nullptr) {
+    return -1;
+  }
+
+  if (data == nullptr) {
+    return static_cast<jint>(Error::EndOfMethod);
+  }
+
+  jsize data_size = env->GetArrayLength(data);
+  if (data_size != 0) {
+    std::vector<jbyte> data_jbyte(data_size);
+    std::vector<uint8_t> data_u8(data_size);
+    env->GetByteArrayRegion(data, 0, data_size, data_jbyte.data());
+    for (int i = 0; i < data_size; i++) {
+      data_u8[i] = static_cast<uint8_t>(data_jbyte[i]);
+    }
+    llm::RawAudio audio{std::move(data_u8), batch_size, n_channels, n_samples};
+    native->prefill_inputs_.emplace_back(llm::MultimodalInput{std::move(audio)});
+  }
+  return 0;
+}
+
+JNIEXPORT void JNICALL
+Java_org_pytorch_executorch_extension_llm_LlmModule_nativeResetContext(
+    JNIEnv* /* env */,
+    jobject /* this */,
+    jlong native_handle) {
+  auto* native =
+      reinterpret_cast<executorch_jni::ExecuTorchLlmNative*>(native_handle);
+  if (native == nullptr) {
+    return;
+  }
+
+  if (native->runner_ != nullptr) {
+    native->runner_->reset();
+  }
+  if (native->multi_modal_runner_ != nullptr) {
+    native->multi_modal_runner_->reset();
+  }
+}
+
+} // extern "C"
+
+void register_natives_for_llm(JNIEnv* env) {
+  // Store the JavaVM for later use in callbacks
+  env->GetJavaVM(&g_jvm);
+
+  jclass llm_module_class =
+      env->FindClass("org/pytorch/executorch/extension/llm/LlmModule");
+  if (llm_module_class == nullptr) {
+    ET_LOG(Error, "Failed to find LlmModule class");
+    env->ExceptionClear();
+    return;
+  }
+
+  // clang-format off
+  static const JNINativeMethod methods[] = {
+      {"nativeCreate",
+       "(ILjava/lang/String;Ljava/lang/String;FLjava/util/List;)J",
+       reinterpret_cast<void*>(
+           Java_org_pytorch_executorch_extension_llm_LlmModule_nativeCreate)},
+      {"nativeDestroy", "(J)V",
+       reinterpret_cast<void*>(
+           Java_org_pytorch_executorch_extension_llm_LlmModule_nativeDestroy)},
+      {"nativeGenerate",
+       "(JLjava/lang/String;ILorg/pytorch/executorch/extension/llm/LlmCallback;Z)I",
+       reinterpret_cast<void*>(
+           Java_org_pytorch_executorch_extension_llm_LlmModule_nativeGenerate)},
+      {"nativeStop", "(J)V",
+       reinterpret_cast<void*>(
+           Java_org_pytorch_executorch_extension_llm_LlmModule_nativeStop)},
+      {"nativeLoad", "(J)I",
+       reinterpret_cast<void*>(
+           Java_org_pytorch_executorch_extension_llm_LlmModule_nativeLoad)},
+      {"nativeAppendTextInput", "(JLjava/lang/String;)I",
+       reinterpret_cast<void*>(
+           Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendTextInput)},
+      {"nativeAppendImagesInput", "(J[IIII)I",
+       reinterpret_cast<void*>(
+           Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendImagesInput)},
+      {"nativeAppendNormalizedImagesInput", "(J[FIII)I",
+       reinterpret_cast<void*>(
+           Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendNormalizedImagesInput)},
+      {"nativeAppendAudioInput", "(J[BIII)I",
+       reinterpret_cast<void*>(
+           Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendAudioInput)},
+      {"nativeAppendAudioInputFloat", "(J[FIII)I",
+       reinterpret_cast<void*>(
+           Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendAudioInputFloat)},
+      {"nativeAppendRawAudioInput", "(J[BIII)I",
+       reinterpret_cast<void*>(
+           Java_org_pytorch_executorch_extension_llm_LlmModule_nativeAppendRawAudioInput)},
+      {"nativeResetContext", "(J)V",
+       reinterpret_cast<void*>(
+           Java_org_pytorch_executorch_extension_llm_LlmModule_nativeResetContext)},
+  };
+  // clang-format on
+
+  int num_methods = sizeof(methods) / sizeof(methods[0]);
+  int result = env->RegisterNatives(llm_module_class, methods, num_methods);
+  if (result != JNI_OK) {
+    ET_LOG(Error, "Failed to register native methods for LlmModule");
+  }
+
+  env->DeleteLocalRef(llm_module_class);
+}
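Two lifecycle points worth noting as review follow-ups: `CallbackHelper`'s destructor attaches a detached thread via `AttachCurrentThread` but never detaches it, and `finalize()` (used as the Java-side safety net) has been deprecated since Java 9. A hedged, standalone sketch of a `java.lang.ref.Cleaner`-based alternative to the `finalize()` net; the class and the `create`/`destroy` stand-ins are hypothetical, not part of this diff:

```java
import java.lang.ref.Cleaner;

// Hypothetical sketch of the handle-lifecycle pattern with Cleaner instead
// of finalize(). create()/destroy() stand in for nativeCreate/nativeDestroy.
final class NativeHandleExample {
  private static final Cleaner CLEANER = Cleaner.create();

  private final Cleaner.Cleanable cleanable;
  private long nativeHandle;

  NativeHandleExample() {
    nativeHandle = create();
    final long handle = nativeHandle;
    // The cleanup lambda must not capture `this`, or the object never
    // becomes phantom-reachable and the cleaner never runs.
    cleanable = CLEANER.register(this, () -> destroy(handle));
  }

  void close() {
    nativeHandle = 0;
    cleanable.clean(); // idempotent: runs destroy(handle) at most once
  }

  private static long create() { return 1L; }          // stand-in
  private static void destroy(long handle) { /* free native state */ }
}
```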