Skip to content

Commit 6dc4fc0

Browse files
committed
Android JNI llama cache temperature in class
1 parent 08c07fa commit 6dc4fc0

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

extension/android/jni/jni_layer_llama.cpp

+4-5
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ class ExecuTorchLlmCallbackJni
114114
class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
115115
private:
116116
friend HybridBase;
117+
float temperature_;
117118
int model_type_category_;
118119
std::unique_ptr<llm::IRunner> runner_;
119120
std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
@@ -169,20 +170,17 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
169170
runner_ = std::make_unique<example::Runner>(
170171
model_path->toStdString().c_str(),
171172
tokenizer_path->toStdString().c_str(),
172-
temperature,
173173
data_path->toStdString().c_str());
174174
} else {
175175
runner_ = std::make_unique<example::Runner>(
176176
model_path->toStdString().c_str(),
177-
tokenizer_path->toStdString().c_str(),
178-
temperature);
177+
tokenizer_path->toStdString().c_str());
179178
}
180179
#if defined(EXECUTORCH_BUILD_MEDIATEK)
181180
} else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
182181
runner_ = std::make_unique<MTKLlamaRunner>(
183182
model_path->toStdString().c_str(),
184-
tokenizer_path->toStdString().c_str(),
185-
temperature);
183+
tokenizer_path->toStdString().c_str());
186184
// Interpret the model type as LLM
187185
model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
188186
#endif
@@ -222,6 +220,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
222220
executorch::extension::llm::GenerationConfig config{
223221
.echo = static_cast<bool>(echo),
224222
.seq_len = seq_len,
223+
.temperature = temperature_,
225224
};
226225
runner_->generate(
227226
prompt->toStdString(),

0 commit comments

Comments
 (0)