1- from typing import List , Optional , Tuple , Union , Callable
2- from functools import reduce
1+ import asyncio
32import os
3+ import re
44import sys
5- import asyncio
5+ from datetime import datetime
6+ from functools import reduce
67from pathlib import Path
7- import re
8- from .paths import CACHE_PATH
9- from .utils import maybe_is_text , md5sum
10- from .qaprompts import (
11- summary_prompt ,
12- qa_prompt ,
13- search_prompt ,
14- citation_prompt ,
15- select_paper_prompt ,
16- make_chain ,
17- )
18- from .types import Answer , Context
19- from .readers import read_doc
20- from langchain .vectorstores import FAISS
8+ from typing import Callable , List , Optional , Tuple , Union
9+
10+ import langchain
11+ from langchain .cache import SQLiteCache
12+ from langchain .callbacks import OpenAICallbackHandler , get_openai_callback
13+ from langchain .callbacks .base import AsyncCallbackHandler
14+ from langchain .callbacks .manager import AsyncCallbackManager
15+ from langchain .chat_models import ChatOpenAI
2116from langchain .docstore .document import Document
22- from langchain .embeddings .openai import OpenAIEmbeddings
2317from langchain .embeddings .base import Embeddings
24- from langchain .chat_models import ChatOpenAI
18+ from langchain .embeddings . openai import OpenAIEmbeddings
2519from langchain .llms .base import LLM
26- from langchain .callbacks import get_openai_callback , OpenAICallbackHandler
27- from langchain .callbacks .base import AsyncCallbackHandler , AsyncCallbackManager
28- from langchain .cache import SQLiteCache
29- import langchain
30- from datetime import datetime
20+ from langchain .vectorstores import FAISS
21+
22+ from .paths import CACHE_PATH
23+ from .qaprompts import (citation_prompt , make_chain , qa_prompt , search_prompt ,
24+ select_paper_prompt , summary_prompt )
25+ from .readers import read_doc
26+ from .types import Answer , Context
27+ from .utils import maybe_is_text , md5sum
3128
3229os .makedirs (os .path .dirname (CACHE_PATH ), exist_ok = True )
3330langchain .llm_cache = SQLiteCache (CACHE_PATH )
@@ -44,7 +41,6 @@ def __init__(
4441 name : str = "default" ,
4542 index_path : Optional [Path ] = None ,
4643 embeddings : Optional [Embeddings ] = None ,
47- get_callbacks : Callable [[str ], AsyncCallbackHandler ] = lambda x : []
4844 ) -> None :
4945 """Initialize the collection of documents.
5046
@@ -55,7 +51,6 @@ def __init__(
5551 name: The name of the collection.
5652 index_path: The path to the index file IF pickled. If None, defaults to using name in $HOME/.paperqa/name
5753 embeddings: The embeddings to use for indexing documents. Default - OpenAI embeddings
58- get_callbacks: A function that allows callbacks to built per stage of the pipeline.
5954 """
6055 self .docs = dict ()
6156 self .chunk_size_limit = chunk_size_limit
@@ -70,7 +65,6 @@ def __init__(
7065 if embeddings is None :
7166 embeddings = OpenAIEmbeddings ()
7267 self .embeddings = embeddings
73- self .get_callbacks = get_callbacks
7468
7569 def update_llm (
7670 self ,
@@ -96,7 +90,6 @@ def add(
9690 key : Optional [str ] = None ,
9791 disable_check : bool = False ,
9892 chunk_chars : Optional [int ] = 3000 ,
99- overwrite : bool = False ,
10093 ) -> None :
10194 """Add a document to the collection."""
10295
@@ -110,8 +103,9 @@ def add(
110103 cite_chain = make_chain (prompt = citation_prompt , llm = self .summary_llm )
111104 # peek at the first chunk
112105 texts , _ = read_doc (path , "" , "" , chunk_chars = chunk_chars )
113- with get_openai_callback ():
114- citation = cite_chain .run (texts [0 ])
106+ if len (texts ) == 0 :
107+ raise ValueError (f"Could not read document { path } . Is it empty?" )
108+ citation = cite_chain .run (texts [0 ])
115109 if len (citation ) < 3 or "Unknown" in citation or "insufficient" in citation :
116110 citation = f"Unknown, { os .path .basename (path )} , { datetime .now ().year } "
117111
@@ -207,12 +201,10 @@ def __getstate__(self):
207201 state ["_faiss_index" ].save_local (self .index_path )
208202 del state ["_faiss_index" ]
209203 del state ["_doc_index" ]
210- del state ["get_callbacks" ]
211204 return state
212205
213206 def __setstate__ (self , state ):
214207 self .__dict__ .update (state )
215- self .get_callbacks = lambda x : []
216208 try :
217209 self ._faiss_index = FAISS .load_local (self .index_path , self .embeddings )
218210 except :
@@ -240,8 +232,9 @@ def get_evidence(
240232 k : int = 3 ,
241233 max_sources : int = 5 ,
242234 marginal_relevance : bool = True ,
243- key_filter : Optional [List [str ]] = None
244- ) -> Answer :
235+ key_filter : Optional [List [str ]] = None ,
236+ get_callbacks : Callable [[str ], AsyncCallbackHandler ] = lambda x : [],
237+ ) -> Answer :
245238 # special case for jupyter notebooks
246239 if "get_ipython" in globals () or "google.colab" in sys .modules :
247240 import nest_asyncio
@@ -258,7 +251,8 @@ def get_evidence(
258251 k = k ,
259252 max_sources = max_sources ,
260253 marginal_relevance = marginal_relevance ,
261- key_filter = key_filter
254+ key_filter = key_filter ,
255+ get_callbacks = get_callbacks ,
262256 )
263257 )
264258
@@ -269,6 +263,7 @@ async def aget_evidence(
269263 max_sources : int = 5 ,
270264 marginal_relevance : bool = True ,
271265 key_filter : Optional [List [str ]] = None ,
266+ get_callbacks : Callable [[str ], AsyncCallbackHandler ] = lambda x : [],
272267 ) -> Answer :
273268 if len (self .docs ) == 0 :
274269 return answer
@@ -293,29 +288,32 @@ async def process(doc):
293288 # check if it is already in answer (possible in agent setting)
294289 if doc .metadata ["key" ] in [c .key for c in answer .contexts ]:
295290 return None , None
296- cb = OpenAICallbackHandler ()
297- manager = AsyncCallbackManager ([cb ] + self .get_callbacks ('evidence:' + doc .metadata ['key' ]))
298- summary_chain = make_chain (summary_prompt , self .summary_llm , manager )
291+ callbacks = [OpenAICallbackHandler ()] + get_callbacks (
292+ "evidence:" + doc .metadata ["key" ]
293+ )
294+ summary_chain = make_chain (summary_prompt , self .summary_llm )
299295 c = Context (
300296 key = doc .metadata ["key" ],
301297 citation = doc .metadata ["citation" ],
302298 context = await summary_chain .arun (
303299 question = answer .question ,
304300 context_str = doc .page_content ,
305301 citation = doc .metadata ["citation" ],
302+ callbacks = callbacks ,
306303 ),
307304 text = doc .page_content ,
308305 )
309306 if "Not applicable" not in c .context :
310307 return c , cb
311308 return None , None
312309
313- results = await asyncio .gather (* [process (doc ) for doc in docs ])
310+ with get_openai_callback () as cb :
311+ results = await asyncio .gather (* [process (doc ) for doc in docs ])
314312 # filter out failures
315313 results = [r for r in results if r [0 ] is not None ]
316314 answer .tokens += sum ([cb .total_tokens for _ , cb in results ])
317315 answer .cost += sum ([cb .total_cost for _ , cb in results ])
318- contexts = [c for c ,_ in results if c is not None ]
316+ contexts = [c for c , _ in results if c is not None ]
319317 if len (contexts ) == 0 :
320318 return answer
321319 contexts = sorted (contexts , key = lambda x : len (x .context ), reverse = True )
@@ -365,6 +363,7 @@ def query(
365363 marginal_relevance : bool = True ,
366364 answer : Optional [Answer ] = None ,
367365 key_filter : Optional [bool ] = None ,
366+ get_callbacks : Callable [[str ], AsyncCallbackHandler ] = lambda x : [],
368367 ) -> Answer :
369368 # special case for jupyter notebooks
370369 if "get_ipython" in globals () or "google.colab" in sys .modules :
@@ -385,6 +384,7 @@ def query(
385384 marginal_relevance = marginal_relevance ,
386385 answer = answer ,
387386 key_filter = key_filter ,
387+ get_callbacks = get_callbacks ,
388388 )
389389 )
390390
@@ -397,6 +397,7 @@ async def aquery(
397397 marginal_relevance : bool = True ,
398398 answer : Optional [Answer ] = None ,
399399 key_filter : Optional [bool ] = None ,
400+ get_callbacks : Callable [[str ], AsyncCallbackHandler ] = lambda x : [],
400401 ) -> Answer :
401402 if k < max_sources :
402403 raise ValueError ("k should be greater than max_sources" )
@@ -414,6 +415,7 @@ async def aquery(
414415 max_sources = max_sources ,
415416 marginal_relevance = marginal_relevance ,
416417 key_filter = keys if key_filter else None ,
418+ get_callbacks = get_callbacks ,
417419 )
418420 context_str , contexts = answer .context , answer .contexts
419421 bib = dict ()
@@ -424,11 +426,14 @@ async def aquery(
424426 )
425427 else :
426428 cb = OpenAICallbackHandler ()
427- manager = AsyncCallbackManager ([ cb ] + self . get_callbacks (' answer' ) )
428- qa_chain = make_chain (qa_prompt , self .llm , manager )
429+ callbacks = [ OpenAICallbackHandler () ] + get_callbacks (" answer" )
430+ qa_chain = make_chain (qa_prompt , self .llm )
429431 answer_text = await qa_chain .arun (
430- question = query , context_str = context_str , length = length_prompt
431- )
432+ question = query ,
433+ context_str = context_str ,
434+ length = length_prompt ,
435+ callbacks = callbacks ,
436+ )
432437 answer .tokens += cb .total_tokens
433438 answer .cost += cb .total_cost
434439 # it still happens lol
0 commit comments