Skip to content

Commit 26d888c

Browse files
committed
db fixed
1 parent d95bf4a commit 26d888c

File tree

4 files changed

+33
-5
lines changed

4 files changed

+33
-5
lines changed

learner/indexer.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ def create_indexer(app: str, docs: str, format: str, incremental: bool, save_pat
4848
if incremental:
4949
if app in records:
5050
print_with_color("Merging with previous indexer...", "yellow")
51-
prev_db = FAISS.load_local(records[app], embeddings)
51+
prev_db = FAISS.load_local(
52+
records[app], embeddings, allow_dangerous_deserialization=True
53+
)
5254
db.merge_from(prev_db)
5355

5456
db_file_path = os.path.join(save_path, app)

record_processor/summarizer/summarizer.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,11 @@ def create_or_update_vector_db(summaries: list, db_path: str):
187187

188188
# Check if the db exists, if not, create a new one.
189189
if os.path.exists(db_path):
190-
prev_db = FAISS.load_local(db_path, get_hugginface_embedding())
190+
prev_db = FAISS.load_local(
191+
db_path,
192+
get_hugginface_embedding(),
193+
allow_dangerous_deserialization=True,
194+
)
191195
db.merge_from(prev_db)
192196

193197
db.save_local(db_path)

ufo/module/context.py

+12
Original file line numberDiff line numberDiff line change
@@ -316,3 +316,15 @@ def to_dict(self) -> Dict[str, Any]:
316316
:return: The dictionary of the context.
317317
"""
318318
return self._context
319+
320+
def from_dict(self, context_dict: Dict[str, Any]) -> None:
321+
"""
322+
Load the context from a dictionary.
323+
:param context_dict: The dictionary of the context.
324+
"""
325+
for key in ContextNames:
326+
if key.name in context_dict:
327+
self._context[key.name] = context_dict.get(key.name)
328+
329+
# Sync the current round step and cost
330+
self._sync_round_values()

ufo/rag/retriever.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,9 @@ def get_indexer(self, path: str):
111111
return None
112112

113113
try:
114-
db = FAISS.load_local(path, get_hugginface_embedding())
114+
db = FAISS.load_local(
115+
path, get_hugginface_embedding(), allow_dangerous_deserialization=True
116+
)
115117
return db
116118
except:
117119
# print_with_color(
@@ -142,7 +144,11 @@ def get_indexer(self, db_path: str):
142144
"""
143145

144146
try:
145-
db = FAISS.load_local(db_path, get_hugginface_embedding())
147+
db = FAISS.load_local(
148+
db_path,
149+
get_hugginface_embedding(),
150+
allow_dangerous_deserialization=True,
151+
)
146152
return db
147153
except:
148154
# print_with_color(
@@ -209,7 +215,11 @@ def get_indexer(self, db_path: str):
209215
"""
210216

211217
try:
212-
db = FAISS.load_local(db_path, get_hugginface_embedding())
218+
db = FAISS.load_local(
219+
db_path,
220+
get_hugginface_embedding(),
221+
allow_dangerous_deserialization=True,
222+
)
213223
return db
214224
except:
215225
# print_with_color(

0 commit comments

Comments
 (0)