WIP: multimodal support #227

Draft · wants to merge 9 commits into base: main
57 changes: 57 additions & 0 deletions fast_llm/data/config.py
@@ -34,3 +34,60 @@ class TokenizerConfig(Config):
desc="Path to the tokenizer file.",
hint=FieldHint.core,
)


@config_class()
class ImageProcessorConfig(Config):
"""
Configuration for the image processor
"""

# Defaults taken from [pixtral](https://github.com/huggingface/transformers/blob/794fde7b1c3d041519fc28ea3e1461b0cfcad4e7/src/transformers/models/pixtral/image_processing_pixtral.py#L201)
# patch_size: list[int] = Field(
# default_factory=lambda: [16, 16],
# desc="Size for each path extracted from the image. Each patch corresponds to a token for the transformer",
# hint=FieldHint.optional,
# )
# max_height: int = Field(
# default=1024,
# desc="Maximum height of the image. Image will be resized if larger",
# hint=FieldHint.optional,
# )
# max_width: int = Field(
# default=1024,
# desc="Maximum width of the image. Image will be resized if larger",
# hint=FieldHint.optional,
# )
# mean: list[float] = Field(
# default_factory=lambda: [0.48145466, 0.4578275, 0.40821073],
# desc="Mean RGB values for pixel normalization",
# hint=FieldHint.optional,
# )
# std: list[float] = Field(
# default_factory=lambda: [0.26862954, 0.26130258, 0.27577711],
# desc="Standard deviation RGB values for pixel normalization",
# hint=FieldHint.optional,
# )
# rescale_factor: float = Field(
# default=255.0,
# desc="Diminisher factor for pixel normalization",
# hint=FieldHint.optional,
# )


@config_class()
class MultiModalProcessorConfig(Config):
"""
Wrapper config that stores the `ImageProcessorConfig` and `TokenizerConfig`
"""

tokenizer: TokenizerConfig = Field(
default_factory=TokenizerConfig,
desc="Configuration for the tokenizer.",
hint=FieldHint.core,
)
image_processor: ImageProcessorConfig = Field(
default_factory=ImageProcessorConfig,
desc="Configuration for the image processor.",
hint=FieldHint.core,
)
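
The `ImageProcessorConfig` fields are still commented out, but the Pixtral defaults they carry already pin down the intended preprocessing: divide raw pixel values by the rescale factor, normalize per channel with the CLIP mean/std, and cut the result into 16x16 patches, one token per patch. A minimal numpy sketch of that pipeline, with a hypothetical helper name and the resize-to-max-height/width step omitted:

import numpy as np

# Hypothetical sketch based on the commented-out Pixtral defaults above; not part of this PR.
MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)
RESCALE_FACTOR = 255.0
PATCH_SIZE = (16, 16)


def preprocess_image(image: np.ndarray) -> np.ndarray:
    """Normalize an (H, W, 3) uint8 image and split it into flattened 16x16 patches."""
    pixels = image.astype(np.float32) / RESCALE_FACTOR  # scale to [0, 1]
    pixels = (pixels - MEAN) / STD  # per-channel normalization
    # Crop to a multiple of the patch size (resizing to max_height/max_width is omitted here).
    height = pixels.shape[0] // PATCH_SIZE[0] * PATCH_SIZE[0]
    width = pixels.shape[1] // PATCH_SIZE[1] * PATCH_SIZE[1]
    pixels = pixels[:height, :width]
    # Each 16x16 patch becomes one transformer token.
    patches = pixels.reshape(
        height // PATCH_SIZE[0], PATCH_SIZE[0], width // PATCH_SIZE[1], PATCH_SIZE[1], 3
    )
    return patches.transpose(0, 2, 1, 3, 4).reshape(-1, PATCH_SIZE[0] * PATCH_SIZE[1] * 3)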
34 changes: 32 additions & 2 deletions fast_llm/data/data/gpt/data.py
@@ -32,10 +32,15 @@ class GPTBatch:
token_ids: torch.Tensor
loss_masking_spans: list[torch.Tensor] | None = None
sequence_lengths: list[torch.Tensor] | None = None
images: list[torch.Tensor] | None = None
image_positions: list[torch.Tensor] | None = None


# TODO: collate images
def gpt_data_collate_fn(
batch: list[GPTSample], use_loss_masking_spans: bool, cross_document_attention: bool
batch: list[GPTSample],
use_loss_masking_spans: bool,
cross_document_attention: bool,
) -> GPTBatch:
stacked_ids = np.stack([sample.token_ids for sample in batch])
stacked_spans = None
@@ -44,8 +49,24 @@ def gpt_data_collate_fn(
stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch]
if not cross_document_attention:
sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch]
batch_images = []
for sample in batch:
if sample.images is not None:
batch_images.append([torch.from_numpy(image) for image in sample.images])
else:
batch_images.append(None)
batch_image_positions = []
for sample in batch:
if sample.image_positions is not None:
batch_image_positions.append(torch.from_numpy(sample.image_positions))
else:
batch_image_positions.append(None)
return GPTBatch(
token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths
token_ids=torch.from_numpy(stacked_ids),
loss_masking_spans=stacked_spans,
sequence_lengths=sequence_lengths,
images=batch_images if any(batch_images) else None,
image_positions=batch_image_positions if any(batch_image_positions) else None,
)
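
Unlike `token_ids`, which are stacked into a single tensor, images stay ragged: the collate function keeps one (possibly `None`) entry per sample, and the `any(...)` checks collapse the whole field to `None` for text-only batches. A rough usage sketch, with the `GPTSample` keyword arguments assumed from the fields read above:

import numpy as np

# Illustrative only: GPTSample construction is assumed from the fields the collate
# function reads (token_ids, images, image_positions); text-only samples leave the
# image fields as None.
samples = [
    GPTSample(
        token_ids=np.arange(8, dtype=np.int64),
        images=[np.zeros((3, 32, 32), dtype=np.uint8)],
        image_positions=np.array([4]),
    ),
    GPTSample(token_ids=np.arange(8, dtype=np.int64)),  # text-only sample
]
batch = gpt_data_collate_fn(samples, use_loss_masking_spans=False, cross_document_attention=True)
# batch.images          -> [[tensor of shape (3, 32, 32)], None]  (one list per sample)
# batch.image_positions -> [tensor([4]), None]
# With no images in any sample, both fields would be None instead of lists of None.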


@@ -67,6 +88,9 @@ def __init__(
vocab_size: int,
max_sequence_length: int,
cross_document_attention: bool = True,
patch_size: list[int] | None = None,
max_image_height: int | None = None,
max_image_width: int | None = None,
):
"""
Create the data and gather some basic information on the dataset(s).
@@ -76,6 +100,9 @@ def setup(
self._vocab_size = vocab_size
self._max_sequence_length = max_sequence_length
self._cross_document_attention = cross_document_attention
self._patch_size = patch_size
self._max_image_height = max_image_height
self._max_image_width = max_image_width

def setup(
self,
@@ -123,6 +150,9 @@ def setup(
tokenizer=self._tokenizer,
truncate_documents=self._config.truncate_documents,
cross_document_attention=self._cross_document_attention,
patch_size=self._patch_size,
image_height=self._max_image_height,
image_width=self._max_image_width,
)
dataset = self._config.datasets[dataset_name].build_and_sample(sampling)
self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)
12 changes: 11 additions & 1 deletion fast_llm/data/dataset/gpt/config.py
@@ -73,6 +73,9 @@ class GPTSamplingData(SamplingData):
tokenizer: "Tokenizer"
truncate_documents: bool = True
cross_document_attention: bool = True
patch_size: list[int] | None = None
image_height: int | None = None
image_width: int | None = None


@config_class()
@@ -178,11 +181,18 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig):
desc="Expected number of tokens in the dataset.",
hint=FieldHint.optional,
)
num_pixels: int | None = Field(
default=None,
desc="Expected number of pixels in the dataset.",
hint=FieldHint.optional,
)

def build(self) -> "GPTMemmapDataset":
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset

return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens)
return GPTMemmapDataset(
str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens, self.num_pixels
)


@config_class()
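
`num_pixels` mirrors `num_tokens`: an expected total that the memmap dataset can presumably check against the files on disk. A hypothetical instantiation, with the path and counts invented for illustration:

# Hypothetical usage; values are made up and other required config fields may apply.
config = GPTMemmapDatasetConfig(
    path="datasets/multimodal/shard_0",
    num_documents=100_000,
    num_tokens=50_000_000,
    num_pixels=2_000_000_000,  # new optional field, forwarded to GPTMemmapDataset above
)
dataset = config.build()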
12 changes: 9 additions & 3 deletions fast_llm/data/dataset/gpt/indexed.py
@@ -11,6 +11,7 @@


class GPTIndexedDataset(IndexedDataset):
# TODO Soham: should we change this to include images?
@abc.abstractmethod
def get_document_sizes(self) -> np.ndarray:
"""
@@ -44,10 +45,15 @@ class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[Indexe

def get_document_sizes(self) -> np.ndarray:
# TODO: This can be really big.
return self._dataset.get_document_sizes()[self._begin : self._end]
doc_sizes, im_sizes = self._dataset.get_document_sizes()
return doc_sizes[self._begin : self._end], im_sizes[self._begin : self._end]

def get_document_size(self, index: int) -> int:
return self._dataset.get_document_size(self._begin + index)
def get_document_size(self, index: int, patch_size: list[int]) -> int:
return self._dataset.get_document_size(self._begin + index, patch_size)

@property
def has_images(self) -> bool:
return self._dataset.has_images


class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset](
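
The new `get_document_size(index, patch_size)` signature implies that a document's effective length now has to include the patch tokens contributed by its images, which is also why `get_document_sizes` returns a separate image-size array. A back-of-the-envelope sketch of that accounting (an assumption about the memmap implementation, not code from this PR):

import math


def document_size_with_images(
    num_text_tokens: int, image_sizes: list[tuple[int, int]], patch_size: list[int]
) -> int:
    # One token per image patch, rounding partial patches up.
    num_patches = sum(
        math.ceil(height / patch_size[0]) * math.ceil(width / patch_size[1])
        for height, width in image_sizes
    )
    return num_text_tokens + num_patches


# Example: a 100-token document with one 64x48 image and 16x16 patches
# occupies 100 + 4 * 3 = 112 positions when sampling sequences.
assert document_size_with_images(100, [(64, 48)], [16, 16]) == 112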