@@ -114,6 +114,7 @@ class ExecuTorchLlmCallbackJni
114
114
class ExecuTorchLlmJni : public facebook ::jni::HybridClass<ExecuTorchLlmJni> {
115
115
private:
116
116
friend HybridBase;
117
+ float temperature_;
117
118
int model_type_category_;
118
119
std::unique_ptr<llm::IRunner> runner_;
119
120
std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
@@ -169,20 +170,17 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
169
170
runner_ = std::make_unique<example::Runner>(
170
171
model_path->toStdString ().c_str (),
171
172
tokenizer_path->toStdString ().c_str (),
172
- temperature,
173
173
data_path->toStdString ().c_str ());
174
174
} else {
175
175
runner_ = std::make_unique<example::Runner>(
176
176
model_path->toStdString ().c_str (),
177
- tokenizer_path->toStdString ().c_str (),
178
- temperature);
177
+ tokenizer_path->toStdString ().c_str ());
179
178
}
180
179
#if defined(EXECUTORCH_BUILD_MEDIATEK)
181
180
} else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
182
181
runner_ = std::make_unique<MTKLlamaRunner>(
183
182
model_path->toStdString ().c_str (),
184
- tokenizer_path->toStdString ().c_str (),
185
- temperature);
183
+ tokenizer_path->toStdString ().c_str ());
186
184
// Interpret the model type as LLM
187
185
model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
188
186
#endif
@@ -222,6 +220,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
222
220
executorch::extension::llm::GenerationConfig config{
223
221
.echo = static_cast <bool >(echo),
224
222
.seq_len = seq_len,
223
+ .temperature = temperature_,
225
224
};
226
225
runner_->generate (
227
226
prompt->toStdString (),
0 commit comments