@@ -191,6 +191,10 @@ def get_batch(data_iterator):
     labels_shape = tensor_parallel.broadcast_data(["labels_shape"], data, torch.int64)[
         "labels_shape"
     ]
+    ids = tensor_parallel.broadcast_data(["ids"], data, torch.uint8)["ids"]
+    ids_shape = tensor_parallel.broadcast_data(["ids_shape"], data, torch.int64)[
+        "ids_shape"
+    ]
     images = tensor_parallel.broadcast_data(["images"], data, torch.float32)["images"]
     split_image_sizes = tensor_parallel.broadcast_data(
         ["split_image_sizes"], data, torch.int64
@@ -229,6 +233,17 @@ def get_batch(data_iterator):
     assert start_idx == labels.numel()
     labels = labels_list

+    # ids to list
+    ids_list = []
+    start_idx = 0
+    for shape in ids_shape:
+        num_elements = torch.prod(shape).item()
+        sub_tensor = ids[start_idx : start_idx + num_elements].reshape(shape.tolist())
+        ids_list.append(sub_tensor)
+        start_idx += num_elements
+    assert start_idx == ids.numel()
+    ids = ids_list
+
     # images to list
     images_list = []
     start_idx = 0
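The unflattening loop above is a line-for-line copy of the existing `labels` logic: the flat broadcast buffer is cut into per-sample tensors using one shape row each. The same round trip in isolation, with made-up sizes:

```python
import torch

# Standalone round trip of the shape-driven split used above.
flat = torch.arange(10, dtype=torch.uint8)  # stand-in for the broadcast buffer
shapes = torch.tensor([[4], [6]])           # one int64 shape row per sample

chunks, start = [], 0
for shape in shapes:
    n = torch.prod(shape).item()
    chunks.append(flat[start : start + n].reshape(shape.tolist()))
    start += n
assert start == flat.numel()
assert [c.numel() for c in chunks] == [4, 6]
```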
@@ -288,7 +303,7 @@ def get_batch(data_iterator):
     attention_mask = input_ids.ne(tokenizer.pad_token_id)
     torch.cuda.nvtx.range_pop()

-    return input_ids, labels, attention_mask, images, image_sizes, modalities
+    return input_ids, labels, attention_mask, images, image_sizes, modalities, ids


 def pad_sequence(input_ids, batch_first, padding_value, tokenizer):
@@ -316,7 +331,13 @@ def get_image_token_count():
     return num_image_tokens


-def loss_func(labels: torch.Tensor, loss_mask: torch.Tensor, logits: torch.Tensor):
+def loss_func(
+    labels: torch.Tensor,
+    loss_mask: torch.Tensor,
+    ids,
+    logits: torch.Tensor,
+):
+    args = get_args()
     labels = labels.transpose(0, 1).contiguous()  # [b s] => [s b]
     logits = logits.transpose(0, 1).contiguous()  # [b s h] => [s b h]

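`ids` is placed before `logits` on purpose: as the last hunk shows, `forward_step` pre-binds everything except the model output via `functools.partial`, and the training loop then supplies `logits` as the final positional argument. A toy illustration of that calling convention, using stub values rather than real Megatron tensors:

```python
from functools import partial
import torch

# Toy illustration (not the real Megatron-LM loop): the trainer only ever
# calls loss_fn(output_tensor), so every extra input, including the new ids,
# must be bound up front, which is why ids precedes logits.
def loss_func_stub(labels, loss_mask, ids, logits):
    return logits.sum(), {"ids": ids}

labels, loss_mask, ids = torch.zeros(2), torch.ones(2), ["sample-0001"]
loss_fn = partial(loss_func_stub, labels, loss_mask, ids)
loss, metrics = loss_fn(torch.randn(2, 4))  # logits supplied last, by the caller
```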
@@ -334,6 +355,11 @@ def loss_func(labels: torch.Tensor, loss_mask: torch.Tensor, logits: torch.Tensor):
     loss = torch.mean(losses)

     # Reduce loss for logging.
+    if args.skip_train:
+        assert isinstance(ids, list) and len(ids) == 1
+        id = "".join([chr(c) for c in ids[0].cpu().numpy()])
+        print(f"Evaluating id: {id}, loss: {loss.detach().clone().item()}", flush=True)
+
     averaged_loss = average_losses_across_data_parallel_group([loss])
     return loss, {"lm loss": averaged_loss[0]}

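This branch only fires under `--skip-train` (Megatron-LM's evaluate-only mode) and asserts a micro batch of one, so each printed loss maps to exactly one sample; `chr()` over the uint8 values reverses the byte encoding. Note that `id` shadows the Python built-in, which is harmless in this scope. The decode in isolation:

```python
import torch

# Standalone sketch of the byte -> string decode in the logging branch.
id_bytes = torch.tensor([ord(c) for c in "sample-0001"], dtype=torch.uint8)
decoded = "".join(chr(c) for c in id_bytes.cpu().numpy())
assert decoded == "sample-0001"
```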
@@ -354,7 +380,7 @@ def forward_step(data_iterator, model: LLaVAOneVisionModel):

     # Get the batch.
     timers("batch-generator", log_level=2).start()
-    input_ids, labels, attention_mask, images, image_sizes, modalities = get_batch(
+    input_ids, labels, attention_mask, images, image_sizes, modalities, ids = get_batch(
         data_iterator
     )
     if "text" in modalities and ("image" in modalities or "video" in modalities):
@@ -367,7 +393,7 @@ def forward_step(data_iterator, model: LLaVAOneVisionModel):
         input_ids, labels, attention_mask, images, image_sizes, modalities
     )

-    return output_tensor, partial(loss_func, labels, loss_mask)
+    return output_tensor, partial(loss_func, labels, loss_mask, ids)


 def add_multimodal_extra_args(parser):