Commit 568eccd

Apply isort and black reformatting
Signed-off-by: dstnluong <dstnluong@users.noreply.github.com>
1 parent 2c9c357 commit 568eccd

File tree

1 file changed: +17 -11 lines changed
nemo/collections/vlm/gemma3vl/data/task_encoder.py

Lines changed: 17 additions & 11 deletions
@@ -19,7 +19,7 @@
 
 import torch
 import torch.nn.functional as F
-from megatron.energon import VQASample, InterleavedSample
+from megatron.energon import InterleavedSample, VQASample
 
 from nemo.collections.vlm.data.task_encoder import DataBatch, DataSample
 from nemo.collections.vlm.data.task_encoder import TaskEncoder as BaseTaskEncoder
@@ -40,7 +40,7 @@ class TaskEncoderConfig(BaseTaskEncoderConfig):
     stop_string: Optional[str] = ""
     system_prompt: Optional[str] = None
     image_token_str: str = "<start_of_image>"
-    image_token_id: int = 262144 # This is the token id for <image_soft_token>
+    image_token_id: int = 262144  # This is the token id for <image_soft_token>
 
 
 @dataclass
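
Side note on the `image_token_id` default touched in the hunk above: the hard-coded value can be cross-checked against the tokenizer. A minimal sketch, assuming the Hugging Face `transformers` `AutoTokenizer` and the `google/gemma-3-4b-it` checkpoint (the checkpoint name is an assumption for illustration, not part of this commit):

# Sketch (assumption): check that "<image_soft_token>" maps to 262144,
# the default used by TaskEncoderConfig.image_token_id above.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-4b-it")  # illustrative checkpoint
soft_token_id = tokenizer.convert_tokens_to_ids("<image_soft_token>")
print(soft_token_id)  # expected: 262144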
@@ -231,7 +231,9 @@ def encode_vqa_sample(self, input_sample: VQASample) -> DataSample:
 
         return sample
 
-    def tokenize_interleaved_sample(self, input_sample: InterleavedSample) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    def tokenize_interleaved_sample(
+        self, input_sample: InterleavedSample
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         """
         Tokenize the input sequence and process images in an interleaved sample.
 
@@ -253,12 +255,14 @@ def tokenize_interleaved_sample(self, input_sample: InterleavedSample) -> Tuple[
                 texts.append(item)
             elif type(item) == torch.Tensor:
                 images.append(item)
-                texts.append(self.config.image_token_str) # Append start token to the last text. HF Processor will replace this token with the actual image tokens during processing.
+                texts.append(
+                    self.config.image_token_str
+                )  # Append start token to the last text. HF Processor will replace this token with the actual image tokens during processing.
             else:
                 raise ValueError(f"Unsupported item type in interleaved sequence: {type(item)}")
-
+
         outputs = self.hf_processor(
-            images=[images], # images is a batched to size of one.
+            images=[images],  # images is a batched to size of one.
             text=" ".join(texts),
             return_tensors="pt",
             images_kwargs={"do_rescale": False},
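
For context on the hunk above: `tokenize_interleaved_sample` walks the interleaved sequence, collecting text pieces and image tensors and appending `image_token_str` as a placeholder that the HF processor later expands into image tokens. A standalone sketch of that flow, assuming the `transformers` `AutoProcessor`, an illustrative `google/gemma-3-4b-it` checkpoint, and a toy interleaved sequence (none of these names come from this commit):

# Minimal sketch of interleaved text/image preprocessing; checkpoint name,
# toy sequence, and tensor sizes are assumptions for illustration only.
import torch
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
image_token_str = "<start_of_image>"

# An interleaved sample alternates text strings and image tensors (C, H, W) in [0, 1].
sequence = ["Describe this picture:", torch.rand(3, 224, 224), "Thanks!"]

texts, images = [], []
for item in sequence:
    if isinstance(item, str):
        texts.append(item)
    elif isinstance(item, torch.Tensor):
        images.append(item)
        texts.append(image_token_str)  # placeholder the processor expands into image tokens
    else:
        raise ValueError(f"Unsupported item type in interleaved sequence: {type(item)}")

outputs = processor(
    images=[images],  # a single batch element holding all images of this sample
    text=" ".join(texts),
    return_tensors="pt",
    images_kwargs={"do_rescale": False},  # tensors are already scaled to [0, 1]
)
print(outputs["input_ids"].shape)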
@@ -296,7 +300,9 @@ def compute_labels_interleaved(self, tokens: torch.Tensor) -> torch.Tensor:
         labels = labels[1:].contiguous()
         return labels
 
-    def pad_tokens_and_labels(self, tokens: torch.Tensor, labels: torch.Tensor, seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    def pad_tokens_and_labels(
+        self, tokens: torch.Tensor, labels: torch.Tensor, seqlen: int
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Pad tokens and labels to a be a multiple of config.pad_to_multiple_of
 
@@ -309,10 +315,10 @@ def pad_tokens_and_labels(self, tokens: torch.Tensor, labels: torch.Tensor, seql
             Tuple[torch.Tensor, torch.Tensor]: Tokens and labels tensor padded to a multiple of config.pad_to_multiple_of
         """
         seqlen_padded = (
-                (seqlen + self.config.pad_to_multiple_of - 1)
-                // self.config.pad_to_multiple_of
-                * self.config.pad_to_multiple_of
-                )
+            (seqlen + self.config.pad_to_multiple_of - 1)
+            // self.config.pad_to_multiple_of
+            * self.config.pad_to_multiple_of
+        )
         pad_len = seqlen_padded - seqlen
 
         if pad_len > 0:
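
The re-indented expression above is a standard ceil-to-multiple computation: `(seqlen + m - 1) // m * m` rounds `seqlen` up to the next multiple of `m`. A small self-contained sketch of the same arithmetic (the helper name and the pad values `0` / `-100` are assumptions for illustration, not taken from this file):

# Sketch of the padding arithmetic used by pad_tokens_and_labels; pad values are assumptions.
import torch
import torch.nn.functional as F


def pad_to_multiple(tokens, labels, multiple, pad_id=0, ignore_index=-100):
    seqlen = tokens.shape[0]
    seqlen_padded = (seqlen + multiple - 1) // multiple * multiple  # ceil to next multiple
    pad_len = seqlen_padded - seqlen
    if pad_len > 0:
        tokens = F.pad(tokens, (0, pad_len), value=pad_id)
        labels = F.pad(labels, (0, pad_len), value=ignore_index)
    return tokens, labels


# e.g. seqlen=300, multiple=64 -> seqlen_padded=320, pad_len=20
tokens, labels = pad_to_multiple(torch.ones(300, dtype=torch.long), torch.ones(300, dtype=torch.long), 64)
print(tokens.shape, labels.shape)  # torch.Size([320]) torch.Size([320])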
