
Commit 33bddff

Dynamically assert model temperature value in argparser

1 parent a3094bd, commit 33bddff
3 files changed, +32 -2 lines changed

Diff for: llm_toolkit/models.py (+1, -1)

@@ -46,7 +46,7 @@
 # Model hyper-parameters.
 MAX_TOKENS: int = 2000
 NUM_SAMPLES: int = 1
-TEMPERATURE: float = 0.4
+TEMPERATURE: float = 1.0
 
 
 class LLM:

Diff for: run_all_experiments.py (+3)

@@ -279,6 +279,9 @@ def parse_args() -> argparse.Namespace:
 
   if args.temperature:
     assert 2 >= args.temperature >= 0, '--temperature must be within 0 and 2.'
+
+  if args.temperature == TEMPERATURE and args.model in models.LLM.all_llm_names():
+    args.temperature = run_one_experiment.get_model_temperature(args)
 
   benchmark_yaml = args.benchmark_yaml
   if benchmark_yaml:
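For reference, the effect of the new check is that a --temperature left at the generic argparse default (now 1.0) is replaced by the model-specific default, while an explicitly chosen non-default value is kept. Below is a minimal, self-contained sketch of that resolution logic; the resolve_temperature helper and the two-entry MODEL_DEFAULTS table are illustrative stand-ins, not the project's actual code.

# Illustrative sketch (hypothetical helper and table, not the project's modules):
# if the user left --temperature at the generic default, fall back to the
# model-specific default; otherwise keep the user's value.
import argparse

DEFAULT_TEMPERATURE = 1.0  # mirrors the new TEMPERATURE constant
MODEL_DEFAULTS = {'gemini-pro': 0.9, 'gpt-4o': 1.0}  # illustrative subset


def resolve_temperature(args: argparse.Namespace) -> float:
  assert 2 >= args.temperature >= 0, '--temperature must be within 0 and 2.'
  # Note: an explicit --temperature 1.0 is indistinguishable from the default,
  # so it is also replaced by the model-specific value.
  if args.temperature == DEFAULT_TEMPERATURE and args.model in MODEL_DEFAULTS:
    return MODEL_DEFAULTS[args.model]
  return args.temperature


parser = argparse.ArgumentParser()
parser.add_argument('--model', default='gemini-pro')
parser.add_argument('--temperature', type=float, default=DEFAULT_TEMPERATURE)
args = parser.parse_args([])       # no flags: falls back to 0.9 for gemini-pro
print(resolve_temperature(args))   # -> 0.9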

Diff for: run_one_experiment.py (+28, -1)

@@ -54,7 +54,7 @@
 NUM_SAMPLES = 2
 MAX_TOKENS: int = 4096
 RUN_TIMEOUT: int = 30
-TEMPERATURE: float = 0.4
+TEMPERATURE: float = 1.0
 
 RESULTS_DIR = './results'
 
@@ -311,3 +311,30 @@ def run(benchmark: Benchmark, model: models.LLM, args: argparse.Namespace,
 
   return AggregatedResult.from_benchmark_result(
       _fuzzing_pipelines(benchmark, model, args, work_dirs))
+
+
+def get_model_temperature(args: argparse.Namespace) -> float:
+  """Retrieves model temperature default value."""
+  default_temperatures = {models.VertexAICodeBisonModel.name: 0.2,
+                          models.VertexAICodeBison32KModel.name: 0.2,
+                          models.GeminiPro.name: 0.9,
+                          models.GeminiUltra.name: 0.2,
+                          models.GeminiExperimental.name: 1.0,
+                          models.GeminiV1D5.name: 1.0,
+                          models.GeminiV2Flash.name: 1.0,
+                          models.GeminiV2.name: 1.0,
+                          models.GeminiV2Think.name: 0.7,
+                          models.ClaudeHaikuV3.name: 0.5,
+                          models.ClaudeOpusV3.name: 0.5,
+                          models.ClaudeSonnetV3D5.name: 0.5,
+                          models.GPT.name: 1.0,
+                          models.GPT4.name: 1.0,
+                          models.GPT4o.name: 1.0,
+                          models.GPT4oMini.name: 1.0,
+                          models.GPT4Turbo.name: 1.0}
+  if args.model.endswith('-chat') or args.model.endswith('-azure'):
+    model_name = '-'.join(args.model.split('-')[:-1])
+  else:
+    model_name = args.model
+
+  return default_temperatures[model_name]
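The suffix handling at the end of get_model_temperature() means that '-chat' and '-azure' model-name variants resolve to their base model's default temperature. A small standalone sketch of that lookup follows; the plain string keys and the lookup_temperature name are placeholders, since the real table is keyed by models.<Class>.name attributes.

# Sketch of the suffix-stripping lookup (string keys are placeholders).
DEFAULT_TEMPERATURES = {'gpt-4': 1.0, 'gemini-pro': 0.9}


def lookup_temperature(model: str) -> float:
  # '-chat' and '-azure' variants share the base model's default temperature.
  if model.endswith('-chat') or model.endswith('-azure'):
    model = '-'.join(model.split('-')[:-1])
  return DEFAULT_TEMPERATURES[model]


assert lookup_temperature('gpt-4-chat') == 1.0   # stripped to 'gpt-4'
assert lookup_temperature('gemini-pro') == 0.9   # used as-is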
