Add Data Model #157
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Changes from all commits
New file — module `ml_filter.data_models` (+86 lines):

```python
from enum import Enum
from typing import Dict, Union

from pydantic import BaseModel, Field


# Define DecodingStrategy Enum
class DecodingStrategy(str, Enum):
    """Decoding strategies for text generation models"""

    GREEDY = "greedy"
    BEAM_SEARCH = "beam_search"
    TOP_K = "top_k"
    TOP_P = "top_p"


# Base class for decoding strategy parameters
class DecodingParameters(BaseModel):
    """Decoding strategy parameters"""

    strategy: DecodingStrategy
```
> **Collaborator:** Why is the strategy a parameter? It seems a bit redundant when we then define a separate class for each decoding strategy anyway.
```python
# Decoding strategy parameter classes
class GreedyParameters(DecodingParameters):
    """Greedy decoding strategy parameters"""

    strategy: DecodingStrategy = Field(default=DecodingStrategy.GREEDY)
```
> **Member:** Shouldn't greedy also have a temperature flag? Temperature of 0 → argmax (special case).
>
> **Collaborator:** I think temperature does not affect the results of greedy decoding. Greedy decoding always selects the most likely token as the next one; the temperature flattens or steepens the token distribution, but the most likely token stays the most likely one regardless of the chosen temperature.
>
> **Collaborator:** I guess what you mean is called temperature sampling.
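The collaborator's claim is easy to verify: temperature divides all logits by the same positive constant, which is a monotonic transformation, so the argmax token never changes. A minimal stand-alone sketch (plain Python, not part of this PR):

```python
import math


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax: divide logits by T, then normalize."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.5]
for t in (0.25, 1.0, 4.0):
    probs = softmax(logits, temperature=t)
    # The distribution flattens or sharpens with T, but the argmax stays
    # at index 0, so greedy decoding returns the same token every time.
    assert probs.index(max(probs)) == 0
```

Temperature only matters once we *sample* from the distribution, which is the temperature-sampling point made in the last comment.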
```python
class BeamSearchParameters(DecodingParameters):
    """Beam search decoding strategy parameters"""

    strategy: DecodingStrategy = Field(default=DecodingStrategy.BEAM_SEARCH)
    num_beams: int = Field(..., gt=0, description="Number of beams must be greater than 0.")
    early_stopping: bool
```
> **Collaborator:** We should also add a parameter for the total number of beams that are tracked in parallel, not just the number per generated token (`num_beams`).
```python
class TopKParameters(DecodingParameters):
    """Top-K decoding strategy parameters"""

    strategy: DecodingStrategy = Field(default=DecodingStrategy.TOP_K)
    top_k: int = Field(..., gt=0, description="Number of top candidates to consider. Must be greater than 0.")
    temperature: float = Field(..., gt=0, description="Sampling temperature. Must be greater than 0.")
```
> **Member:** Do we really need temperature for top-k decoding? From my understanding, temperature makes the probability distribution more flattened or peaked.
>
> **Collaborator:** For top-k sampling, temperature actually makes sense. The top-k tokens you sample from do not change; however, the probability with which you then randomly sample one of these k tokens is affected significantly by the temperature.
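The interplay described in this thread can be sketched in a few lines: the candidate set is fixed by the logit ranking alone, while temperature reshapes the sampling weights *within* that set. A hypothetical, stand-alone illustration (not the project's actual sampler):

```python
import math
import random


def top_k_sample(logits, k, temperature, rng):
    """Sample one token index from the k highest-logit candidates."""
    # 1. The candidate set depends only on the ranking of the logits,
    #    so it is identical at every temperature.
    candidates = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # 2. Temperature reshapes the sampling weights within that set:
    #    low T concentrates mass on the best candidate, high T flattens it.
    scaled = [logits[i] / temperature for i in candidates]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    return rng.choices(candidates, weights=weights, k=1)[0]


rng = random.Random(0)
logits = [3.0, 2.5, 1.0, -1.0]
draws = [top_k_sample(logits, k=2, temperature=0.7, rng=rng) for _ in range(100)]
# Only the top-2 candidates (indices 0 and 1) can ever be drawn.
assert set(draws) <= {0, 1}
```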
```python
class TopPParameters(DecodingParameters):
    """Top-P decoding strategy parameters"""

    strategy: DecodingStrategy = Field(default=DecodingStrategy.TOP_P)
    top_p: float = Field(
        ..., gt=0, le=1, description="Cumulative probability for nucleus sampling. Must be in the range (0, 1]."
    )
    temperature: float = Field(..., gt=0, description="Sampling temperature. Must be greater than 0.")
```
```python
# General information about a document
class DocumentInfo(BaseModel):
    """General information about a document"""

    document_id: str
    prompt: str
    prompt_lang: str
    raw_data_path: str
    model: str
    decoding_parameters: Union[GreedyParameters, BeamSearchParameters, TopKParameters, TopPParameters]
```
> **Member:** *(suggested change; diff not captured in this export)*
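Since every subclass already fixes its own `strategy` value, the plain `Union` could become a pydantic *discriminated union*, which picks the right parameter class automatically when validating raw data. A hedged sketch with simplified stand-in models (assuming pydantic v2; the class names below are illustrative, not the PR's):

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field


class GreedyParams(BaseModel):  # stand-in for GreedyParameters
    strategy: Literal["greedy"] = "greedy"


class TopKParams(BaseModel):  # stand-in for TopKParameters
    strategy: Literal["top_k"] = "top_k"
    top_k: int = Field(..., gt=0)
    temperature: float = Field(..., gt=0)


class Doc(BaseModel):  # stand-in for DocumentInfo
    decoding_parameters: Annotated[
        Union[GreedyParams, TopKParams],
        Field(discriminator="strategy"),
    ]


# The "strategy" field now selects the concrete class during validation:
doc = Doc(decoding_parameters={"strategy": "top_k", "top_k": 5, "temperature": 0.7})
assert isinstance(doc.decoding_parameters, TopKParams)
```

With a discriminator, an invalid `strategy` value also fails fast with a targeted error instead of being tried against every union member in turn.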
```python
class CorrelationMetrics(BaseModel):
    """Correlation metrics for performance evaluation"""

    correlation: Dict[str, Dict[str, float]]  # Correlation per ground truth approach
```
> **Collaborator:** What do the keys of the dict represent? The model (prompt + prompt_lang + LLM) that generated the scores? That would mean the correlation is always measured against the ground truth, correct? If we allowed tuples of strings as keys, we could also measure the correlation between different models.
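For reference, the inner `{"pearson": ..., "spearman": ...}` values come from standard formulas; a dependency-free Pearson sketch (in practice one would likely use `scipy.stats`):

```python
import math


def pearson(xs, ys):
    """Pearson correlation coefficient between two equal-length score lists."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)


model_scores = [1.0, 2.0, 3.0, 4.0]
ground_truth = [1.1, 1.9, 3.2, 3.8]
# One entry of the CorrelationMetrics.correlation dict could then look like:
correlation = {"average": {"pearson": pearson(model_scores, ground_truth)}}
assert 0.9 < correlation["average"]["pearson"] <= 1.0
```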
```python
class TTestResults(BaseModel):
    """T-Test results for performance evaluation"""

    t_test_p_values: Dict[str, float]  # p-values for each ground truth approach


class StatisticReport(BaseModel):
    """Complete statistical report combining various metrics"""

    document_info: DocumentInfo
    correlation_metrics: CorrelationMetrics
    t_test_results: TTestResults
```
New test file (+103 lines):
```python
import pytest
from pydantic import ValidationError

from ml_filter.data_models import (
    BeamSearchParameters,
    CorrelationMetrics,
    DecodingStrategy,
    DocumentInfo,
    GreedyParameters,
    StatisticReport,
    TopKParameters,
    TopPParameters,
    TTestResults,
)


def test_greedy_parameters():
    params = GreedyParameters()
    assert params.strategy == DecodingStrategy.GREEDY


def test_beam_search_parameters():
    params = BeamSearchParameters(num_beams=10, early_stopping=False)
    assert params.strategy == DecodingStrategy.BEAM_SEARCH
    assert params.num_beams == 10
    assert not params.early_stopping


def test_top_k_parameters():
    params = TopKParameters(top_k=30, temperature=0.7)
    assert params.strategy == DecodingStrategy.TOP_K
    assert params.top_k == 30
    assert params.temperature == 0.7


def test_top_p_parameters():
    params = TopPParameters(top_p=0.85, temperature=0.9)
    assert params.strategy == DecodingStrategy.TOP_P
    assert params.top_p == 0.85
    assert params.temperature == 0.9


def test_invalid_decoding_parameters():
    with pytest.raises(ValidationError):
        BeamSearchParameters(num_beams=-1, early_stopping=False)  # Invalid num_beams
    with pytest.raises(ValidationError):
        TopKParameters(top_k=-5, temperature=0.7)  # Invalid top_k
    with pytest.raises(ValidationError):
        TopPParameters(top_p=1.5, temperature=0.8)  # Invalid top_p


def test_document_info_with_greedy():
    doc_info = DocumentInfo(
        document_id="doc_001",
        prompt="Assess the educational value of the text.",
        prompt_lang="en",
        raw_data_path="/path/to/raw_data.json",
        model="test_model",
        decoding_parameters=GreedyParameters(),
    )
    assert doc_info.document_id == "doc_001"
    assert doc_info.decoding_parameters.strategy == DecodingStrategy.GREEDY


def test_document_info_with_top_p():
    doc_info = DocumentInfo(
        document_id="doc_002",
        prompt="Assess whether the text contains adult content.",
        prompt_lang="en",
        raw_data_path="/path/to/raw_data.json",
        model="test_model",
        decoding_parameters=TopPParameters(top_p=0.8, temperature=0.6),
    )
    assert doc_info.document_id == "doc_002"
    assert doc_info.decoding_parameters.top_p == 0.8
    assert doc_info.decoding_parameters.temperature == 0.6


def test_statistic_report():
    doc_info = DocumentInfo(
        document_id="doc_003",
        prompt="Assess whether the text contains chain-of-thought reasoning.",
        prompt_lang="en",
        raw_data_path="/path/to/raw_data.json",
        model="test_model",
        decoding_parameters=BeamSearchParameters(num_beams=5, early_stopping=True),
    )
    correlation_metrics = CorrelationMetrics(
        correlation={
            "average": {"pearson": 0.85, "spearman": 0.82},
            "min": {"pearson": 0.75, "spearman": 0.72},
        }
    )
    t_test_results = TTestResults(t_test_p_values={"average": 0.03, "min": 0.05})
    report = StatisticReport(
        document_info=doc_info,
        correlation_metrics=correlation_metrics,
        t_test_results=t_test_results,
    )

    assert report.document_info.document_id == "doc_003"
    assert report.correlation_metrics.correlation["average"]["pearson"] == 0.85
    assert report.t_test_results.t_test_p_values["average"] == 0.03
```
> **Member** (on lines +17 to +103): I think all of these tests don't really test any functionality and are a bit redundant.
>
> **Collaborator:** I agree.
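One way to address this critique is to test behavior the models actually add, e.g. that parameters survive a JSON round trip. A hypothetical example with a simplified stand-in model (assuming pydantic v2, where `model_dump_json`/`model_validate_json` are the serialization methods):

```python
from pydantic import BaseModel, Field


class TopPParams(BaseModel):  # simplified stand-in for TopPParameters
    strategy: str = "top_p"
    top_p: float = Field(..., gt=0, le=1)
    temperature: float = Field(..., gt=0)


def test_top_p_round_trip():
    params = TopPParams(top_p=0.8, temperature=0.6)
    restored = TopPParams.model_validate_json(params.model_dump_json())
    # Serialization must not lose or alter any field.
    assert restored == params


test_top_p_round_trip()
```

Unlike asserting a constructor argument back, this exercises validation and serialization together, so it would catch a field renamed or dropped in the schema.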
> *(A review comment here was hidden; reason given: "old type annotations".)*