
Commit 861f805

Cache document parsing (#46)
* Refactor paths into `paths.py`
* Create cache for all parsed documents
* Allow non-ascii in cache keys
* Use hash of file in cache key, not pathname
* Include __version__ in cache key
* Refactor serialization for cache
* Switch to md5sum for file caching
1 parent 8c1ff7b commit 861f805
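
The net effect of the commit, as a hedged usage sketch (not part of the commit; the file name, citation, and key below are made up): `read_doc` now builds a cache key from an md5 hash of the file's contents plus the parsing parameters and `__version__`, so re-parsing an unchanged document is served from an on-disk SQLite cache.

```python
from paperqa.readers import read_doc

# Hypothetical inputs for illustration only.
path = "example.pdf"
citation = "Doe et al., An Example Paper, 2023"

# First call parses the document and stores the result in
# ~/.paperqa/ocr_cache.db, keyed on (file hash, parameters, __version__).
chunks = read_doc(path, citation, key="Doe2023")

# A second call with the same file and parameters is answered from the cache;
# editing the file or upgrading paperqa changes the key and forces a re-parse.
chunks = read_doc(path, citation, key="Doe2023")
```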

File tree

4 files changed (+119 lines, -3 lines)


paperqa/contrib/zotero.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@

 from pyzotero import zotero

-from ..docs import CACHE_PATH
+from ..paths import CACHE_PATH

 StrPath = Union[str, Path]

paperqa/docs.py

Lines changed: 1 addition & 1 deletion
@@ -5,6 +5,7 @@
 import asyncio
 from pathlib import Path
 import re
+from .paths import CACHE_PATH
 from .utils import maybe_is_text, maybe_is_truncated
 from .qaprompts import (
     summary_prompt,
@@ -27,7 +28,6 @@
 import langchain
 from datetime import datetime

-CACHE_PATH = Path.home() / ".paperqa" / "llm_cache.db"
 os.makedirs(os.path.dirname(CACHE_PATH), exist_ok=True)
 langchain.llm_cache = SQLiteCache(CACHE_PATH)

paperqa/paths.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+from pathlib import Path
+
+CACHE_PATH = Path.home() / ".paperqa" / "llm_cache.db"
+OCR_CACHE_PATH = CACHE_PATH.parent / "ocr_cache.db"

paperqa/readers.py

Lines changed: 113 additions & 1 deletion
@@ -1,8 +1,27 @@
-from .utils import maybe_is_code
+import os
+from .paths import OCR_CACHE_PATH
+from .version import __version__
 from html2text import html2text
 from pathlib import Path
+import json
+import logging
+from hashlib import md5

 from langchain.text_splitter import TokenTextSplitter
+from langchain.cache import SQLiteCache
+from langchain.schema import Generation
+
+OCR_CACHE = None
+
+
+def _get_ocr_cache() -> SQLiteCache:
+    """Used to lazily create the cache directory and cache object."""
+    global OCR_CACHE
+    if OCR_CACHE is None:
+        os.makedirs(os.path.dirname(OCR_CACHE_PATH), exist_ok=True)
+        OCR_CACHE = SQLiteCache(OCR_CACHE_PATH)
+    return OCR_CACHE
+

 TextSplitter = TokenTextSplitter

@@ -99,7 +118,100 @@ def parse_code_txt(path, citation, key, chunk_chars=2000, overlap=50):
     return splits, metadatas


+def _serialize_s(obj):
+    """Convert a json-like object to a string"""
+    # We sort the keys to ensure
+    # that the same object always gets serialized to the same string.
+    return json.dumps(obj, sort_keys=True, ensure_ascii=False)
+
+
+def _deserialize_s(obj):
+    """The inverse of _serialize_s"""
+    return json.loads(obj)
+
+
+def _serialize(obj):
+    # llmchain wants a list of "Generation" objects, so we simply
+    # stick this regular text into it.
+    return [Generation(text=_serialize_s(obj))]
+
+
+def _deserialize(obj):
+    # (The inverse of _serialize)
+    try:
+        return _deserialize_s(obj[0].text)
+    except json.JSONDecodeError:
+        return None
+
+
+def _filehash(path):
+    """Fast hash of a file - about 1ms per MB."""
+    bufsize = 65536
+    h = md5()
+    with open(path, "rb") as f:
+        while True:
+            data = f.read(bufsize)
+            if not data:
+                break
+            h.update(data)
+    return h.hexdigest()
+
+
 def read_doc(path, citation, key, chunk_chars=3000, overlap=100, disable_check=False):
+    logger = logging.getLogger(__name__)
+    logger.debug(f"Creating cache key for {path}")
+    cache_key = _serialize_s(
+        dict(
+            hash=str(_filehash(path)),
+            citation=citation,
+            key=key,
+            chunk_chars=chunk_chars,
+            overlap=overlap,
+            disable_check=disable_check,
+            version=__version__,
+        )
+    )
+    logger.debug(f"Looking up cache key for {path}")
+    cache_lookup = _get_ocr_cache().lookup(prompt=cache_key, llm_string="")
+
+    out = None
+    successful_lookup = False
+    cache_exists = cache_lookup is not None
+    if cache_exists:
+        logger.debug(f"Found cache key for {path}")
+        out = _deserialize(cache_lookup)
+
+    successful_lookup = out is not None
+    if successful_lookup:
+        logger.debug(f"Succesfully loaded cache key for {path}")
+    elif cache_exists:
+        logger.debug(f"Failed to decode existing cache for {path}")
+
+    if out is None:
+        logger.debug(f"Did not load cache, so parsing {path}")
+
+        # The actual call:
+        out = _read_doc(
+            path=path,
+            citation=citation,
+            key=key,
+            chunk_chars=chunk_chars,
+            overlap=overlap,
+            disable_check=disable_check,
+        )
+
+    logger.debug(f"Done parsing document {path}")
+    if not successful_lookup:
+        logger.debug(f"Updating cache for {path}")
+        _get_ocr_cache().update(
+            prompt=cache_key,
+            llm_string="",
+            return_val=_serialize(out),
+        )
+    return out
+
+
 def _read_doc(path, citation, key, chunk_chars=3000, overlap=100, disable_check=False):
     """Parse a document into chunks."""
     if isinstance(path, Path):
         path = str(path)
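
For context on the serialization helpers in this diff: langchain's `SQLiteCache` stores a list of `Generation` objects per `(prompt, llm_string)` pair, so the parsed output is JSON-encoded into a single `Generation` on write and decoded on lookup. A minimal round-trip sketch, assuming the module is importable as `paperqa.readers` and `~/.paperqa` is writable (the cache key and sample data below are made up):

```python
from paperqa.readers import _get_ocr_cache, _serialize, _deserialize

cache = _get_ocr_cache()  # lazily creates ~/.paperqa/ocr_cache.db on first use

# Any JSON-serializable object survives the round trip unchanged.
parsed = {"splits": ["chunk one", "chunk two"], "metadatas": [{"key": "Doe2023"}]}
cache.update(prompt="example-cache-key", llm_string="", return_val=_serialize(parsed))

hit = cache.lookup(prompt="example-cache-key", llm_string="")
assert _deserialize(hit) == parsed
```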
