VisRAG is a novel vision-language model (VLM)-based RAG pipeline. Instead of first parsing the document to obtain text, the pipeline embeds each document directly as an image using a VLM, retrieves the relevant documents, and uses them to enhance the generation of a VLM. Compared to traditional text-based RAG, VisRAG maximizes the retention and utilization of the information in the original documents, eliminating the information loss introduced during the parsing process.
VisRAG-Ret is a document embedding model built on MiniCPM-V 2.0, a vision-language model that integrates SigLIP as the vision encoder and MiniCPM-2B as the language model.
In the paper, we use MiniCPM-V 2.0, MiniCPM-V 2.6, and GPT-4o as the generators. In practice, you can use any VLM you like!
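The retrieval step of the pipeline described above can be sketched as follows. This is a minimal illustration rather than the VisRAG implementation: it assumes that query and page embeddings have already been produced by an embedding model, and the helper names `l2_normalize` and `top_k_pages`, as well as the toy vectors, are hypothetical and introduced here only for the example.

```python
import math


def l2_normalize(vec):
    """Scale a non-zero vector to unit length (assumes the norm is not 0)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def top_k_pages(query_emb, page_embs, k=2):
    """Return the indices of the k pages most similar to the query.

    Similarity is the dot product of L2-normalized vectors, i.e. cosine
    similarity. Hypothetical helper, not part of the VisRAG code base.
    """
    q = l2_normalize(query_emb)
    scored = []
    for idx, emb in enumerate(page_embs):
        p = l2_normalize(emb)
        score = sum(a * b for a, b in zip(q, p))
        scored.append((score, idx))
    # Highest similarity first.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [idx for _, idx in scored[:k]]


# Toy 2-dimensional embeddings standing in for embedded page images.
page_embs = [[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]]
query_emb = [0.9, 0.1]

retrieved = top_k_pages(query_emb, page_embs, k=2)
print(retrieved)
```

In a full pipeline, the page images at the returned indices would then be passed, together with the query, to whichever VLM serves as the generator.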
conda create --name VisRAG python==3.10.8
conda activate VisRAG
conda install nvidia/label/cuda-11.8.0::cuda-toolkit
cd VisRAG
pip install -r requirements.txt
pip install -e .
cd timm_modified
pip install -e .
cd ..
Our training dataset for VisRAG-Ret contains 362,110 query-document (Q-D) pairs. It consists of the train sets of openly available academic datasets (34%) and a synthetic dataset of pages from web-crawled PDF documents, augmented with pseudo-queries generated by a VLM (GPT-4o) (66%).
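As a rough sanity check on the split above, the approximate number of pairs per source can be derived from the stated total and percentages. The percentages are rounded in the text, so the counts below are estimates, not official figures; the variable names are introduced here for illustration only.

```python
TOTAL_PAIRS = 362_110

# Shares as stated in the text (rounded to whole percentages).
SHARES = {"academic": 0.34, "synthetic": 0.66}

# Approximate pair counts; exact numbers may differ because of rounding.
approx_counts = {
    name: round(TOTAL_PAIRS * share) for name, share in SHARES.items()
}
print(approx_counts)
```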
Data coming soon
bash scripts/train_retriever/train.sh 2048 16 8 0.02 1 true false config/deepspeed.json 1e-5 false wmean causal 1 true 2 false <model_path> <repository_name>
<repository_name> can be 'openbmb/VisRAG-Ret-Train-In-domain-data' or 'openbmb/VisRAG-Ret-Train-Synthetic-data'.
The generation part does not use any fine-tuning; we directly use off-the-shelf LLMs/VLMs for generation.
Data coming soon
bash scripts/eval_retriever/eval.sh 512 2048 16 8 wmean causal ArxivQA,ChartQA,MP-DocVQA,InfoVQA,PlotQA,SlideVQA <ckpt_path>
The parameters above are the ones we used in our paper; use them to reproduce the reported results.
Coming soon
Model on Hugging Face: https://huggingface.co/openbmb/VisRAG-Ret
from transformers import AutoModel, AutoTokenizer
import torch
import torch.nn.functional as F
from PIL import Image
import os


def weighted_mean_pooling(hidden, attention_mask):
    # Position-weighted mean: each attended token is weighted by its 1-based
    # position among attended tokens, so later tokens count more.
    attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
    s = torch.sum(hidden * attention_mask_.unsqueeze(-1).float(), dim=1)
    d = attention_mask_.sum(dim=1, keepdim=True).float()
    reps = s / d
    return reps


@torch.no_grad()
def encode(text_or_image_list):
    # Text inputs (queries) carry no image; image inputs (document pages)
    # carry an empty text string.
    if isinstance(text_or_image_list[0], str):
        inputs = {
            "text": text_or_image_list,
            "image": [None] * len(text_or_image_list),
            "tokenizer": tokenizer,
        }
    else:
        inputs = {
            "text": [""] * len(text_or_image_list),
            "image": text_or_image_list,
            "tokenizer": tokenizer,
        }
    outputs = model(**inputs)
    attention_mask = outputs.attention_mask
    hidden = outputs.last_hidden_state

    reps = weighted_mean_pooling(hidden, attention_mask)
    # L2-normalize so that a dot product equals cosine similarity.
    embeddings = F.normalize(reps, p=2, dim=1).detach().cpu().numpy()
    return embeddings


tokenizer = AutoTokenizer.from_pretrained("openbmb/VisRAG-Ret", trust_remote_code=True)
model = AutoModel.from_pretrained("openbmb/VisRAG-Ret", trust_remote_code=True)
model.eval()

script_dir = os.path.dirname(os.path.realpath(__file__))
queries = ["What does a dog look like?"]
passages = [
    Image.open(os.path.join(script_dir, "test_image/cat.jpeg")).convert("RGB"),
    Image.open(os.path.join(script_dir, "test_image/dog.jpg")).convert("RGB"),
]

# Queries are prefixed with an instruction; document images are not.
INSTRUCTION = "Represent this query for retrieving relevant documents: "
queries = [INSTRUCTION + query for query in queries]

embeddings_query = encode(queries)
embeddings_doc = encode(passages)

# Rows are queries, columns are passages; higher means more similar.
scores = embeddings_query @ embeddings_doc.T
print(scores.tolist())
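To make the pooling step in the script above concrete, here is a small worked example in plain Python for a single sequence. It is a sketch that mirrors the arithmetic of `weighted_mean_pooling` without PyTorch; the function name `weighted_mean_pooling_1d` and the toy hidden states are hypothetical and are not part of the VisRAG code.

```python
def weighted_mean_pooling_1d(hidden, attention_mask):
    """Position-weighted mean over one sequence of token vectors.

    hidden: list of token vectors (lists of floats), one per position.
    attention_mask: list of 0/1 flags, 1 for real tokens and 0 for padding.
    Each real token is weighted by the running count of real tokens seen
    so far (mask * cumulative sum of mask); padding gets weight 0.
    """
    weights = []
    running = 0
    for m in attention_mask:
        running += m
        weights.append(m * running)
    total = sum(weights)
    dim = len(hidden[0])
    return [
        sum(w * token[d] for w, token in zip(weights, hidden)) / total
        for d in range(dim)
    ]


# Three real tokens followed by one padding token (toy values).
hidden = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [9.0, 9.0]]
mask = [1, 1, 1, 0]

# Weights are [1, 2, 3, 0], so the result is
# [(1*1 + 2*0 + 3*1) / 6, (1*0 + 2*1 + 3*1) / 6].
pooled = weighted_mean_pooling_1d(hidden, mask)
print(pooled)
```

Because the weights grow with position, later tokens contribute more to the pooled vector than earlier ones, and the padded position is ignored entirely.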
- The code in this repo is released under the Apache-2.0 License.
- The usage of VisRAG-Ret model weights must strictly follow MiniCPM Model License.md.
- The models and weights of VisRAG-Ret are completely free for academic research. After filling out a "questionnaire" for registration, VisRAG-Ret weights are also available for free commercial use.
- Shi Yu: [email protected]
- Chaoyue Tang: [email protected]