Skip to content

Commit bda7bef

Browse files
committed
Added max concurrency
1 parent 1dc6ad0 commit bda7bef

File tree

3 files changed: +22 −3 lines changed

paperqa/docs.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
)
3131
from .readers import read_doc
3232
from .types import Answer, Context
33-
from .utils import maybe_is_text, md5sum
33+
from .utils import maybe_is_text, md5sum, gather_with_concurrency
3434

3535
os.makedirs(os.path.dirname(CACHE_PATH), exist_ok=True)
3636
langchain.llm_cache = SQLiteCache(CACHE_PATH)
@@ -47,6 +47,7 @@ def __init__(
4747
name: str = "default",
4848
index_path: Optional[Path] = None,
4949
embeddings: Optional[Embeddings] = None,
50+
max_concurrent: int = 5,
5051
) -> None:
5152
"""Initialize the collection of documents.
5253
@@ -57,6 +58,7 @@ def __init__(
5758
name: The name of the collection.
5859
index_path: The path to the index file IF pickled. If None, defaults to using name in $HOME/.paperqa/name
5960
embeddings: The embeddings to use for indexing documents. Default - OpenAI embeddings
61+
max_concurrent: Number of concurrent LLM model calls to make
6062
"""
6163
self.docs = []
6264
self.chunk_size_limit = chunk_size_limit
@@ -71,6 +73,7 @@ def __init__(
7173
if embeddings is None:
7274
embeddings = OpenAIEmbeddings()
7375
self.embeddings = embeddings
76+
self.max_concurrent = max_concurrent
7477
self._deleted_keys = set()
7578

7679
def update_llm(
@@ -295,6 +298,8 @@ def __setstate__(self, state):
295298
# must be a better way to have backwards compatibility
296299
if not hasattr(self, "_deleted_keys"):
297300
self._deleted_keys = set()
301+
if not hasattr(self, "max_concurrent"):
302+
self.max_concurrent = 5
298303
self.update_llm(None, None)
299304

300305
def _build_faiss_index(self):
@@ -396,7 +401,9 @@ async def process(doc):
396401
return c, callbacks[0]
397402
return None, None
398403

399-
results = await asyncio.gather(*[process(doc) for doc in docs])
404+
results = await gather_with_concurrency(
405+
self.max_concurrent, *[process(doc) for doc in docs]
406+
)
400407
# filter out failures
401408
results = [r for r in results if r[0] is not None]
402409
answer.tokens += sum([cb.total_tokens for _, cb in results])

paperqa/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import math
22
import string
3+
import asyncio
34

45
import pypdf
56

@@ -68,3 +69,14 @@ def md5sum(file_path: StrPath) -> str:
6869

6970
with open(file_path, "rb") as f:
7071
return hashlib.md5(f.read()).hexdigest()
72+
73+
74+
async def gather_with_concurrency(n, *coros):
    """Await *coros* concurrently, allowing at most *n* to run at once.

    Drop-in bounded variant of ``asyncio.gather``: results come back in the
    same order as *coros*, but a shared semaphore caps how many coroutines
    are in flight simultaneously.
    (Pattern from https://stackoverflow.com/a/61478547/2392535)
    """
    limiter = asyncio.Semaphore(n)

    async def _bounded(coro):
        # Hold a semaphore slot for the full lifetime of the coroutine;
        # the slot is released as soon as it completes, letting the next start.
        async with limiter:
            return await coro

    wrapped = [_bounded(c) for c in coros]
    return await asyncio.gather(*wrapped)

paperqa/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.11.0"
1+
__version__ = "1.12.0"

Comments (0)