@@ -120,6 +120,7 @@ class ExecuTorchLlmCallbackJni
 class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
  private:
   friend HybridBase;
+  float temperature_;
   int model_type_category_;
   std::unique_ptr<llm::IRunner> runner_;
   std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
@@ -175,20 +176,17 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       runner_ = std::make_unique<example::Runner>(
           model_path->toStdString().c_str(),
           tokenizer_path->toStdString().c_str(),
-          temperature,
           data_path->toStdString().c_str());
     } else {
       runner_ = std::make_unique<example::Runner>(
           model_path->toStdString().c_str(),
-          tokenizer_path->toStdString().c_str(),
-          temperature);
+          tokenizer_path->toStdString().c_str());
     }
 #if defined(EXECUTORCH_BUILD_MEDIATEK)
   } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
     runner_ = std::make_unique<MTKLlamaRunner>(
         model_path->toStdString().c_str(),
-        tokenizer_path->toStdString().c_str(),
-        temperature);
+        tokenizer_path->toStdString().c_str());
     // Interpret the model type as LLM
     model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
 #endif
@@ -228,6 +226,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
     executorch::extension::llm::GenerationConfig config{
         .echo = static_cast<bool>(echo),
         .seq_len = seq_len,
+        .temperature = temperature_,
     };
     runner_->generate(
         prompt->toStdString(),
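
For context, here is a standalone sketch of the refactor this diff makes: `temperature` moves off the `Runner` constructor and becomes a per-generation setting, stored on the JNI wrapper at load time and forwarded on each call through the config struct. All types below (`GenerationConfig`, `Runner`, `LlmJni`) are simplified stand-ins for the real ExecuTorch classes, not the actual API; compile with `-std=c++20` for designated initializers.

```cpp
#include <cstdio>
#include <memory>
#include <string>

// Stand-in for executorch::extension::llm::GenerationConfig: sampling
// parameters now travel with each generate() call.
struct GenerationConfig {
  bool echo = true;
  int seq_len = 128;
  float temperature = 0.8f;  // per-call sampling knob, no longer a ctor arg
};

// Stand-in for example::Runner: temperature is gone from the constructor.
class Runner {
 public:
  explicit Runner(std::string model_path)
      : model_path_(std::move(model_path)) {}

  void generate(const std::string& prompt, const GenerationConfig& config) {
    std::printf("generate('%s') seq_len=%d temperature=%.2f\n",
                prompt.c_str(), config.seq_len, config.temperature);
  }

 private:
  std::string model_path_;
};

// Stand-in for the JNI hybrid class: capture temperature once at load time,
// then inject it into the config on every generation.
class LlmJni {
 public:
  LlmJni(std::string model_path, float temperature)
      : temperature_(temperature),
        runner_(std::make_unique<Runner>(std::move(model_path))) {}

  void generate(const std::string& prompt, int seq_len, bool echo) {
    GenerationConfig config{
        .echo = echo,
        .seq_len = seq_len,
        .temperature = temperature_,
    };
    runner_->generate(prompt, config);
  }

 private:
  float temperature_;
  std::unique_ptr<Runner> runner_;
};

int main() {
  LlmJni jni("model.pte", 0.7f);
  jni.generate("Hello", 64, /*echo=*/true);
}
```

One upside of this shape: since temperature lives in the per-call config rather than baked into the runner, a caller could vary it between requests without rebuilding the runner.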