
Commit 35afa95

Author: caozhou

llava ov eval and filter
1 parent 77e104c commit 35afa95

File tree: 6 files changed, +135 -25 lines

flagscale/train/models/llava_onevision/dataloader_provider.py (+19 -7)

@@ -67,11 +67,16 @@ def datasets_provider(worker_config=None):
 
 
 def train_valid_test_dataloaders_provider(train_val_test_num_samples):
     """Build multimodal train, validation and test dataloaders."""
+    args = get_args()
+
+    # In llava-ov, set skip_train False to eval each sample.
+    # Training while evaluating is not supported yet.
+    if args.skip_train:
+        args.eval_iters = args.train_iters
+
     if get_tensor_model_parallel_rank() != 0:
         return None, None, None
 
-    args = get_args()
-
     worker_debug_path = None
     worker_log_level = 0
 
@@ -110,11 +115,18 @@ def train_valid_test_dataloaders_provider(train_val_test_num_samples):
                     "loading dataloader checkpoint failed. Skipping. " + str(e)
                 )
     if args.training_dataset_only:
-        return (
-            EnergonDataloader(train_dataloader),
-            EnergonDataloader(None),
-            EnergonDataloader(None),
-        )
+        if not args.skip_train:
+            return (
+                EnergonDataloader(train_dataloader),
+                None,
+                None,
+            )
+        else:
+            return (
+                None,
+                EnergonDataloader(train_dataloader),
+                None,
+            )
     valid_dataloader = [
         EnergonDataloader(get_loader(valid_ds, worker_config=worker_config))
         for valid_ds in valid_ds1
flagscale/train/models/llava_onevision/dataset_helpers.py (+16 -3)

@@ -36,6 +36,8 @@ class AnyResTaskSample:
     images: List[torch.Tensor]
     image_sizes: List[torch.Tensor]
     modalities: List[torch.Tensor]
+    ids: torch.Tensor
+    ids_shape: torch.Tensor
 
 # Typing for the resulting batch data after encode_batch()
 @dataclass
@@ -50,6 +52,8 @@ class AnyResTaskBatch(Batch):
     image_sizes: torch.Tensor
     split_image_sizes: torch.Tensor
     modalities: torch.Tensor
+    ids: torch.Tensor
+    ids_shape: torch.Tensor
 
 
 class AnyResTaskEncoder(DefaultTaskEncoder[InterleavedSample, InterleavedSample, AnyResTaskBatch, dict]):
@@ -84,6 +88,10 @@ def encode_interleaved(self, sample: InterleavedSample):
         else:
             assert ValueError("The sequence must have 4 or 5 elements, but got {len(sample.sequence)}.")
 
+        id = "".join(sample.__key__.split("/")[1:])
+        ids_tensor = torch.tensor([ord(c) for c in id], dtype=torch.uint8)
+        ids_shape = torch.tensor(ids_tensor.shape)
+
         # process modalities to tensor
         modalities_list = []
         for modality in modalities:
@@ -107,7 +115,9 @@ def encode_interleaved(self, sample: InterleavedSample):
             labels_shape=torch.tensor(labels.shape),
             images=images,
             image_sizes=image_sizes,
-            modalities=modalities
+            modalities=modalities,
+            ids=ids_tensor,
+            ids_shape=ids_shape
         )
 
     def batch(self, samples: List[AnyResTaskSample]) -> AnyResTaskBatch:
@@ -121,7 +131,8 @@ def batch(self, samples: List[AnyResTaskSample]) -> AnyResTaskBatch:
         # Adapt video data by decord
         image_sizes = torch.stack([image_sizes if len(image_sizes.shape) == 1 else torch.tensor((1, image_sizes.item())) for s in samples for image_sizes in s.image_sizes], dim=0)
         modalities = torch.stack([modalities for s in samples for modalities in s.modalities], dim=0)
-
+        ids = torch.cat([s.ids.flatten() for s in samples], dim=0)
+        ids_shape = torch.stack([s.ids_shape for s in samples], dim=0)
         batch = AnyResTaskBatch(
             __keys__=[s.__key__ for s in samples],
             __subflavors__=[s.__subflavors__ for s in samples],
@@ -132,7 +143,9 @@ def batch(self, samples: List[AnyResTaskSample]) -> AnyResTaskBatch:
             images=images,
             image_sizes=image_sizes,
             split_image_sizes=split_image_sizes,
-            modalities=modalities
+            modalities=modalities,
+            ids=ids,
+            ids_shape=ids_shape,
         )
 
         return batch
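The sample key has to survive the tensor-only broadcast/batching path, so it is encoded as a uint8 tensor of character codes plus a shape tensor, and decoded back with chr on the consumer side. A standalone sketch of that round trip (not the repo's code; assumes ASCII-only keys, which uint8 requires):

import torch

def encode_id(key: str) -> tuple:
    # One uint8 per character, plus the length, so variable-length ids can be
    # flattened and concatenated into a single batch tensor (see batch() above).
    ids = torch.tensor([ord(c) for c in key], dtype=torch.uint8)
    return ids, torch.tensor(ids.shape)

def decode_id(ids: torch.Tensor) -> str:
    # Mirrors the decoding used in loss_func() and llava_ov_wds.py.
    return "".join(chr(int(c)) for c in ids.cpu().numpy())

ids, ids_shape = encode_id("0001234")
assert decode_id(ids) == "0001234"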

flagscale/train/train_llava_onevision.py (+30 -4)

@@ -191,6 +191,10 @@ def get_batch(data_iterator):
     labels_shape = tensor_parallel.broadcast_data(["labels_shape"], data, torch.int64)[
         "labels_shape"
     ]
+    ids = tensor_parallel.broadcast_data(["ids"], data, torch.uint8)["ids"]
+    ids_shape = tensor_parallel.broadcast_data(["ids_shape"], data, torch.int64)[
+        "ids_shape"
+    ]
     images = tensor_parallel.broadcast_data(["images"], data, torch.float32)["images"]
     split_image_sizes = tensor_parallel.broadcast_data(
         ["split_image_sizes"], data, torch.int64
@@ -229,6 +233,17 @@ def get_batch(data_iterator):
     assert start_idx == labels.numel()
     labels = labels_list
 
+    # ids to list
+    ids_list = []
+    start_idx = 0
+    for shape in ids_shape:
+        num_elements = torch.prod(shape).item()
+        sub_tensor = ids[start_idx : start_idx + num_elements].reshape(shape.tolist())
+        ids_list.append(sub_tensor)
+        start_idx += num_elements
+    assert start_idx == ids.numel()
+    ids = ids_list
+
     # images to list
     images_list = []
     start_idx = 0
@@ -288,7 +303,7 @@ def get_batch(data_iterator):
     attention_mask = input_ids.ne(tokenizer.pad_token_id)
     torch.cuda.nvtx.range_pop()
 
-    return input_ids, labels, attention_mask, images, image_sizes, modalities
+    return input_ids, labels, attention_mask, images, image_sizes, modalities, ids
 
 
 def pad_sequence(input_ids, batch_first, padding_value, tokenizer):
@@ -316,7 +331,13 @@ def get_image_token_count():
     return num_image_tokens
 
 
-def loss_func(labels: torch.Tensor, loss_mask: torch.Tensor, logits: torch.Tensor):
+def loss_func(
+    labels: torch.Tensor,
+    loss_mask: torch.Tensor,
+    ids,
+    logits: torch.Tensor,
+):
+    args = get_args()
     labels = labels.transpose(0, 1).contiguous()  # [b s] => [s b]
     logits = logits.transpose(0, 1).contiguous()  # [b s h] => [s b h]
 
@@ -334,6 +355,11 @@ def loss_func(labels: torch.Tensor, loss_mask: torch.Tensor, logits: torch.Tensor):
     loss = torch.mean(losses)
 
     # Reduce loss for logging.
+    if args.skip_train:
+        assert isinstance(ids, list) and len(ids) == 1
+        id = "".join([chr(c) for c in ids[0].cpu().numpy()])
+        print(f"Evaluating id: {id}, loss: {loss.detach().clone().item()}", flush=True)
+
     averaged_loss = average_losses_across_data_parallel_group([loss])
     return loss, {"lm loss": averaged_loss[0]}
 
@@ -354,7 +380,7 @@ def forward_step(data_iterator, model: LLaVAOneVisionModel):
 
     # Get the batch.
     timers("batch-generator", log_level=2).start()
-    input_ids, labels, attention_mask, images, image_sizes, modalities = get_batch(
+    input_ids, labels, attention_mask, images, image_sizes, modalities, ids = get_batch(
         data_iterator
     )
     if "text" in modalities and ("image" in modalities or "video" in modalities):
@@ -367,7 +393,7 @@ def forward_step(data_iterator, model: LLaVAOneVisionModel):
         input_ids, labels, attention_mask, images, image_sizes, modalities
     )
 
-    return output_tensor, partial(loss_func, labels, loss_mask)
+    return output_tensor, partial(loss_func, labels, loss_mask, ids)
 
 
 def add_multimodal_extra_args(parser):
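The print added to loss_func is the contract consumed by the new log-grepping tool in the next file: one line per evaluated sample in the form "Evaluating id: <id>, loss: <loss>". A small sketch of parsing such a line with the same regex the tool uses (the example line itself is illustrative):

import re

line = "Evaluating id: 0001234, loss: 1.5321"  # illustrative example of the printed format
match = re.search(r"Evaluating id: (\d+), loss: ([\d.]+)", line)
assert match is not None
sample_id, loss = match.group(1), float(match.group(2))
print(sample_id, loss)
# Note: the (\d+) group assumes digit-only ids; non-numeric ids would need a
# broader pattern here and in the tool below.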
+43 (new file)

@@ -0,0 +1,43 @@
+import os
+import re
+import json
+import argparse
+from typing import Dict
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Grep id and loss from log files.')
+    parser.add_argument('--input_dir', type=str, help='Directory to search log files.')
+    parser.add_argument('--output', type=str, help='Path to save the result.')
+    args = parser.parse_args()
+
+    result_dict: Dict[str, float] = {}
+    for root, dirs, files in os.walk(args.input_dir):
+        for file in files:
+            if file.endswith('.log'):
+                file_path = os.path.join(root, file)
+                with open(file_path, 'r') as f:
+                    lines = f.readlines()
+                for line in lines:
+                    match = re.search(r'Evaluating id: (\d+), loss: ([\d.]+)', line)
+                    if match:
+                        evaluating_id = match.group(1)
+                        loss = float(match.group(2))
+                        if evaluating_id in result_dict:
+                            assert loss == result_dict[evaluating_id]
+                        # Customize filtering rules such as
+                        # if loss < 0.5:
+                        #     result_dict[evaluating_id] = loss
+
+                        # NOTE: No filtering currently, Comment out if Customize
+                        result_dict[evaluating_id] = loss
+
+    result = {"ids": list(result_dict.keys())}
+    assert args.output.endswith(".json")
+    with open(args.output, 'w') as f:
+        json.dump(result, f, indent=4)
+    print("Done")
+
+
+if __name__ == "__main__":
+    main()
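A usage sketch for the tool above; the script's path is not shown in this view, so filter_eval_log.py is a placeholder name, and the directories are illustrative:

# Scan all *.log files under the training output directory for "Evaluating id: ..." lines
# and write the kept ids to a JSON file.
python filter_eval_log.py --input_dir ./outputs/logs --output ./kept_ids.json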

tools/datasets/llava_onevision/llava_ov_wds.py (+20 -11)

@@ -1432,6 +1432,8 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
             # batch["images"] = torch.stack(images)
             # else:
             batch["images"] = images
+            assert "id" in instances[0]
+            batch["ids"] = [torch.tensor([ord(c) for c in instance["id"]], dtype=torch.uint8) for instance in instances]
 
         if "prompt" in instances[0]:
             batch["prompts"] = [instance["prompt"] for instance in instances]
@@ -1847,15 +1849,22 @@ def make_inputs_require_grad(module, input, output):
     if not os.path.exists(output):
         os.mkdir(output)
     start_time = time.time()
+    filter_ids = []
+    filter_json = os.environ.get("FILTER_JSON", "")
+    if filter_json:
+        with open(filter_json, 'r') as file:
+            data = json.load(file)
+            filter_ids = data["ids"]
     with wds.ShardWriter(os.path.join(output, f'llava-ov-{dist.get_rank()}-%d.tar'), maxcount=10000) as shard_writer:
         dataloader = trainer.get_train_dataloader()
         print(f"sample num: {len(dataloader)}")
-        global_id = 0
         for entry in tqdm(dataloader):
-            if global_id == 0:
-                for x in entry.keys():
-                    #print(f"key={x}, type={type(entry[x])}")
-                    pass
+            assert 'ids' in entry
+            assert len(entry["ids"]) == 1
+            id = "".join([chr(c) for c in entry["ids"][0].cpu().numpy()])
+            if filter_ids and id not in filter_ids:
+                print(f"The id {id} is filtered out.")
+                continue
 
             sequence = []
             sequence.append(entry['input_ids'][0].cpu())
@@ -1878,15 +1887,16 @@ def make_inputs_require_grad(module, input, output):
                 sequence.append([torch.tensor(entry['image_sizes'][0])])
                 sequence.append([entry['modalities'][0]])
                 if entry['modalities'][0] == "video":
-                    print(f"Processing video and image_sizes: {entry['image_sizes'][0]}, {images.shape}")
+                    print(f"Processing id {id} video and image_sizes: {entry['image_sizes'][0]}, {images.shape}")
                 elif entry['modalities'][0] == "text":
-                    print("Processing text.")
+                    print(f"Processing id {id} text.")
                 elif entry['modalities'][0] == "image":
-                    print("Processing single image.")
+                    print(f"Processing id {id} single image.")
                 else:
                     raise ValueError()
             else:
                 # Process images
+                print(f"Processing id {id} multi images.")
                 images = []
                 each_image_shape = None
                 for image in entry['images']:
@@ -1912,15 +1922,14 @@ def make_inputs_require_grad(module, input, output):
 
             sequence.append(images)
             sequence.append(image_sizes)
-            sequence.append(modalities)
+            sequence.append(modalities)
 
             sample = {
-                "__key__": str(global_id),
+                "__key__": str(id),
                 "sequence.pyd": sequence,
             }
 
             shard_writer.write(sample)
-            global_id += 1
 
     print(f"rank {dist.get_rank()} datasets saved to {training_args.output_dir}")
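The FILTER_JSON file read here is exactly what the log-grepping tool writes: a JSON object with an "ids" list, and only samples whose id appears in that list are written to the webdataset shards. An illustrative way to produce such a file by hand (the id values and file name are placeholders):

import json

# Minimal FILTER_JSON: keep only these sample ids when rebuilding the shards.
with open("kept_ids.json", "w") as f:
    json.dump({"ids": ["0001234", "0005678"]}, f, indent=4)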

tools/datasets/llava_onevision/make_llava_ov_wds.sh (+7 -0)

@@ -27,6 +27,12 @@ set -u
 HOSTFILE=$3
 set +u
 
+if [ $# -ge 4 ]; then
+    FILTER_JSON=$4
+else
+    FILTER_JSON=""
+fi
+
 echo "BASE_RUN_NAME: ${EXPNAME_PATH}"
 
 CKPT_PATH="./checkpoints"
@@ -52,6 +58,7 @@ do
     export WANDB_MODE=offline && \
     export ACCELERATE_CPU_AFFINITY=1 && \
     export PYTHONPATH=$LLaVA_NeXT_HOME:$PYTHONPATH && \
+    export FILTER_JSON=$FILTER_JSON && \
     source /root/miniconda3/bin/activate flagscale && \
     torchrun --nproc_per_node=8 --nnodes=${NNodes} --node_rank=${rank} --master_addr=${MASTER_ADDR} --master_port=13888 llava_ov_wds.py \
         --model_name_or_path ${CKPT_PATH} \
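With the new optional fourth argument, the conversion script forwards the filter file to llava_ov_wds.py through the FILTER_JSON environment variable. A hedged usage sketch; only the third (hostfile) and fourth arguments are visible in this diff, so the first two are placeholders:

# <arg1> and <arg2> stand in for the script's first two positional arguments, which
# are not shown in this diff; the optional fourth argument is the filter JSON.
bash tools/datasets/llava_onevision/make_llava_ov_wds.sh <arg1> <arg2> hostfile.txt kept_ids.json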
