1- from typing import List , Optional , Tuple , Dict , Callable , Any
1+ from typing import List , Optional , Tuple , Dict , Callable , Any , Union
22from functools import reduce
33import os
44import os
1010 qa_prompt ,
1111 search_prompt ,
1212 citation_prompt ,
13- chat_pref ,
13+ make_chain ,
1414)
1515from dataclasses import dataclass
1616from .readers import read_doc
1717from langchain .vectorstores import FAISS
1818from langchain .embeddings .openai import OpenAIEmbeddings
19- from langchain .llms import OpenAI , OpenAIChat
19+ from langchain .chat_models import ChatOpenAI
2020from langchain .llms .base import LLM
2121from langchain .chains import LLMChain
2222from langchain .callbacks import get_openai_callback
@@ -64,7 +64,7 @@ def __init__(
6464 summary_llm : Optional [LLM ] = None ,
6565 name : str = "default" ,
6666 index_path : Optional [Path ] = None ,
67- model_name : str = ' gpt-3.5-turbo'
67+ model_name : str = " gpt-3.5-turbo" ,
6868 ) -> None :
6969 """Initialize the collection of documents.
7070
@@ -82,26 +82,32 @@ def __init__(
8282 self .chunk_size_limit = chunk_size_limit
8383 self .keys = set ()
8484 self ._faiss_index = None
85- if llm is None :
86- llm = OpenAIChat (temperature = 0.1 , max_tokens = 512 , prefix_messages = chat_pref , model_name = model_name )
87- if summary_llm is None :
88- summary_llm = llm
8985 self .update_llm (llm , summary_llm )
9086 if index_path is None :
9187 index_path = Path .home () / ".paperqa" / name
9288 self .index_path = index_path
9389 self .name = name
9490
def update_llm(
    self,
    llm: Optional[Union[LLM, str]] = None,
    summary_llm: Optional[Union[LLM, str]] = None,
) -> None:
    """Update the LLMs used for answering questions.

    Args:
        llm: Language model used for the question-answering chain, or an
            OpenAI chat-model name as a shorthand. Defaults to
            "gpt-3.5-turbo".
        summary_llm: Model (or model name) used for the summary, search,
            and citation chains. Defaults to the value of ``llm``.
    """
    if llm is None:
        llm = "gpt-3.5-turbo"
    # isinstance (rather than `type(x) is str`) also accepts str subclasses;
    # a bare string is treated as an OpenAI chat-model name.
    if isinstance(llm, str):
        llm = ChatOpenAI(temperature=0.1, model=llm)
    if isinstance(summary_llm, str):
        summary_llm = ChatOpenAI(temperature=0.1, model=summary_llm)
    self.llm = llm
    if summary_llm is None:
        # Fall back to the main model when no dedicated summary model is given.
        summary_llm = llm
    self.summary_llm = summary_llm
    self.summary_chain = make_chain(prompt=summary_prompt, llm=summary_llm)
    self.qa_chain = make_chain(prompt=qa_prompt, llm=llm)
    # Search and citation use the (typically cheaper) summary model.
    self.search_chain = make_chain(prompt=search_prompt, llm=summary_llm)
    self.cite_chain = make_chain(prompt=citation_prompt, llm=summary_llm)
105111
106112 def add (
107113 self ,
@@ -112,12 +118,12 @@ def add(
112118 chunk_chars : Optional [int ] = 3000 ,
113119 ) -> None :
114120 """Add a document to the collection."""
115-
116- # first check to see if we already have this document
121+
122+ # first check to see if we already have this document
117123 # this way we don't make api call to create citation on file we already have
118124 if path in self .docs :
119125 raise ValueError (f"Document { path } already in collection." )
120-
126+
121127 if citation is None :
122128 # peak first chunk
123129 texts , _ = read_doc (path , "" , "" , chunk_chars = chunk_chars )
@@ -126,7 +132,6 @@ def add(
126132 if len (citation ) < 3 or "Unknown" in citation or "insufficient" in citation :
127133 citation = f"Unknown, { os .path .basename (path )} , { datetime .now ().year } "
128134
129-
130135 if key is None :
131136 # get first name and year from citation
132137 try :
@@ -212,9 +217,7 @@ def __setstate__(self, state):
212217 except :
213218 # they use some special exception type, but I don't want to import it
214219 self ._faiss_index = None
215- self .update_llm (
216- OpenAIChat (temperature = 0.1 , max_tokens = 512 , prefix_messages = chat_pref )
217- )
220+ self .update_llm ("gpt-3.5-turbo" )
218221
219222 def _build_faiss_index (self ):
220223 if self ._faiss_index is None :
@@ -252,7 +255,9 @@ def get_evidence(
252255 doc .metadata ["key" ],
253256 doc .metadata ["citation" ],
254257 self .summary_chain .run (
255- question = answer .question , context_str = doc .page_content
258+ question = answer .question ,
259+ context_str = doc .page_content ,
260+ citation = doc .metadata ["citation" ],
256261 ),
257262 doc .page_content ,
258263 )
0 commit comments