|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "id": "e0056726-3359-4a9b-9913-016617525a6d", |
| 6 | + "metadata": {}, |
| 7 | + "source": [ |
| 8 | + "# Scalable late interaction vectors in Elasticsearch: Bit Vectors #\n", |
| 9 | + "\n", |
| 10 | + "In this notebook, we will be looking at how to convert late interaction vectors to bit vectors to \n", |
| 11 | + "1. Save significant disk space \n",
| 12 | + "2. Lower query latency\n", |
| 13 | + " \n", |
| 14 | + "We will also look at how we can use hamming distance to speed our queries up even further. \n", |
| 15 | + "This notebook builds on part 1, where we downloaded the images, created the ColPali vectors, and saved them to disk. Please run that notebook first before trying the techniques in this one. \n",
| 16 | + " \n", |
| 17 | + "Also check out our accompanying blog post on [Scaling Late Interaction Models](TODO) for more context on this notebook. " |
| 18 | + ] |
| 19 | + }, |
| 20 | + { |
| 21 | + "cell_type": "markdown", |
| 22 | + "id": "49dbcc61-5dab-4cf6-bbc5-7fa898707ce6", |
| 23 | + "metadata": {}, |
| 24 | + "source": [ |
| 25 | + "This is the key part of this notebook. We use the `to_bit_vectors()` function to convert our vectors into bit vectors. \n", |
| 26 | + "The function is simple in essence. Values `> 0` are converted to `1`; values `<= 0` are converted to `0`. We then convert our array of `0`s and `1`s to a hex string that represents our bit vector. \n",
| 27 | + "So don't be surprised that the values we index look like strings rather than arrays. This is intended! \n",
| 28 | + "\n", |
| 29 | + "Learn more about [bit vectors and hamming distance in our blog](https://www.elastic.co/search-labs/blog/bit-vectors-in-elasticsearch) about this topic. " |
| 30 | + ] |
| 31 | + }, |
| 32 | + { |
| 33 | + "cell_type": "code", |
| 34 | + "execution_count": 1, |
| 35 | + "id": "be6ffdc5-fbaa-40b5-8b33-5540a3f957ba", |
| 36 | + "metadata": {}, |
| 37 | + "outputs": [], |
| 38 | + "source": [ |
| 39 | + "import numpy as np\n", |
| 40 | + "\n", |
| 41 | + "\n", |
| 42 | + "def to_bit_vectors(embedding: list) -> list:\n", |
| 43 | + " embeddings = []\n", |
| 44 | + "    for patch_embedding in embedding:\n",
| 45 | + " patch_embedding = np.array(patch_embedding)\n", |
| 46 | + " binary_vector = (\n", |
| 47 | + " np.packbits(np.where(patch_embedding > 0, 1, 0))\n", |
| 48 | + " .astype(np.int8)\n", |
| 49 | + " .tobytes()\n", |
| 50 | + " .hex()\n", |
| 51 | + " )\n", |
| 52 | + " embeddings.append(binary_vector)\n", |
| 53 | + " return embeddings" |
| 54 | + ] |
| 55 | + }, |
| 56 | + { |
| 57 | + "cell_type": "markdown", |
| 58 | + "id": "52b7449b-8fbf-46b7-90c9-330070f6996a", |
| 59 | + "metadata": {}, |
| 60 | + "source": [ |
| 61 | + "Here we are defining our mapping for our Elasticsearch index. Note how we set the `element_type` parameter to `bit` to inform Elasticsearch that we will be indexing bit vectors in this field. " |
| 62 | + ] |
| 63 | + }, |
| 64 | + { |
| 65 | + "cell_type": "code", |
| 66 | + "execution_count": 2, |
| 67 | + "id": "2de5872d-b372-40fe-85c5-111b9f9fa6c8", |
| 68 | + "metadata": {}, |
| 69 | + "outputs": [ |
| 70 | + { |
| 71 | + "name": "stdout", |
| 72 | + "output_type": "stream", |
| 73 | + "text": [ |
| 74 | + "[INFO] Index 'searchlabs-colpali-hamming' already exists.\n" |
| 75 | + ] |
| 76 | + } |
| 77 | + ], |
| 78 | + "source": [ |
| 79 | + "import os\n", |
| 79 | + "import time\n",
| 80 | + "from dotenv import load_dotenv\n", |
| 81 | + "from elasticsearch import Elasticsearch\n", |
| 82 | + "\n", |
| 83 | + "load_dotenv(\"elastic.env\")\n", |
| 84 | + "\n", |
| 85 | + "ELASTIC_API_KEY = os.getenv(\"ELASTIC_API_KEY\")\n", |
| 86 | + "ELASTIC_HOST = os.getenv(\"ELASTIC_HOST\")\n", |
| 87 | + "INDEX_NAME = \"searchlabs-colpali-hamming\"\n", |
| 88 | + "\n", |
| 89 | + "es = Elasticsearch(ELASTIC_HOST, api_key=ELASTIC_API_KEY)\n", |
| 90 | + "\n", |
| 91 | + "mappings = {\n", |
| 92 | + " \"mappings\": {\n", |
| 93 | + " \"properties\": {\n", |
| 94 | + " \"col_pali_vectors\": {\"type\": \"rank_vectors\", \"element_type\": \"bit\"}\n", |
| 95 | + " }\n", |
| 96 | + " }\n", |
| 97 | + "}\n", |
| 98 | + "\n", |
| 99 | + "if not es.indices.exists(index=INDEX_NAME):\n", |
| 100 | + " print(f\"[INFO] Creating index: {INDEX_NAME}\")\n", |
| 101 | + " es.indices.create(index=INDEX_NAME, body=mappings)\n", |
| 102 | + "else:\n", |
| 103 | + " print(f\"[INFO] Index '{INDEX_NAME}' already exists.\")\n", |
| 104 | + "\n", |
| 105 | + "\n", |
| 106 | + "def index_document(es_client, index, doc_id, document, retries=10, initial_backoff=1):\n", |
| 107 | + " for attempt in range(1, retries + 1):\n", |
| 108 | + " try:\n", |
| 109 | + " return es_client.index(index=index, id=doc_id, document=document)\n", |
| 110 | + " except Exception as e:\n", |
| 111 | + " if attempt < retries:\n", |
| 112 | + " wait_time = initial_backoff * (2 ** (attempt - 1))\n", |
| 113 | + " print(f\"[WARN] Failed to index {doc_id} (attempt {attempt}): {e}\")\n", |
| 114 | + " time.sleep(wait_time)\n", |
| 115 | + " else:\n", |
| 116 | + " print(f\"Failed to index {doc_id} after {retries} attempts: {e}\")\n", |
| 117 | + " raise" |
| 118 | + ] |
| 119 | + }, |
| 120 | + { |
| 121 | + "cell_type": "code", |
| 122 | + "execution_count": 3, |
| 123 | + "id": "bdf6ff33-3e22-43c1-9f3e-c3dd663b40e2", |
| 124 | + "metadata": {}, |
| 125 | + "outputs": [ |
| 126 | + { |
| 127 | + "data": { |
| 128 | + "application/vnd.jupyter.widget-view+json": { |
| 129 | + "model_id": "022b4af8891b4a06962e023c7f92d8f4", |
| 130 | + "version_major": 2, |
| 131 | + "version_minor": 0 |
| 132 | + }, |
| 133 | + "text/plain": [ |
| 134 | + "Indexing documents: 0%| | 0/500 [00:00<?, ?it/s]" |
| 135 | + ] |
| 136 | + }, |
| 137 | + "metadata": {}, |
| 138 | + "output_type": "display_data" |
| 139 | + }, |
| 140 | + { |
| 141 | + "name": "stdout", |
| 142 | + "output_type": "stream", |
| 143 | + "text": [ |
| 144 | + "Completed indexing 500 documents\n" |
| 145 | + ] |
| 146 | + } |
| 147 | + ], |
| 148 | + "source": [ |
| 149 | + "from concurrent.futures import ThreadPoolExecutor\n", |
| 150 | + "from tqdm.notebook import tqdm\n", |
| 151 | + "import pickle\n", |
| 152 | + "\n", |
| 153 | + "\n", |
| 154 | + "def process_file(file_name, vectors):\n", |
| 155 | + " if es.exists(index=INDEX_NAME, id=file_name):\n", |
| 156 | + " return\n", |
| 157 | + "\n", |
| 158 | + " bit_vectors = to_bit_vectors(vectors)\n", |
| 159 | + "\n", |
| 160 | + " index_document(\n", |
| 161 | + " es_client=es,\n", |
| 162 | + " index=INDEX_NAME,\n", |
| 163 | + " doc_id=file_name,\n", |
| 164 | + " document={\"col_pali_vectors\": bit_vectors},\n", |
| 165 | + " )\n", |
| 166 | + "\n", |
| 167 | + "\n", |
| 168 | + "with open(\"col_pali_vectors.pkl\", \"rb\") as f:\n", |
| 169 | + " file_to_multi_vectors = pickle.load(f)\n", |
| 170 | + "\n", |
| 171 | + "with ThreadPoolExecutor(max_workers=10) as executor:\n", |
| 172 | + " list(\n", |
| 173 | + " tqdm(\n", |
| 174 | + " executor.map(\n", |
| 175 | + " lambda item: process_file(*item), file_to_multi_vectors.items()\n", |
| 176 | + " ),\n", |
| 177 | + " total=len(file_to_multi_vectors),\n", |
| 178 | + " desc=\"Indexing documents\",\n", |
| 179 | + " )\n", |
| 180 | + " )\n", |
| 181 | + "\n", |
| 182 | + "print(f\"Completed indexing {len(file_to_multi_vectors)} documents\")" |
| 183 | + ] |
| 184 | + }, |
| 185 | + { |
| 186 | + "cell_type": "code", |
| 187 | + "execution_count": 4, |
| 188 | + "id": "1dfc3713-d649-46db-aa81-171d6d92668e", |
| 189 | + "metadata": {}, |
| 190 | + "outputs": [ |
| 191 | + { |
| 192 | + "data": { |
| 193 | + "application/vnd.jupyter.widget-view+json": { |
| 194 | + "model_id": "064e33061bac40e4802138e30599225b", |
| 195 | + "version_major": 2, |
| 196 | + "version_minor": 0 |
| 197 | + }, |
| 198 | + "text/plain": [ |
| 199 | + "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]" |
| 200 | + ] |
| 201 | + }, |
| 202 | + "metadata": {}, |
| 203 | + "output_type": "display_data" |
| 204 | + } |
| 205 | + ], |
| 206 | + "source": [ |
| 207 | + "import torch\n", |
| 208 | + "from PIL import Image\n", |
| 209 | + "from colpali_engine.models import ColPali, ColPaliProcessor\n", |
| 210 | + "\n", |
| 211 | + "model_name = \"vidore/colpali-v1.3\"\n", |
| 212 | + "model = ColPali.from_pretrained(\n", |
| 213 | + "    model_name,\n",
| 214 | + " torch_dtype=torch.float32,\n", |
| 215 | + " device_map=\"mps\", # \"mps\" for Apple Silicon, \"cuda\" if available, \"cpu\" otherwise\n", |
| 216 | + ").eval()\n", |
| 217 | + "\n", |
| 218 | + "col_pali_processor = ColPaliProcessor.from_pretrained(model_name)\n", |
| 219 | + "\n", |
| 220 | + "\n", |
| 221 | + "def create_col_pali_query_vectors(query: str) -> list:\n", |
| 222 | + " queries = col_pali_processor.process_queries([query]).to(model.device)\n", |
| 223 | + " with torch.no_grad():\n", |
| 224 | + " return model(**queries).tolist()[0]" |
| 225 | + ] |
| 226 | + }, |
| 227 | + { |
| 228 | + "cell_type": "markdown",
| 229 | + "id": "5e86697d-d9dd-4224-85c8-023c71c88548", |
| 230 | + "metadata": {}, |
| 231 | + "source": [ |
| 232 | + "Here we run the search against our index, comparing our query vector (converted to bit vectors) with the bit vectors in the index. \n",
| 233 | + "Trading off a bit of accuracy, this allows us to use hamming distance (`maxSimInvHamming(...)`), which can leverage optimizations such as bit-masks, SIMD, etc. Again - learn more about [bit vectors and hamming distance in our blog](https://www.elastic.co/search-labs/blog/bit-vectors-in-elasticsearch) about this topic. \n",
| 234 | + "\n",
| 235 | + "See the cell below for a different technique to query our bit vectors. "
| 236 | + ] |
| 237 | + }, |
| 238 | + { |
| 239 | + "cell_type": "code", |
| 240 | + "execution_count": 5, |
| 241 | + "id": "8e322b23-b4bc-409d-9e00-2dab93f6a295", |
| 242 | + "metadata": {}, |
| 243 | + "outputs": [ |
| 244 | + { |
| 245 | + "name": "stdout", |
| 246 | + "output_type": "stream", |
| 247 | + "text": [ |
| 248 | + "{\"_source\": false, \"query\": {\"script_score\": {\"query\": {\"match_all\": {}}, \"script\": {\"source\": \"maxSimInvHamming(params.query_vector, 'col_pali_vectors')\", \"params\": {\"query_vector\": [\"7747bcd9732859c3645aa81036f5c960\", \"729b3c418ba8594a67daa042eca1c961\", \"609e3d8a2ac379c2204aa0cfa8345bdc\", \"30bf378a2ac279da245aa8dfa83c3bdc\", \"64af77ea2acdf9c28c0aa5df863677f4\", \"686f3fce2ac871c26e6aaddf023455ec\", \"383f31a8e8c0f8ca2c4ab54f047c7dec\", \"203b33caaac279da0acaa54f8a3c6bcc\", \"319a63eba8d279ca30dbbccf8f757b8e\", \"203b73ca28d2798a325bb44f8c3c5bce\", \"203bb7caa8d2718a1a4bb14f8a3c5bdc\", \"203bb7caa8d2798a1a6aa14f8a3c5fdc\", \"303b33caa8d2798a0a4aa14f8a3c5bdc\", \"303b33caaad379ca0e4aa14f8a3c5bdc\", \"709b33caaac379ca0c4aa14f8a3c5fdc\", \"708e37eaaac779ca2c4aa1df863c1fdc\", \"648e77ea6acd79caac4ae1df86363ffc\", \"648e77ea6acdf9caac4ae5df06363ffc\", \"608f37ea2ac579ca2c4ea1df063c3ffc\", \"709f37c8aac379ca2c4ea1df863c1fdc\", \"70af31c82ac671ce2c6ab14fc43c1bfc\"]}}}}, \"size\": 5}\n" |
| 249 | + ] |
| 250 | + }, |
| 251 | + { |
| 252 | + "data": { |
| 253 | + "text/html": [ |
| 254 | + "<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'><img src=\"searchlabs-colpali/image_104.jpg\" alt=\"image_104.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_3.jpg\" alt=\"image_3.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_12.jpg\" alt=\"image_12.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_2.jpg\" alt=\"image_2.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_92.jpg\" alt=\"image_92.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"></div>" |
| 255 | + ], |
| 256 | + "text/plain": [ |
| 257 | + "<IPython.core.display.HTML object>" |
| 258 | + ] |
| 259 | + }, |
| 260 | + "metadata": {}, |
| 261 | + "output_type": "display_data" |
| 262 | + } |
| 263 | + ], |
| 264 | + "source": [ |
| 265 | + "from IPython.display import display, HTML\n", |
| 266 | + "import os\n", |
| 267 | + "import json\n", |
| 268 | + "\n", |
| 269 | + "DOCUMENT_DIR = \"searchlabs-colpali\"\n", |
| 270 | + "\n", |
| 271 | + "query = \"What do companies use for recruiting?\"\n", |
| 272 | + "query_vector = to_bit_vectors(create_col_pali_query_vectors(query))\n", |
| 273 | + "es_query = {\n", |
| 274 | + " \"_source\": False,\n", |
| 275 | + " \"query\": {\n", |
| 276 | + " \"script_score\": {\n", |
| 277 | + " \"query\": {\"match_all\": {}},\n", |
| 278 | + " \"script\": {\n", |
| 279 | + " \"source\": \"maxSimInvHamming(params.query_vector, 'col_pali_vectors')\",\n", |
| 280 | + " \"params\": {\"query_vector\": query_vector},\n", |
| 281 | + " },\n", |
| 282 | + " }\n", |
| 283 | + " },\n", |
| 284 | + " \"size\": 5,\n", |
| 285 | + "}\n", |
| 286 | + "print(json.dumps(es_query))\n", |
| 287 | + "\n", |
| 288 | + "results = es.search(index=INDEX_NAME, body=es_query)\n", |
| 289 | + "image_ids = [hit[\"_id\"] for hit in results[\"hits\"][\"hits\"]]\n", |
| 290 | + "\n", |
| 291 | + "html = \"<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'>\"\n", |
| 292 | + "for image_id in image_ids:\n", |
| 293 | + " image_path = os.path.join(DOCUMENT_DIR, image_id)\n", |
| 294 | + " html += f'<img src=\"{image_path}\" alt=\"{image_id}\" style=\"max-width:300px; height:auto; margin:10px;\">'\n", |
| 295 | + "html += \"</div>\"\n", |
| 296 | + "\n", |
| 297 | + "display(HTML(html))" |
| 298 | + ] |
| 299 | + }, |
| 300 | + { |
| 301 | + "cell_type": "markdown", |
| 302 | + "id": "e27b68ac-bec8-4415-919e-8b916bc35816", |
| 303 | + "metadata": {}, |
| 304 | + "source": [ |
| 305 | + "Above we have seen how to query our data using the `maxSimInvHamming(...)` function. \n", |
| 306 | + "We can also pass the full-fidelity ColPali vectors and use the `maxSimDotProduct(...)` function for [asymmetric similarity](https://www.elastic.co/guide/en/elasticsearch/reference/8.18/rank-vectors.html#rank-vectors-scoring) between the vectors. "
| 307 | + ] |
| 308 | + }, |
| 309 | + { |
| 310 | + "cell_type": "code", |
| 311 | + "execution_count": 6, |
| 312 | + "id": "32fd9ee4-d7c6-4954-a766-7b06735290ff", |
| 313 | + "metadata": {}, |
| 314 | + "outputs": [ |
| 315 | + { |
| 316 | + "data": { |
| 317 | + "text/html": [ |
| 318 | + "<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'><img src=\"searchlabs-colpali/image_104.jpg\" alt=\"image_104.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_3.jpg\" alt=\"image_3.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_2.jpg\" alt=\"image_2.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_12.jpg\" alt=\"image_12.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"><img src=\"searchlabs-colpali/image_92.jpg\" alt=\"image_92.jpg\" style=\"max-width:300px; height:auto; margin:10px;\"></div>" |
| 319 | + ], |
| 320 | + "text/plain": [ |
| 321 | + "<IPython.core.display.HTML object>" |
| 322 | + ] |
| 323 | + }, |
| 324 | + "metadata": {}, |
| 325 | + "output_type": "display_data" |
| 326 | + } |
| 327 | + ], |
| 328 | + "source": [ |
| 329 | + "query = \"What do companies use for recruiting?\"\n", |
| 330 | + "query_vector = create_col_pali_query_vectors(query)\n", |
| 331 | + "es_query = {\n", |
| 332 | + " \"_source\": False,\n", |
| 333 | + " \"query\": {\n", |
| 334 | + " \"script_score\": {\n", |
| 335 | + " \"query\": {\"match_all\": {}},\n", |
| 336 | + " \"script\": {\n", |
| 337 | + " \"source\": \"maxSimDotProduct(params.query_vector, 'col_pali_vectors')\",\n", |
| 338 | + " \"params\": {\"query_vector\": query_vector},\n", |
| 339 | + " },\n", |
| 340 | + " }\n", |
| 341 | + " },\n", |
| 342 | + " \"size\": 5,\n", |
| 343 | + "}\n", |
| 344 | + "\n", |
| 345 | + "results = es.search(index=INDEX_NAME, body=es_query)\n", |
| 346 | + "image_ids = [hit[\"_id\"] for hit in results[\"hits\"][\"hits\"]]\n", |
| 347 | + "\n", |
| 348 | + "html = \"<div style='display: flex; flex-wrap: wrap; align-items: flex-start;'>\"\n", |
| 349 | + "for image_id in image_ids:\n", |
| 350 | + " image_path = os.path.join(DOCUMENT_DIR, image_id)\n", |
| 351 | + " html += f'<img src=\"{image_path}\" alt=\"{image_id}\" style=\"max-width:300px; height:auto; margin:10px;\">'\n", |
| 352 | + "html += \"</div>\"\n", |
| 353 | + "\n", |
| 354 | + "display(HTML(html))" |
| 355 | + ] |
| 356 | + }, |
| 357 | + { |
| 358 | + "cell_type": "code", |
| 359 | + "execution_count": null, |
| 360 | + "id": "ee8df1e3-af66-4e35-9c26-7257c281536f", |
| 361 | + "metadata": {}, |
| 362 | + "outputs": [], |
| 363 | + "source": [ |
| 364 | + "# We kill the kernel forcefully to free up the memory from the ColPali model.\n", |
| 365 | + "print(\"Shutting down the kernel to free memory...\")\n", |
| 366 | + "import os\n", |
| 367 | + "\n", |
| 368 | + "os._exit(0)" |
| 369 | + ] |
| 370 | + } |
| 371 | + ], |
| 372 | + "metadata": { |
| 373 | + "kernelspec": { |
| 374 | + "display_name": "dependecy-test-colpali-blog", |
| 375 | + "language": "python", |
| 376 | + "name": "dependecy-test-colpali-blog" |
| 377 | + }, |
| 378 | + "language_info": { |
| 379 | + "codemirror_mode": { |
| 380 | + "name": "ipython", |
| 381 | + "version": 3 |
| 382 | + }, |
| 383 | + "file_extension": ".py", |
| 384 | + "mimetype": "text/x-python", |
| 385 | + "name": "python", |
| 386 | + "nbconvert_exporter": "python", |
| 387 | + "pygments_lexer": "ipython3", |
| 388 | + "version": "3.12.6" |
| 389 | + } |
| 390 | + }, |
| 391 | + "nbformat": 4, |
| 392 | + "nbformat_minor": 5 |
| 393 | +} |
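For readers skimming the notebook JSON above, the core conversion can also be run standalone. The sketch below reproduces the notebook's `to_bit_vectors()` helper and adds a toy `hamming_distance()` check to illustrate what `maxSimInvHamming(...)` computes per patch pair (the 8-dim example vector is made up for illustration; real ColPali patch vectors are 128-dim):

```python
import numpy as np


def to_bit_vectors(embedding: list) -> list:
    """Binarize each patch vector (> 0 -> 1, <= 0 -> 0) and hex-encode it."""
    embeddings = []
    for patch_embedding in embedding:
        patch_embedding = np.array(patch_embedding)
        binary_vector = (
            np.packbits(np.where(patch_embedding > 0, 1, 0))
            .astype(np.int8)
            .tobytes()
            .hex()
        )
        embeddings.append(binary_vector)
    return embeddings


def hamming_distance(hex_a: str, hex_b: str) -> int:
    """Number of differing bits between two hex-encoded bit vectors."""
    return sum(
        bin(a ^ b).count("1")
        for a, b in zip(bytes.fromhex(hex_a), bytes.fromhex(hex_b))
    )


# One 8-dim patch vector packs into a single byte -> 2 hex chars.
patch = [0.5, -0.1, 0.2, -0.3, 0.9, -0.7, 0.0, 0.4]
vec = to_bit_vectors([patch])
print(vec)  # ['a9'] (bits 10101001)

# A 128-dim patch vector becomes 16 bytes -> 32 hex chars.
print(len(to_bit_vectors([[1.0] * 128])[0]))  # 32

# Flipping the sign of one component changes exactly one bit.
flipped = patch.copy()
flipped[0] = -0.5
print(hamming_distance(vec[0], to_bit_vectors([flipped])[0]))  # 1
```

This is why the indexed values look like hex strings: each 128-dim float patch vector collapses to 16 bytes, a 32x reduction over float32, and comparing two such vectors needs only XOR and popcount.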