2727 search_prompt ,
2828 select_paper_prompt ,
2929 summary_prompt ,
30+ get_score ,
3031)
3132from .readers import read_doc
3233from .types import Answer , Context
33- from .utils import maybe_is_text , md5sum , gather_with_concurrency
34+ from .utils import maybe_is_text , md5sum , gather_with_concurrency , guess_is_4xx
3435
3536os .makedirs (os .path .dirname (CACHE_PATH ), exist_ok = True )
3637langchain .llm_cache = SQLiteCache (CACHE_PATH )
@@ -373,31 +374,46 @@ async def aget_evidence(
373374 docs = self ._faiss_index .similarity_search (
374375 answer .question , k = _k , fetch_k = 5 * _k
375376 )
377+ # ok now filter
378+ if key_filter is not None :
379+ docs = [doc for doc in docs if doc .metadata ["dockey" ] in key_filter ][:k ]
376380
377381 async def process (doc ):
378382 if doc .metadata ["dockey" ] in self ._deleted_keys :
379383 return None , None
380- if key_filter is not None and doc .metadata ["dockey" ] not in key_filter :
381- return None , None
382384 # check if it is already in answer (possible in agent setting)
383385 if doc .metadata ["key" ] in [c .key for c in answer .contexts ]:
384386 return None , None
385387 callbacks = [OpenAICallbackHandler ()] + get_callbacks (
386388 "evidence:" + doc .metadata ["key" ]
387389 )
388390 summary_chain = make_chain (summary_prompt , self .summary_llm )
389- c = Context (
390- key = doc .metadata ["key" ],
391- citation = doc .metadata ["citation" ],
392- context = await summary_chain .arun (
391+ # This is dangerous because it
392+ # could mask errors that are important
393+ # I also cannot know what the exception
394+ # type is because any model could be used
395+ # my best idea is see if there is a 4XX
396+ # http code in the exception
397+ try :
398+ context = await summary_chain .arun (
393399 question = answer .question ,
394400 context_str = doc .page_content ,
395401 citation = doc .metadata ["citation" ],
396402 callbacks = callbacks ,
397- ),
403+ )
404+ except Exception as e :
405+ if guess_is_4xx (e ):
406+ return None , None
407+ raise e
408+ c = Context (
409+ key = doc .metadata ["key" ],
410+ citation = doc .metadata ["citation" ],
411+ context = context ,
398412 text = doc .page_content ,
413+ score = get_score (context ),
399414 )
400415 if "not applicable" not in c .context .casefold ():
416+ print (c .score )
401417 return c , callbacks [0 ]
402418 return None , None
403419
@@ -411,7 +427,7 @@ async def process(doc):
411427 contexts = [c for c , _ in results if c is not None ]
412428 if len (contexts ) == 0 :
413429 return answer
414- contexts = sorted (contexts , key = lambda x : len ( x . context ) , reverse = True )
430+ contexts = sorted (contexts , key = lambda x : x . score , reverse = True )
415431 contexts = contexts [:max_sources ]
416432 # add to answer (if not already there)
417433 keys = [c .key for c in answer .contexts ]
@@ -499,11 +515,12 @@ async def aquery(
499515 if answer is None :
500516 answer = Answer (query )
501517 if len (answer .contexts ) == 0 :
502- if key_filter or (key_filter is None and len (self .docs ) > 5 ):
518+ if key_filter or (key_filter is None and len (self .docs ) > max_sources ):
503519 callbacks = [OpenAICallbackHandler ()] + get_callbacks ("filter" )
504520 keys = await self .adoc_match (answer .question , callbacks = callbacks )
505521 answer .tokens += callbacks [0 ].total_tokens
506522 answer .cost += callbacks [0 ].total_cost
523+ key_filter = True if len (keys ) > 0 else False
507524 answer = await self .aget_evidence (
508525 answer ,
509526 k = k ,
@@ -532,8 +549,8 @@ async def aquery(
532549 answer .tokens += cb .total_tokens
533550 answer .cost += cb .total_cost
534551 # it still happens lol
535- if "(Foo2012 )" in answer_text :
536- answer_text = answer_text .replace ("(Foo2012 )" , "" )
552+ if "(Example2012 )" in answer_text :
553+ answer_text = answer_text .replace ("(Example2012 )" , "" )
537554 for c in contexts :
538555 key = c .key
539556 text = c .context
0 commit comments