Skip to content

Commit c11f8dc

Browse files
committed
Android JNI (llama): cache the temperature as a class member
1 parent d82e852 commit c11f8dc

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

extension/android/jni/jni_layer_llama.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ class ExecuTorchLlmCallbackJni
120120
class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
121121
private:
122122
friend HybridBase;
123+
float temperature_;
123124
int model_type_category_;
124125
std::unique_ptr<llm::IRunner> runner_;
125126
std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
@@ -175,20 +176,17 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
175176
runner_ = std::make_unique<example::Runner>(
176177
model_path->toStdString().c_str(),
177178
tokenizer_path->toStdString().c_str(),
178-
temperature,
179179
data_path->toStdString().c_str());
180180
} else {
181181
runner_ = std::make_unique<example::Runner>(
182182
model_path->toStdString().c_str(),
183-
tokenizer_path->toStdString().c_str(),
184-
temperature);
183+
tokenizer_path->toStdString().c_str());
185184
}
186185
#if defined(EXECUTORCH_BUILD_MEDIATEK)
187186
} else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
188187
runner_ = std::make_unique<MTKLlamaRunner>(
189188
model_path->toStdString().c_str(),
190-
tokenizer_path->toStdString().c_str(),
191-
temperature);
189+
tokenizer_path->toStdString().c_str());
192190
// Interpret the model type as LLM
193191
model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
194192
#endif
@@ -228,6 +226,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
228226
executorch::extension::llm::GenerationConfig config{
229227
.echo = static_cast<bool>(echo),
230228
.seq_len = seq_len,
229+
.temperature = temperature_,
231230
};
232231
runner_->generate(
233232
prompt->toStdString(),

0 commit comments

Comments
 (0)