Add data cleaning in fast-llm prepare, concept #210

Draft · wants to merge 2 commits into base: main
23 changes: 16 additions & 7 deletions fast_llm/config.py
@@ -9,7 +9,16 @@

import yaml

from fast_llm.utils import Assert, Tag, get_type_name, header, log, pop_nested_dict_value, set_nested_dict_value
from fast_llm.utils import (
Assert,
Tag,
Registry,
get_type_name,
header,
log,
pop_nested_dict_value,
set_nested_dict_value,
)

logger = logging.getLogger(__name__)

@@ -634,17 +643,17 @@ def _serialize_value(cls, value: typing.Any) -> int | float | bool | str | None:
value = str(value)
return value

def to_copy[
T
](self: T, *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], strict: bool = True,) -> T:
def to_copy[T](
self: T,
*updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]],
strict: bool = True,
) -> T:
return self.from_dict(self, *updates, strict=strict)

def to_serialized(self, verbose: int | None = FieldVerboseLevel.core) -> dict[str, typing.Any]:
return self._to_dict(verbose=verbose, format_=_ConfigDictFormat.nested, serializable=True)

def to_logs[
T
](
def to_logs[T](
self,
verbose: int | None = FieldVerboseLevel.core,
log_fn: typing.Callable[[str], T] = logger.info,
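The change to `fast_llm/config.py` only reflows the PEP 695 generic signatures of `to_copy` and `to_logs` (plus the expanded import block); behavior is unchanged. For reference, a minimal usage sketch of `to_copy`, where `MyConfig` and its fields are hypothetical and not part of this PR:

```python
# Minimal sketch, not part of the diff. `MyConfig`, `learning_rate` and
# `batch_size` are hypothetical; `Config`, `Field`, `config_class`,
# `from_dict` and `to_copy` are the fast_llm APIs touched above.
from fast_llm.config import Config, Field, config_class


@config_class()
class MyConfig(Config):
    learning_rate: float = Field(default=1e-3)
    batch_size: int = Field(default=8)


base = MyConfig.from_dict({"learning_rate": 3e-4})
# `to_copy` delegates to `from_dict(self, *updates, strict=strict)` and returns
# a new instance of the same type with the nested-dict updates applied.
tuned = base.to_copy({"batch_size": 32})
```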
48 changes: 14 additions & 34 deletions fast_llm/data/preparator/gpt_memmap/config.py
@@ -1,4 +1,3 @@
import os
import pathlib
import typing

@@ -8,6 +7,9 @@
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.utils import Assert

from fast_llm.data.preparator.gpt_memmap.distributed_config import DatasetPreparatorDistributedConfig
from fast_llm.data.preparator.gpt_memmap.hf_processors.configs import HFProcessorConfig, ProcessorsConfig

if typing.TYPE_CHECKING:
from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator
MEMMAP_DTYPES = {
@@ -77,39 +79,6 @@ class GPTHuggingfaceDatasetConfig(Config):
)


@config_class
class DatasetPreparatorDistributedConfig(Config):
# TODO: Unify with fast_llm.engine.distributed.config.DistributedConfig

default_world_size: typing.ClassVar[int] = int(os.environ.get("WORLD_SIZE", 1))
default_rank: typing.ClassVar[int] = int(os.environ.get("RANK", 0))
world_size: int = Field(
default=None,
desc="Size of the world group. Typically provided by torchrun or equivalent through the `WORLD_SIZE` environment variable.",
hint=FieldHint.expert,
valid=check_field(Assert.gt, 0),
)
rank: int = Field(
default=None,
desc="Rank of the local process. Typically provided by torchrun or equivalent through the `RANK` environment variable.",
hint=FieldHint.expert,
valid=check_field(Assert.geq, 0),
)
backend: str = Field(
default="gloo",
desc="Distributed backend to use.",
hint=FieldHint.optional,
)

def _validate(self) -> None:
if self.world_size is None:
self.world_size = self.default_world_size
if self.rank is None:
self.rank = self.default_rank
super()._validate()
Assert.in_range(self.rank, 0, self.world_size)


@config_class()
class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
preparator_name: typing.ClassVar[str] = "gpt_memmap"
@@ -165,12 +134,23 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig):
hint=FieldHint.optional,
)

# TODO: Add desc and hint.
processors: ProcessorsConfig = Field(default=ProcessorsConfig)

def _validate(self) -> None:
assert self.tokenizer.path is not None
if self.dataset.data_type is not None:
Assert.incl(DataType.from_numpy(self.dataset.data_type.numpy), MEMMAP_DTYPES_INV)
super()._validate()

# Propagate dataset field name and worker count if not set in the processors' configs.
for processor_config_field_name in self.processors.get_processor_types_map().keys():
config: HFProcessorConfig = getattr(self.processors, processor_config_field_name)
if config.field is None:
config.field = self.dataset.field
if config.num_proc is None:
config.num_proc = self.tokenize_workers

@classmethod
def get_dataset_preparator_class(cls) -> type["GPTMemmapDatasetPreparator"]:
from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator
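To make the new `processors` field concrete, here is a hedged sketch of how it could be supplied when building the preparator config. `dataset.path` and the exact dict shape accepted by `from_dict` are assumptions; the other field names all appear in this diff:

```python
# Hedged sketch, not taken from the PR. `dataset.path` is an assumed field
# name; `dataset.field`, `tokenizer.path`, `processors`, `order` and the
# per-processor options are defined in the diffs of this PR.
from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig

config = GPTMemmapDatasetPreparatorConfig.from_dict(
    {
        "dataset": {"path": "HuggingFaceFW/fineweb", "field": "text"},
        "tokenizer": {"path": "/path/to/tokenizer.json"},
        "processors": {
            # Tighten the document-length filter; other processors keep their defaults.
            "doc_length": {"min_length_chars": 200, "max_length_chars": 100_000},
            # PII redaction is disabled by default, so enable it explicitly.
            "pii": {"use_processor": True, "redaction_method": "mask"},
            # Restrict the pipeline to these two steps (assuming the driver iterates `order`).
            "order": ["doc_length", "pii"],
        },
    }
)
# During `_validate`, any processor whose `field` or `num_proc` is unset
# inherits `dataset.field` and `tokenize_workers` respectively.
```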
38 changes: 38 additions & 0 deletions fast_llm/data/preparator/gpt_memmap/distributed_config.py
@@ -0,0 +1,38 @@
import os
import typing

from fast_llm.config import Config, Field, FieldHint, check_field, config_class
from fast_llm.utils import Assert


@config_class
class DatasetPreparatorDistributedConfig(Config):
# TODO: Unify with fast_llm.engine.distributed.config.DistributedConfig

default_world_size: typing.ClassVar[int] = int(os.environ.get("WORLD_SIZE", 1))
default_rank: typing.ClassVar[int] = int(os.environ.get("RANK", 0))
world_size: int = Field(
default=None,
desc="Size of the world group. Typically provided by torchrun or equivalent through the `WORLD_SIZE` environment variable.",
hint=FieldHint.expert,
valid=check_field(Assert.gt, 0),
)
rank: int = Field(
default=None,
desc="Rank of the local process. Typically provided by torchrun or equivalent through the `RANK` environment variable.",
hint=FieldHint.expert,
valid=check_field(Assert.geq, 0),
)
backend: str = Field(
default="gloo",
desc="Distributed backend to use.",
hint=FieldHint.optional,
)

def _validate(self) -> None:
if self.world_size is None:
self.world_size = self.default_world_size
if self.rank is None:
self.rank = self.default_rank
super()._validate()
Assert.in_range(self.rank, 0, self.world_size)
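This class is a straight move out of `gpt_memmap/config.py` into its own module; the defaults are still captured from the `WORLD_SIZE` and `RANK` environment variables when the module is imported. An illustrative sketch of that behavior follows; the `from_dict({})` construction and the public `validate()` call are assumptions about the base `Config` API:

```python
# Illustrative sketch, not part of the diff. Assumes `Config.from_dict` accepts
# an empty override dict and exposes a public `validate()` wrapping `_validate`.
import os

# Under torchrun these are set for every worker; set them here for illustration,
# before the module is imported and the class-level defaults are captured.
os.environ.setdefault("WORLD_SIZE", "2")
os.environ.setdefault("RANK", "0")

from fast_llm.data.preparator.gpt_memmap.distributed_config import (
    DatasetPreparatorDistributedConfig,
)

dist = DatasetPreparatorDistributedConfig.from_dict({})  # no explicit overrides
dist.validate()
# `_validate` filled `world_size` and `rank` from the class-level defaults and
# checked that the rank lies in [0, world_size).
assert 0 <= dist.rank < dist.world_size
```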
172 changes: 172 additions & 0 deletions fast_llm/data/preparator/gpt_memmap/hf_processors/configs.py
@@ -0,0 +1,172 @@
import abc
import datasets
import typing

from fast_llm.config import Config, Configurable, Field, FieldUpdate, config_class
from fast_llm.data.preparator.gpt_memmap.distributed_config import DatasetPreparatorDistributedConfig


# TODO: Add desc and hint to all fields.


@config_class
class HFProcessorConfig(Config):
use_processor: bool = Field(default=True)
human_readable_name: str = Field(default="")
batch_size: int | None = Field(default=None)
num_proc: int | None = Field(default=None)
field: str | None = Field(default=None)


class HFProcessor[ConfigType: HFProcessorConfig](Configurable[ConfigType], abc.ABC):
config_class: typing.ClassVar[type[HFProcessorConfig]] = HFProcessorConfig

def __init__(self, config: ConfigType, distributed_config: DatasetPreparatorDistributedConfig, *args, **kwargs):
super().__init__(config, *args, **kwargs)

self._distributed_config = distributed_config

@abc.abstractmethod
def apply(self, dataset: datasets.Dataset) -> datasets.Dataset:
raise NotImplementedError


@config_class
class DocLengthFilterProcessorConfig(HFProcessorConfig):
human_readable_name: str | None = FieldUpdate(default="Document Length Filter")
min_length_chars: int = Field(default=0)
max_length_chars: int = Field(default=1_000_000)


class DocLengthFilterProcessor[ConfigType: DocLengthFilterProcessorConfig](HFProcessor[ConfigType]):
config_class: typing.ClassVar[type[DocLengthFilterProcessorConfig]] = DocLengthFilterProcessorConfig

def apply(self, dataset: datasets.Dataset) -> datasets.Dataset:
from fast_llm.data.preparator.gpt_memmap.hf_processors.processors import apply_doc_length_filter_processor

return apply_doc_length_filter_processor(self._config, dataset)


@config_class
class NGramRepetitionFilterProcessorConfig(HFProcessorConfig):
human_readable_name: str | None = FieldUpdate(default="N-Gram Repetition Filter")
n: int = Field(default=5)
max_repetitions: int = Field(default=32)


class NGramRepetitionFilterProcessor[ConfigType: NGramRepetitionFilterProcessorConfig](HFProcessor[ConfigType]):
config_class: typing.ClassVar[type[NGramRepetitionFilterProcessorConfig]] = NGramRepetitionFilterProcessorConfig

def apply(self, dataset: datasets.Dataset) -> datasets.Dataset:
from fast_llm.data.preparator.gpt_memmap.hf_processors.processors import (
apply_ngram_repetition_filter_processor,
)

return apply_ngram_repetition_filter_processor(self._config, dataset)


@config_class
class FrequencyBasedFilterProcessorConfig(HFProcessorConfig):
human_readable_name: str | None = FieldUpdate(default="Frequency-Based Filter")
max_single_word_ratio: float = Field(default=0.3)
max_top_two_word_ratio: float = Field(default=0.5)


class FrequencyBasedFilterProcessor[ConfigType: FrequencyBasedFilterProcessorConfig](HFProcessor[ConfigType]):
config_class: typing.ClassVar[type[FrequencyBasedFilterProcessorConfig]] = FrequencyBasedFilterProcessorConfig

def apply(self, dataset: datasets.Dataset) -> datasets.Dataset:
from fast_llm.data.preparator.gpt_memmap.hf_processors.processors import apply_frequency_based_filter_processor

return apply_frequency_based_filter_processor(self._config, dataset)


@config_class
class BinaryContentFilterProcessorConfig(HFProcessorConfig):
human_readable_name: str | None = FieldUpdate(default="Binary Content Filter")
max_bin_ratio: float = Field(default=0.5)


class BinaryContentFilterProcessor[ConfigType: BinaryContentFilterProcessorConfig](HFProcessor[ConfigType]):
config_class: typing.ClassVar[type[BinaryContentFilterProcessorConfig]] = BinaryContentFilterProcessorConfig

def apply(self, dataset: datasets.Dataset) -> datasets.Dataset:
from fast_llm.data.preparator.gpt_memmap.hf_processors.processors import apply_binary_content_filter_processor

return apply_binary_content_filter_processor(self._config, dataset)


@config_class
class NumericalContentFilterProcessorConfig(HFProcessorConfig):
human_readable_name: str | None = FieldUpdate(default="Numerical Content Filter")
max_numeric_token_ratio: float = Field(default=0.5)


class NumericalContentFilterProcessor[ConfigType: NumericalContentFilterProcessorConfig](HFProcessor[ConfigType]):
config_class: typing.ClassVar[type[NumericalContentFilterProcessorConfig]] = NumericalContentFilterProcessorConfig

def apply(self, dataset: datasets.Dataset) -> datasets.Dataset:
from fast_llm.data.preparator.gpt_memmap.hf_processors.processors import (
apply_numerical_content_filter_processor,
)

return apply_numerical_content_filter_processor(self._config, dataset)


@config_class
class PiiRedactionProcessorConfig(HFProcessorConfig):
use_processor: bool = FieldUpdate(default=False)
human_readable_name: str | None = FieldUpdate(default="PII Redaction Processor")
# TODO: make enum
redaction_method: str = Field(default="remove") # Options: 'remove', 'mask'


class PiiRedactionProcessor[ConfigType: PiiRedactionProcessorConfig](HFProcessor[ConfigType]):
config_class: typing.ClassVar[type[PiiRedactionProcessorConfig]] = PiiRedactionProcessorConfig

def apply(self, dataset: datasets.Dataset) -> datasets.Dataset:
from fast_llm.data.preparator.gpt_memmap.hf_processors.processors import apply_pii_redaction_processor

return apply_pii_redaction_processor(self._config, self._distributed_config, dataset)


@config_class
class MalwareRemovalProcessorConfig(HFProcessorConfig):
use_processor: bool = FieldUpdate(default=False)
human_readable_name: str | None = FieldUpdate(default="Malware Removal Processor")


class MalwareRemovalProcessor[ConfigType: MalwareRemovalProcessorConfig](HFProcessor[ConfigType]):
config_class: typing.ClassVar[type[MalwareRemovalProcessorConfig]] = MalwareRemovalProcessorConfig

def apply(self, dataset: datasets.Dataset) -> datasets.Dataset:
from fast_llm.data.preparator.gpt_memmap.hf_processors.processors import apply_malware_removal_processor

return apply_malware_removal_processor(self._config, dataset)


@config_class
class ProcessorsConfig(Config):
doc_length: DocLengthFilterProcessorConfig = Field(default=DocLengthFilterProcessorConfig)
n_gramms: NGramRepetitionFilterProcessorConfig = Field(default=NGramRepetitionFilterProcessorConfig)
frequency: FrequencyBasedFilterProcessorConfig = Field(default=FrequencyBasedFilterProcessorConfig)
binary: BinaryContentFilterProcessorConfig = Field(default=BinaryContentFilterProcessorConfig)
numerical: NumericalContentFilterProcessorConfig = Field(default=NumericalContentFilterProcessorConfig)
pii: PiiRedactionProcessorConfig = Field(default=PiiRedactionProcessorConfig)
malware: MalwareRemovalProcessorConfig = Field(default=MalwareRemovalProcessorConfig)

# TODO: add validation so all steps are actual field names
order: list[str] = Field(
default_factory=lambda: ["doc_length", "n_gramms", "frequency", "binary", "numerical", "pii", "malware"]
)

def get_processor_types_map(self):
return {
"doc_length": DocLengthFilterProcessor,
"n_gramms": NGramRepetitionFilterProcessor,
"frequency": FrequencyBasedFilterProcessor,
"binary": BinaryContentFilterProcessor,
"numerical": NumericalContentFilterProcessor,
"pii": PiiRedactionProcessor,
"malware": MalwareRemovalProcessor,
}
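For orientation, a hedged sketch of how a driver in `prepare.py` might consume `ProcessorsConfig`: iterate `order`, skip disabled steps, and chain the `apply` calls. The actual wiring in `prepare.py` is not part of this diff, so treat the function below as an assumption about intended use rather than the implementation:

```python
# Hypothetical driver, not part of the diff. Only the config and processor
# classes it touches are defined in this PR.
import datasets

from fast_llm.data.preparator.gpt_memmap.distributed_config import DatasetPreparatorDistributedConfig
from fast_llm.data.preparator.gpt_memmap.hf_processors.configs import ProcessorsConfig


def run_processors(
    dataset: datasets.Dataset,
    processors_config: ProcessorsConfig,
    distributed_config: DatasetPreparatorDistributedConfig,
) -> datasets.Dataset:
    types_map = processors_config.get_processor_types_map()
    for name in processors_config.order:
        config = getattr(processors_config, name)
        if not config.use_processor:
            # `pii` and `malware` are off by default; skip anything disabled.
            continue
        processor = types_map[name](config, distributed_config)
        dataset = processor.apply(dataset)
    return dataset
```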