86 changes: 86 additions & 0 deletions src/ml_filter/data_models.py
@@ -0,0 +1,86 @@
from enum import Enum
from typing import Dict, Union

[Member] old type annotations


from pydantic import BaseModel, Field


# Define DecodingStrategy Enum
class DecodingStrategy(str, Enum):
    """Decoding strategies for text generation models"""

    GREEDY = "greedy"
    BEAM_SEARCH = "beam_search"
    TOP_K = "top_k"
    TOP_P = "top_p"


# Base class for decoding strategy parameters
class DecodingParameters(BaseModel):
    """Decoding strategy parameters"""

    strategy: DecodingStrategy

[Collaborator] Why is the strategy a parameter? This seems a little bit redundant when we then define a separate class for each decoding strategy anyway.



# Decoding strategy parameter classes
class GreedyParameters(DecodingParameters):
    """Greedy decoding strategy parameters"""

    strategy: DecodingStrategy = Field(default=DecodingStrategy.GREEDY)

[Member] Shouldn't greedy also have a temperature flag? Temperature of 0 -> argmax (special case); otherwise we would sample a single token from the probability distribution.

[Collaborator] I think temperature does not affect the results when decoding greedily. Greedy decoding always selects the most likely token as the next one. The temperature flattens or steepens the distribution over tokens, but the most likely one will always stay the most likely one, regardless of the chosen temperature.

[Collaborator] I guess what you mean is called temperature sampling.
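
A minimal sketch of this invariance, with made-up logits (illustrative only, not from the PR): dividing logits by any positive temperature is a monotone transformation, so the argmax, and hence the greedy choice, never changes.

import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max())  # subtract the max for numerical stability
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.5, -1.0])  # made-up next-token logits

for t in (0.1, 1.0, 10.0):
    probs = softmax(logits / t)
    # the distribution gets flatter or sharper, but the argmax never moves
    assert probs.argmax() == logits.argmax()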



class BeamSearchParameters(DecodingParameters):
    """Beam search decoding strategy parameters"""

    strategy: DecodingStrategy = Field(default=DecodingStrategy.BEAM_SEARCH)
    num_beams: int = Field(..., gt=0, description="Number of beams must be greater than 0.")
    early_stopping: bool

[Collaborator] We should also add a parameter for the total number of beams that are tracked in parallel, not just for each generated token (=num_beams).
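
A hypothetical shape for that extension; the field name total_beams and its semantics are assumptions for illustration, not part of the PR:

class BeamSearchParameters(DecodingParameters):
    """Beam search decoding strategy parameters"""

    strategy: DecodingStrategy = Field(default=DecodingStrategy.BEAM_SEARCH)
    num_beams: int = Field(..., gt=0, description="Beams expanded per generated token. Must be greater than 0.")
    # Hypothetical field prompted by the review comment above; name is assumed.
    total_beams: int = Field(..., gt=0, description="Total beams tracked in parallel. Must be greater than 0.")
    early_stopping: bool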



class TopKParameters(DecodingParameters):
    """Top-K decoding strategy parameters"""

    strategy: DecodingStrategy = Field(default=DecodingStrategy.TOP_K)
    top_k: int = Field(..., gt=0, description="Number of top candidates to consider. Must be greater than 0.")
    temperature: float = Field(..., gt=0, description="Sampling temperature. Must be greater than 0.")

[Member] Do we really need temperature for top-k decoding? From my understanding, temperature makes the probability distribution more flattened or peaked. If we apply top-k to a flattened or peaked probability distribution, the result should be the same, no?

[Collaborator] For top-k sampling, temperature actually makes sense. The top k tokens you sample from do not change; however, the probability with which you then randomly sample one of these k tokens is affected significantly by the temperature.
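
A short sketch of that distinction, with made-up logits (illustrative only): the top-k candidate set is invariant to any positive temperature, but the sampling distribution over it is not.

import numpy as np

def top_k_distribution(logits: np.ndarray, k: int, temperature: float):
    top = np.argsort(logits)[-k:]  # indices of the k largest logits; unchanged by temperature
    scaled = logits[top] / temperature
    e = np.exp(scaled - scaled.max())
    return top, e / e.sum()  # sampling probabilities; these do depend on temperature

logits = np.array([3.0, 2.5, 1.0, -2.0])  # made-up next-token logits
ids_cold, p_cold = top_k_distribution(logits, k=2, temperature=0.5)
ids_hot, p_hot = top_k_distribution(logits, k=2, temperature=2.0)
assert set(ids_cold) == set(ids_hot)  # same candidate set either way
# ...but p_cold is much more peaked than p_hot, so sampled outputs differ in practice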



class TopPParameters(DecodingParameters):
    """Top-P decoding strategy parameters"""

    strategy: DecodingStrategy = Field(default=DecodingStrategy.TOP_P)
    top_p: float = Field(
        ..., gt=0, le=1, description="Cumulative probability for nucleus sampling. Must be in the range (0, 1]."
    )
    temperature: float = Field(..., gt=0, description="Sampling temperature. Must be greater than 0.")


# General information about a document
class DocumentInfo(BaseModel):
    """General information about a document"""

    document_id: str
    prompt: str
    prompt_lang: str
    raw_data_path: str
    model: str
    decoding_parameters: Union[GreedyParameters, BeamSearchParameters, TopKParameters, TopPParameters]

[Member] Suggested change:
-    decoding_parameters: Union[GreedyParameters, BeamSearchParameters, TopKParameters, TopPParameters]
+    decoding_parameters: GreedyParameters | BeamSearchParameters | TopKParameters | TopPParameters
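
A related option that would also answer the earlier question about the strategy field being redundant: narrow strategy to a Literal in each subclass and let Pydantic use it as a discriminator. A self-contained sketch, assuming Pydantic v2 and trimmed to two strategies for brevity:

from enum import Enum
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field

class DecodingStrategy(str, Enum):
    GREEDY = "greedy"
    TOP_K = "top_k"

class GreedyParameters(BaseModel):
    strategy: Literal[DecodingStrategy.GREEDY] = DecodingStrategy.GREEDY

class TopKParameters(BaseModel):
    strategy: Literal[DecodingStrategy.TOP_K] = DecodingStrategy.TOP_K
    top_k: int = Field(..., gt=0)
    temperature: float = Field(..., gt=0)

class DocumentInfo(BaseModel):
    # the "strategy" value routes validation to the matching subclass
    decoding_parameters: Annotated[
        Union[GreedyParameters, TopKParameters],
        Field(discriminator="strategy"),
    ]

doc = DocumentInfo.model_validate(
    {"decoding_parameters": {"strategy": "top_k", "top_k": 30, "temperature": 0.7}}
)
assert isinstance(doc.decoding_parameters, TopKParameters)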



class CorrelationMetrics(BaseModel):
    """Correlation metrics for performance evaluation"""

    correlation: Dict[str, Dict[str, float]]  # Correlation per ground truth approach

[Collaborator] What do the keys of the dict represent? The model (prompt + prompt_lang + llm) that generated the scores? This would mean that the correlation is always measured against the ground truth, correct? If we allowed tuples of strings as the keys, we could also measure the correlation between different models.
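
For illustration, one possible reading of the current schema, mirroring the test data further down: outer keys name a ground-truth aggregation, and each inner dict maps a metric name to a coefficient.

metrics = CorrelationMetrics(
    correlation={
        "average": {"pearson": 0.85, "spearman": 0.82},  # vs. averaged ground truth
        "min": {"pearson": 0.75, "spearman": 0.72},      # vs. minimum ground truth
    }
)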



class TTestResults(BaseModel):
    """T-Test results for performance evaluation"""

    t_test_p_values: Dict[str, float]  # p-values for each ground truth approach


class StatisticReport(BaseModel):
    """Complete statistical report combining various metrics"""

    document_info: DocumentInfo
    correlation_metrics: CorrelationMetrics
    t_test_results: TTestResults
103 changes: 103 additions & 0 deletions tests/test_data_models.py
@@ -0,0 +1,103 @@
import pytest
from pydantic import ValidationError

from ml_filter.data_models import (
    BeamSearchParameters,
    CorrelationMetrics,
    DecodingStrategy,
    DocumentInfo,
    GreedyParameters,
    StatisticReport,
    TopKParameters,
    TopPParameters,
    TTestResults,
)


def test_greedy_parameters():
    params = GreedyParameters()
    assert params.strategy == DecodingStrategy.GREEDY


def test_beam_search_parameters():
    params = BeamSearchParameters(num_beams=10, early_stopping=False)
    assert params.strategy == DecodingStrategy.BEAM_SEARCH
    assert params.num_beams == 10
    assert not params.early_stopping


def test_top_k_parameters():
    params = TopKParameters(top_k=30, temperature=0.7)
    assert params.strategy == DecodingStrategy.TOP_K
    assert params.top_k == 30
    assert params.temperature == 0.7


def test_top_p_parameters():
    params = TopPParameters(top_p=0.85, temperature=0.9)
    assert params.strategy == DecodingStrategy.TOP_P
    assert params.top_p == 0.85
    assert params.temperature == 0.9


def test_invalid_decoding_parameters():
    with pytest.raises(ValidationError):
        BeamSearchParameters(num_beams=-1, early_stopping=False)  # Invalid num_beams
    with pytest.raises(ValidationError):
        TopKParameters(top_k=-5, temperature=0.7)  # Invalid top_k
    with pytest.raises(ValidationError):
        TopPParameters(top_p=1.5, temperature=0.8)  # Invalid top_p


def test_document_info_with_greedy():
    doc_info = DocumentInfo(
        document_id="doc_001",
        prompt="Assess the educational value of the text.",
        prompt_lang="en",
        raw_data_path="/path/to/raw_data.json",
        model="test_model",
        decoding_parameters=GreedyParameters(),
    )
    assert doc_info.document_id == "doc_001"
    assert doc_info.decoding_parameters.strategy == DecodingStrategy.GREEDY


def test_document_info_with_top_p():
    doc_info = DocumentInfo(
        document_id="doc_002",
        prompt="Assess whether the text contains adult content.",
        prompt_lang="en",
        raw_data_path="/path/to/raw_data.json",
        model="test_model",
        decoding_parameters=TopPParameters(top_p=0.8, temperature=0.6),
    )
    assert doc_info.document_id == "doc_002"
    assert doc_info.decoding_parameters.top_p == 0.8
    assert doc_info.decoding_parameters.temperature == 0.6


def test_statistic_report():
    doc_info = DocumentInfo(
        document_id="doc_003",
        prompt="Assess whether the text contains a chain of thought.",
        prompt_lang="en",
        raw_data_path="/path/to/raw_data.json",
        model="test_model",
        decoding_parameters=BeamSearchParameters(num_beams=5, early_stopping=True),
    )
    correlation_metrics = CorrelationMetrics(
        correlation={
            "average": {"pearson": 0.85, "spearman": 0.82},
            "min": {"pearson": 0.75, "spearman": 0.72},
        }
    )
    t_test_results = TTestResults(t_test_p_values={"average": 0.03, "min": 0.05})
    report = StatisticReport(
        document_info=doc_info,
        correlation_metrics=correlation_metrics,
        t_test_results=t_test_results,
    )

    assert report.document_info.document_id == "doc_003"
    assert report.correlation_metrics.correlation["average"]["pearson"] == 0.85
    assert report.t_test_results.t_test_p_values["average"] == 0.03

Comment on lines +17 to +103:

[Member] I think all of these tests don't really test any functionality and are a bit redundant.

[Collaborator] I agree.

4 changes: 3 additions & 1 deletion tests/test_translate.py
@@ -39,6 +39,8 @@ def test_translate_jsonl_to_multiple_languages(
     """Test the translate_jsonl_to_multiple_languages method."""

     class MockTranslationClient:
+        name: str = "mock_translation_client"
+
         def translate_text(self, text, source_language, target_language):
             return mock_translate_text(text, source_language, target_language)

@@ -81,7 +83,7 @@ def supported_target_languages(self):

     # Verify output files
     for lang in target_languages:
-        output_file = output_folder / f"input_{lang}.jsonl"
+        output_file = output_folder / f"input_{lang}_{mock_client.name}.jsonl"
         assert output_file.exists()

         with open(output_file, "r", encoding="utf-8") as f: