
Commit

fix token counting bug, minor revamping
Signed-off-by: Panos Vagenas <[email protected]>
vagenas committed Nov 12, 2024
1 parent e1cb823 commit 6f30048
Showing 1 changed file with 91 additions and 92 deletions.
183 changes: 91 additions & 92 deletions docs/examples/advanced_chunking_with_merging.ipynb
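
The token counting bug, in brief: the previous revision sized its plain-text splitter to MAX_TOKENS for the chunk body alone, while headings and captions were prepended to each chunk afterwards, so the embedded text could exceed the budget (the stale cell output replaced below reports 70 tokens against MAX_TOKENS = 64). The fix builds a per-chunk semchunk splitter sized to whatever budget remains after headings and captions. A minimal sketch of that budgeting, mirroring the notebook's tokenizer and semchunk usage (the helper names count_tokens and split_with_budget are illustrative, not part of the notebook):

    import semchunk
    from transformers import AutoTokenizer

    EMBED_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"
    MAX_TOKENS = 64
    tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL_ID)

    def count_tokens(text: str) -> int:
        # max_length=None keeps the tokenizer from truncating (and warning on) long inputs
        return len(tokenizer.tokenize(text, max_length=None))

    def split_with_budget(text: str, headings: list[str], captions: list[str]) -> list[str]:
        # Headings and captions get prepended to every chunk before embedding,
        # so the splitter must be sized to the budget they leave over.
        overhead = sum(count_tokens(t) for t in headings + captions)
        available = MAX_TOKENS - overhead
        if available <= 0:
            raise ValueError("Headings and captions alone exhaust the token budget.")
        sem_chunker = semchunk.chunkerify(tokenizer, chunk_size=available)
        return sem_chunker.chunk(text)

The old code built a single splitter with chunk_size=max_tokens up front; the per-chunk available_length computation in the diff below removes exactly that over-count. The full diff follows.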
@@ -40,7 +40,7 @@
"from dataclasses import dataclass\n",
"from pathlib import Path\n",
"from tempfile import mkdtemp\n",
"from typing import Any, Iterator, Optional, Union\n",
"from typing import Iterator, Optional, Union\n",
"\n",
"import lancedb\n",
"import semchunk\n",
@@ -54,7 +54,7 @@
"from docling_core.types import DoclingDocument\n",
"from pydantic import ConfigDict, PositiveInt\n",
"from sentence_transformers import SentenceTransformer\n",
"from transformers import AutoTokenizer\n",
"from transformers import AutoTokenizer, PreTrainedTokenizerBase\n",
"\n",
"from docling.document_converter import DocumentConverter"
]
@@ -65,11 +65,12 @@
"metadata": {},
"outputs": [],
"source": [
"DOC_SOURCE = \"http://bill.murdocks.org/iccbr2011murdock_web.pdf\"\n",
"EMBED_MODEL_ID = \"sentence-transformers/all-MiniLM-L6-v2\"\n",
"TOKENIZER = AutoTokenizer.from_pretrained(EMBED_MODEL_ID)\n",
"EMBED_MODEL = SentenceTransformer(EMBED_MODEL_ID)\n",
"MAX_TOKENS = 64"
"MAX_TOKENS = 64\n",
"DOC_SOURCE = \"http://bill.murdocks.org/iccbr2011murdock_web.pdf\"\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(EMBED_MODEL_ID)\n",
"embed_model = SentenceTransformer(EMBED_MODEL_ID)"
]
},
{
@@ -85,14 +86,14 @@
"metadata": {},
"outputs": [],
"source": [
"class HybridChunker(BaseChunker):\n",
"class HybridChunker(BaseChunker): # TODO: improve naming\n",
"\n",
" model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)\n",
"\n",
" inner_chunker: BaseChunker = HierarchicalChunker()\n",
" # TODO: improve typing for tokenizer below (ran into issues with `PreTrainedTokenizer`):\n",
" tokenizer: Any\n",
" tokenizer: PreTrainedTokenizerBase\n",
" max_tokens: PositiveInt\n",
" delim: str = \"\\n\"\n",
"\n",
" def _count_tokens(self, text: Optional[Union[str, list[str]]]):\n",
" if text is None:\n",
@@ -104,9 +105,6 @@
" return total\n",
" return len(self.tokenizer.tokenize(text, max_length=None))\n",
"\n",
" def _make_splitter(self):\n",
" return semchunk.chunkerify(self.tokenizer, self.max_tokens)\n",
"\n",
" @dataclass\n",
" class _ChunkLengthInfo:\n",
" total_len: int\n",
@@ -116,6 +114,7 @@
" def _doc_chunk_length(self, doc_chunk: DocChunk):\n",
" text_length = self._count_tokens(doc_chunk.text)\n",
" # Note that count_tokens handles None and lists, making this code simpler:\n",
" # TODO check if delim properly considered\n",
" headings_length = self._count_tokens(doc_chunk.meta.headings)\n",
" captions_length = self._count_tokens(doc_chunk.meta.captions)\n",
" total = text_length + headings_length + captions_length\n",
@@ -137,14 +136,13 @@
" new_chunk = DocChunk(text=window_text, meta=meta)\n",
" return new_chunk\n",
"\n",
" @classmethod\n",
" def _merge_text(cls, t1, t2):\n",
" def _merge_text(self, t1, t2):\n",
" if t1 == \"\":\n",
" return t2\n",
" elif t2 == \"\":\n",
" return t1\n",
" else:\n",
" return t1 + \"\\n\" + t2\n",
" return f\"{t1}{self.delim}{t2}\"\n",
"\n",
" def _split_by_doc_items(self, doc_chunk: DocChunk) -> list[DocChunk]:\n",
" if doc_chunk.meta.doc_items == None or len(doc_chunk.meta.doc_items) <= 1:\n",
@@ -206,20 +204,23 @@
" def _split_using_plain_text(\n",
" self,\n",
" doc_chunk: DocChunk,\n",
" plain_text_splitter,\n",
" ):\n",
" ) -> list[DocChunk]:\n",
" lengths = self._doc_chunk_length(doc_chunk)\n",
" if lengths.total_len <= self.max_tokens:\n",
" return [doc_chunk]\n",
" else:\n",
"\n",
" # How much room is there for text after subtracting out the headers and captions:\n",
" available_length = self.max_tokens - lengths.other_len\n",
" sem_chunker = semchunk.chunkerify(\n",
" self.tokenizer, chunk_size=available_length\n",
" )\n",
" if available_length <= 0:\n",
" raise ValueError(\n",
" \"Headers and captions for this chunk are longer than the total amount of size for the chunk. This is not supported now.\"\n",
" )\n",
" \"Headers and captions for this chunk are longer than the total amount of size for the chunk. This is not supported now.\"\n",
" ) # TODO switch to warning\n",
" text = doc_chunk.text\n",
" segments = plain_text_splitter.chunk(text)\n",
" segments = sem_chunker.chunk(text)\n",
" chunks = [DocChunk(text=s, meta=doc_chunk.meta) for s in segments]\n",
" return chunks\n",
"\n",
@@ -283,36 +284,33 @@
" )\n",
" return final_merged_chunks\n",
"\n",
" @classmethod\n",
" def _make_text_for_embedding(cls, chunk: DocChunk):\n",
" def _make_text_for_embedding(self, chunk: DocChunk):\n",
" output = \"\"\n",
" if chunk.meta.headings != None:\n",
" for h in chunk.meta.headings:\n",
" output += h + \"\\n\"\n",
" output += h + self.delim\n",
" if chunk.meta.captions != None:\n",
" for c in chunk.meta.captions:\n",
" output += c + \"\\n\"\n",
" output += c + self.delim\n",
" output += chunk.text\n",
" return output\n",
"\n",
" def _adjust_chunks_for_fixed_size(self, chunks: list[DocChunk], splitter):\n",
" split_by_items = [x for c in chunks for x in self._split_by_doc_items(c)]\n",
" split_recursively = [\n",
" x for c in split_by_items for x in self._split_using_plain_text(c, splitter)\n",
" ]\n",
" merged = self._merge_chunks(split_recursively)\n",
" text_expanded = [\n",
" def _adjust_chunks_for_fixed_size(self, chunks: list[DocChunk]):\n",
" res = chunks\n",
" res = [x for c in res for x in self._split_by_doc_items(c)]\n",
" res = [x for c in res for x in self._split_using_plain_text(c)]\n",
" res = self._merge_chunks(res)\n",
" res = [\n",
" DocChunk.model_validate(\n",
" {**c.model_dump(), \"text\": self._make_text_for_embedding(c)}\n",
" )\n",
" for c in merged\n",
" for c in res\n",
" ]\n",
" return text_expanded\n",
" return res\n",
"\n",
" def chunk(self, dl_doc: DoclingDocument, **kwargs) -> Iterator[BaseChunk]:\n",
" preliminary_chunks = self.inner_chunker.chunk(dl_doc=dl_doc, **kwargs)\n",
" splitter = self._make_splitter()\n",
" output_chunks = self._adjust_chunks_for_fixed_size(preliminary_chunks, splitter)\n",
" output_chunks = self._adjust_chunks_for_fixed_size(preliminary_chunks)\n",
" return iter(output_chunks)"
]
},
@@ -327,49 +325,50 @@
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"conv_res = DocumentConverter().convert(source=DOC_SOURCE)\n",
"doc = conv_res.document"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using CPU. Note: This module is much faster with a GPU.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"J. William Murdock\n",
"[email protected] IBM T.J. Watson Research Center P.O. Box 704 Yorktown Heights, NY 10598\n",
"39\n",
"J. William Murdock\n",
"Abstract. The Jeopardy! television quiz show asks natural-language questions and requires natural-language answers. One useful source of information for answering Jeopardy! questions is text from written sources such as encyclopedias or news articles. A text passage may partially or fully indicate that some candidate answer is the correct answer to the question. Recognizing\n",
"70\n",
"J. William Murdock\n",
"whether it does requires determining the extent to which what the passage is saying about the candidate answer is similar to what the question is saying about the desired answer. This paper describes how structure mapping [1] (an algorithm originally developed for analogical reasoning) is applied to determine similarity between content in questions and passages. That algorithm\n",
"70\n",
"J. William Murdock\n",
"is one of many used in the Watson question answering system [2]. It contributes a significant amount to Watson's effectiveness.\n",
"32\n",
"1 Introduction\n",
"Watson is a question answering system built on a set of technologies known as DeepQA [2]. Watson has been customized and configured to compete at Jeopardy!, an American television quiz show. Watson takes in a question and produces a ranked list of answers with confidence scores attached to each of these answers.\n",
"62\n"
"chunk.text='J. William Murdock\\[email protected] IBM T.J. Watson Research Center P.O. Box 704 Yorktown Heights, NY 10598'\n",
"num tokens: 39\n",
"\n",
"chunk.text='J. William Murdock\\nAbstract. The Jeopardy! television quiz show asks natural-language questions and requires natural-language answers. One useful source of information for answering Jeopardy! questions is text from written sources such as encyclopedias or news articles. A text passage may partially or fully indicate that some candidate answer is the correct'\n",
"num tokens: 64\n",
"\n",
"chunk.text='J. William Murdock\\nanswer to the question. Recognizing whether it does requires determining the extent to which what the passage is saying about the candidate answer is similar to what the question is saying about the desired answer. This paper describes how structure mapping [1] (an algorithm originally developed for analogical reasoning) is applied'\n",
"num tokens: 64\n",
"\n",
"chunk.text=\"J. William Murdock\\nto determine similarity between content in questions and passages. That algorithm is one of many used in the Watson question answering system [2]. It contributes a significant amount to Watson's effectiveness.\"\n",
"num tokens: 44\n",
"\n",
"chunk.text='1 Introduction\\nWatson is a question answering system built on a set of technologies known as DeepQA [2]. Watson has been customized and configured to compete at Jeopardy!, an American television quiz show. Watson takes in a question and produces a ranked list of answers with confidence scores attached to each of these answers.'\n",
"num tokens: 62\n",
"\n"
]
}
],
"source": [
"conv_res = DocumentConverter().convert(source=DOC_SOURCE)\n",
"doc = conv_res.document\n",
"\n",
"chunker = HybridChunker(\n",
" tokenizer=TOKENIZER,\n",
" tokenizer=tokenizer,\n",
" max_tokens=MAX_TOKENS,\n",
")\n",
"chunks = list(chunker.chunk(dl_doc=doc))\n",
"\n",
"for chunk in chunks[:5]:\n",
" print(chunk.text)\n",
" print(chunker._count_tokens(chunk.text))"
" print(f\"{chunk.text=}\")\n",
" print(f\"num tokens: {len(tokenizer.tokenize(chunk.text, max_length=None))}\")\n",
" print()"
]
},
{
Expand All @@ -381,7 +380,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -423,64 +422,64 @@
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>[0.034203544, 0.10181023, 0.003722408, 0.00506...</td>\n",
" <td>5 Evaluation and Conclusions\\nconsider to be a...</td>\n",
" <td>[5 Evaluation and Conclusions]</td>\n",
" <td>None</td>\n",
" <td>1.469304</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>[0.04400234, -0.034766007, -0.00025527124, 0.0...</td>\n",
" <td>References\\n4. McCord, M. C. (1990). Slot Gram...</td>\n",
" <td>[References]</td>\n",
" <td>None</td>\n",
" <td>1.525625</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>[0.112926826, -0.010892201, 0.007714559, -0.06...</td>\n",
" <td>3 Syntactic-Semantic Graphs\\nplay , about , Ut...</td>\n",
" <th>2</th>\n",
" <td>[0.10043394, 0.00652478, 0.011601829, -0.06390...</td>\n",
" <td>3 Syntactic-Semantic Graphs\\npassage using sem...</td>\n",
" <td>[3 Syntactic-Semantic Graphs]</td>\n",
" <td>None</td>\n",
" <td>1.540550</td>\n",
" <td>1.569923</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <th>3</th>\n",
" <td>[0.025994677, 0.08402823, 0.03268827, -0.03727...</td>\n",
" <td>4 Algorithm\\nIn using this algorithm, we have ...</td>\n",
" <td>[4 Algorithm]</td>\n",
" <td>None</td>\n",
" <td>1.576838</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>[0.050165094, 0.08015387, 0.035965856, 0.00846...</td>\n",
" <td>5 Evaluation and Conclusions\\nword order) are ...</td>\n",
" <td>[5 Evaluation and Conclusions]</td>\n",
" <td>None</td>\n",
" <td>1.580265</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" vector \\\n",
"0 [-0.025746439, 0.03888134, 0.0033668755, -0.03... \n",
"1 [0.034203544, 0.10181023, 0.003722408, 0.00506... \n",
"2 [0.04400234, -0.034766007, -0.00025527124, 0.0... \n",
"3 [0.112926826, -0.010892201, 0.007714559, -0.06... \n",
"4 [0.025994677, 0.08402823, 0.03268827, -0.03727... \n",
"1 [0.04400234, -0.034766007, -0.00025527124, 0.0... \n",
"2 [0.10043394, 0.00652478, 0.011601829, -0.06390... \n",
"3 [0.025994677, 0.08402823, 0.03268827, -0.03727... \n",
"4 [0.050165094, 0.08015387, 0.035965856, 0.00846... \n",
"\n",
" text \\\n",
"0 References\\n3. Forbus, K. and Oblinger, D. (19... \n",
"1 5 Evaluation and Conclusions\\nconsider to be a... \n",
"2 References\\n4. McCord, M. C. (1990). Slot Gram... \n",
"3 3 Syntactic-Semantic Graphs\\nplay , about , Ut... \n",
"4 4 Algorithm\\nIn using this algorithm, we have ... \n",
"1 References\\n4. McCord, M. C. (1990). Slot Gram... \n",
"2 3 Syntactic-Semantic Graphs\\npassage using sem... \n",
"3 4 Algorithm\\nIn using this algorithm, we have ... \n",
"4 5 Evaluation and Conclusions\\nword order) are ... \n",
"\n",
" headings captions _distance \n",
"0 [References] None 0.332435 \n",
"1 [5 Evaluation and Conclusions] None 1.469304 \n",
"2 [References] None 1.525625 \n",
"3 [3 Syntactic-Semantic Graphs] None 1.540550 \n",
"4 [4 Algorithm] None 1.576838 "
"1 [References] None 1.525625 \n",
"2 [3 Syntactic-Semantic Graphs] None 1.569923 \n",
"3 [4 Algorithm] None 1.576838 \n",
"4 [5 Evaluation and Conclusions] None 1.580265 "
]
},
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -503,10 +502,10 @@
"\n",
"\n",
"db_uri = str(Path(mkdtemp()) / \"docling.db\") # or set as needed\n",
"index = make_lancedb_index(db_uri, doc.name, chunks, EMBED_MODEL)\n",
"index = make_lancedb_index(db_uri, doc.name, chunks, embed_model)\n",
"\n",
"sample_query = \"Making SME greedy and pragmatic\"\n",
"sample_embedding = EMBED_MODEL.encode(sample_query)\n",
"sample_embedding = embed_model.encode(sample_query)\n",
"results = index.search(sample_embedding).limit(5)\n",
"\n",
"results.to_pandas()"
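One detail of the "minor revamping" worth isolating: the hard-coded "\n" joiners in _merge_text and _make_text_for_embedding were replaced by the chunker's new delim field. A minimal sketch of the embedding-text serialization, assuming plain lists in place of the DocChunk metadata (the name make_text_for_embedding matches the notebook; the standalone signature is illustrative):

    DELIM = "\n"  # the chunker's new `delim` field, default "\n"

    def make_text_for_embedding(text: str, headings=None, captions=None) -> str:
        # Prepend headings, then captions, each followed by the delimiter,
        # so the embedding model sees the chunk together with its context.
        out = ""
        for h in headings or []:
            out += h + DELIM
        for c in captions or []:
            out += c + DELIM
        return out + text

    # e.g. make_text_for_embedding("Watson is a question answering system ...",
    #                               headings=["1 Introduction"])
    # -> "1 Introduction\nWatson is a question answering system ..."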

