 from langchain.llms.base import LLM
 from langchain.chains import LLMChain
 from langchain.callbacks import get_openai_callback
+from langchain.cache import InMemoryCache
+import langchain
+
+langchain.llm_cache = InMemoryCache()
 
 
 @dataclass
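
Note: the new module-level cache dedupes repeated identical LLM calls for the lifetime of the process. A minimal sketch of the behavior this enables, not part of the diff (the model name and prompt are illustrative, and an OpenAI API key is assumed to be configured):

import langchain
from langchain.cache import InMemoryCache
from langchain.llms import OpenAI

langchain.llm_cache = InMemoryCache()

llm = OpenAI(model_name="text-davinci-003")
first = llm("Name one noble gas.")   # billed API call
second = llm("Name one noble gas.")  # same prompt + params -> served from the cache
assert first == second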
@@ -200,14 +204,20 @@ def get_evidence(
         answer: Answer,
         k: int = 3,
         max_sources: int = 5,
+        marginal_relevance: bool = True,
     ) -> str:
         if self._faiss_index is None:
             self._build_faiss_index()
 
         # want to work through indices but less k
-        docs = self._faiss_index.max_marginal_relevance_search(
-            answer.question, k=k, fetch_k=5 * k
-        )
+        if marginal_relevance:
+            docs = self._faiss_index.max_marginal_relevance_search(
+                answer.question, k=k, fetch_k=5 * k
+            )
+        else:
+            docs = self._faiss_index.similarity_search(
+                answer.question, k=k, fetch_k=5 * k
+            )
         for doc in docs:
             c = (
                 doc.metadata["key"],
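
Note: the new `marginal_relevance` flag lets callers fall back from MMR, which re-ranks a `fetch_k`-sized candidate pool for diversity, to plain nearest-neighbor retrieval. A hypothetical side-by-side on any langchain FAISS vectorstore (`index` is an assumption, not from this diff):

# MMR: a diverse top-k drawn from a larger candidate pool
diverse_docs = index.max_marginal_relevance_search("my question", k=3, fetch_k=15)
# plain similarity: the 3 nearest neighbors, near-duplicates and all
nearest_docs = index.similarity_search("my question", k=3)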
@@ -251,9 +261,14 @@ def query_gen(
         k: int = 10,
         max_sources: int = 5,
         length_prompt: str = "about 100 words",
+        marginal_relevance: bool = True,
     ):
         yield from self._query(
-            query, k=k, max_sources=max_sources, length_prompt=length_prompt
+            query,
+            k=k,
+            max_sources=max_sources,
+            length_prompt=length_prompt,
+            marginal_relevance=marginal_relevance,
         )
 
     def query(
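
Note: `query_gen` streams intermediate `Answer` states as evidence is gathered, while `query` (next hunk) simply drains the same generator and returns the last state. A hedged usage sketch, assuming a populated `Docs` instance named `docs` (question text is illustrative):

partial = None
for partial in docs.query_gen("What limits perovskite solar cell stability?"):
    print(f"{len(partial.contexts)} evidence contexts gathered so far")
final_answer = partial  # the last yielded state is the completed Answer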
@@ -262,20 +277,37 @@ def query(
         k: int = 10,
         max_sources: int = 5,
         length_prompt: str = "about 100 words",
+        marginal_relevance: bool = True,
     ):
         for answer in self._query(
-            query, k=k, max_sources=max_sources, length_prompt=length_prompt
+            query,
+            k=k,
+            max_sources=max_sources,
+            length_prompt=length_prompt,
+            marginal_relevance=marginal_relevance,
         ):
             pass
         return answer
 
-    def _query(self, query: str, k: int, max_sources: int, length_prompt: str):
+    def _query(
+        self,
+        query: str,
+        k: int,
+        max_sources: int,
+        length_prompt: str,
+        marginal_relevance: bool,
+    ):
         if k < max_sources:
             raise ValueError("k should be greater than max_sources")
         tokens = 0
         answer = Answer(query)
         with get_openai_callback() as cb:
-            for answer in self.get_evidence(answer, k=k, max_sources=max_sources):
+            for answer in self.get_evidence(
+                answer,
+                k=k,
+                max_sources=max_sources,
+                marginal_relevance=marginal_relevance,
+            ):
                 yield answer
             tokens += cb.total_tokens
         context_str, citations = answer.context, answer.contexts
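
Note: token accounting here relies on `get_openai_callback`, which tallies usage for every OpenAI call made inside its context manager. The same pattern in isolation (the `llm` object is an assumption; any langchain OpenAI LLM works):

from langchain.callbacks import get_openai_callback

tokens = 0
with get_openai_callback() as cb:
    llm("first call")
    llm("second call")
tokens += cb.total_tokens  # prompt + completion tokens across both calls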
@@ -290,11 +322,14 @@ def _query(self, query: str, k: int, max_sources: int, length_prompt: str):
             answer_text = self.qa_chain.run(
                 question=query, context_str=context_str, length=length_prompt
             )[1:]
-            if maybe_is_truncated(answer_text):
-                answer_text = self.edit_chain.run(
-                    question=query, answer=answer_text
-                )
+            # if maybe_is_truncated(answer_text):
+            #     answer_text = self.edit_chain.run(
+            #         question=query, answer=answer_text
+            #     )
             tokens += cb.total_tokens
+        # it still happens lol
+        if "(Foo2012)" in answer_text:
+            answer_text = answer_text.replace("(Foo2012)", "")
         for key, citation, summary, text in citations:
             # do check for whole key (so we don't catch Callahan2019a with Callahan2019)
             skey = key.split(" ")[0]
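
Note: the whole-key comparison guards against substring false positives when deciding which citations actually made it into the answer. A toy illustration of the failure mode it avoids (the key string is made up):

key = "Callahan2019a pages 3-4"
skey = key.split(" ")[0]          # -> "Callahan2019a"
# a naive substring test would wrongly match the shorter key too:
assert "Callahan2019" in skey     # false positive with plain `in`
assert skey != "Callahan2019"     # comparing against the whole key avoids it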