Dynamically assert model temperature value in argparser #856
base: main
Conversation
Hi @DonggeLiu, could you please review this when you get a chance? Many thanks in advance!
Thanks @dberardi99.
I left a comment to clarify a bit more about the task.
Please let me know if that makes sense : )
llm_toolkit/models.py (outdated)

@@ -46,7 +46,7 @@
 # Model hyper-parameters.
 MAX_TOKENS: int = 2000
 NUM_SAMPLES: int = 1
-TEMPERATURE: float = 0.4
+TEMPERATURE: float = 1.0
Oops, let's keep the default temperature the same for now to avoid causing surprising results in other people's recent experiments.
We can grid search for the best default values for our use case later.
Perfect, I got it. I'll reset it to its prior value.
run_all_experiments.py (outdated)

@@ -279,6 +279,9 @@ def parse_args() -> argparse.Namespace:

   if args.temperature:
     assert 2 >= args.temperature >= 0, '--temperature must be within 0 and 2.'

+  if args.temperature == TEMPERATURE and args.model in models.LLM.all_llm_names():
+    args.temperature = run_one_experiment.get_model_temperature(args)
I reckon the issue we are solving has two tasks:
- Main: Different models have different temperature ranges. We want to assert that, if the user specified a temperature in the args, it falls within the corresponding model's range.
- Minor: Define a default temperature for each model.
Let's solve the main task first:
1.1. Define the temperature range for each model class in https://github.com/google/oss-fuzz-gen/blob/main/llm_toolkit/models.py. Use inheritance to minimize the changes needed.
1.2. Replace this hardcoded assertion with a dynamic assertion based on the model name:
oss-fuzz-gen/run_all_experiments.py, lines 280 to 281 (at 33bddff):

if args.temperature:
  assert 2 >= args.temperature >= 0, '--temperature must be within 0 and 2.'
Then we can work on the minor task:
- Add default temperatures under each class.
- Set that default temperature as the value here, if the user did not specify one.
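The two-step plan above could be sketched roughly as follows. This is a hedged illustration, not the real models.py: the subclass names, ranges, and the `check_temperature` helper are hypothetical stand-ins.

```python
# Step 1.1: per-model temperature ranges via inheritance.
# Step 1.2: a dynamic assertion keyed on the model name.
TEMPERATURE: float = 0.4
TEMPERATURE_RANGE: list[float] = [0.0, 2.0]


class LLM:
  name = 'base-llm'
  temperature_range: list[float] = TEMPERATURE_RANGE  # default range


class WideRangeModel(LLM):
  name = 'wide-range-model'
  # Inherits the default [0.0, 2.0] range from LLM unchanged.


class NarrowRangeModel(LLM):
  name = 'narrow-range-model'
  temperature_range = [0.0, 1.0]  # this family only accepts [0, 1]


def check_temperature(model_name: str, temperature: float) -> None:
  """Dynamic replacement for the hardcoded 0 <= t <= 2 assertion."""
  ranges = {c.name: c.temperature_range
            for c in (WideRangeModel, NarrowRangeModel)}
  low, high = ranges.get(model_name, TEMPERATURE_RANGE)
  assert low <= temperature <= high, (
      f'--temperature must be within {low} and {high}.')
```

With this, `check_temperature('narrow-range-model', 1.5)` raises an AssertionError while the wide-range model accepts the same value.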
Everything's clear! Just one thing: do you prefer to solve only the different temperature ranges in this PR, or to implement both the temperature ranges and the default temperatures here?
Thanks @dberardi99, some nits.
@@ -61,6 +61,8 @@ class LLM:

   _max_attempts = 5  # Maximum number of attempts to get prediction response

+  temperature_range: list[float] = [0.0, 2.0]  # Default model temperature range
Let's keep this attribute but make the value a constant like TEMPERATURE
defined at the top of the file:
temperature_range: list[float] = TEMPERATURE_RANGE
if (hasattr(subcls, 'temperature_range') and hasattr(subcls, 'name')
    and subcls.name != AIBinaryModel.name):
  ranges[subcls.name] = subcls.temperature_range
return ranges
IIUC, you are replicating all_llm_names.
Would it be more maintainable and extensible to extract and reuse the repeated logic of these two functions? E.g., a function that returns all models, so that we can reuse it to acquire other attributes in the future.
@classmethod
def _all_llm_models(cls):
  """Returns a list of LLM model classes that have a `name` attribute
  and are not `AIBinaryModel`."""
  models = []
  for subcls in cls.all_llm_subclasses():
    # May need a different filter logic here.
    if subcls.name != AIBinaryModel.name:
      models.append(subcls)
  return models

@classmethod
def all_llm_names(cls) -> list[str]:
  """Returns the current model name and all child model names."""
  return [m.name for m in cls._all_llm_models()]

@classmethod
def all_llm_temperature_ranges(cls) -> dict[str, list[float]]:
  """Returns the current model name and all child model temperature ranges."""
  return {
      m.name: m.temperature_range
      for m in cls._all_llm_models()
      if hasattr(m, 'temperature_range')
  }
Feel free to adjust/simplify these functions as you see fit, particularly the filtering logic.
The above are just examples.
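Put together, a self-contained version of the suggested refactor might look like this. The `all_llm_subclasses` walk, `AIBinaryModel`, and the sample subclass below are assumed stand-ins for the real ones in models.py, included only so the sketch runs on its own.

```python
class LLM:
  name = 'base-llm'
  temperature_range: list[float] = [0.0, 2.0]

  @classmethod
  def all_llm_subclasses(cls):
    """Yield this class and all transitive subclasses."""
    yield cls
    for sub in cls.__subclasses__():
      yield from sub.all_llm_subclasses()

  @classmethod
  def _all_llm_models(cls):
    """All LLM classes except AIBinaryModel (the shared filter logic)."""
    return [s for s in cls.all_llm_subclasses()
            if s.name != AIBinaryModel.name]

  @classmethod
  def all_llm_names(cls) -> list[str]:
    """Returns the current model name and all child model names."""
    return [m.name for m in cls._all_llm_models()]

  @classmethod
  def all_llm_temperature_ranges(cls) -> dict[str, list[float]]:
    """Returns each model's name mapped to its temperature range."""
    return {m.name: m.temperature_range
            for m in cls._all_llm_models()
            if hasattr(m, 'temperature_range')}


class AIBinaryModel(LLM):
  name = 'ai-binary-model'  # excluded by the shared filter


class SampleModel(LLM):
  name = 'sample-model'
  temperature_range = [0.0, 1.0]  # hypothetical narrower range
```

Both public classmethods now share `_all_llm_models`, so a future attribute lookup (e.g. per-model defaults) only needs one more small classmethod.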
ranges = models.LLM.all_llm_temperature_ranges()
assert ranges[args.model][1] >= args.temperature >= ranges[args.model][0], (
    f'--temperature must be within {ranges[args.model][0]} and '
    f'{ranges[args.model][1]}.')
Add "... for model {args.model}" to the assertion message, to specify the model name we parsed.
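With that nit applied, the assertion might read as follows. This is a sketch only; the `validate_temperature` helper name is hypothetical, and `ranges` follows the `all_llm_temperature_ranges` shape discussed above.

```python
import argparse


def validate_temperature(args: argparse.Namespace,
                         ranges: dict[str, list[float]]) -> None:
  """Assert the parsed temperature falls in the model's range,
  naming the model in the error message.

  `ranges` maps model name -> [low, high].
  """
  low, high = ranges[args.model]
  assert low <= args.temperature <= high, (
      f'--temperature must be within {low} and {high} '
      f'for model {args.model}.')
```

A failing check then reports, e.g., "--temperature must be within 0.0 and 1.0 for model sample-model.", which tells the user exactly which model's range was violated.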
A new method get_model_temperature(args: argparse.Namespace) has been added to hardcode the model temperature value based on the model name. The chat-session models (namely the ones whose names end with "-chat" or "-azure") are treated as their corresponding base models. The temperature values can be found in the data table in #366 (comment).

The temperature value is automatically aligned to the chosen model's value only if no temperature has been set (args.temperature == TEMPERATURE). If an unrecognized model name is supplied (args.model in models.LLM.all_llm_names() is false), the above method is skipped and the temperature is left unchanged at its default value.

In addition, the default temperature has been changed to 1.0, since that is the value for the default model (vertex_ai_gemini-1-5).
Fix #366
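The suffix-stripping fallback described in the PR summary could look roughly like this. It is a hedged sketch: the real get_model_temperature lives in run_one_experiment and takes the full args namespace, and the defaults table here is hypothetical.

```python
def get_model_temperature(model_name: str,
                          defaults: dict[str, float]) -> float:
  """Look up a per-model default temperature.

  Chat-session variants (names ending in '-chat' or '-azure') fall
  back to their base model, as described in the PR summary.
  """
  base = model_name
  for suffix in ('-chat', '-azure'):
    if base.endswith(suffix):
      base = base[:-len(suffix)]
      break
  # 1.0 is the fallback, matching the default model's temperature.
  return defaults.get(base, 1.0)
```

For example, a '-chat' variant of a known base model resolves to the base model's temperature, while an unknown name falls back to 1.0.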