diff --git a/run.py b/run.py index 6152aa88e..cf51eae4a 100644 --- a/run.py +++ b/run.py @@ -45,6 +45,7 @@ def get_gpu_list(): from vlmeval.inference import infer_data_job from vlmeval.inference_video import infer_data_job_video from vlmeval.inference_mt import infer_data_job_mt +from vlmeval.inference_mixed import infer_data_job_mixed from vlmeval.smp import * from vlmeval.utils.result_transfer import MMMU_result_transfer, MMTBench_result_transfer @@ -336,6 +337,17 @@ def main(): api_nproc=args.api_nproc, ignore_failed=args.ignore, use_vllm=args.use_vllm) + elif dataset.TYPE == 'MixedOutput': + model = infer_data_job_mixed( + model, + work_dir=pred_root, + model_name=model_name, + dataset=dataset, + actual_dataset_name=dataset_name, + verbose=args.verbose, + api_nproc=args.api_nproc, + ignore_failed=args.ignore, + use_vllm=args.use_vllm) else: model = infer_data_job( model, @@ -402,8 +414,8 @@ def main(): judge_kwargs['model'] = 'gpt-4.1' elif listinstr(['MathCanvas'], dataset_name): judge_kwargs['model'] = 'gpt-4.1-2025-04-14' - elif listinstr(['MMReason'], dataset_name): - judge_kwargs['model'] = 'gpt-4.1' + elif dataset.TYPE == 'MixedOutput': + judge_kwargs['model'] = 'qwen-72b' if args.use_verifier: judge_kwargs['use_verifier'] = True diff --git a/vlmeval/dataset/SIBench.py b/vlmeval/dataset/SIBench.py new file mode 100644 index 000000000..d7196cdfa --- /dev/null +++ b/vlmeval/dataset/SIBench.py @@ -0,0 +1,366 @@ +from .image_base import ImageBaseDataset +from .image_mcq import ImageMCQDataset +from .video_base import VideoBaseDataset +from ..smp import * +import os +import decord +import re +import warnings +from .utils import build_judge, DEBUG_MESSAGE + + +class SIBench(ImageMCQDataset, ImageBaseDataset, VideoBaseDataset): + # ------------------------------------------------- + # + # Before running this script, you must download the + # required data from the following Hugging Face repo: + # + # ==> https://huggingface.co/datasets/Two-hot/SIBench + # + # Please download it and place it in the 'data/' directory. 
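+    #
+    # At inference time the mixed pipeline resolves relative 'image_path' entries
+    # against LMUDataRoot() (see infer_data_job_mixed), while videos are read from a
+    # parallel root in which '/data' is replaced by '/data_sampled_video'; sampled
+    # frames are cached in a 'frames' folder next to each video source
+    # (see build_prompt / frame_paths below).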
+ # + # ------------------------------------------------- + MODALITY = 'MixedInput' + TYPE = 'MixedOutput' + + NEED_EXTRA_PROMPT_SOURCE = ['vstibench', 'MMSI-Bench', '3DSRBench', 'OmniSpatial', 'Spatial-MM', 'SpatialMQA', + 'VSI-Bench', 'STI-Bench', 'SpatialEval', 'SITE-Bench', 'SPHERE-VLM', 'SRBench', 'BLINK' + ] + # do not need = SpatialBench, SPAR-Bench, Super-CLEVR-3D, Omni3D-Bench + SETTING = ['relative_distance', 'Reach_Prediction', 'Object_Shape', 'Height', 'Existence', 'Spatial_Compatibility', + 'Coordinate_Conversion', 'Counting', 'Route_Planning', 'Trajectory_Description', 'Geometric_Reasoning', + 'Spatial_Imagination', 'Object_Size_Estimation', 'Spatial_Grid', 'Situational_QA', + 'Velocity_Acceleration','Maze_Navigation', 'Temporal-Appearance_Order', 'Camera_Pose', + 'Occlusion', 'multi-view_reasoning','Object_Localization',"Spatial_Relation", "SIBench", "SIBench-mini" + ] + +# Counting Camera_Pose Coordinate_Conversion multi-view_reasoning Object_Shape Object_Size_Estimation +# Occlusion relative_distance Situational_QA Spatial_Grid Spatial_Relation Trajectory_Description +# Reach_Prediction Height Existence Spatial_Compatibility Route_Planning Geometric_Reasoning +# Velocity_Acceleration Spatial_Imagination Temporal-Appearance_Order Object_Localization + + VIDEO_MODALITY_INCLUDED_SETTING = [''] + + FRAMES_TMPL_SYS = """ +You will receive {} distinct frames that have been uniformly sampled from a video sequence, +arranged in the same temporal order as they appear in the video. +Please analyze these frames and answer the question based on your observations. +""" + FRAMES_TMPL_SYS_4VIDEO_LLM = """ +You will receive several distinct frames that have been uniformly sampled from a video sequence, +arranged in the same temporal order as they appear in the video. +Please analyze these frames and answer the question based on your observations. +""" + + def __init__(self, dataset='MMBench', skip_noimg=True, nframe=30, fps=-1): + super(SIBench, self).__init__(dataset, skip_noimg) + + self.frame_tmpl = 'frame-{}-of-{}.jpg' + self.frame_tmpl_fps = 'frame-{}-of-{}-{}fps.jpg' + + self.nframe = nframe + self.fps = fps + if self.fps > 0 and self.nframe > 0: + raise ValueError('fps and nframe should not be set at the same time') + if self.fps <= 0 and self.nframe <= 0: + raise ValueError('fps and nframe should be set at least one valid value') + + @classmethod + def supported_datasets(cls): + return cls.SETTING + + def add_extra_prompt(self, prompt, answer_type, data_source): + if data_source in self.NEED_EXTRA_PROMPT_SOURCE: + if answer_type == 'MCQ': + prompt += "\nSelect from the given options, answer with letters only." + elif answer_type == 'YN': + prompt += "\nAnswer with 'Yes' or 'No' only." + elif answer_type.startswith('Number'): + prompt += "\nAnswer using a single number and nothing else." 
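+                # Note: during scoring (see evaluate()), 'Number' answers are graded with
+                # mean relative accuracy (compute_mra) and 'Number_Int' answers with exact
+                # match, which is why a bare number is requested here.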
+ else: + raise NotImplementedError(f"Answer type '{answer_type}' is not supported.") + elif data_source is None: + raise KeyError("Required key 'data_source' is missing.") + return prompt + + def frame_paths(self, video, data_base): + # need self.frame_root & self.frame_tmpl & self.nframe + frame_root = osp.join(data_base, video.split('/')[0], 'frames') + os.makedirs(frame_root, exist_ok=True) + return [osp.join(frame_root, self.frame_tmpl.format(i, self.nframe)) for i in range(1, self.nframe + 1)] + + def save_video_frames(self, line, data_base): + # need self.nframe & self.fps + video = line['video_path'] + vid_path = os.path.normpath(os.path.join(data_base, line['video_path'])) + vid = decord.VideoReader(vid_path) + video_info = { + 'fps': vid.get_avg_fps(), + 'n_frames': len(vid), + } + if self.nframe > 0 and self.fps < 0: + step_size = len(vid) / (self.nframe + 1) + indices = [int(i * step_size) for i in range(1, self.nframe + 1)] + frame_paths = self.frame_paths(video, data_base) + elif self.fps > 0: + # not constrained by num_frames, get frames by fps + total_duration = video_info['n_frames'] / video_info['fps'] + required_frames = int(total_duration * self.fps) + step_size = video_info['fps'] / self.fps + indices = [int(i * step_size) for i in range(required_frames)] + frame_paths = self.frame_paths_fps(video, len(indices)) + + flag = np.all([osp.exists(p) for p in frame_paths]) + + if not flag: + images = [vid[i].asnumpy() for i in indices] + images = [Image.fromarray(arr) for arr in images] + for im, pth in zip(images, frame_paths): + if not osp.exists(pth): + im.save(pth) + + return frame_paths + + def save_video_into_images(self, line, data_base): + frame_paths = self.save_video_frames(line, data_base) + return frame_paths + + def build_prompt_for_video(self, line, video_llm, data_base): + # need video_llm + if isinstance(line, int): + assert line < len(self) + line = self.data.iloc[line] + + video_path = os.path.normpath(os.path.join(data_base, line['video_path'])) + prompt = line['question'] + answer_type = line.get('type') + data_source = line.get('data_source') + prompt = self.add_extra_prompt(prompt, answer_type, data_source) + + if video_llm: + message = [dict(type='text', value=self.FRAMES_TMPL_SYS_4VIDEO_LLM)] + message.append(dict(type='text', value=prompt)) + message.append(dict(type='video', value=video_path)) + else: + img_frame_paths = self.save_video_into_images(line, data_base) + message = [dict(type='text', value=self.FRAMES_TMPL_SYS.format(len(img_frame_paths)))] + message.append(dict(type='text', value=prompt)) + for im in img_frame_paths: + message.append(dict(type='image', value=im)) + return message + + def build_prompt_for_image(self, line, data_base): + msgs = [] + if line.get('image_path'): + tgt_path = toliststr(''.join(line['image_path'].split()).split(',')) + for _ in range(len(tgt_path)): + tgt_path[_] = os.path.join(data_base, tgt_path[_]) + else: + raise KeyError("Required key 'image_path' is missing.") + + if isinstance(tgt_path, list): + msgs.extend([dict(type='image', value=p) for p in tgt_path]) + else: + msgs = [dict(type='image', value=tgt_path)] + + question = line['question'] + prompt = question + answer_type = line.get('type') + data_source = line.get('data_source') + prompt = self.add_extra_prompt(prompt, answer_type, data_source) + msgs.append(dict(type='text', value=prompt)) + return msgs + + def build_prompt(self, line, video_llm=None, data_base='.'): + if isinstance(line, int): + line = self.data.iloc[line] + + if 
line.get('input_type') in ['image', 'multi-view']: + return self.build_prompt_for_image(line=line, data_base=data_base) + elif line.get('input_type') == 'video': + video_data_base = data_base.replace('/data', '/data_sampled_video') + return self.build_prompt_for_video(line=line, video_llm=video_llm, data_base=video_data_base) + else: + raise NotImplementedError(f"Unrecognized input type: {line.get('input_type')}.") + + def extract_numbers_from_string(self, text, reverse_order): + number_strings = re.findall(r'-?\d{1,3}(?:,\d{3})*(?:\.\d+)?', text) + result = [] + for num_str in number_strings: + cleaned_str = num_str.replace(',', '') + try: + result.append(float(cleaned_str)) + except ValueError: + continue + + if reverse_order: + result.reverse() + + return result + + def compute_mra(self, y_true, y_pred): + C = np.arange(0.5, 1.0, 0.05) + mra_sum = 0 + for theta in C: + relative_error = np.abs(y_pred - y_true) / y_true + if relative_error < (1 - theta): + mra_sum += 1 + mra = mra_sum / len(C) + return mra + + def yn_Extraction(self, pred): + pred = pred.strip().lower() + pred = re.sub(r'[^\w\s]', '', pred) + + if pred == "yes": + return "yes" + elif pred == "no": + return "no" + else: + return pred + + def check_string_format(self, s): + # 1: ("A.", "B:", etc.) + if re.match(r'^[A-F][\.\:]', s): + return True + # 2: ("(A)", " (A)", etc.) + if '(' in s[:3]: + return True + # 3: ("A", "Apple", "A Answer", etc.) + if s[0] in 'ABCDEF': + return True + + return False + + def mcq_check(self, predict): + if isinstance(predict, float): + predict = 'z' + if '(' in predict[:3]: + predict = predict[1] + predict = predict.split('.')[0].split(':')[0] + + return predict + + def build_prompt_mcq(self, reasoning_text): + prompt_template = """You are a multiple-choice answer extractor. + Your sole task is to identify the final answer from a piece of reasoning text + and return *only* the corresponding option letter. + Your response must strictly follow the format: return only the option letter, + enclosed in English double quotes. + Do not include any other text, explanation, or prefixes. + --- + **Example 1:** + **Input:** "Based on the analysis, options A and B are clearly wrong. Option C mentions... + This is correct. Therefore, the final answer is C." + **Output:** "C" + **Example 2:** + **Input:** "Let's go through them one by one. A... B... C... D... After a comprehensive comparison, + option A's description is the most complete and accurate. So, the answer is A." + **Output:** "A" + **Example 3:** + **Input:** "The analysis shows that B is the correct choice because..." + **Output:** "B" + --- + Now, strictly following the format above, extract the answer from the following text: + """ + return prompt_template + reasoning_text + + def llm_process(self, pred, model): + prompt = self.build_prompt_mcq(pred) + logger = get_logger('Evaluation') + retry = 3 + + while retry: + ans = model.generate(prompt).strip(" '\"") + if 'Failed to obtain answer via API' in ans: + logger.warning('GPT API failed to answer. 
')
+            else:
+                if ans:
+                    return ans  # dict(opt=ans, log=ans)
+                else:
+                    logger.warning(
+                        f'Failed to infer: prediction is {ans}'
+                    )
+            retry -= 1
+
+        if retry == 0:
+            return 'z'  # dict(opt='z', log='Failed to predict')
+
+    def extract_mcq(self, pred, model):
+        need_llm = not self.check_string_format(pred)
+        if need_llm:
+            pred = self.llm_process(pred, model)
+
+        return self.mcq_check(pred)
+
+    def evaluate(self, eval_file, **judge_kwargs):
+        from .utils.multiple_choice import extract_characters_regex, report_acc
+        from .utils.yorn import YOrN_Extraction
+        assert eval_file.endswith('.xlsx'), 'data file should be an xlsx file'
+        FAIL_MSG = 'Failed to obtain answer via API.'
+        tmp_file = eval_file.replace('.xlsx', '_tmp.pkl')
+        # tgt_file = eval_file.replace('.xlsx', '_rating.json')
+        score_file = eval_file.replace('.xlsx', '_score.xlsx')
+        score_file_csv = eval_file.replace('.xlsx', '_score.csv')
+
+        model = build_judge(**judge_kwargs)
+        if not model.working():
+            warnings.warn('OPENAI API is not working properly, will use exact matching for evaluation')
+            warnings.warn(DEBUG_MESSAGE)
+            model = None
+
+        if not osp.exists(score_file):
+            res = {} if not osp.exists(tmp_file) else load(tmp_file)
+            res = {k: v for k, v in res.items() if FAIL_MSG not in v}
+
+            data = load(eval_file)
+            cnt_rejected = 0
+            data_un = data[~pd.isna(data['prediction'])]
+
+            for idx in data['index']:
+                ans = data.loc[data['index'] == idx, 'answer'].values[0]
+                pred = data.loc[data['index'] == idx, 'prediction'].values[0]
+                output_type = data.loc[data['index'] == idx, 'type'].values[0]
+
+                if output_type == 'MCQ':
+                    extract_pred = self.extract_mcq(pred, model)  # extract_characters_regex(pred)
+                    if extract_pred == '':
+                        cnt_rejected += 1
+                        data.loc[data['index'] == idx, 'hit'] = 0
+                    else:
+                        data.loc[data['index'] == idx, 'hit'] = int(extract_pred == ans)
+                elif output_type == 'YN':
+                    extract_pred_yn = self.yn_Extraction(pred[:3])  # YOrN_Extraction(pred)
+                    ans_yn = self.yn_Extraction(ans[:3])
+                    if ans_yn == 'yes' or ans_yn == 'no':
+                        ans = ans_yn
+                    pred = extract_pred_yn
+                    if pred == 'Unknown':
+                        cnt_rejected += 1
+                        data.loc[data['index'] == idx, 'hit'] = 0
+                    else:
+                        data.loc[data['index'] == idx, 'hit'] = int(pred.strip().lower() == ans.strip().lower())
+                elif output_type.startswith('Number'):
+                    try:
+                        extract_pred = eval(str(pred.strip()))
+                    except Exception:
+                        extract_pred = -1.0  # pred.strip()  # self.extract_numbers_from_string(pred, True)
+
+                    ans = eval(str(ans))
+                    if output_type == 'Number':
+                        data.loc[data['index'] == idx, 'hit'] = self.compute_mra(ans, extract_pred)
+                    elif output_type == 'Number_Int':
+                        data.loc[data['index'] == idx, 'hit'] = int(extract_pred == ans)
+                else:
+                    raise NotImplementedError(f'Unsupported output type {output_type}.')
+            print(
+                f'Among {len(data)} questions, failed to obtain prediction for {len(data) - len(data_un)} questions, '
+                f'failed to obtain the score for another {cnt_rejected} questions. '
+                f'Those questions will be counted as 0 score in ALL rating.'
+ ) + + dump(data, score_file) + data = load(score_file) + acc = report_acc(data) + dump(acc, score_file_csv) + return acc diff --git a/vlmeval/dataset/__init__.py b/vlmeval/dataset/__init__.py index cca809a2c..40a48059c 100644 --- a/vlmeval/dataset/__init__.py +++ b/vlmeval/dataset/__init__.py @@ -20,6 +20,7 @@ from .image_ccocr import CCOCRDataset from .image_shortqa import ImageShortQADataset, PathVQA_VAL, PathVQA_TEST from .text_mcq import CustomTextMCQDataset, TextMCQDataset +from .SIBench import SIBench from .vcr import VCRDataset from .mmlongbench import MMLongBench @@ -247,13 +248,15 @@ def evaluate(self, eval_file, **judge_kwargs): TextMCQDataset ] +MIXED_DATASET = [SIBench] + CUSTOM_DATASET = [ CustomMCQDataset, CustomVQADataset, CustomTextMCQDataset ] DATASET_COLLECTION = [ConcatDataset, ConcatVideoDataset] -DATASET_CLASSES = IMAGE_DATASET + VIDEO_DATASET + TEXT_DATASET + CUSTOM_DATASET + DATASET_COLLECTION # noqa: E501 +DATASET_CLASSES = IMAGE_DATASET + VIDEO_DATASET + TEXT_DATASET + MIXED_DATASET + CUSTOM_DATASET + DATASET_COLLECTION # noqa: E501 SUPPORTED_DATASETS = [] for DATASET_CLS in DATASET_CLASSES: SUPPORTED_DATASETS.extend(DATASET_CLS.supported_datasets()) diff --git a/vlmeval/inference_mixed.py b/vlmeval/inference_mixed.py new file mode 100644 index 000000000..f5b4dbb4d --- /dev/null +++ b/vlmeval/inference_mixed.py @@ -0,0 +1,324 @@ +import torch +import torch.distributed as dist +from vlmeval.config import supported_VLM +from vlmeval.utils import track_progress_rich +from vlmeval.smp import * + +FAIL_MSG = 'Failed to obtain answer via API.' +NOT_USE_SIBENCH_PROMPT = False + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--data', type=str, nargs='+', required=True) + parser.add_argument('--model', type=str, nargs='+', required=True) + parser.add_argument('--nproc', type=int, default=4, required=True) + parser.add_argument('--verbose', action='store_true') + args = parser.parse_args() + return args + + +# Only API model is accepted +def infer_data_api( + model, + work_dir, + model_name, + dataset, + actual_dataset_name, + index_set=None, + api_nproc=4, + ignore_failed=False +): + rank, world_size = get_rank_and_world_size() + assert rank == 0 and world_size == 1 + dataset_name = dataset.dataset_name + data = dataset.data + if index_set is not None: + data = data[data['index'].isin(index_set)] + + model = supported_VLM[model_name]() if isinstance(model, str) else model + assert getattr(model, 'is_api', False) + if hasattr(model, 'set_dump_image'): + model.set_dump_image(dataset.dump_image) + + lt, indices = len(data), list(data['index']) + + structs = [] + for i in range(lt): + item = data.iloc[i] + if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name) and NOT_USE_SIBENCH_PROMPT: + assert hasattr(model, 'build_prompt') + struct = model.build_prompt(item, dataset=dataset_name) + else: + struct = dataset.build_prompt(item) + structs.append(struct) + + out_file = f'{work_dir}/{model_name}_{actual_dataset_name}_supp.pkl' + + # To reuse records in MMBench_V11 + if dataset_name in ['MMBench', 'MMBench_CN']: + v11_pred = f'{work_dir}/{model_name}_{actual_dataset_name}_V11.xlsx' + if osp.exists(v11_pred): + try: + reuse_inds = load('http://opencompass.openxlab.space/utils/mmb_reuse.pkl') + data = load(v11_pred) + ans_map = {x: y for x, y in zip(data['index'], data['prediction']) if x in reuse_inds} + dump(ans_map, out_file) + except Exception as err: + print(type(err), err) + + res = {} + if osp.exists(out_file): + 
res = load(out_file) + if ignore_failed: + res = {k: v for k, v in res.items() if FAIL_MSG not in v} + + structs = [s for i, s in zip(indices, structs) if i not in res] + indices = [i for i in indices if i not in res] + + gen_func = model.generate + structs = [dict(message=struct, dataset=dataset_name) for struct in structs] + + if len(structs): + track_progress_rich(gen_func, structs, nproc=api_nproc, chunksize=api_nproc, save=out_file, keys=indices) + + res = load(out_file) + if index_set is not None: + res = {k: v for k, v in res.items() if k in index_set} + os.remove(out_file) + return res + + +def infer_data( + model, + model_name, + work_dir, + dataset, + actual_dataset_name, + data_base, + out_file, + verbose=False, + api_nproc=4, + use_vllm=False +): + dataset_name = dataset.dataset_name + prev_file = f'{work_dir}/{model_name}_{actual_dataset_name}_PREV.pkl' + res = load(prev_file) if osp.exists(prev_file) else {} + if osp.exists(out_file): + res.update(load(out_file)) + + rank, world_size = get_rank_and_world_size() + sheet_indices = list(range(rank, len(dataset), world_size)) + lt = len(sheet_indices) + data = dataset.data.iloc[sheet_indices] + data_indices = [i for i in data['index']] + + # If finished, will exit without building the model + all_finished = True + for i in range(lt): + idx = data.iloc[i]['index'] + if idx not in res: + all_finished = False + if all_finished: + res = {k: res[k] for k in data_indices} + dump(res, out_file) + return model + + # Data need to be inferred + data = data[~data['index'].isin(res)] + lt = len(data) + + kwargs = {} + if model_name is not None and ( + 'Llama-4' in model_name + or 'Qwen2-VL' in model_name + or 'Qwen2.5-VL' in model_name + ): + kwargs = {'use_vllm': use_vllm} + + # (25.06.05) In newer version of transformers (after 4.50), with device_map='auto' and torchrun launcher, + # Transformers automatically adopt TP parallelism, which leads to compatibility problems with VLMEvalKit + # (In VLMEvalKit, we use torchrun to launch multiple model instances on a single node). + # To bypass this problem, we unset `WORLD_SIZE` before building the model to not use TP parallel. + ws_bak = os.environ.pop('WORLD_SIZE', None) + model = supported_VLM[model_name](**kwargs) if isinstance(model, str) else model + if ws_bak: + os.environ['WORLD_SIZE'] = ws_bak + + is_api = getattr(model, 'is_api', False) + if is_api: + lt, indices = len(data), list(data['index']) + supp = infer_data_api( + model=model, + work_dir=work_dir, + model_name=model_name, + dataset=dataset, + actual_dataset_name=actual_dataset_name, + index_set=set(indices), + api_nproc=api_nproc) + for idx in indices: + assert idx in supp + res.update(supp) + res = {k: res[k] for k in data_indices} + dump(res, out_file) + return model + else: + model.set_dump_image(dataset.dump_image) + + assert not getattr(dataset, 'pack', False), 'Current model not supported pack mode!' + if 'megabench' in dataset_name.lower() and 'llava_onevision' in model_name: + print( + 'LLaVA-OneVision does not support Megabench dataset as video dataset, ' + 'will set its VIDEO_LLM to False to enable multi-image input for video.' 
+ ) + setattr(model, 'VIDEO_LLM', False) + + for i in tqdm(range(lt), desc=f'Infer {model_name}/{actual_dataset_name}, Rank {rank}/{world_size}'): + idx = data.iloc[i]['index'] + if idx in res: + continue + + if data.iloc[i]['input_type'] in ['image', 'multi-view']: + if hasattr(model, 'use_custom_prompt') and model.use_custom_prompt(dataset_name) and NOT_USE_SIBENCH_PROMPT: + struct = model.build_prompt(data.iloc[i], dataset=dataset_name) + else: + struct = dataset.build_prompt(data.iloc[i], data_base=data_base) + + # If `SKIP_ERR` flag is set, the model will skip the generation if error is encountered + if os.environ.get('SKIP_ERR', False) == '1': + FAIL_MSG = 'Failed to obtain answer' + try: + response = model.generate(message=struct, dataset=dataset_name) + except RuntimeError as err: + torch.cuda.synchronize() + warnings.warn(f'{type(err)} {str(err)}') + response = f'{FAIL_MSG}: {type(err)} {str(err)}' + else: + response = model.generate(message=struct, dataset=dataset_name) + + elif data.iloc[i]['input_type'] == 'video': + if getattr(model, 'nframe', None) is not None and getattr(model, 'nframe', 0) > 0: + if dataset.nframe > 0: + if getattr(model, 'nframe', 0) != dataset.nframe: + print(f'{model_name} is a video-llm model, nframe is set to {dataset.nframe}') + setattr(model, 'nframe', dataset.nframe) + elif getattr(model, 'fps', 0) == 0: + raise ValueError(f'fps is not suitable for {model_name}') + else: + setattr(model, 'nframe', None) + if getattr(model, 'fps', None) is not None and getattr(model, 'fps', 0) > 0: + if dataset.fps > 0: + if getattr(model, 'fps', 0) != dataset.fps: + print(f'{model_name} is a video-llm model, fps is set to {dataset.fps}, not using default') + setattr(model, 'fps', dataset.fps) + elif getattr(model, 'nframe', 0) == 0: + raise ValueError(f'nframe is not suitable for {model_name}') + else: + setattr(model, 'fps', None) + if ( + 'Qwen2-VL' in model_name + or 'Qwen2.5-VL' in model_name + or 'Qwen2.5-Omni' in model_name + ): + if getattr(model, 'nframe', None) is None and dataset.nframe > 0: + print(f'using {model_name} default setting for video, dataset.nframe is ommitted') + if getattr(model, 'fps', None) is None and dataset.fps > 0: + print(f'using {model_name} default setting for video, dataset.fps is ommitted') + + if (hasattr(model, 'use_custom_prompt') + and model.use_custom_prompt(dataset_name) + and NOT_USE_SIBENCH_PROMPT): + if dataset.nframe == 0: + raise ValueError(f'nframe must be set for custom prompt, fps is not suitable for {model_name}') + struct = model.build_prompt( + dataset.data.iloc[i], dataset=dataset, video_llm=getattr(model, 'VIDEO_LLM', False) + ) + else: + struct = dataset.build_prompt( + dataset.data.iloc[i], video_llm=getattr(model, 'VIDEO_LLM', False), data_base=data_base + ) + + # If `SKIP_ERR` flag is set, the model will skip the generation if error is encountered + if os.environ.get('SKIP_ERR', False) == '1': + FAIL_MSG = 'Failed to obtain answer' + try: + response = model.generate(message=struct, dataset=dataset_name) + except RuntimeError as err: + torch.cuda.synchronize() + warnings.error(f'{type(err)} {str(err)}') + response = f'{FAIL_MSG}: {type(err)} {str(err)}' + else: + response = model.generate(message=struct, dataset=dataset_name) + else: + torch.cuda.empty_cache() + raise NotImplementedError + torch.cuda.empty_cache() + + if verbose: + print(response, flush=True) + + res[idx] = response + if (i + 1) % 10 == 0: + dump(res, out_file) + + res = {k: res[k] for k in data_indices} + dump(res, out_file) + return model + 
+ +# A wrapper for infer_data, do the pre & post processing +def infer_data_job_mixed( + model, + work_dir, + model_name, + dataset, + actual_dataset_name, + verbose=False, + api_nproc=4, + ignore_failed=False, + use_vllm=False +): + lmu_path = LMUDataRoot() + data_base = lmu_path + rank, world_size = get_rank_and_world_size() + + result_file = osp.join(work_dir, f'{model_name}_{actual_dataset_name}.xlsx') + + prev_file = f'{work_dir}/{model_name}_{actual_dataset_name}_PREV.pkl' + if osp.exists(result_file): + if rank == 0: + data = load(result_file) + results = {k: v for k, v in zip(data['index'], data['prediction'])} + if not ignore_failed: + results = {k: v for k, v in results.items() if FAIL_MSG not in str(v)} + dump(results, prev_file) + if world_size > 1: + dist.barrier() + + tmpl = osp.join(work_dir, '{}' + f'{world_size}_{actual_dataset_name}.pkl') + out_file = tmpl.format(rank) + + model = infer_data( + model=model, work_dir=work_dir, model_name=model_name, dataset=dataset, actual_dataset_name=actual_dataset_name, + data_base=data_base, out_file=out_file, verbose=verbose, api_nproc=api_nproc, use_vllm=use_vllm) + if world_size > 1: + dist.barrier() + + if rank == 0: + data_all = {} + for i in range(world_size): + data_all.update(load(tmpl.format(i))) + + data = dataset.data + for x in data['index']: + assert x in data_all + data['prediction'] = [str(data_all[x]) for x in data['index']] + if 'image' in data: + data.pop('image') + + dump(data, result_file) + for i in range(world_size): + os.remove(tmpl.format(i)) + if world_size > 1: + dist.barrier() + return model
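
For reference, below is a minimal sketch of how the pieces added by this patch fit together when driven directly from Python rather than through run.py. The model name ('GPT4o'), the 'outputs/GPT4o' work directory, and the 'SIBench-mini' split are placeholders; the judge name simply mirrors the new MixedOutput default set in run.py above, and the SIBench data is assumed to already sit under LMUDataRoot() as described in SIBench.py.

import os
from vlmeval.dataset import SIBench
from vlmeval.inference_mixed import infer_data_job_mixed

# Build the dataset object; nframe/fps follow the defaults used in SIBench.__init__.
dataset = SIBench('SIBench-mini', nframe=30, fps=-1)

work_dir = 'outputs/GPT4o'  # placeholder output directory
os.makedirs(work_dir, exist_ok=True)

# Run inference; run.py takes this branch automatically when dataset.TYPE == 'MixedOutput'.
infer_data_job_mixed(
    'GPT4o',                 # placeholder model name, resolved via supported_VLM
    work_dir=work_dir,
    model_name='GPT4o',
    dataset=dataset,
    actual_dataset_name='SIBench-mini',
    verbose=True,
    api_nproc=4,
)

# Score the predictions; 'qwen-72b' mirrors the judge default this patch adds to run.py.
acc = dataset.evaluate(os.path.join(work_dir, 'GPT4o_SIBench-mini.xlsx'), model='qwen-72b')
print(acc)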