16 changes: 8 additions & 8 deletions configs/score_documents/lorem_ipsum.yaml
@@ -1,22 +1,22 @@
 #Points to an endpoint where the model is running

 settings:
-  model_name: google/gemma-2-9b-it
+  model_name: google/gemma-3-27b-it
   # we need to set this here manually as this is specified only when hosting the model
   num_gpus: 1
   tokenizer_name_or_path: ${settings.model_name}

 paths:
   raw_data_file_paths:
-    - data/test_fineweb2_dump.jsonl
+    - /home/abbas-khan/ml_filter/data/test_fineweb2_dump.jsonl
   output_directory_path: data/output
-  prompt_template_file_path: data/prompts/fineweb_edu/educational_prompt.yaml
+  prompt_template_file_path: /raid/s3/opengptx/mehdi-ali/git_repos/ml_filter/data/prompts/reasoning/general_reasoning.yaml
   start_indexes:
-    - 10
+    - 0

 llm_rest_client:
   model_name: ${settings.model_name}
-  max_tokens: 8192 # The maximum total number of tokens supported by the model (input + output)
+  max_tokens: 8096 # The maximum total number of tokens supported by the model (input + output)
   sampling_params:
     max_tokens: 500 # The maximum number of tokens to generate
     temperature: 0.7
@@ -44,8 +44,8 @@ prompt_builder:

 document_processor:
   output_directory_path: ${settings.paths.output_directory_path}
-  queue_size: 1000
+  queue_size: 511
   num_processes: 1
-  score_metric_name: educational_score
+  score_metric_name: reasoning_score
   strings_to_remove: []
-  jq_language_pattern: .metadata.language
+  jq_language_pattern: .language
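
Note on the `jq_language_pattern` change: the new pattern implies the input JSONL now stores the language code at the top level rather than under `metadata`. A minimal sketch of what each pattern extracts, using illustrative records (the field contents here are assumptions, not taken from the dataset):

    # Python emulation of the two jq lookups on example records
    record_old = {"text": "...", "metadata": {"language": "en"}}  # matched by .metadata.language
    record_new = {"text": "...", "language": "en"}                # matched by .language
    assert record_old["metadata"]["language"] == record_new["language"] == "en"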
511 changes: 511 additions & 0 deletions data/511_en.jsonl

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions data/prompts/reasoning/general_reasoning.yaml
@@ -0,0 +1,18 @@
+prompt: |
+  Below is an extract from a document. Evaluate the quality of its reasoning using the additive 5-point scoring system described below.
+  Points are accumulated based on the depth and sophistication of reasoning demonstrated across any domain (programming, mathematics, ethics, logic, natural sciences, social sciences,
+  and other disciplines):
+
+  - Add 1 point if the extract identifies a reasoning challenge - whether it is a problem to solve, code to understand, a conflict to resolve, a concept to explain, or a question to answer. The domain and basic elements are recognizable, even if the presentation is minimal or disorganized. The document may contain some irrelevant or non-reasoning content, such as advertisements and promotional material.
+  - Add another point if the extract demonstrates structured thinking - using clear organization, meaningful terminology, logical sequencing, or systematic approaches that reveal an understanding of the domain's conventions and best practices. It might mix reasoning content with non-reasoning material.
+  - Award a third point if the extract provides explanatory reasoning - offering natural-language explanations, describing processes or relationships, connecting cause and effect, or making the underlying logic transparent to readers. The document may include a minimal amount of extraneous information.
+  - Grant a fourth point if the extract engages in analytical reasoning - examining components systematically, tracing through steps or implications, considering multiple perspectives, addressing potential objections, or demonstrating critical evaluation of the reasoning process.
+  - Bestow a fifth point if the extract demonstrates integrative or meta-level reasoning - discussing broader implications, design principles, trade-offs, limitations, or connections to other domains, or reflecting on the reasoning process itself to reach a coherent, well-justified conclusion.
+
+  The extract:
+  {placeholder}
+  After examining the extract:
+  - Briefly justify your score in up to 100 words.
+  - Conclude with the score using the format: 'Reasoning score: <total points>'
+
+prompt_name: reasoning_content_filter
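
For context, a minimal sketch of how a template like this could be rendered before scoring (the loading code is illustrative; the project may wire this up differently):

    # Illustrative: load the YAML template and fill the {placeholder} slot
    import yaml

    with open("data/prompts/reasoning/general_reasoning.yaml") as f:
        spec = yaml.safe_load(f)

    document = "Proof: if n is even, then n = 2k for some integer k, so n**2 = 4k**2 ..."
    prompt = spec["prompt"].format(placeholder=document)  # prompt_name: reasoning_content_filter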
2 changes: 1 addition & 1 deletion src/ml_filter/data_processing/document_processor.py
@@ -62,7 +62,7 @@ def __init__(
         if score_metric_name not in score_metrics:
             raise ValueError(f"Invalid score metric name: {score_metric_name}.")

-        self.score_metric = score_metrics[score_metric_name]
+        self.score_metric = score_metrics[score_metric_name]()
         self.termination_event = multiprocessing.Event()

     @staticmethod
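
This one-character fix matters because `score_metrics` maps names to classes, not instances, and the full regex `pattern` is only composed in `__post_init__` when an instance is created. A quick illustration:

    # Illustrative: the registry stores classes, so the lookup result must be called
    metric_cls = score_metrics["reasoning_score"]  # a class object; no usable .pattern yet
    metric = metric_cls()                          # instance; __post_init__ builds .pattern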
62 changes: 27 additions & 35 deletions src/ml_filter/data_processing/llm_score_metrics.py
@@ -1,55 +1,47 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
+
+# General pattern used for extracting the numeric score
+VALUE_PATTERN = r"\s*(\d+(?:\.\d+)?)"


 @dataclass
 class LLMScoreMetric:
-    """A class used to represent a scoring metric for a Language Learning Model (LLM).
+    """
+    Base class for LLM score metrics.

     Attributes:
-        metric_name (str): The name of the metric.
-        pattern (str): The pattern used for the metric.
+        metric_name (str): Name of the metric.
+        prefix (str): Regex prefix to identify the score type (e.g., "Educational score:").
+        pattern (str): Full regex pattern for extracting the score (auto-generated).
     """

     metric_name: str
-    pattern: str
+    prefix: str
+    pattern: str = field(init=False)
+
+    def __post_init__(self):
+        self.pattern = rf"{self.prefix}{VALUE_PATTERN}"


 @dataclass
 class EducationalScoreMetric(LLMScoreMetric):
-    """
-    A metric class for extracting educational scores from text.
-
-    This class inherits from `LLMScoreMetric` and is designed to identify and
-    process educational scores using a specific regex pattern.
-
-    Attributes:
-        metric_name (str): The name of the metric, set to "educational_score".
-        pattern (str): The regex pattern used to extract the educational score
-            from text. The pattern looks for the phrase "Educational score:"
-            followed by one or more digits.
-    """
-
-    metric_name: str = "educational_score"
-    pattern: str = r"Educational score:\s*(\d+(?:\.\d+)?)"
+    def __init__(self):
+        super().__init__(metric_name="educational_score", prefix=r"Educational score:")


 @dataclass
 class AdultScoreMetric(LLMScoreMetric):
-    """
-    A metric class for extracting educational scores from text.
-
-    This class inherits from `LLMScoreMetric` and is designed to identify and
-    process educational scores using a specific regex pattern.
-
-    Attributes:
-        metric_name (str): The name of the metric, set to "educational_score".
-        pattern (str): The regex pattern used to extract the educational score
-            from text. The pattern looks for the phrase "Educational score:"
-            followed by one or more digits.
-    """
-
-    metric_name: str = "adult_score"
-    pattern: str = r"Adult score:\s*(\d+(?:\.\d+)?)"
+    def __init__(self):
+        super().__init__(metric_name="adult_score", prefix=r"Adult score:")
+
+
+@dataclass
+class ReasoningScoreMetric(LLMScoreMetric):
+    def __init__(self):
+        super().__init__(metric_name="reasoning_score", prefix=r"Reasoning score:")


-score_metrics = {"educational_score": EducationalScoreMetric, "adult_score": AdultScoreMetric}
+# Factory dictionary to retrieve metric classes by name
+score_metrics = {
+    "educational_score": EducationalScoreMetric,
+    "adult_score": AdultScoreMetric,
+    "reasoning_score": ReasoningScoreMetric,
+}
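
As a sanity check, the refactored metrics can be exercised directly (illustrative usage, not project code):

    import re

    metric = score_metrics["reasoning_score"]()
    # __post_init__ composed the pattern: Reasoning score:\s*(\d+(?:\.\d+)?)
    match = re.search(metric.pattern, "The argument is systematic and well-justified. Reasoning score: 4")
    assert match is not None and float(match.group(1)) == 4.0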
3 changes: 2 additions & 1 deletion src/ml_filter/models/annotator_models.py
@@ -2,7 +2,8 @@

 import torch
 from transformers import AutoConfig, PretrainedConfig
-from transformers.modeling_utils import ModelOutput, PreTrainedModel
+from transformers.modeling_utils import PreTrainedModel
+from transformers.modeling_outputs import ModelOutput

 from constants import MODEL_CLASS_MAP
 from ml_filter.models.annotator_model_head import (
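
Background on this import fix: `ModelOutput` is defined in `transformers.utils` and publicly exposed through `transformers.modeling_outputs`; importing it from `transformers.modeling_utils` relied on an internal re-export that newer releases may no longer provide. If older transformers versions must also be supported, a fallback is possible (illustrative sketch):

    try:
        from transformers.modeling_outputs import ModelOutput  # current public location
    except ImportError:
        from transformers.modeling_utils import ModelOutput    # very old transformers releases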