Skip to content

Commit f20a2d2

Browse files
authored
ColPali blog part 2 (#414)
Examples on: * Bit vectors * Hamming distance * Asymmetric similarity comparison * Average vectors * Token pooling
1 parent 93d7ba8 commit f20a2d2

File tree

3 files changed

+1144
-0
lines changed

3 files changed

+1144
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,393 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "e0056726-3359-4a9b-9913-016617525a6d",
6+
"metadata": {},
7+
"source": [
8+
"# Scalable late interaction vectors in Elasticsearch: Bit Vectors #\n",
9+
"\n",
10+
"In this notebook, we will be looking at how to convert late interaction vectors to bit vectors to \n",
11+
"1. Save siginificant disk space \n",
12+
"2. Lower query latency\n",
13+
" \n",
14+
"We will also look at how we can use hamming distance to speed our queries up even further. \n",
15+
"This notebook builds on part 1 where we downloaded the images, created ColPali vectors and saved them to disk. Please execute this notebook before trying the techniques in this notebook. \n",
16+
" \n",
17+
"Also check out our accompanying blog post on [Scaling Late Interaction Models](TODO) for more context on this notebook. "
18+
]
19+
},
20+
{
21+
"cell_type": "markdown",
22+
"id": "49dbcc61-5dab-4cf6-bbc5-7fa898707ce6",
23+
"metadata": {},
24+
"source": [
25+
"This is the key part of this notebook. We use the `to_bit_vectors()` function to convert our vectors into bit vectors. \n",
26+
"The function is simple in essence. Values `> 0` are converted to `1`, values `< 0` are converted to `0`. We then convert our array of `0`s and `1`s to a hex string, that represents our bit vector. \n",
27+
"So don't be surprised that the values that we will be indexing look like strings and not arrays as before. This is intended! \n",
28+
"\n",
29+
"Learn more about [bit vectors and hamming distance in our blog](https://www.elastic.co/search-labs/blog/bit-vectors-in-elasticsearch) about this topic. "
30+
]
31+
},
32+
{
33+
"cell_type": "code",
34+
"execution_count": 1,
35+
"id": "be6ffdc5-fbaa-40b5-8b33-5540a3f957ba",
36+
"metadata": {},
37+
"outputs": [],
38+
"source": [
39+
"import numpy as np\n",
40+
"\n",
41+
"\n",
42+
"def to_bit_vectors(embedding: list) -> list:\n",
43+
" embeddings = []\n",
44+
" for idx, patch_embedding in enumerate(embedding):\n",
45+
" patch_embedding = np.array(patch_embedding)\n",
46+
" binary_vector = (\n",
47+
" np.packbits(np.where(patch_embedding > 0, 1, 0))\n",
48+
" .astype(np.int8)\n",
49+
" .tobytes()\n",
50+
" .hex()\n",
51+
" )\n",
52+
" embeddings.append(binary_vector)\n",
53+
" return embeddings"
54+
]
55+
},
56+
{
57+
"cell_type": "markdown",
58+
"id": "52b7449b-8fbf-46b7-90c9-330070f6996a",
59+
"metadata": {},
60+
"source": [
61+
"Here we are defining our mapping for our Elasticsearch index. Note how we set the `element_type` parameter to `bit` to inform Elasticsearch that we will be indexing bit vectors in this field. "
62+
]
63+
},
64+
{
65+
"cell_type": "code",
66+
"execution_count": 2,
67+
"id": "2de5872d-b372-40fe-85c5-111b9f9fa6c8",
68+
"metadata": {},
69+
"outputs": [
70+
{
71+
"name": "stdout",
72+
"output_type": "stream",
73+
"text": [
74+
"[INFO] Index 'searchlabs-colpali-hamming' already exists.\n"
75+
]
76+
}
77+
],
78+
"source": [
79+
"import os\n",
80+
"from dotenv import load_dotenv\n",
81+
"from elasticsearch import Elasticsearch\n",
82+
"\n",
83+
"load_dotenv(\"elastic.env\")\n",
84+
"\n",
85+
"ELASTIC_API_KEY = os.getenv(\"ELASTIC_API_KEY\")\n",
86+
"ELASTIC_HOST = os.getenv(\"ELASTIC_HOST\")\n",
87+
"INDEX_NAME = \"searchlabs-colpali-hamming\"\n",
88+
"\n",
89+
"es = Elasticsearch(ELASTIC_HOST, api_key=ELASTIC_API_KEY)\n",
90+
"\n",
91+
"mappings = {\n",
92+
" \"mappings\": {\n",
93+
" \"properties\": {\n",
94+
" \"col_pali_vectors\": {\"type\": \"rank_vectors\", \"element_type\": \"bit\"}\n",
95+
" }\n",
96+
" }\n",
97+
"}\n",
98+
"\n",
99+
"if not es.indices.exists(index=INDEX_NAME):\n",
100+
" print(f\"[INFO] Creating index: {INDEX_NAME}\")\n",
101+
" es.indices.create(index=INDEX_NAME, body=mappings)\n",
102+
"else:\n",
103+
" print(f\"[INFO] Index '{INDEX_NAME}' already exists.\")\n",
104+
"\n",
105+
"\n",
106+
"def index_document(es_client, index, doc_id, document, retries=10, initial_backoff=1):\n",
107+
" for attempt in range(1, retries + 1):\n",
108+
" try:\n",
109+
" return es_client.index(index=index, id=doc_id, document=document)\n",
110+
" except Exception as e:\n",
111+
" if attempt < retries:\n",
112+
" wait_time = initial_backoff * (2 ** (attempt - 1))\n",
113+
" print(f\"[WARN] Failed to index {doc_id} (attempt {attempt}): {e}\")\n",
114+
" time.sleep(wait_time)\n",
115+
" else:\n",
116+
" print(f\"Failed to index {doc_id} after {retries} attempts: {e}\")\n",
117+
" raise"
118+
]
119+
},
120+
{
121+
"cell_type": "code",
122+
"execution_count": 3,
123+
"id": "bdf6ff33-3e22-43c1-9f3e-c3dd663b40e2",
124+
"metadata": {},
125+
"outputs": [
126+
{
127+
"data": {
128+
"application/vnd.jupyter.widget-view+json": {
129+
"model_id": "022b4af8891b4a06962e023c7f92d8f4",
130+
"version_major": 2,
131+
"version_minor": 0
132+
},
133+
"text/plain": [
134+
"Indexing documents: 0%| | 0/500 [00:00<?, ?it/s]"
135+
]
136+
},
137+
"metadata": {},
138+
"output_type": "display_data"
139+
},
140+
{
141+
"name": "stdout",
142+
"output_type": "stream",
143+
"text": [
144+
"Completed indexing 500 documents\n"
145+
]
146+
}
147+
],
148+
"source": [
149+
"from concurrent.futures import ThreadPoolExecutor\n",
150+
"from tqdm.notebook import tqdm\n",
151+
"import pickle\n",
152+
"\n",
153+
"\n",
154+
"def process_file(file_name, vectors):\n",
155+
" if es.exists(index=INDEX_NAME, id=file_name):\n",
156+
" return\n",
157+
"\n",
158+
" bit_vectors = to_bit_vectors(vectors)\n",
159+
"\n",
160+
" index_document(\n",
161+
" es_client=es,\n",
162+
" index=INDEX_NAME,\n",
163+
" doc_id=file_name,\n",
164+
" document={\"col_pali_vectors\": bit_vectors},\n",
165+
" )\n",
166+
"\n",
167+
"\n",
168+
"with open(\"col_pali_vectors.pkl\", \"rb\") as f:\n",
169+
" file_to_multi_vectors = pickle.load(f)\n",
170+
"\n",
171+
"with ThreadPoolExecutor(max_workers=10) as executor:\n",
172+
" list(\n",
173+
" tqdm(\n",
174+
" executor.map(\n",
175+
" lambda item: process_file(*item), file_to_multi_vectors.items()\n",
176+
" ),\n",
177+
" total=len(file_to_multi_vectors),\n",
178+
" desc=\"Indexing documents\",\n",
179+
" )\n",
180+
" )\n",
181+
"\n",
182+
"print(f\"Completed indexing {len(file_to_multi_vectors)} documents\")"
183+
]
184+
},
185+
{
186+
"cell_type": "code",
187+
"execution_count": 4,
188+
"id": "1dfc3713-d649-46db-aa81-171d6d92668e",
189+
"metadata": {},
190+
"outputs": [
191+
{
192+
"data": {
193+
"application/vnd.jupyter.widget-view+json": {
194+
"model_id": "064e33061bac40e4802138e30599225b",
195+
"version_major": 2,
196+
"version_minor": 0
197+
},
198+
"text/plain": [
199+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
200+
]
201+
},
202+
"metadata": {},
203+
"output_type": "display_data"
204+
}
205+
],
206+
"source": [
207+
"import torch\n",
208+
"from PIL import Image\n",
209+
"from colpali_engine.models import ColPali, ColPaliProcessor\n",
210+
"\n",
211+
"model_name = \"vidore/colpali-v1.3\"\n",
212+
"model = ColPali.from_pretrained(\n",
213+
" \"vidore/colpali-v1.3\",\n",
214+
" torch_dtype=torch.float32,\n",
215+
" device_map=\"mps\", # \"mps\" for Apple Silicon, \"cuda\" if available, \"cpu\" otherwise\n",
216+
").eval()\n",
217+
"\n",
218+
"col_pali_processor = ColPaliProcessor.from_pretrained(model_name)\n",
219+
"\n",
220+
"\n",
221+
"def create_col_pali_query_vectors(query: str) -> list:\n",
222+
" queries = col_pali_processor.process_queries([query]).to(model.device)\n",
223+
" with torch.no_grad():\n",
224+
" return model(**queries).tolist()[0]"
225+
]
226+
},
227+
{
228+
"cell_type": "raw",
229+
"id": "5e86697d-d9dd-4224-85c8-023c71c88548",
230+
"metadata": {},
231+
"source": [
232+
"Here we run the search against our index comparing our query vector converted to bit vectors to the bit vectors in our index. \n",
233+
"Trading of a bit of accuracy, this is allows us to use hamming distance (`maxSimInvHamming(...)`), which is able to leverage optimzations such as bit-masks, SIMD, etc. Again - learn more about [bit vectors and hamming distance in our blog](https://www.elastic.co/search-labs/blog/bit-vectors-in-elasticsearch) about this topic. \n",
234+
"\n",
235+
"See the cell below about a different technique to query our bit vectors. "
236+
]
237+
},
238+
{
239+
"cell_type": "code",
240+
"execution_count": 5,
241+
"id": "8e322b23-b4bc-409d-9e00-2dab93f6a295",
242+
"metadata": {},
243+
"outputs": [
244+
{
245+
"name": "stdout",
246+
"output_type": "stream",
247+
"text": [
248+
"{\"_source\": false, \"query\": {\"script_score\": {\"query\": {\"match_all\": {}}, \"script\": {\"source\": \"maxSimInvHamming(params.query_vector, 'col_pali_vectors')\", \"params\": {\"query_vector\": [\"7747bcd9732859c3645aa81036f5c960\", \"729b3c418ba8594a67daa042eca1c961\", \"609e3d8a2ac379c2204aa0cfa8345bdc\", \"30bf378a2ac279da245aa8dfa83c3bdc\", \"64af77ea2acdf9c28c0aa5df863677f4\", \"686f3fce2ac871c26e6aaddf023455ec\", \"383f31a8e8c0f8ca2c4ab54f047c7dec\", \"203b33caaac279da0acaa54f8a3c6bcc\", \"319a63eba8d279ca30dbbccf8f757b8e\", \"203b73ca28d2798a325bb44f8c3c5bce\", \"203bb7caa8d2718a1a4bb14f8a3c5bdc\", \"203bb7caa8d2798a1a6aa14f8a3c5fdc\", \"303b33caa8d2798a0a4aa14f8a3c5bdc\", \"303b33caaad379ca0e4aa14f8a3c5bdc\", \"709b33caaac379ca0c4aa14f8a3c5fdc\", \"708e37eaaac779ca2c4aa1df863c1fdc\", \"648e77ea6acd79caac4ae1df86363ffc\", \"648e77ea6acdf9caac4ae5df06363ffc\", \"608f37ea2ac579ca2c4ea1df063c3ffc\", \"709f37c8aac379ca2c4ea1df863c1fdc\", \"70af31c82ac671ce2c6ab14fc43c1bfc\"]}}}}, \"size\": 5}\n"
249+
]
250+
},
251+
{
252+
"data": {
253+
"text/html": [
254+
"<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'><img src=\"searchlabs-colpali/image_104.jpg\" alt=\"image_104.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_3.jpg\" alt=\"image_3.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_12.jpg\" alt=\"image_12.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_2.jpg\" alt=\"image_2.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_92.jpg\" alt=\"image_92.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"></div>"
255+
],
256+
"text/plain": [
257+
"<IPython.core.display.HTML object>"
258+
]
259+
},
260+
"metadata": {},
261+
"output_type": "display_data"
262+
}
263+
],
264+
"source": [
265+
"from IPython.display import display, HTML\n",
266+
"import os\n",
267+
"import json\n",
268+
"\n",
269+
"DOCUMENT_DIR = \"searchlabs-colpali\"\n",
270+
"\n",
271+
"query = \"What do companies use for recruiting?\"\n",
272+
"query_vector = to_bit_vectors(create_col_pali_query_vectors(query))\n",
273+
"es_query = {\n",
274+
" \"_source\": False,\n",
275+
" \"query\": {\n",
276+
" \"script_score\": {\n",
277+
" \"query\": {\"match_all\": {}},\n",
278+
" \"script\": {\n",
279+
" \"source\": \"maxSimInvHamming(params.query_vector, 'col_pali_vectors')\",\n",
280+
" \"params\": {\"query_vector\": query_vector},\n",
281+
" },\n",
282+
" }\n",
283+
" },\n",
284+
" \"size\": 5,\n",
285+
"}\n",
286+
"print(json.dumps(es_query))\n",
287+
"\n",
288+
"results = es.search(index=INDEX_NAME, body=es_query)\n",
289+
"image_ids = [hit[\"_id\"] for hit in results[\"hits\"][\"hits\"]]\n",
290+
"\n",
291+
"html = \"<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'>\"\n",
292+
"for image_id in image_ids:\n",
293+
" image_path = os.path.join(DOCUMENT_DIR, image_id)\n",
294+
" html += f'<img src=\"{image_path}\" alt=\"{image_id}\" style=\"max-width:300px; height:auto; margin:10px;\">'\n",
295+
"html += \"</div>\"\n",
296+
"\n",
297+
"display(HTML(html))"
298+
]
299+
},
300+
{
301+
"cell_type": "markdown",
302+
"id": "e27b68ac-bec8-4415-919e-8b916bc35816",
303+
"metadata": {},
304+
"source": [
305+
"Above we have seen how to query our data using the `maxSimInvHamming(...)` function. \n",
306+
"We can also just pass the full fidelity col pali vector and use the `maxSimDotProduct(...)` function for [asymmetric similarity](https://www.elastic.co/guide/en/elasticsearch/reference/8.18/rank-vectors.html#rank-vectors-scoring) between the vectors. "
307+
]
308+
},
309+
{
310+
"cell_type": "code",
311+
"execution_count": 6,
312+
"id": "32fd9ee4-d7c6-4954-a766-7b06735290ff",
313+
"metadata": {},
314+
"outputs": [
315+
{
316+
"data": {
317+
"text/html": [
318+
"<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'><img src=\"searchlabs-colpali/image_104.jpg\" alt=\"image_104.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_3.jpg\" alt=\"image_3.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_2.jpg\" alt=\"image_2.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_12.jpg\" alt=\"image_12.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_92.jpg\" alt=\"image_92.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"></div>"
319+
],
320+
"text/plain": [
321+
"<IPython.core.display.HTML object>"
322+
]
323+
},
324+
"metadata": {},
325+
"output_type": "display_data"
326+
}
327+
],
328+
"source": [
329+
"query = \"What do companies use for recruiting?\"\n",
330+
"query_vector = create_col_pali_query_vectors(query)\n",
331+
"es_query = {\n",
332+
" \"_source\": False,\n",
333+
" \"query\": {\n",
334+
" \"script_score\": {\n",
335+
" \"query\": {\"match_all\": {}},\n",
336+
" \"script\": {\n",
337+
" \"source\": \"maxSimDotProduct(params.query_vector, 'col_pali_vectors')\",\n",
338+
" \"params\": {\"query_vector\": query_vector},\n",
339+
" },\n",
340+
" }\n",
341+
" },\n",
342+
" \"size\": 5,\n",
343+
"}\n",
344+
"\n",
345+
"results = es.search(index=INDEX_NAME, body=es_query)\n",
346+
"image_ids = [hit[\"_id\"] for hit in results[\"hits\"][\"hits\"]]\n",
347+
"\n",
348+
"html = \"<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'>\"\n",
349+
"for image_id in image_ids:\n",
350+
" image_path = os.path.join(DOCUMENT_DIR, image_id)\n",
351+
" html += f'<img src=\"{image_path}\" alt=\"{image_id}\" style=\"max-width:300px; height:auto; margin:10px;\">'\n",
352+
"html += \"</div>\"\n",
353+
"\n",
354+
"display(HTML(html))"
355+
]
356+
},
357+
{
358+
"cell_type": "code",
359+
"execution_count": null,
360+
"id": "ee8df1e3-af66-4e35-9c26-7257c281536f",
361+
"metadata": {},
362+
"outputs": [],
363+
"source": [
364+
"# We kill the kernel forcefully to free up the memory from the ColPali model.\n",
365+
"print(\"Shutting down the kernel to free memory...\")\n",
366+
"import os\n",
367+
"\n",
368+
"os._exit(0)"
369+
]
370+
}
371+
],
372+
"metadata": {
373+
"kernelspec": {
374+
"display_name": "dependecy-test-colpali-blog",
375+
"language": "python",
376+
"name": "dependecy-test-colpali-blog"
377+
},
378+
"language_info": {
379+
"codemirror_mode": {
380+
"name": "ipython",
381+
"version": 3
382+
},
383+
"file_extension": ".py",
384+
"mimetype": "text/x-python",
385+
"name": "python",
386+
"nbconvert_exporter": "python",
387+
"pygments_lexer": "ipython3",
388+
"version": "3.12.6"
389+
}
390+
},
391+
"nbformat": 4,
392+
"nbformat_minor": 5
393+
}

0 commit comments

Comments
 (0)