Add Data Model #157
base: master
Conversation
le1nux
left a comment
Left a few minor comments. Otherwise LGTM :)
```python
strategy: DecodingStrategy = Field(default=DecodingStrategy.TOP_K)
top_k: int = Field(..., gt=0, description="Number of top candidates to consider. Must be greater than 0.")
temperature: float = Field(..., gt=0, description="Sampling temperature. Must be greater than 0.")
```
Do we really need temperature for top-k decoding?
From my understanding, temperature makes the probability distribution flatter or more peaked.
If we apply top-k to a flattened or peaked probability distribution, shouldn't the result be the same?
For top-k sampling, temperature actually makes sense. The top k tokens you sample from do not change; however, the probability with which you then randomly sample one of these k tokens is affected significantly by the temperature.
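The point above can be sketched in a few lines of plain Python (the logits are made up for illustration): rescaling by temperature never changes *which* tokens land in the top k, but it does reshape the probabilities used to sample among them.

```python
import math

def top_k_probs(logits, k, temperature):
    """Softmax with temperature, then keep only the k most likely tokens."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    total = sum(weights)
    probs = [w / total for w in weights]
    # Indices of the k most likely tokens, highest probability first.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    mass = sum(probs[i] for i in top)
    return top, [probs[i] / mass for i in top]

# Two temperatures over the same hypothetical logits: the top-2 *set*
# is identical, but the renormalized sampling probabilities differ sharply.
idx_cold, p_cold = top_k_probs([4.0, 3.0, 1.0, 0.5], k=2, temperature=0.5)
idx_hot, p_hot = top_k_probs([4.0, 3.0, 1.0, 0.5], k=2, temperature=2.0)
assert set(idx_cold) == set(idx_hot)
print(p_cold, p_hot)
```

Here the most likely token gets roughly 0.88 of the renormalized mass at temperature 0.5 but only about 0.62 at temperature 2.0, so temperature clearly matters even after the top-k cut.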
```python
class GreedyParameters(DecodingParameters):
    """Greedy decoding strategy parameters"""

    strategy: DecodingStrategy = Field(default=DecodingStrategy.GREEDY)
```
Shouldn't greedy also have a temperature flag?
A temperature of 0 -> argmax (special case);
otherwise we would sample a single token from the probability distribution.
I think temperature does not affect the results when decoding greedily. Greedy decoding always selects the most likely token as the next one. The temperature flattens or steepens the distribution over tokens, but the most likely token stays the most likely one, regardless of the chosen temperature.
I guess what you mean is called temperature sampling
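The invariance claimed above is easy to check: dividing logits by any positive temperature is a strictly monotonic transformation, so the token greedy decoding picks never changes. A minimal sketch with made-up logits:

```python
logits = [2.0, 1.0, 0.5, -1.0]  # hypothetical next-token logits

def argmax(xs):
    """Index of the largest element, i.e. the token greedy decoding selects."""
    return max(range(len(xs)), key=lambda i: xs[i])

# A positive temperature rescales all logits by the same positive factor,
# which preserves their ordering, so the greedy choice is unchanged.
for temperature in (0.1, 1.0, 10.0):
    assert argmax([x / temperature for x in logits]) == argmax(logits)
```

Temperature only starts to matter once you actually sample from the tempered distribution (temperature sampling), rather than taking the argmax.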
```python
prompt_lang: str
raw_data_path: str
model: str
decoding_parameters: Union[GreedyParameters, BeamSearchParameters, TopKParameters, TopPParameters]
```
Suggested change:

```python
decoding_parameters: GreedyParameters | BeamSearchParameters | TopKParameters | TopPParameters
```
```python
from enum import Enum
from typing import Dict, Union
```
old type annotations
```python
def test_greedy_parameters():
    params = GreedyParameters()
    assert params.strategy == DecodingStrategy.GREEDY


def test_beam_search_parameters():
    params = BeamSearchParameters(num_beams=10, early_stopping=False)
    assert params.strategy == DecodingStrategy.BEAM_SEARCH
    assert params.num_beams == 10
    assert not params.early_stopping


def test_top_k_parameters():
    params = TopKParameters(top_k=30, temperature=0.7)
    assert params.strategy == DecodingStrategy.TOP_K
    assert params.top_k == 30
    assert params.temperature == 0.7


def test_top_p_parameters():
    params = TopPParameters(top_p=0.85, temperature=0.9)
    assert params.strategy == DecodingStrategy.TOP_P
    assert params.top_p == 0.85
    assert params.temperature == 0.9


def test_invalid_decoding_parameters():
    with pytest.raises(ValidationError):
        BeamSearchParameters(num_beams=-1, early_stopping=False)  # Invalid num_beams
    with pytest.raises(ValidationError):
        TopKParameters(top_k=-5, temperature=0.7)  # Invalid top_k
    with pytest.raises(ValidationError):
        TopPParameters(top_p=1.5, temperature=0.8)  # Invalid top_p


def test_document_info_with_greedy():
    doc_info = DocumentInfo(
        document_id="doc_001",
        prompt="Assess the educational value of the text.",
        prompt_lang="en",
        raw_data_path="/path/to/raw_data.json",
        model="test_model",
        decoding_parameters=GreedyParameters(),
    )
    assert doc_info.document_id == "doc_001"
    assert doc_info.decoding_parameters.strategy == DecodingStrategy.GREEDY


def test_document_info_with_top_p():
    doc_info = DocumentInfo(
        document_id="doc_002",
        prompt="Assess whether the text contains adult content.",
        prompt_lang="en",
        raw_data_path="/path/to/raw_data.json",
        model="test_model",
        decoding_parameters=TopPParameters(top_p=0.8, temperature=0.6),
    )
    assert doc_info.document_id == "doc_002"
    assert doc_info.decoding_parameters.top_p == 0.8
    assert doc_info.decoding_parameters.temperature == 0.6


def test_statistic_report():
    doc_info = DocumentInfo(
        document_id="doc_003",
        prompt="Assess whether the text contains chain of thoughts.",
        prompt_lang="en",
        raw_data_path="/path/to/raw_data.json",
        model="test_model",
        decoding_parameters=BeamSearchParameters(num_beams=5, early_stopping=True),
    )
    correlation_metrics = CorrelationMetrics(
        correlation={
            "average": {"pearson": 0.85, "spearman": 0.82},
            "min": {"pearson": 0.75, "spearman": 0.72},
        }
    )
    t_test_results = TTestResults(t_test_p_values={"average": 0.03, "min": 0.05})
    report = StatisticReport(
        document_info=doc_info,
        correlation_metrics=correlation_metrics,
        t_test_results=t_test_results,
    )

    assert report.document_info.document_id == "doc_003"
    assert report.correlation_metrics.correlation["average"]["pearson"] == 0.85
    assert report.t_test_results.t_test_p_values["average"] == 0.03
```
I think none of these tests really tests any functionality; they are a bit redundant.
I agree
rrutmann
left a comment
The classes in general look good, apart from the minor things I mentioned in the comments. However, if I see it correctly, only data classes are defined here, and they are not used in the rest of the code, correct? The actual functionality that makes use of these configurations is currently missing.
```python
strategy: DecodingStrategy = Field(default=DecodingStrategy.BEAM_SEARCH)
num_beams: int = Field(..., gt=0, description="Number of beams must be greater than 0.")
early_stopping: bool
```
We should also add a parameter for the total number of beams that are tracked in parallel, not just for each generated token (=num_beams)
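If such a parameter were added, the model could look like the sketch below (the field name `num_parallel_beams` is a hypothetical suggestion, not from the PR; Pydantic is assumed to be available):

```python
from pydantic import BaseModel, Field

class BeamSearchParameters(BaseModel):
    """Sketch of beam search parameters with a separate parallel-beam count."""

    num_beams: int = Field(..., gt=0, description="Beams expanded per generated token.")
    # Hypothetical field raised in review: total beams tracked in parallel.
    num_parallel_beams: int = Field(..., gt=0, description="Total beams tracked in parallel.")
    early_stopping: bool

params = BeamSearchParameters(num_beams=4, num_parallel_beams=8, early_stopping=True)
assert params.num_parallel_beams == 8
```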
```python
class DecodingParameters(BaseModel):
    """Decoding strategy parameters"""

    strategy: DecodingStrategy
```
Why is the strategy a parameter? This seems a little bit redundant when we then define a separate class for each decoding strategy anyway
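One possible answer: a shared `strategy` field can double as a discriminator, letting Pydantic dispatch a tagged union to the right class when parsing raw dicts or JSON. A minimal sketch (assuming Pydantic's discriminated unions with `Literal` tags instead of the enum; only two of the PR's classes are shown):

```python
from typing import Literal, Union

from pydantic import BaseModel, Field

class GreedyParameters(BaseModel):
    strategy: Literal["greedy"] = "greedy"

class TopKParameters(BaseModel):
    strategy: Literal["top_k"] = "top_k"
    top_k: int = Field(..., gt=0)
    temperature: float = Field(..., gt=0)

class DocumentInfo(BaseModel):
    # "strategy" acts as the discriminator: Pydantic inspects it to decide
    # which parameter class to instantiate.
    decoding_parameters: Union[GreedyParameters, TopKParameters] = Field(
        discriminator="strategy"
    )

doc = DocumentInfo(decoding_parameters={"strategy": "top_k", "top_k": 5, "temperature": 0.7})
assert isinstance(doc.decoding_parameters, TopKParameters)
```

Without the field, callers would have to pick and construct the concrete class themselves before validation.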
```python
class CorrelationMetrics(BaseModel):
    """Correlation metrics for performance evaluation"""

    correlation: Dict[str, Dict[str, float]]  # Correlation per ground truth approach
```
What do the keys of the dict represent? The model (prompt + prompt_lang + llm) that generated the scores? That would mean the correlation is always measured against the ground truth, correct? If we allowed tuples of strings as keys, we could also measure the correlation between different models.
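The tuple-key idea might look like this sketch (scorer names and scores are made up; Pearson is implemented inline to stay dependency-free). One caveat: JSON object keys must be strings, so serializing tuple keys through Pydantic would need custom handling.

```python
from itertools import combinations
from typing import Dict, List, Tuple

# Hypothetical scorers: the ground truth plus two model configurations.
scores: Dict[str, List[float]] = {
    "ground_truth": [1.0, 2.0, 3.0, 4.0],
    "model_a": [1.1, 1.9, 3.2, 3.8],
    "model_b": [4.0, 3.0, 2.0, 1.0],
}

def pearson(xs: List[float], ys: List[float]) -> float:
    """Sample Pearson correlation coefficient."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = sum((x - mx) ** 2 for x in xs) ** 0.5
    sy = sum((y - my) ** 2 for y in ys) ** 0.5
    return cov / (sx * sy)

# Tuple keys name both sides of each correlation, so model-vs-model
# pairs fit alongside model-vs-ground-truth pairs.
correlation: Dict[Tuple[str, str], float] = {
    (a, b): pearson(scores[a], scores[b]) for a, b in combinations(scores, 2)
}
```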
This PR introduces a modular and extensible structure for configuring decoding strategies in a Pydantic-based framework. The changes include dedicated data classes for each decoding strategy (e.g., Greedy, Beam Search, Top-K, Top-P) with configurable parameters. These classes are seamlessly integrated into the existing statistical reporting framework.

Key Changes

Decoding Strategy Parameter Classes:
- `GreedyParameters`: No additional parameters beyond the strategy name.
- `BeamSearchParameters`: Configurable fields: `num_beams` (number of beams for beam search), `early_stopping` (whether to stop early).
- `TopKParameters`: Configurable fields: `top_k` (number of top candidates to consider), `temperature` (sampling temperature).
- `TopPParameters`: Configurable fields: `top_p` (probability mass for nucleus sampling), `temperature` (sampling temperature).

Integration into DocumentInfo:
- The `decoding_parameters` field in the `DocumentInfo` class now accepts any of the decoding parameter classes (`GreedyParameters`, `BeamSearchParameters`, `TopKParameters`, `TopPParameters`).

Statistical Reporting:
- `StatisticReport` integrates `DocumentInfo` and statistical metrics (e.g., `CorrelationMetrics` and `TTestResults`).