Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

package org.pytorch.executorch.extension.llm;

import com.facebook.jni.HybridData;
import com.facebook.jni.annotations.DoNotStrip;
import java.io.File;
import java.util.List;
import org.pytorch.executorch.ExecuTorchRuntime;
Expand All @@ -28,18 +26,19 @@ public class LlmModule {
public static final int MODEL_TYPE_TEXT_VISION = 2;
public static final int MODEL_TYPE_MULTIMODAL = 2;

private final HybridData mHybridData;
private long mNativeHandle;
private static final int DEFAULT_SEQ_LEN = 128;
private static final boolean DEFAULT_ECHO = true;

@DoNotStrip
private static native HybridData initHybrid(
private static native long nativeCreate(
int modelType,
String modulePath,
String tokenizerPath,
float temperature,
List<String> dataFiles);

private static native void nativeDestroy(long nativeHandle);

/**
* Constructs a LLM Module for a model with given type, model path, tokenizer, temperature, and
* dataFiles.
Expand All @@ -61,7 +60,7 @@ public LlmModule(
throw new RuntimeException("Cannot load tokenizer path " + tokenizerPath);
}

mHybridData = initHybrid(modelType, modulePath, tokenizerPath, temperature, dataFiles);
mNativeHandle = nativeCreate(modelType, modulePath, tokenizerPath, temperature, dataFiles);
}

/**
Expand Down Expand Up @@ -107,7 +106,16 @@ public LlmModule(LlmModuleConfig config) {
}

// Tears down the native state owned by this module. The handle is zeroed after destruction so
// that a later call (for example from finalize()) does not destroy the same peer twice.
public void resetNative() {
  mHybridData.resetNative();
  final long handle = mNativeHandle;
  if (handle == 0) {
    // Already released; nothing left to destroy.
    return;
  }
  nativeDestroy(handle);
  mNativeHandle = 0;
}

/**
 * Releases the native peer when the owner never called {@code resetNative()} explicitly.
 *
 * <p>The superclass finalizer is invoked from a {@code finally} block so that it still runs when
 * native teardown throws.
 *
 * @throws Throwable anything raised by {@code resetNative()} or by the superclass finalizer
 */
@Override
protected void finalize() throws Throwable {
  try {
    resetNative();
  } finally {
    super.finalize();
  }
}

/**
Expand Down Expand Up @@ -150,7 +158,12 @@ public int generate(String prompt, LlmCallback llmCallback, boolean echo) {
* @param llmCallback callback object to receive results
* @param echo indicate whether to echo the input prompt or not (text completion vs chat)
*/
public native int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo);
// Java-side entry point for generation; forwards to the handle-based static native.
// Throws IllegalStateException when the native peer has already been released.
public int generate(String prompt, int seqLen, LlmCallback llmCallback, boolean echo) {
  // resetNative() zeroes the handle; fail here instead of handing a dead handle to native code.
  if (mNativeHandle == 0) {
    throw new IllegalStateException(
        "LlmModule native handle has been released; create a new instance");
  }
  return nativeGenerate(mNativeHandle, prompt, seqLen, llmCallback, echo);
}

private static native int nativeGenerate(
long nativeHandle, String prompt, int seqLen, LlmCallback llmCallback, boolean echo);

/**
* Start generating tokens from the module.
Expand Down Expand Up @@ -206,14 +219,15 @@ public int generate(
*/
// Appends an int[] image to the multimodal input. A non-zero native status is surfaced as a
// RuntimeException carrying the code; on success the method always returns 0.
// Throws IllegalStateException when the native peer has already been released.
@Experimental
public long prefillImages(int[] image, int width, int height, int channels) {
  // resetNative() zeroes the handle; fail here instead of handing a dead handle to native code.
  if (mNativeHandle == 0) {
    throw new IllegalStateException(
        "LlmModule native handle has been released; create a new instance");
  }
  // Exactly one native call, routed through the handle-based static native.
  int nativeResult = nativeAppendImagesInput(mNativeHandle, image, width, height, channels);
  if (nativeResult != 0) {
    throw new RuntimeException("Prefill failed with error code: " + nativeResult);
  }
  return 0;
}

private native int appendImagesInput(int[] image, int width, int height, int channels);
private static native int nativeAppendImagesInput(
long nativeHandle, int[] image, int width, int height, int channels);

/**
* Prefill a multimodal Module with the given images input.
Expand All @@ -228,15 +242,16 @@ public long prefillImages(int[] image, int width, int height, int channels) {
*/
// Appends a float[] (normalized) image to the multimodal input. A non-zero native status is
// surfaced as a RuntimeException carrying the code; on success the method always returns 0.
// Throws IllegalStateException when the native peer has already been released.
@Experimental
public long prefillImages(float[] image, int width, int height, int channels) {
  // resetNative() zeroes the handle; fail here instead of handing a dead handle to native code.
  if (mNativeHandle == 0) {
    throw new IllegalStateException(
        "LlmModule native handle has been released; create a new instance");
  }
  // Exactly one native call, routed through the handle-based static native.
  int nativeResult =
      nativeAppendNormalizedImagesInput(mNativeHandle, image, width, height, channels);
  if (nativeResult != 0) {
    throw new RuntimeException("Prefill failed with error code: " + nativeResult);
  }
  return 0;
}

private native int appendNormalizedImagesInput(
float[] image, int width, int height, int channels);
private static native int nativeAppendNormalizedImagesInput(
long nativeHandle, float[] image, int width, int height, int channels);

/**
* Prefill a multimodal Module with the given audio input.
Expand All @@ -251,14 +266,15 @@ private native int appendNormalizedImagesInput(
*/
// Appends byte[] audio to the multimodal input. A non-zero native status is surfaced as a
// RuntimeException carrying the code; on success the method always returns 0.
// Throws IllegalStateException when the native peer has already been released.
@Experimental
public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames) {
  // resetNative() zeroes the handle; fail here instead of handing a dead handle to native code.
  if (mNativeHandle == 0) {
    throw new IllegalStateException(
        "LlmModule native handle has been released; create a new instance");
  }
  // Exactly one native call, routed through the handle-based static native.
  int nativeResult = nativeAppendAudioInput(mNativeHandle, audio, batch_size, n_bins, n_frames);
  if (nativeResult != 0) {
    throw new RuntimeException("Prefill failed with error code: " + nativeResult);
  }
  return 0;
}

private native int appendAudioInput(byte[] audio, int batch_size, int n_bins, int n_frames);
private static native int nativeAppendAudioInput(
long nativeHandle, byte[] audio, int batch_size, int n_bins, int n_frames);

/**
* Prefill a multimodal Module with the given audio input.
Expand All @@ -273,14 +289,16 @@ public long prefillAudio(byte[] audio, int batch_size, int n_bins, int n_frames)
*/
// Appends float[] audio to the multimodal input. A non-zero native status is surfaced as a
// RuntimeException carrying the code; on success the method always returns 0.
// Throws IllegalStateException when the native peer has already been released.
@Experimental
public long prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames) {
  // resetNative() zeroes the handle; fail here instead of handing a dead handle to native code.
  if (mNativeHandle == 0) {
    throw new IllegalStateException(
        "LlmModule native handle has been released; create a new instance");
  }
  // Exactly one native call, routed through the handle-based static native.
  int nativeResult =
      nativeAppendAudioInputFloat(mNativeHandle, audio, batch_size, n_bins, n_frames);
  if (nativeResult != 0) {
    throw new RuntimeException("Prefill failed with error code: " + nativeResult);
  }
  return 0;
}

private native int appendAudioInputFloat(float[] audio, int batch_size, int n_bins, int n_frames);
private static native int nativeAppendAudioInputFloat(
long nativeHandle, float[] audio, int batch_size, int n_bins, int n_frames);

/**
* Prefill a multimodal Module with the given raw audio input.
Expand All @@ -295,15 +313,16 @@ public long prefillAudio(float[] audio, int batch_size, int n_bins, int n_frames
*/
// Appends raw byte[] audio to the multimodal input. A non-zero native status is surfaced as a
// RuntimeException carrying the code; on success the method always returns 0.
// Throws IllegalStateException when the native peer has already been released.
@Experimental
public long prefillRawAudio(byte[] audio, int batch_size, int n_channels, int n_samples) {
  // resetNative() zeroes the handle; fail here instead of handing a dead handle to native code.
  if (mNativeHandle == 0) {
    throw new IllegalStateException(
        "LlmModule native handle has been released; create a new instance");
  }
  // Exactly one native call, routed through the handle-based static native.
  int nativeResult =
      nativeAppendRawAudioInput(mNativeHandle, audio, batch_size, n_channels, n_samples);
  if (nativeResult != 0) {
    throw new RuntimeException("Prefill failed with error code: " + nativeResult);
  }
  return 0;
}

private native int appendRawAudioInput(
byte[] audio, int batch_size, int n_channels, int n_samples);
private static native int nativeAppendRawAudioInput(
long nativeHandle, byte[] audio, int batch_size, int n_channels, int n_samples);

/**
* Prefill a multimodal Module with the given text input.
Expand All @@ -315,28 +334,38 @@ private native int appendRawAudioInput(
*/
// Appends a text prompt to the multimodal input. A non-zero native status is surfaced as a
// RuntimeException carrying the code; on success the method always returns 0.
// Throws IllegalStateException when the native peer has already been released.
@Experimental
public long prefillPrompt(String prompt) {
  // resetNative() zeroes the handle; fail here instead of handing a dead handle to native code.
  if (mNativeHandle == 0) {
    throw new IllegalStateException(
        "LlmModule native handle has been released; create a new instance");
  }
  // Exactly one native call, routed through the handle-based static native.
  int nativeResult = nativeAppendTextInput(mNativeHandle, prompt);
  if (nativeResult != 0) {
    throw new RuntimeException("Prefill failed with error code: " + nativeResult);
  }
  return 0;
}

// returns status
private native int appendTextInput(String prompt);
private static native int nativeAppendTextInput(long nativeHandle, String prompt);

/**
* Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.
*
* <p>The startPos will be reset to 0.
*/
public native void resetContext();
// Forwards the context reset to the handle-based static native.
// Throws IllegalStateException when the native peer has already been released.
public void resetContext() {
  // resetNative() zeroes the handle; fail here instead of handing a dead handle to native code.
  if (mNativeHandle == 0) {
    throw new IllegalStateException(
        "LlmModule native handle has been released; create a new instance");
  }
  nativeResetContext(mNativeHandle);
}

private static native void nativeResetContext(long nativeHandle);

/** Stop current generate() before it finishes. */
@DoNotStrip
public native void stop();
// Forwards the stop request to the handle-based static native.
// Throws IllegalStateException when the native peer has already been released.
public void stop() {
  // resetNative() zeroes the handle; fail here instead of handing a dead handle to native code.
  if (mNativeHandle == 0) {
    throw new IllegalStateException(
        "LlmModule native handle has been released; create a new instance");
  }
  nativeStop(mNativeHandle);
}

private static native void nativeStop(long nativeHandle);

/** Force loading the module. Otherwise the model is loaded during first generate(). */
@DoNotStrip
public native int load();
// Forwards the load request to the handle-based static native and returns its status code.
// Throws IllegalStateException when the native peer has already been released.
public int load() {
  // resetNative() zeroes the handle; fail here instead of handing a dead handle to native code.
  if (mNativeHandle == 0) {
    throw new IllegalStateException(
        "LlmModule native handle has been released; create a new instance");
  }
  return nativeLoad(mNativeHandle);
}

private static native int nativeLoad(long nativeHandle);
}
55 changes: 55 additions & 0 deletions extension/android/jni/jni_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,60 @@

namespace executorch::jni_helper {

void throwExecutorchException(
JNIEnv* env,
uint32_t errorCode,
const std::string& details) {
if (!env) {
return;
}

// Find the exception class
jclass exceptionClass =
env->FindClass("org/pytorch/executorch/ExecutorchRuntimeException");
if (exceptionClass == nullptr) {
// Class not found, clear the exception and return
env->ExceptionClear();
return;
}

// Find the static factory method: makeExecutorchException(int, String)
jmethodID makeExceptionMethod = env->GetStaticMethodID(
exceptionClass,
"makeExecutorchException",
"(ILjava/lang/String;)Ljava/lang/RuntimeException;");
if (makeExceptionMethod == nullptr) {
env->ExceptionClear();
env->DeleteLocalRef(exceptionClass);
return;
}

// Create the details string
jstring jDetails = env->NewStringUTF(details.c_str());
if (jDetails == nullptr) {
env->ExceptionClear();
env->DeleteLocalRef(exceptionClass);
return;
}

// Call the factory method to create the exception object
jobject exception = env->CallStaticObjectMethod(
exceptionClass,
makeExceptionMethod,
static_cast<jint>(errorCode),
jDetails);

env->DeleteLocalRef(jDetails);

if (exception != nullptr) {
env->Throw(static_cast<jthrowable>(exception));
env->DeleteLocalRef(exception);
}

env->DeleteLocalRef(exceptionClass);
}

#if EXECUTORCH_HAS_FBJNI
void throwExecutorchException(uint32_t errorCode, const std::string& details) {
// Get the current JNI environment
auto env = facebook::jni::Environment::current();
Expand All @@ -34,5 +88,6 @@ void throwExecutorchException(uint32_t errorCode, const std::string& details) {
auto exception = makeExceptionMethod(exceptionClass, errorCode, jDetails);
facebook::jni::throwNewJavaException(exception.get());
}
#endif

} // namespace executorch::jni_helper
29 changes: 28 additions & 1 deletion extension/android/jni/jni_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,42 @@

#pragma once

#include <fbjni/fbjni.h>
#include <jni.h>
#include <string>

#if __has_include(<fbjni/fbjni.h>)
#include <fbjni/fbjni.h>
#define EXECUTORCH_HAS_FBJNI 1
#else
#define EXECUTORCH_HAS_FBJNI 0
#endif

namespace executorch::jni_helper {

/**
* Throws a Java ExecutorchRuntimeException corresponding to the given error
* code and details. Uses the Java factory method
* ExecutorchRuntimeException.makeExecutorchException(int, String).
*
* This version takes JNIEnv* directly and works with pure JNI.
*
* @param env The JNI environment.
* @param errorCode The error code from the C++ Executorch runtime.
* @param details Additional details to include in the exception message.
*/
void throwExecutorchException(
JNIEnv* env,
uint32_t errorCode,
const std::string& details);

#if EXECUTORCH_HAS_FBJNI
/**
* Throws a Java ExecutorchRuntimeException corresponding to the given error
* code and details. Uses the Java factory method
* ExecutorchRuntimeException.makeExecutorchException(int, String).
*
* This version uses fbjni to get the current JNI environment.
*
* @param errorCode The error code from the C++ Executorch runtime.
* @param details Additional details to include in the exception message.
*/
Expand All @@ -29,5 +55,6 @@ struct JExecutorchRuntimeException
static constexpr auto kJavaDescriptor =
"Lorg/pytorch/executorch/ExecutorchRuntimeException;";
};
#endif

} // namespace executorch::jni_helper
8 changes: 5 additions & 3 deletions extension/android/jni/jni_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,10 +535,10 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
} // namespace executorch::extension

#ifdef EXECUTORCH_BUILD_LLAMA_JNI
extern void register_natives_for_llm();
extern void register_natives_for_llm(JNIEnv* env);
#else
// No op if we don't build LLM
void register_natives_for_llm() {}
void register_natives_for_llm(JNIEnv* /* env */) {}
#endif
extern void register_natives_for_runtime();

Expand All @@ -552,7 +552,9 @@ void register_natives_for_training() {}
JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) {
return facebook::jni::initialize(vm, [] {
executorch::extension::ExecuTorchJni::registerNatives();
register_natives_for_llm();
// Get JNIEnv for pure JNI registration in LLM
JNIEnv* env = facebook::jni::Environment::current();
register_natives_for_llm(env);
register_natives_for_runtime();
register_natives_for_training();
});
Expand Down
Loading
Loading