Commit f458957

Committed Jul 15, 2024
init
1 parent d607812 commit f458957

31 files changed, +150418 −1 lines
 

.gitignore (+4)

@@ -1,3 +1,7 @@
+# this file
+*.png
+*.jpg
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]

README.md (+103 −1)

@@ -1 +1,103 @@
-# ExoViP

<div align="center">

# ExoViP: Step-by-step Verification and Exploration with Exoskeleton Modules for Compositional Visual Reasoning

[![arXiv](https://img.shields.io/badge/arXiv-<INDEX>-b31b1b.svg)](https://arxiv.org/abs/<INDEX>)
[![Conference](http://img.shields.io/badge/COLM-2024-4b44ce.svg)](https://colmweb.org/)

</div>

Official implementation of our paper: ExoViP: Step-by-step Verification and Exploration with Exoskeleton Modules for Compositional Visual Reasoning

![image](assets/framework.png)

## Introduction

In this work, we devise a "plug-and-play" method, ExoViP, that corrects errors at both the planning and execution stages through introspective verification. We employ verification modules as "exoskeletons" to enhance current vision-language programming schemes. Specifically, our proposed verification module uses a mixture of three sub-verifiers to validate predictions after each reasoning step; it then calibrates the visual module predictions and refines the reasoning trace planned by the LLM.
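
As a rough illustration of this verification step, the sketch below re-ranks a module's candidate predictions by averaging three sub-verifier scores. All names and signatures here are hypothetical (not the repository's API), and the real sub-verifiers are vision-language models rather than the toy stubs shown:

```python
# Hypothetical sketch of mixture-of-sub-verifiers re-ranking; the function and
# parameter names are illustrative assumptions, not the repository's API.
from typing import Callable, List

Verifier = Callable[[object, str], float]  # (image, candidate text) -> score in [0, 1]

def verify_step(image, candidates: List[str], sub_verifiers: List[Verifier]) -> str:
    """Score every candidate with each sub-verifier and keep the best one."""
    def mixed_score(cand: str) -> float:
        scores = [v(image, cand) for v in sub_verifiers]
        return sum(scores) / len(scores)  # simple mixture: the mean
    return max(candidates, key=mixed_score)

# toy stand-ins for the image-text matching, captioning, and VQA sub-verifiers
stubs: List[Verifier] = [lambda img, t: float(len(t)) / 100] * 3
print(verify_step(None, ["cat", "a black cat on a sofa"], stubs))
```

In the actual system, these calibrated scores are also what drives the refinement of the reasoning trace planned by the LLM.
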
## Environment

Paste your OPENAI-API-KEY and OPENAI-API-BASE into `engine/.env` and `tasks/*.ipynb`, then create and activate the conda environment:

```
conda env create -f environment.yaml
conda activate exovip
```

If Huggingface is not reachable from your network, you can download all checkpoints into the `prev_trained_models` directory.
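
For reference, here is a minimal sketch of loading those values in Python. It assumes the conventional variable names `OPENAI_API_KEY` and `OPENAI_API_BASE` and the `python-dotenv` package; check `engine/.env` for the names the code actually expects:

```python
# minimal sketch, assuming python-dotenv (pip install python-dotenv) and
# conventional variable names; engine/.env may use different ones
import os
from dotenv import load_dotenv

load_dotenv("engine/.env")
openai_api_key = os.environ["OPENAI_API_KEY"]
openai_api_base = os.environ["OPENAI_API_BASE"]
```
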
## Highlights

Errors in existing methods fall into two categories:

- Module error: the visual modules fail to execute the program correctly
- Planning error: the LLM cannot parse the language query into a correct, solvable program

![image](assets/error.png)

We compared the statistics of a random sample of 100 failure cases before (left) and after (right) applying our method.

![image](assets/stat.png)

## Start

Our method has been validated on six tasks:

- Compositional Image Question Answering: [GQA](https://cs.stanford.edu/people/dorarad/gqa/about.html)
- Referring Expression Understanding: [RefCOCO/RefCOCO+/RefCOCOg](https://github.com/lichengunc/refer)
- Natural Language for Visual Reasoning: [NLVR](https://github.com/lil-lab/nlvr/tree/master/nlvr2)
- Visual Abstract Reasoning: [KILOGRAM](https://github.com/lil-lab/kilogram)
- Language-guided Image Editing: [MagicBrush](https://github.com/OSU-NLP-Group/MagicBrush)
- Spatial-Temporal Video Reasoning: [AGQA](http://ai.stanford.edu/blog/agqa/)

***NOTE**: All experiments are run on subsets of these datasets; please refer to `datasets`.*

Code demos:

```bash
cd tasks

# GQA
gqa.ipynb

# NLVR
nlvr.ipynb

# RefCOCO(+/g)
refcoco.ipynb

# KILOGRAM
kilogram.ipynb

# MagicBrush
magicbrush.ipynb

# AGQA
agqa.ipynb
```
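
To run a demo notebook non-interactively, one option is executing it with `nbconvert` (a sketch; assumes Jupyter's `nbformat` and `nbconvert` are installed in the `exovip` environment, and that you are inside `tasks/`):

```python
# a sketch for executing a demo notebook headlessly from within tasks/;
# assumes nbformat and nbconvert are available (they ship with Jupyter)
import nbformat
from nbconvert.preprocessors import ExecutePreprocessor

nb = nbformat.read("gqa.ipynb", as_version=4)
ep = ExecutePreprocessor(timeout=600, kernel_name="python3")
ep.preprocess(nb, {"metadata": {"path": "."}})  # run all cells in place
nbformat.write(nb, "gqa.out.ipynb")             # save the executed copy
```
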
## Available Modules

![image](assets/modules.png)

## Examples

![image](assets/GQA.png)

## Acknowledgement

[visprog](https://github.com/allenai/visprog), a neuro-symbolic system that solves complex and compositional visual tasks given natural language instructions

## Citation

If you find our work helpful, please cite it.

```bibtex
@article{wang2024exovip,
  title={ExoViP: Step-by-step Verification and Exploration with Exoskeleton Modules for Compositional Visual Reasoning},
  author={Yuxuan Wang and Alan Yuille and Zhuowan Li and Zilong Zheng},
  journal={COLM 2024},
  year={2024}
}
```

baselines/refcoco_baseline.ipynb (+188)

@@ -0,0 +1,188 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import sys\n",
    "import json\n",
    "from pathlib import Path\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "from IPython.core.display import HTML\n",
    "from functools import partial\n",
    "\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from transformers.generation import GenerationConfig\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: The default behavior now has injection attack prevention off.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen-VL-Chat\", trust_remote_code=True)\n",
    "\n",
    "# use bf16\n",
    "# model = AutoModelForCausalLM.from_pretrained(\"Qwen/Qwen-VL-Chat\", device_map=\"auto\", trust_remote_code=True, bf16=True).eval()\n",
    "# use fp16\n",
    "# model = AutoModelForCausalLM.from_pretrained(\"Qwen/Qwen-VL-Chat\", device_map=\"auto\", trust_remote_code=True, fp16=True).eval()\n",
    "# use cpu only\n",
    "# model = AutoModelForCausalLM.from_pretrained(\"Qwen/Qwen-VL-Chat\", device_map=\"cpu\", trust_remote_code=True).eval()\n",
    "# use cuda device\n",
    "model = AutoModelForCausalLM.from_pretrained(\"Qwen/Qwen-VL-Chat\", device_map=\"cuda\", trust_remote_code=True).eval()\n",
    "\n",
    "# Specify hyperparameters for generation\n",
    "model.generation_config = GenerationConfig.from_pretrained(\"Qwen/Qwen-VL-Chat\", trust_remote_code=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from PIL import ImageDraw\n",
    "test_file = os.path.join(Path.home(), 'codes/ExoViP/datasets/refcoco/test.json')\n",
    "with open(test_file) as jp:\n",
    "    test = json.load(jp)\n",
    "eval_pred = 0\n",
    "eval_cnt = 0\n",
    "\n",
    "for idx, dct in tqdm(test.items()):\n",
    "    # eval_cnt += 1\n",
    "    # if eval_cnt < 5: continue\n",
    "\n",
    "    img_id = dct['img']\n",
    "    img_path = os.path.join(Path.home(), 'codes/ExoViP/datasets/refcoco/imgs', img_id)\n",
    "    image = Image.open(img_path)\n",
    "    h, w = image.height, image.width\n",
    "\n",
    "    instruction = dct['instruction']\n",
    "    # print(instruction)\n",
    "\n",
    "    # build a multimodal query for Qwen-VL-Chat\n",
    "    query = tokenizer.from_list_format([\n",
    "        {\"image\": img_path,\n",
    "         \"text\": instruction}\n",
    "    ])\n",
    "\n",
    "    response, history = model.chat(tokenizer, query=query, history=None)\n",
    "    # image = tokenizer.draw_bbox_on_latest_picture(response, history)\n",
    "    # image.save(str(eval_cnt)+'.jpg')\n",
    "    # display(image)\n",
    "\n",
    "    # parse the predicted box '(x1,y1),(x2,y2)' from the response\n",
    "    PATTERN = re.compile(r'\\((.*?)\\),\\((.*?)\\)')\n",
    "    predict_bbox = re.findall(PATTERN, response)\n",
    "    try:\n",
    "        if ',' not in predict_bbox[0][0] or ',' not in predict_bbox[0][1]:\n",
    "            predict_bbox = (0., 0., 0., 0.)\n",
    "        else:\n",
    "            x1, y1 = [float(tmp) for tmp in predict_bbox[0][0].split(',')]\n",
    "            x2, y2 = [float(tmp) for tmp in predict_bbox[0][1].split(',')]\n",
    "            # Qwen-VL coordinates are normalized to 0-1000; rescale to pixels\n",
    "            x1, y1, x2, y2 = (int(x1 / 1000 * w), int(y1 / 1000 * h), int(x2 / 1000 * w), int(y2 / 1000 * h))\n",
    "            predict_bbox = (x1, y1, x2, y2)\n",
    "    except Exception:\n",
    "        predict_bbox = (0., 0., 0., 0.)\n",
    "    box = predict_bbox\n",
    "    label = dct['box']\n",
    "    # print(box)\n",
    "    # print(label)\n",
    "    # print()\n",
    "    # draw = ImageDraw.Draw(image)\n",
    "    # draw.rectangle(box,outline='red',width=4)\n",
    "    # draw.rectangle(label,outline='green',width=4)\n",
    "    # image.save(str(eval_cnt)+'.jpg')\n",
    "\n",
    "    # calculate IoU between the predicted box and the ground-truth label\n",
    "    label_area = (label[2]-label[0]) * (label[3]-label[1])\n",
    "    box_area = (box[2]-box[0]) * (box[3]-box[1])\n",
    "    x1 = max(box[0], label[0])\n",
    "    x2 = min(box[2], label[2])\n",
    "    y1 = max(box[1], label[1])\n",
    "    y2 = min(box[3], label[3])\n",
    "    intersection = max(0, x2-x1) * max(0, y2-y1)\n",
    "    iou = intersection / (label_area + box_area - intersection)\n",
    "    # print(iou)\n",
    "    eval_pred += iou\n",
    "    eval_cnt += 1\n",
    "\n",
    "    # # visualize\n",
    "    # # W,H=image.size\n",
    "    # draw = ImageDraw.Draw(result)\n",
    "    # draw.rectangle(label,outline='red',width=4)\n",
    "    # result.save(f'{idx}.jpg')\n",
    "    # print(idx, instruction)\n",
    "    # if eval_cnt > 5:\n",
    "    #     break\n",
    "\n",
    "    if eval_cnt % 20 == 0:\n",
    "        print(f'step {eval_cnt} iou: ', round(eval_pred/eval_cnt, 2))\n",
    "    # break\n",
    "\n",
    "# mean IoU over the full test set\n",
    "print('iou: ', eval_pred/len(test.keys()))\n",
    "result_file = os.path.join(Path.home(), 'codes/visprog/results/refcoco/qwen.json')\n",
    "with open(result_file, 'w') as jp:\n",
    "    json.dump(test, jp)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.4 ('few-shot-vr')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "f6aae81381dc24e2fd0d8778e266667bb8dbd7e1c04425e21584f774a2d20c40"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
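
For quick reference, the notebook's inline IoU computation is equivalent to this standalone helper (a reference sketch only, not part of the repository):

```python
# reference sketch of the notebook's IoU computation;
# boxes are (x1, y1, x2, y2) in pixel coordinates
def iou(box_a, box_b):
    inter_w = max(0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    intersection = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return intersection / (area_a + area_b - intersection)

assert iou((0, 0, 10, 10), (0, 0, 10, 10)) == 1.0
assert iou((0, 0, 10, 10), (5, 5, 15, 15)) == 25 / 175
```
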
