diff --git a/libs/knowledge-store/ragstack_knowledge_store/compare_retrieval.py b/libs/knowledge-store/ragstack_knowledge_store/compare_retrieval.py new file mode 100644 index 000000000..c69ef24fe --- /dev/null +++ b/libs/knowledge-store/ragstack_knowledge_store/compare_retrieval.py @@ -0,0 +1,20 @@ +import pickle +from langchain_core.documents import Document + +from typing import Dict, List + +def get_stuff(table_name): + with open(f"debug_retrieval_{table_name}.pkl", "rb") as file: + return pickle.load(file) + + +metadata_based: Dict[str, List[Document]] = get_stuff("metadata_based") +link_based: Dict[str, List[Document]] = get_stuff("link_column_based") + +count = 1 +for query in metadata_based.keys(): + metadata_chunks = metadata_based[query] + link_chunks = link_based[query] + + print(f"Query {count} has {len(metadata_chunks)} metadata chunks and {len(link_chunks)} link chunks. Diff: {len(metadata_chunks)-len(link_chunks)}") + count += 1 diff --git a/libs/knowledge-store/ragstack_knowledge_store/concurrency copy.py b/libs/knowledge-store/ragstack_knowledge_store/concurrency copy.py new file mode 100644 index 000000000..29b6ac224 --- /dev/null +++ b/libs/knowledge-store/ragstack_knowledge_store/concurrency copy.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import contextlib +import logging +import threading +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Literal, + NamedTuple, + Protocol, + Sequence, +) + +if TYPE_CHECKING: + from types import TracebackType + + from cassandra.cluster import ResponseFuture, Session + from cassandra.query import PreparedStatement, SimpleStatement + +logger = logging.getLogger(__name__) + + +class _Callback(Protocol): + def __call__(self, rows: Sequence[Any], /) -> None: + ... + + +class ConcurrentQueries(contextlib.AbstractContextManager["ConcurrentQueries"]): + """Context manager for concurrent queries with a max limit of 5 ongoing queries.""" + + _MAX_CONCURRENT_QUERIES = 5 + + def __init__(self, session: Session) -> None: + self._session = session + self._completion = threading.Condition() + self._pending = 0 + self._error: BaseException | None = None + self._semaphore = threading.Semaphore(self._MAX_CONCURRENT_QUERIES) + + def _handle_result( + self, + result: Sequence[NamedTuple], + future: ResponseFuture, + callback: Callable[[Sequence[NamedTuple]], Any] | None, + ) -> None: + if callback is not None: + callback(result) + + if future.has_more_pages: + future.start_fetching_next_page() + else: + with self._completion: + self._pending -= 1 + self._semaphore.release() # Release the semaphore once a query completes + if self._pending == 0: + self._completion.notify() + + def _handle_error(self, error: BaseException, future: ResponseFuture) -> None: + logger.error( + "Error executing query: %s", + future.query, + exc_info=error, + ) + with self._completion: + self._error = error + self._pending -= 1 # Decrement pending count + self._semaphore.release() # Release the semaphore on error + self._completion.notify() + + def execute( + self, + query: PreparedStatement | SimpleStatement, + parameters: tuple[Any, ...] | None = None, + callback: _Callback | None = None, + timeout: float | None = None, + ) -> None: + """Execute a query concurrently with a max of 5 concurrent queries. + + Args: + query: The query to execute. + parameters: Parameter tuple for the query. Defaults to `None`. + callback: Callback to apply to the results. Defaults to `None`. + timeout: Timeout to use (if not the session default). 
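+
+        Note:
+            If `_MAX_CONCURRENT_QUERIES` queries are already in flight, this
+            call blocks on the semaphore until one of them completes before
+            submitting the new query.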
+ """ + with self._completion: + if self._error is not None: + return + + # Acquire the semaphore before proceeding to ensure we do not exceed the max limit + self._semaphore.acquire() + + with self._completion: + if self._error is not None: + # Release semaphore before returning + self._semaphore.release() + return + self._pending += 1 + + try: + execute_kwargs = {} + if timeout is not None: + execute_kwargs["timeout"] = timeout + future: ResponseFuture = self._session.execute_async( + query, + parameters, + **execute_kwargs, + ) + future.add_callbacks( + self._handle_result, + self._handle_error, + callback_kwargs={ + "future": future, + "callback": callback, + }, + errback_kwargs={ + "future": future, + }, + ) + except Exception as e: + with self._completion: + self._error = e + self._pending -= 1 # Decrement pending count + self._semaphore.release() # Release semaphore + self._completion.notify() + raise + + def __exit__( + self, + _exc_type: type[BaseException] | None, + _exc_inst: BaseException | None, + _exc_traceback: TracebackType | None, + ) -> Literal[False]: + with self._completion: + while self._error is None and self._pending > 0: + self._completion.wait() + + if self._error is not None: + raise self._error + + # Don't swallow the exception. + return False diff --git a/libs/knowledge-store/ragstack_knowledge_store/concurrency.py b/libs/knowledge-store/ragstack_knowledge_store/concurrency.py index 7a2e57b37..29b6ac224 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/concurrency.py +++ b/libs/knowledge-store/ragstack_knowledge_store/concurrency.py @@ -17,23 +17,27 @@ from types import TracebackType from cassandra.cluster import ResponseFuture, Session - from cassandra.query import PreparedStatement + from cassandra.query import PreparedStatement, SimpleStatement logger = logging.getLogger(__name__) class _Callback(Protocol): - def __call__(self, rows: Sequence[Any], /) -> None: ... + def __call__(self, rows: Sequence[Any], /) -> None: + ... class ConcurrentQueries(contextlib.AbstractContextManager["ConcurrentQueries"]): - """Context manager for concurrent queries.""" + """Context manager for concurrent queries with a max limit of 5 ongoing queries.""" + + _MAX_CONCURRENT_QUERIES = 5 def __init__(self, session: Session) -> None: self._session = session self._completion = threading.Condition() self._pending = 0 self._error: BaseException | None = None + self._semaphore = threading.Semaphore(self._MAX_CONCURRENT_QUERIES) def _handle_result( self, @@ -49,6 +53,7 @@ def _handle_result( else: with self._completion: self._pending -= 1 + self._semaphore.release() # Release the semaphore once a query completes if self._pending == 0: self._completion.notify() @@ -60,19 +65,18 @@ def _handle_error(self, error: BaseException, future: ResponseFuture) -> None: ) with self._completion: self._error = error + self._pending -= 1 # Decrement pending count + self._semaphore.release() # Release the semaphore on error self._completion.notify() def execute( self, - query: PreparedStatement, + query: PreparedStatement | SimpleStatement, parameters: tuple[Any, ...] | None = None, callback: _Callback | None = None, timeout: float | None = None, ) -> None: - """Execute a query concurrently. - - Because this is done concurrently, it expects a callback if you need - to inspect the results. + """Execute a query concurrently with a max of 5 concurrent queries. Args: query: The query to execute. @@ -80,33 +84,47 @@ def execute( callback: Callback to apply to the results. Defaults to `None`. 
timeout: Timeout to use (if not the session default). """ - # TODO: We could have some form of throttling, where we track the number - # of pending calls and queue things if it exceed some threshold. + with self._completion: + if self._error is not None: + return + + # Acquire the semaphore before proceeding to ensure we do not exceed the max limit + self._semaphore.acquire() with self._completion: - self._pending += 1 if self._error is not None: + # Release semaphore before returning + self._semaphore.release() return + self._pending += 1 - execute_kwargs = {} - if timeout is not None: - execute_kwargs["timeout"] = timeout - future: ResponseFuture = self._session.execute_async( - query, - parameters, - **execute_kwargs, - ) - future.add_callbacks( - self._handle_result, - self._handle_error, - callback_kwargs={ - "future": future, - "callback": callback, - }, - errback_kwargs={ - "future": future, - }, - ) + try: + execute_kwargs = {} + if timeout is not None: + execute_kwargs["timeout"] = timeout + future: ResponseFuture = self._session.execute_async( + query, + parameters, + **execute_kwargs, + ) + future.add_callbacks( + self._handle_result, + self._handle_error, + callback_kwargs={ + "future": future, + "callback": callback, + }, + errback_kwargs={ + "future": future, + }, + ) + except Exception as e: + with self._completion: + self._error = e + self._pending -= 1 # Decrement pending count + self._semaphore.release() # Release semaphore + self._completion.notify() + raise def __exit__( self, @@ -122,6 +140,4 @@ def __exit__( raise self._error # Don't swallow the exception. - # We don't need to do anything with the exception (`_exc_*` parameters) - # since returning false here will automatically re-raise it. return False diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_embed_legal.py b/libs/knowledge-store/ragstack_knowledge_store/graph_embed_legal.py new file mode 100644 index 000000000..ddb1cff75 --- /dev/null +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_embed_legal.py @@ -0,0 +1,221 @@ +import cassio +import json +import time +import os +from glob import glob +from typing import Any, Dict, List, Generator, Tuple + +from langchain_core.documents import Document +from ragstack_knowledge_store.graph_store_tags import CONTENT_ID +from langchain_core.graph_vectorstores.links import add_links, get_links, Link +from ragstack_knowledge_store.keybert_link_extractor import KeybertLinkExtractor +from ragstack_knowledge_store.langchain_cassandra_tags import CassandraGraphVectorStore +from langchain_text_splitters import MarkdownHeaderTextSplitter +from langchain_openai.embeddings import OpenAIEmbeddings + +from keyphrase_vectorizers import KeyphraseCountVectorizer + +import tiktoken +from dotenv import load_dotenv +from tqdm import tqdm +import re + +from cassio.config import check_resolve_keyspace, check_resolve_session + +EMBEDDING_MODEL = "text-embedding-3-small" +LLM_MODEL = "gpt-4o-mini" +BATCH_SIZE = 250 +KEYSPACE = "legal_graph_store" +TABLE_NAME = "tag_based" +DRY_RUN = False + +load_dotenv() + +def delete_all_files_in_folder(folder_path: str) -> None: + for file_name in os.listdir(folder_path): + file_path = os.path.join(folder_path, file_name) + if os.path.isfile(file_path): + os.remove(file_path) + +token_counter = tiktoken.encoding_for_model(EMBEDDING_MODEL) +def token_count(text: str) -> int: + return len(token_counter.encode(text)) + +headers_to_split_on = [ + ("#", "Header 1"), + ("##", "Header 2"), + ("###", "Header 3"), +] + +markdown_splitter = 
MarkdownHeaderTextSplitter( + headers_to_split_on=headers_to_split_on, + return_each_line=False, + strip_headers=False +) + +# Define the regex pattern +outgoing_section_pattern = r"(\d+\.\d+)\s+\*\*(.*?)\*\*" +incoming_internal_section_pattern = r"\*\*Section\s(\d+\.\d+)\*\*" +incoming_external_section_pattern1 = r"\*\*(.*?)\s\((.*?)\),\sSection\s(\d+\.\d+)\*\*" +incoming_external_section_pattern2 = r"\*\*Section\s(\d+\.\d+)\sof\sthe\s(.*?)\s\((.*?)\)\*\*" + +# Others to fix in the original dataset: +# Section 2.3 +# Section 2.1 +# **Section 5** of the Software Development Agreement +# **Section 4.1** of that Agreement + +keybert_link_extractor = KeybertLinkExtractor( + extract_keywords_kwargs={ + "vectorizer": KeyphraseCountVectorizer(stop_words="english"), + "use_mmr":True, + "diversity": 0.7 + } +) + +def build_document_batch(doc_batch: List[Document]) -> List[Document]: + keybert_links_batch = keybert_link_extractor.extract_many(doc_batch) + for keybert_links, doc in zip(keybert_links_batch, doc_batch): + # drop links with one word + # pruned_links = [link for link in keybert_links if " " in link.tag] + add_links(doc, keybert_links) + return doc_batch + + +def load_chunks(markdown_file_paths: List[str]) -> Generator[List[Document], None, None]: + doc_batch: List[Document] = [] + + for markdown_file_path in tqdm(markdown_file_paths): + with open(markdown_file_path, 'r') as file: + markdown_text = file.read() + + docs = markdown_splitter.split_text(markdown_text) + + for doc in docs: + doc.metadata[CONTENT_ID] = markdown_file_path + + doc_title = doc.metadata.get("Header 1", "") + + section_links = [] + + # find outgoing links + for out_section in re.findall(outgoing_section_pattern, doc.page_content): + out_number = out_section[0] + out_title = out_section[1] + section_links.append(Link("section", direction="in", tag=f"{doc_title} {out_number}")) + + # find incoming links + for in_number in re.findall(incoming_internal_section_pattern, doc.page_content): + section_links.append(Link("section", direction="out", tag=f"{doc_title} {in_number}")) + + for in_section1 in re.findall(incoming_external_section_pattern1, doc.page_content): + in_title1 = in_section1[0] + in_abbreviation1 = in_section1[1] + in_number1 = in_section1[2] + section_links.append(Link("section", direction="out", tag=f"{in_title1} ({in_abbreviation1}) {in_number1}")) + + for in_section2 in re.findall(incoming_external_section_pattern2, doc.page_content): + in_number2 = in_section2[0] + in_title2 = in_section2[1] + in_abbreviation2 = in_section2[2] + section_links.append(Link("section", direction="out", tag=f"{in_title2} ({in_abbreviation2}) {in_number2}")) + + add_links(doc, section_links) + + + doc_batch.append(doc) + + if len(doc_batch) == BATCH_SIZE: + yield build_document_batch(doc_batch=doc_batch) + doc_batch: List[Document] = [] + + yield build_document_batch(doc_batch=doc_batch) + +def init_graph_store() -> CassandraGraphVectorStore: + cassio.init(auto=True) + + session = check_resolve_session() + keyspace = check_resolve_keyspace(KEYSPACE) + statement = session.prepare(f"DROP TABLE IF EXISTS {keyspace}.{TABLE_NAME};") + session.execute(statement) + + embedding_model = OpenAIEmbeddings(model=EMBEDDING_MODEL) + return CassandraGraphVectorStore( + node_table=TABLE_NAME, + session=session, + embedding=embedding_model, + keyspace=keyspace, + ) + +class CustomJSONEncoder(json.JSONEncoder): + def default(self, obj: Any) -> Any: + if isinstance(obj, Link): + return {"direction": obj.direction, "kind": obj.kind, "tag": 
obj.tag} + return super().default(obj) + +def sort_dict_by_count_and_key(data: dict[str, int]) -> List[Tuple[str, int]]: + return sorted(data.items(), key=lambda item: (-item[1], item[0])) + + +def load_and_insert_chunks(dry_run: bool = True): + in_links = set() + out_links = set() + bidir_links: Dict[str, int] = {} + + if dry_run: + delete_all_files_in_folder("chunk_debug") + else: + graph_store = init_graph_store() + + markdown_file_paths = glob(pathname="datasets2/legal_documents/**/*.md", recursive=True) + + index = 0 + + for chunk_batch in load_chunks(markdown_file_paths=markdown_file_paths): + if not dry_run: + while True: + try: + graph_store.add_documents(chunk_batch) + break + except Exception as e: + print(f"Encountered issue trying to store document batch: {e}") + time.sleep(2) + graph_store = init_graph_store() + + for chunk in chunk_batch: + if dry_run: + + id = chunk.metadata[CONTENT_ID] + id = re.sub(r'[^\w\-.]', '_', id) + with open(f"chunk_debug/{str(index).zfill(5)}_{id}", "w") as f: + f.write(chunk.page_content + "\n\n") + f.write(json.dumps(chunk.metadata, cls=CustomJSONEncoder) + "\n\n") + links = get_links(chunk) + for link in links: + f.write(f"LINK Kind: '{link.kind}', Direction: '{link.direction}', Tag: '{link.tag}'\n") + index += 1 + + links = get_links(chunk) + for link in links: + if link.direction == "in": + in_links.add(link.tag) + elif link.direction == "out": + out_links.add(link.tag) + elif link.direction == "bidir": + if link.tag in bidir_links: + bidir_links[link.tag] += 1 + else: + bidir_links[link.tag] = 0 + + with open("debug_links_legal.json", "w") as f: + json.dump(fp=f, obj={ + "in_links": sorted(list(in_links)), + "out_links": sorted(list(out_links)), + "bidir_links": sort_dict_by_count_and_key(bidir_links), + }) + + print(f"Links In: {len(in_links)}, Out: {len(out_links)}, BiDir: {len(bidir_links)}") + + +def main(): + load_and_insert_chunks(dry_run=DRY_RUN) diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_query_legal.py b/libs/knowledge-store/ragstack_knowledge_store/graph_query_legal.py new file mode 100644 index 000000000..c350b8f51 --- /dev/null +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_query_legal.py @@ -0,0 +1,153 @@ +import cassio +import json +import pickle +from dotenv import load_dotenv +# from transformers import pipeline +from ragstack_knowledge_store.langchain_cassandra_tags import CassandraGraphVectorStore +# from langchain_core.graph_vectorstores import GraphVectorStoreRetriever +# from langchain_core.language_models.chat_models import BaseChatModel +# from langchain_core.output_parsers import StrOutputParser +# from langchain_core.prompts import ChatPromptTemplate +# from langchain_core.runnables import RunnablePassthrough +# from langchain_core.documents import Document +# from langchain_core.callbacks import ( +# AsyncCallbackManagerForRetrieverRun, +# CallbackManagerForRetrieverRun, +# ) +# from langchain_openai.chat_models import ChatOpenAI +from langchain_openai.embeddings import OpenAIEmbeddings + +# from typing import List, Tuple + +from tqdm import tqdm + +EMBEDDING_MODEL = "text-embedding-3-small" +CHAT_MODEL = "gpt-4o-mini" +KEYSPACE_NAME = "legal_graph_store" +TABLE_NAME = "tag_based" + +load_dotenv() + +# class GraphReRankRetriever(GraphVectorStoreRetriever): +# cross_encoder = pipeline("text-classification", model="cross-encoder/ms-marco-MiniLM-L-6-v2", top_k=None) +# re_rank:bool = False + +# def __init__(self, *args, re_rank: str = "true", **kwargs): +# super().__init__(*args, 
**kwargs) +# self.re_rank = re_rank.lower() == "true" + + +# def rerank_with_cross_encoder(self, query: str, documents: List[Document], k: int = 5) -> List[Document]: +# if not self.re_rank: +# return documents + +# # Re-rank documents using the cross-encoder +# scored_documents = [ +# (doc, self.cross_encoder(f"{query} [SEP] {doc.page_content}")[0][0]['score']) for doc in documents +# ] + +# # Sort documents by score in descending order and return top 5 +# ranked_documents = sorted(scored_documents, key=lambda x: x[1], reverse=True) + +# return [doc for doc, _ in ranked_documents[:k]] # Return the top 5 documents + +# def _get_relevant_documents( +# self, query: str, *, run_manager: CallbackManagerForRetrieverRun +# ) -> List[Document]: +# result = super()._get_relevant_documents(query, run_manager=run_manager) +# return self.rerank_with_cross_encoder(query=query, documents=result) + +# async def _aget_relevant_documents( +# self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun +# ) -> List[Document]: +# # Call the original async method +# result = await super()._aget_relevant_documents(query, run_manager=run_manager) +# return self.rerank_with_cross_encoder(query=query, documents=result) + + + +# def get_llm(chat_model_name: str) -> BaseChatModel: +# return ChatOpenAI(model=chat_model_name, temperature=0.0) + + +def get_graph_store() -> CassandraGraphVectorStore: + cassio.init(auto=True) + embedding_model = OpenAIEmbeddings(model=EMBEDDING_MODEL) + return CassandraGraphVectorStore(embedding=embedding_model, keyspace=KEYSPACE_NAME, node_table=TABLE_NAME) + +def get_retriever(depth: int, search_type: str): + graph_store = get_graph_store() + return graph_store.as_retriever( + search_type=search_type, + search_kwargs={ + "k": 5, + "fetch_k": 20, # Fetch 20 docs, but we'll return top 5 after re-ranking + "depth": depth, + }, + ) + +def test_retrieval(): + retriever = get_retriever(depth=2, search_type="traversal") + with open("datasets2/crag/legal/questions.jsonl") as f: + lines = f.readlines() + + retrieved_chunks = {} + for line in tqdm(lines): + data = json.loads(line) + chunks = retriever.invoke(data["query"]) + retrieved_chunks[data["query"]] = chunks + + with open(f"debug_retrieval_{TABLE_NAME}.pkl", "wb") as file: # Open the file in write-binary mode + pickle.dump(retrieved_chunks, file) + + +# def query_pipeline(depth: int, search_type: str, re_rank: str = "true", **kwargs): +# llm = get_llm(chat_model_name=CHAT_MODEL) +# graph_store = get_graph_store() + +# retriever = GraphReRankRetriever( +# re_rank=re_rank, +# vectorstore=graph_store, +# search_type=search_type, +# search_kwargs={ +# "k": 5, +# "fetch_k": 20, # Fetch 20 docs, but we'll return top 5 after re-ranking +# "depth": depth, +# }, +# ) + +# # # Prepare the retriever +# # retriever = graph_store.as_retriever( +# # search_type=search_type, +# # search_kwargs={ +# # "k": 5, +# # "fetch_k": 20, # Fetch 20 docs, but we'll return top 5 after re-ranking +# # "depth": depth, +# # }, +# # ) + + +# # Define the prompt template +# prompt_template = """ +# Retrieved Information: +# {retrieved_docs} + +# User Query: +# {query} + +# Response Instruction: +# Please generate a response without using markdown that uses the retrieved information to directly and clearly answer the user's query. Ensure that the response is relevant, accurate, and well-organized. 
+# """ # noqa: E501 + +# prompt = ChatPromptTemplate.from_template(prompt_template) + +# # Return the pipeline with retriever and re-ranking step +# return ( +# { +# "retrieved_docs": retriever, +# "query": RunnablePassthrough(), +# } +# | prompt # Generate prompt with re-ranked docs +# | llm # Pass through the LLM +# | StrOutputParser() # Final output parser +# ) diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_query_legal_tags_async.py b/libs/knowledge-store/ragstack_knowledge_store/graph_query_legal_tags_async.py new file mode 100644 index 000000000..374569670 --- /dev/null +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_query_legal_tags_async.py @@ -0,0 +1,153 @@ +import cassio +import json +import pickle +from dotenv import load_dotenv +# from transformers import pipeline +from ragstack_knowledge_store.langchain_cassandra_tags_async import CassandraGraphVectorStore +# from langchain_core.graph_vectorstores import GraphVectorStoreRetriever +# from langchain_core.language_models.chat_models import BaseChatModel +# from langchain_core.output_parsers import StrOutputParser +# from langchain_core.prompts import ChatPromptTemplate +# from langchain_core.runnables import RunnablePassthrough +# from langchain_core.documents import Document +# from langchain_core.callbacks import ( +# AsyncCallbackManagerForRetrieverRun, +# CallbackManagerForRetrieverRun, +# ) +# from langchain_openai.chat_models import ChatOpenAI +from langchain_openai.embeddings import OpenAIEmbeddings + +# from typing import List, Tuple + +from tqdm import tqdm + +EMBEDDING_MODEL = "text-embedding-3-small" +CHAT_MODEL = "gpt-4o-mini" +KEYSPACE_NAME = "legal_graph_store" +TABLE_NAME = "tag_based" + +load_dotenv() + +# class GraphReRankRetriever(GraphVectorStoreRetriever): +# cross_encoder = pipeline("text-classification", model="cross-encoder/ms-marco-MiniLM-L-6-v2", top_k=None) +# re_rank:bool = False + +# def __init__(self, *args, re_rank: str = "true", **kwargs): +# super().__init__(*args, **kwargs) +# self.re_rank = re_rank.lower() == "true" + + +# def rerank_with_cross_encoder(self, query: str, documents: List[Document], k: int = 5) -> List[Document]: +# if not self.re_rank: +# return documents + +# # Re-rank documents using the cross-encoder +# scored_documents = [ +# (doc, self.cross_encoder(f"{query} [SEP] {doc.page_content}")[0][0]['score']) for doc in documents +# ] + +# # Sort documents by score in descending order and return top 5 +# ranked_documents = sorted(scored_documents, key=lambda x: x[1], reverse=True) + +# return [doc for doc, _ in ranked_documents[:k]] # Return the top 5 documents + +# def _get_relevant_documents( +# self, query: str, *, run_manager: CallbackManagerForRetrieverRun +# ) -> List[Document]: +# result = super()._get_relevant_documents(query, run_manager=run_manager) +# return self.rerank_with_cross_encoder(query=query, documents=result) + +# async def _aget_relevant_documents( +# self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun +# ) -> List[Document]: +# # Call the original async method +# result = await super()._aget_relevant_documents(query, run_manager=run_manager) +# return self.rerank_with_cross_encoder(query=query, documents=result) + + + +# def get_llm(chat_model_name: str) -> BaseChatModel: +# return ChatOpenAI(model=chat_model_name, temperature=0.0) + + +def get_graph_store() -> CassandraGraphVectorStore: + cassio.init(auto=True) + embedding_model = OpenAIEmbeddings(model=EMBEDDING_MODEL) + return 
CassandraGraphVectorStore(embedding=embedding_model, keyspace=KEYSPACE_NAME, node_table=TABLE_NAME) + +def get_retriever(depth: int, search_type: str): + graph_store = get_graph_store() + return graph_store.as_retriever( + search_type=search_type, + search_kwargs={ + "k": 5, + "fetch_k": 20, # Fetch 20 docs, but we'll return top 5 after re-ranking + "depth": depth, + }, + ) + +def test_retrieval(): + retriever = get_retriever(depth=2, search_type="traversal") + with open("datasets2/crag/legal/questions.jsonl") as f: + lines = f.readlines() + + retrieved_chunks = {} + for line in tqdm(lines): + data = json.loads(line) + chunks = retriever.invoke(data["query"]) + retrieved_chunks[data["query"]] = chunks + + with open(f"debug_retrieval_{TABLE_NAME}.pkl", "wb") as file: # Open the file in write-binary mode + pickle.dump(retrieved_chunks, file) + +test_retrieval() +# def query_pipeline(depth: int, search_type: str, re_rank: str = "true", **kwargs): +# llm = get_llm(chat_model_name=CHAT_MODEL) +# graph_store = get_graph_store() + +# retriever = GraphReRankRetriever( +# re_rank=re_rank, +# vectorstore=graph_store, +# search_type=search_type, +# search_kwargs={ +# "k": 5, +# "fetch_k": 20, # Fetch 20 docs, but we'll return top 5 after re-ranking +# "depth": depth, +# }, +# ) + +# # # Prepare the retriever +# # retriever = graph_store.as_retriever( +# # search_type=search_type, +# # search_kwargs={ +# # "k": 5, +# # "fetch_k": 20, # Fetch 20 docs, but we'll return top 5 after re-ranking +# # "depth": depth, +# # }, +# # ) + + +# # Define the prompt template +# prompt_template = """ +# Retrieved Information: +# {retrieved_docs} + +# User Query: +# {query} + +# Response Instruction: +# Please generate a response without using markdown that uses the retrieved information to directly and clearly answer the user's query. Ensure that the response is relevant, accurate, and well-organized. 
+# """ # noqa: E501 + +# prompt = ChatPromptTemplate.from_template(prompt_template) + +# # Return the pipeline with retriever and re-ranking step +# return ( +# { +# "retrieved_docs": retriever, +# "query": RunnablePassthrough(), +# } +# | prompt # Generate prompt with re-ranked docs +# | llm # Pass through the LLM +# | StrOutputParser() # Final output parser +# ) diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index b392f65b1..74945355b 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -4,6 +4,8 @@ import logging import re import secrets +import sys +from collections import deque from collections.abc import Iterable from dataclasses import asdict, dataclass, field, is_dataclass from enum import Enum @@ -11,19 +13,25 @@ TYPE_CHECKING, Any, Sequence, + Set, Union, cast, ) -from cassandra.cluster import ConsistencyLevel, PreparedStatement, Session +from tqdm import tqdm + +from cassandra.cluster import ConsistencyLevel, PreparedStatement, Session, SimpleStatement from cassio.config import check_resolve_keyspace, check_resolve_session from typing_extensions import assert_never from ._mmr_helper import MmrHelper from .concurrency import ConcurrentQueries -from .content import Kind from .links import Link +from concurrent.futures import ThreadPoolExecutor +from queue import Queue, Empty +import threading + if TYPE_CHECKING: from .embedding_model import EmbeddingModel @@ -31,10 +39,8 @@ CONTENT_ID = "content_id" -CONTENT_COLUMNS = "content_id, kind, text_content, links_blob, metadata_blob" - SELECT_CQL_TEMPLATE = ( - "SELECT {columns} FROM {table_name}{where_clause}{order_clause}{limit_clause};" + "SELECT {columns} FROM {table_name}{where_clause}{order_clause}{limit_clause}" ) @@ -46,10 +52,18 @@ class Node: """Text contained by the node.""" id: str | None = None """Unique ID for the node. 
Will be generated by the GraphStore if not set.""" + embedding: list[float] = field(default_factory=list) + """Vector embedding of the text""" metadata: dict[str, Any] = field(default_factory=dict) """Metadata for the node.""" links: set[Link] = field(default_factory=set) - """Links for the node.""" + """All the links for the node.""" + + def incoming_links(self) -> set[Link]: + return set([l for l in self.links if (l.direction in ["in", "bidir"])]) + + def outgoing_links(self) -> set[Link]: + return set([l for l in self.links if (l.direction in ["out", "bidir"])]) class SetupMode(Enum): @@ -114,13 +128,25 @@ def _deserialize_links(json_blob: str | None) -> set[Link]: for link in cast(list[dict[str, Any]], json.loads(json_blob or "")) } +def _metadata_s_link_key(link: Link) -> str: + return "link_from_" + json.dumps({"kind": link.kind, "tag": link.tag}) + +def _metadata_s_link_value() -> str: + return "link_from" def _row_to_node(row: Any) -> Node: - metadata = _deserialize_metadata(row.metadata_blob) - links = _deserialize_links(row.links_blob) + if hasattr(row, "metadata_blob"): + metadata_blob = getattr(row, "metadata_blob") + metadata = _deserialize_metadata(metadata_blob) + links: set[Link] = _deserialize_links(metadata.get("links")) + metadata["links"] = links + else: + metadata = {} + links = set() return Node( - id=row.content_id, - text=row.text_content, + id=getattr(row, CONTENT_ID, ""), + embedding=getattr(row, "text_embedding", []), + text=getattr(row, "text_content", ""), metadata=metadata, links=links, ) @@ -129,13 +155,6 @@ def _row_to_node(row: Any) -> Node: _CQL_IDENTIFIER_PATTERN = re.compile(r"[a-zA-Z][a-zA-Z0-9_]*") -@dataclass -class _Edge: - target_content_id: str - target_text_embedding: list[float] - target_link_to_tags: set[tuple[str, str]] - - class GraphStore: """A hybrid vector-and-graph store backed by Cassandra. @@ -201,25 +220,24 @@ def __init__( self._insert_passage = session.prepare( f""" INSERT INTO {keyspace}.{node_table} ( - content_id, kind, text_content, text_embedding, link_to_tags, - link_from_tags, links_blob, metadata_blob, metadata_s - ) VALUES (?, '{Kind.passage}', ?, ?, ?, ?, ?, ?, ?) + {CONTENT_ID}, text_content, text_embedding, metadata_blob, metadata_s + ) VALUES (?, ?, ?, ?, ?) """ # noqa: S608 ) self._query_by_id = session.prepare( f""" - SELECT {CONTENT_COLUMNS} + SELECT {CONTENT_ID}, text_content, metadata_blob FROM {keyspace}.{node_table} - WHERE content_id = ? + WHERE {CONTENT_ID} = ? """ # noqa: S608 ) - self._query_ids_and_link_to_tags_by_id = session.prepare( + self._query_id_and_metadata_by_id = session.prepare( f""" - SELECT content_id, link_to_tags + SELECT {CONTENT_ID}, metadata_blob FROM {keyspace}.{node_table} - WHERE content_id = ? + WHERE {CONTENT_ID} = ? 
""" # noqa: S608 ) @@ -232,18 +250,12 @@ def _apply_schema(self) -> None: embedding_dim = len(self._embedding.embed_query("Test Query")) self._session.execute(f""" CREATE TABLE IF NOT EXISTS {self.table_name()} ( - content_id TEXT, - kind TEXT, + {CONTENT_ID} TEXT, text_content TEXT, text_embedding VECTOR, - - link_to_tags SET>, - link_from_tags SET>, - links_blob TEXT, metadata_blob TEXT, metadata_s MAP, - - PRIMARY KEY (content_id) + PRIMARY KEY ({CONTENT_ID}) ) """) @@ -254,12 +266,6 @@ def _apply_schema(self) -> None: USING 'StorageAttachedIndex'; """) - self._session.execute(f""" - CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_link_from_tags - ON {self.table_name()}(link_from_tags) - USING 'StorageAttachedIndex'; - """) - self._session.execute(f""" CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_metadata_s_index ON {self.table_name()}(ENTRIES(metadata_s)) @@ -277,32 +283,24 @@ def add_nodes( """Add nodes to the graph store.""" node_ids: list[str] = [] texts: list[str] = [] - metadatas: list[dict[str, Any]] = [] - nodes_links: list[set[Link]] = [] + metadata_list: list[dict[str, Any]] = [] + incoming_links_list: list[set[Link]] = [] for node in nodes: if not node.id: node_ids.append(secrets.token_hex(8)) else: node_ids.append(node.id) texts.append(node.text) - metadatas.append(node.metadata) - nodes_links.append(node.links) + combined_metadata = node.metadata.copy() + combined_metadata["links"] = _serialize_links(node.links) + metadata_list.append(combined_metadata) + incoming_links_list.append(node.incoming_links()) text_embeddings = self._embedding.embed_texts(texts) with self._concurrent_queries() as cq: - tuples = zip(node_ids, texts, text_embeddings, metadatas, nodes_links) - for node_id, text, text_embedding, metadata, links in tuples: - link_to_tags = set() # link to these tags - link_from_tags = set() # link from these tags - - for tag in links: - if tag.direction in {"in", "bidir"}: - # An incoming link should be linked *from* nodes with the given - # tag. - link_from_tags.add((tag.kind, tag.tag)) - if tag.direction in {"out", "bidir"}: - link_to_tags.add((tag.kind, tag.tag)) + tuples = zip(node_ids, texts, text_embeddings, metadata_list, incoming_links_list) + for node_id, text, text_embedding, metadata, incoming_links in tuples: metadata_s = { k: self._coerce_string(v) @@ -310,17 +308,17 @@ def add_nodes( if _is_metadata_field_indexed(k, self._metadata_indexing_policy) } + for incoming_link in incoming_links: + metadata_s[_metadata_s_link_key(link=incoming_link)] =_metadata_s_link_value() + metadata_blob = _serialize_metadata(metadata) - links_blob = _serialize_links(links) + cq.execute( self._insert_passage, parameters=( node_id, text, text_embedding, - link_to_tags, - link_from_tags, - links_blob, metadata_blob, metadata_s, ), @@ -412,74 +410,58 @@ def mmr_traversal_search( score_threshold=score_threshold, ) - # For each unselected node, stores the outgoing tags. - outgoing_tags: dict[str, set[tuple[str, str]]] = {} - - # Fetch the initial candidates and add them to the helper and - # outgoing_tags. - columns = "content_id, text_embedding, link_to_tags" - adjacent_query = self._get_search_cql( - has_limit=True, - columns=columns, - metadata_keys=list(metadata_filter.keys()), - has_embedding=True, - has_link_from_tags=True, - ) - - visited_tags: set[tuple[str, str]] = set() + # For each unselected node, stores the outgoing links. 
+ outgoing_links_map: dict[str, set[Link]] = {} + visited_links: set[Link] = set() def fetch_neighborhood(neighborhood: Sequence[str]) -> None: - # Put the neighborhood into the outgoing tags, to avoid adding it + nonlocal outgoing_links_map + nonlocal visited_links + + # Put the neighborhood into the outgoing links, to avoid adding it # to the candidate set in the future. - outgoing_tags.update({content_id: set() for content_id in neighborhood}) + outgoing_links_map.update({content_id: set() for content_id in neighborhood}) - # Initialize the visited_tags with the set of outgoing from the + # Initialize the visited_links with the set of outgoing links from the # neighborhood. This prevents re-visiting them. - visited_tags = self._get_outgoing_tags(neighborhood) + visited_links = self._get_outgoing_links(neighborhood) # Call `self._get_adjacent` to fetch the candidates. - adjacents = self._get_adjacent( - visited_tags, - adjacent_query=adjacent_query, + adjacent_nodes: Iterable[Node] = self._get_adjacent_nodes( + outgoing_links=visited_links, query_embedding=query_embedding, k_per_tag=adjacent_k, metadata_filter=metadata_filter, ) - new_candidates = {} - for adjacent in adjacents: - if adjacent.target_content_id not in outgoing_tags: - outgoing_tags[adjacent.target_content_id] = ( - adjacent.target_link_to_tags - ) + candidates: dict[str, list[float]] = {} + for node in adjacent_nodes: + if node.id not in outgoing_links_map: + outgoing_links_map[node.id] = node.outgoing_links() + candidates[node.id] = node.embedding - new_candidates[adjacent.target_content_id] = ( - adjacent.target_text_embedding - ) - helper.add_candidates(new_candidates) + helper.add_candidates(candidates) def fetch_initial_candidates() -> None: - initial_candidates_query = self._get_search_cql( - has_limit=True, - columns=columns, - metadata_keys=list(metadata_filter.keys()), - has_embedding=True, - ) + nonlocal outgoing_links_map + nonlocal visited_links - params = self._get_search_params( + initial_query, initial_params = self._get_search_cql_and_params( + columns=f"{CONTENT_ID}, text_embedding, metadata_blob", limit=fetch_k, metadata=metadata_filter, embedding=query_embedding, ) - fetched = self._session.execute( - query=initial_candidates_query, parameters=params + rows = self._session.execute( + query=initial_query, parameters=initial_params ) - candidates = {} - for row in fetched: - if row.content_id not in outgoing_tags: - candidates[row.content_id] = row.text_embedding - outgoing_tags[row.content_id] = set(row.link_to_tags or []) + candidates: dict[str, list[float]] = {} + for row in rows: + node = _row_to_node(row) + if node.id not in outgoing_links_map: + outgoing_links_map[node.id] = node.outgoing_links() + candidates[node.id] = node.embedding helper.add_candidates(candidates) if initial_roots: @@ -502,40 +484,34 @@ def fetch_initial_candidates() -> None: # If the next nodes would not exceed the depth limit, find the # adjacent nodes. # - # TODO: For a big performance win, we should track which tags we've + # TODO: For a big performance win, we should track which links we've # already incorporated. We don't need to issue adjacent queries for # those. - # Find the tags linked to from the selected ID. - link_to_tags = outgoing_tags.pop(selected_id) + # Find the outgoing links linked to from the selected ID. + selected_outgoing_links = outgoing_links_map.pop(selected_id) - # Don't re-visit already visited tags. - link_to_tags.difference_update(visited_tags) + # Don't re-visit already visited links. 
+ selected_outgoing_links.difference_update(visited_links) # Find the nodes with incoming links from those tags. - adjacents = self._get_adjacent( - link_to_tags, - adjacent_query=adjacent_query, + adjacent_nodes: Iterable[Node] = self._get_adjacent_nodes( + outgoing_links=selected_outgoing_links, query_embedding=query_embedding, k_per_tag=adjacent_k, metadata_filter=metadata_filter, ) - # Record the link_to_tags as visited. - visited_tags.update(link_to_tags) + # Record the selected_outgoing_links as visited. + visited_links.update(selected_outgoing_links) - new_candidates = {} - for adjacent in adjacents: - if adjacent.target_content_id not in outgoing_tags: - outgoing_tags[adjacent.target_content_id] = ( - adjacent.target_link_to_tags - ) - new_candidates[adjacent.target_content_id] = ( - adjacent.target_text_embedding - ) - if next_depth < depths.get( - adjacent.target_content_id, depth + 1 - ): + candidates: dict[str, list[float]] = {} + for node in adjacent_nodes: + if node.id not in outgoing_links_map: + outgoing_links_map[node.id] = node.outgoing_links() + candidates[node.id] = node.embedding + + if next_depth < depths.get(node.id, depth + 1): # If this is a new shortest depth, or there was no # previous depth, update the depths. This ensures that # when we discover a node we will have the shortest @@ -546,19 +522,21 @@ def fetch_initial_candidates() -> None: # a shorter path via nodes selected later. This is # currently "intended", but may be worth experimenting # with. - depths[adjacent.target_content_id] = next_depth - helper.add_candidates(new_candidates) + depths[node.id] = next_depth + helper.add_candidates(candidates) return self._nodes_with_ids(helper.selected_ids) - def traversal_search( - self, - query: str, - *, - k: int = 4, - depth: int = 1, - metadata_filter: dict[str, Any] = {}, # noqa: B006 - ) -> Iterable[Node]: + + + def traversal_search_sync( + self, + query: str, + *, + k: int = 4, + depth: int = 1, + metadata_filter: dict[str, Any] = {}, # noqa: B006 + ) -> Iterable[Node]: """Retrieve documents from this knowledge store. First, `k` nodes are retrieved using a vector search for the `query` string. @@ -575,114 +553,169 @@ def traversal_search( Returns: Collection of retrieved documents. """ - # Depth 0: - # Query for `k` nodes similar to the question. - # Retrieve `content_id` and `link_to_tags`. - # - # Depth 1: - # Query for nodes that have an incoming tag in the `link_to_tags` set. - # Combine node IDs. - # Query for `link_to_tags` of those "new" node IDs. - # - # ... - - traversal_query = self._get_search_cql( - columns="content_id, link_to_tags", - has_limit=True, - metadata_keys=list(metadata_filter.keys()), - has_embedding=True, - ) + visited_ids: dict[str, int] = {} + visited_link_keys: dict[str, int] = {} - visit_nodes_query = self._get_search_cql( - columns="content_id AS target_content_id", - has_link_from_tags=True, - metadata_keys=list(metadata_filter.keys()), + work_queue = deque() + + # Initial traversal query + traversal_query, traversal_params = self._get_search_cql_and_params( + columns=f"{CONTENT_ID}, metadata_blob", + metadata=metadata_filter, + embedding=self._embedding.embed_query(query), + limit=k ) - with self._concurrent_queries() as cq: - # Map from visited ID to depth - visited_ids: dict[str, int] = {} - - # Map from visited tag `(kind, tag)` to depth. Allows skipping queries - # for tags that we've already traversed. 
- visited_tags: dict[tuple[str, str], int] = {} - - def visit_nodes(d: int, nodes: Sequence[Any]) -> None: - nonlocal visited_ids - nonlocal visited_tags - - # Visit nodes at the given depth. - # Each node has `content_id` and `link_to_tags`. - - # Iterate over nodes, tracking the *new* outgoing kind tags for this - # depth. This is tags that are either new, or newly discovered at a - # lower depth. - outgoing_tags = set() - for node in nodes: - content_id = node.content_id - - # Add visited ID. If it is closer it is a new node at this depth: - if d <= visited_ids.get(content_id, depth): - visited_ids[content_id] = d - - # If we can continue traversing from this node, - if d < depth and node.link_to_tags: - # Record any new (or newly discovered at a lower depth) - # tags to the set to traverse. - for kind, value in node.link_to_tags: - if d <= visited_tags.get((kind, value), depth): - # Record that we'll query this tag at the - # given depth, so we don't fetch it again - # (unless we find it an earlier depth) - visited_tags[(kind, value)] = d - outgoing_tags.add((kind, value)) - - if outgoing_tags: - # If there are new tags to visit at the next depth, query for the - # node IDs. - for kind, value in outgoing_tags: - params = self._get_search_params( - link_from_tags=(kind, value), metadata=metadata_filter - ) - cq.execute( - query=visit_nodes_query, - parameters=params, - callback=lambda rows, d=d: visit_targets(d, rows), - ) + # Execute the initial query synchronously + initial_rows = self._session.execute(traversal_query, traversal_params) + + for row in initial_rows: + node = _row_to_node(row=row) + work_queue.append((node, 0)) + + while work_queue: + node, d = work_queue.popleft() + # Check if node has been visited at a lower depth + if d <= visited_ids.get(node.id, depth): + visited_ids[node.id] = d + if d < depth: + # Get outgoing link keys + for outgoing_link in tqdm(node.outgoing_links()): + link_key = _metadata_s_link_key(link=outgoing_link) + if d <= visited_link_keys.get(link_key, depth): + visited_link_keys[link_key] = d + # Query nodes with this link key + query, params = self._get_search_cql_and_params( + columns=CONTENT_ID, + metadata=metadata_filter, + link_keys=[link_key], + ) + target_rows = self._session.execute(query, params) + for row in target_rows: + target_node_id = getattr(row, CONTENT_ID) + if d < visited_ids.get(target_node_id, depth): + # Fetch node by ID + node_query = self._query_id_and_metadata_by_id + node_params = (target_node_id,) + node_rows = self._session.execute(node_query, node_params) + for node_row in node_rows: + target_node = _row_to_node(node_row) + work_queue.append((target_node, d + 1)) - def visit_targets(d: int, targets: Sequence[Any]) -> None: - nonlocal visited_ids - - # target_content_id, tag=(kind,value) - new_nodes_at_next_depth = set() - for target in targets: - content_id = target.target_content_id - if d < visited_ids.get(content_id, depth): - new_nodes_at_next_depth.add(content_id) - - if new_nodes_at_next_depth: - for node_id in new_nodes_at_next_depth: - cq.execute( - self._query_ids_and_link_to_tags_by_id, - parameters=(node_id,), - callback=lambda rows, d=d: visit_nodes(d + 1, rows), - ) + return self._nodes_with_ids(visited_ids.keys()) - query_embedding = self._embedding.embed_query(query) - params = self._get_search_params( - limit=k, - metadata=metadata_filter, - embedding=query_embedding, - ) + def traversal_search( + self, + query: str, + *, + k: int = 4, + depth: int = 1, + metadata_filter: dict[str, Any] = {}, # noqa: B006 
+ ) -> Iterable[Node]: + """Retrieve documents from this knowledge store. - cq.execute( - traversal_query, - parameters=params, - callback=lambda nodes: visit_nodes(0, nodes), - ) + First, `k` nodes are retrieved using a vector search for the `query` string. + Then, additional nodes are discovered up to the given `depth` from those + starting nodes. + + Args: + query: The query string. + k: The number of Documents to return from the initial vector search. + Defaults to 4. + depth: The maximum depth of edges to traverse. Defaults to 1. + metadata_filter: Optional metadata to filter the results. + + Returns: + Collection of retrieved documents. + """ + visited_ids: dict[str, int] = {} + visited_link_keys: dict[str, int] = {} + + # Locks for thread safety + visited_ids_lock = threading.Lock() + visited_link_keys_lock = threading.Lock() + + work_queue = Queue() + + # Initial traversal query + traversal_query, traversal_params = self._get_search_cql_and_params( + columns=f"{CONTENT_ID}, metadata_blob", + metadata=metadata_filter, + embedding=self._embedding.embed_query(query), + limit=k + ) + + # Execute the initial query synchronously + initial_rows = self._session.execute(traversal_query, traversal_params) + + for row in initial_rows: + node = _row_to_node(row=row) + work_queue.put((node, 0)) + + def worker(): + while True: + try: + node, d = work_queue.get(timeout=1) + except Empty: + # If no work is available after timeout, exit the worker + return + + with visited_ids_lock: + if d <= visited_ids.get(node.id, depth): + visited_ids[node.id] = d + else: + # Node already visited at a lower depth + work_queue.task_done() + continue + + if d < depth: + # Get outgoing link keys + outgoing_links = node.outgoing_links() + for outgoing_link in outgoing_links: + link_key = _metadata_s_link_key(link=outgoing_link) + with visited_link_keys_lock: + if d <= visited_link_keys.get(link_key, depth): + visited_link_keys[link_key] = d + else: + continue # Already visited at lower depth + + # Query nodes with this link key + query, params = self._get_search_cql_and_params( + columns=CONTENT_ID, + metadata=metadata_filter, + link_keys=[link_key], + ) + target_rows = self._session.execute(query, params) + for row in target_rows: + target_node_id = getattr(row, CONTENT_ID) + with visited_ids_lock: + if d < visited_ids.get(target_node_id, depth): + # Fetch node by ID + node_query = self._query_id_and_metadata_by_id + node_params = (target_node_id,) + node_rows = self._session.execute(node_query, node_params) + for node_row in node_rows: + target_node = _row_to_node(node_row) + work_queue.put((target_node, d + 1)) + work_queue.task_done() + + num_workers = 10 # Adjust the number of worker threads as needed + threads = [] + for _ in range(num_workers): + t = threading.Thread(target=worker) + t.start() + threads.append(t) + + # Wait for all items to be processed + work_queue.join() + + # Wait for all worker threads to finish + for t in threads: + t.join() return self._nodes_with_ids(visited_ids.keys()) + def similarity_search( self, embedding: list[float], @@ -691,7 +724,10 @@ def similarity_search( ) -> Iterable[Node]: """Retrieve nodes similar to the given embedding, optionally filtered by metadata.""" # noqa: E501 query, params = self._get_search_cql_and_params( - embedding=embedding, limit=k, metadata=metadata_filter + columns=f"{CONTENT_ID}, text_content, metadata_blob", + embedding=embedding, + limit=k, + metadata=metadata_filter, ) for row in self._session.execute(query, params): @@ -703,7 +739,11 @@ def 
metadata_search( n: int = 5, ) -> Iterable[Node]: """Retrieve nodes based on their metadata.""" - query, params = self._get_search_cql_and_params(metadata=metadata, limit=n) + query, params = self._get_search_cql_and_params( + columns=f"{CONTENT_ID}, text_content, metadata_blob", + metadata=metadata, + limit=n, + ) for row in self._session.execute(query, params): yield _row_to_node(row) @@ -712,78 +752,75 @@ def get_node(self, content_id: str) -> Node: """Get a node by its id.""" return self._nodes_with_ids(ids=[content_id])[0] - def _get_outgoing_tags( + def _get_outgoing_links( self, source_ids: Iterable[str], - ) -> set[tuple[str, str]]: - """Return the set of outgoing tags for the given source ID(s). + ) -> set[Link]: + """Return the set of outgoing links for the given source ID(s). Args: - source_ids: The IDs of the source nodes to retrieve outgoing tags for. + source_ids: The IDs of the source nodes to retrieve outgoing links for. """ - tags = set() + outgoing_links: Set[Link] = set() def add_sources(rows: Iterable[Any]) -> None: for row in rows: - if row.link_to_tags: - tags.update(row.link_to_tags) + node = _row_to_node(row=row) + outgoing_links.update(node.outgoing_links()) with self._concurrent_queries() as cq: for source_id in source_ids: cq.execute( - self._query_ids_and_link_to_tags_by_id, + self._query_id_and_metadata_by_id, (source_id,), callback=add_sources, ) - return tags + return outgoing_links - def _get_adjacent( + def _get_adjacent_nodes( self, - tags: set[tuple[str, str]], - adjacent_query: PreparedStatement, + outgoing_links: set[Link], query_embedding: list[float], - k_per_tag: int | None = None, - metadata_filter: dict[str, Any] | None = None, - ) -> Iterable[_Edge]: - """Return the target nodes with incoming links from any of the given tags. + k_per_link: int = 10, + metadata_filter: dict[str, Any] = {}, + ) -> Iterable[Node]: + """Return the target nodes with incoming links from any of the given outgoing_links. Args: - tags: The tags to look for links *from*. - adjacent_query: Prepared query for adjacent nodes. + outgoing_links: The links to search for query_embedding: The query embedding. Used to rank target nodes. - k_per_tag: The number of target nodes to fetch for each outgoing tag. + k_per_link: The number of target nodes to fetch for each outgoing link. metadata_filter: Optional metadata to filter the results. Returns: List of adjacent edges. """ - targets: dict[str, _Edge] = {} + targets: dict[str, Node] = {} + + columns = f"{CONTENT_ID}, text_embedding, metadata_blob" def add_targets(rows: Iterable[Any]) -> None: - # TODO: Figure out how to use the "kind" on the edge. - # This is tricky, since we currently issue one query for anything - # adjacent via any kind, and we don't have enough information to - # determine which kind(s) a given target was reached from. 
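+            # Deduplicate adjacent nodes by content id: a target reached via
+            # several link tags is only recorded once.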
+ nonlocal targets + for row in rows: - if row.content_id not in targets: - targets[row.content_id] = _Edge( - target_content_id=row.content_id, - target_text_embedding=row.text_embedding, - target_link_to_tags=set(row.link_to_tags or []), - ) + target_node = _row_to_node(row) + if target_node.id not in targets: + targets[target_node.id] = target_node with self._concurrent_queries() as cq: - for kind, value in tags: - params = self._get_search_params( - limit=k_per_tag or 10, + for outgoing_link in outgoing_links: + link_key = _metadata_s_link_key(link=outgoing_link) + query, params = self._get_search_cql_and_params( + columns=columns, + limit=k_per_link, metadata=metadata_filter, embedding=query_embedding, - link_from_tags=(kind, value), + link_keys=[link_key] ) cq.execute( - query=adjacent_query, + query=query, parameters=params, callback=add_targets, ) @@ -848,21 +885,13 @@ def _coerce_string(value: Any) -> str: def _extract_where_clause_cql( self, - has_id: bool = False, metadata_keys: Sequence[str] = (), - has_link_from_tags: bool = False, ) -> str: wc_blocks: list[str] = [] - if has_id: - wc_blocks.append("content_id == ?") - - if has_link_from_tags: - wc_blocks.append("link_from_tags CONTAINS (?, ?)") - for key in sorted(metadata_keys): if _is_metadata_field_indexed(key, self._metadata_indexing_policy): - wc_blocks.append(f"metadata_s['{key}'] = ?") + wc_blocks.append(f"metadata_s['{key}'] = %s") else: msg = "Non-indexed metadata fields cannot be used in queries." raise ValueError(msg) @@ -875,14 +904,9 @@ def _extract_where_clause_cql( def _extract_where_clause_params( self, metadata: dict[str, Any], - link_from_tags: tuple[str, str] | None = None, ) -> list[Any]: params: list[Any] = [] - if link_from_tags is not None: - params.append(link_from_tags[0]) - params.append(link_from_tags[1]) - for key, value in sorted(metadata.items()): if _is_metadata_field_indexed(key, self._metadata_indexing_policy): params.append(self._coerce_string(value=value)) @@ -892,22 +916,28 @@ def _extract_where_clause_params( return params - def _get_search_cql( + def _get_search_cql_and_params( self, - has_limit: bool = False, - columns: str | None = CONTENT_COLUMNS, - metadata_keys: Sequence[str] = (), - has_id: bool = False, - has_embedding: bool = False, - has_link_from_tags: bool = False, - ) -> PreparedStatement: - where_clause = self._extract_where_clause_cql( - has_id=has_id, - metadata_keys=metadata_keys, - has_link_from_tags=has_link_from_tags, - ) - limit_clause = " LIMIT ?" if has_limit else "" - order_clause = " ORDER BY text_embedding ANN OF ?" if has_embedding else "" + columns: str, + limit: int | None = None, + metadata: dict[str, Any] | None = None, + embedding: list[float] | None = None, + link_keys: list[str] | None = None, + ) -> tuple[PreparedStatement|SimpleStatement, tuple[Any, ...]]: + if link_keys is not None: + if metadata is None: + metadata = {} + else: + # don't add link search to original metadata dict + metadata = metadata.copy() + for link_key in link_keys: + metadata[link_key] = _metadata_s_link_value() + + metadata_keys = list(metadata.keys()) if metadata else [] + + where_clause = self._extract_where_clause_cql(metadata_keys=metadata_keys) + limit_clause = " LIMIT ?" if limit is not None else "" + order_clause = " ORDER BY text_embedding ANN OF ?" 
if embedding is not None else "" select_cql = SELECT_CQL_TEMPLATE.format( columns=columns, @@ -917,50 +947,18 @@ def _get_search_cql( limit_clause=limit_clause, ) - if select_cql in self._prepared_query_cache: - return self._prepared_query_cache[select_cql] - - prepared_query = self._session.prepare(select_cql) - prepared_query.consistency_level = ConsistencyLevel.ONE - self._prepared_query_cache[select_cql] = prepared_query - - return prepared_query - - def _get_search_params( - self, - limit: int | None = None, - metadata: dict[str, Any] | None = None, - embedding: list[float] | None = None, - link_from_tags: tuple[str, str] | None = None, - ) -> tuple[PreparedStatement, tuple[Any, ...]]: - where_params = self._extract_where_clause_params( - metadata=metadata or {}, link_from_tags=link_from_tags - ) - + where_params = self._extract_where_clause_params(metadata=metadata or {}) limit_params = [limit] if limit is not None else [] order_params = [embedding] if embedding is not None else [] - return tuple(list(where_params) + order_params + limit_params) + params = tuple(list(where_params) + order_params + limit_params) - def _get_search_cql_and_params( - self, - limit: int | None = None, - columns: str | None = CONTENT_COLUMNS, - metadata: dict[str, Any] | None = None, - embedding: list[float] | None = None, - link_from_tags: tuple[str, str] | None = None, - ) -> tuple[PreparedStatement, tuple[Any, ...]]: - query = self._get_search_cql( - has_limit=limit is not None, - columns=columns, - metadata_keys=list(metadata.keys()) if metadata else (), - has_embedding=embedding is not None, - has_link_from_tags=link_from_tags is not None, - ) - params = self._get_search_params( - limit=limit, - metadata=metadata, - embedding=embedding, - link_from_tags=link_from_tags, - ) - return query, params + if len(metadata_keys) > 0: + return SimpleStatement(query_string=select_cql, fetch_size=100), params + elif select_cql in self._prepared_query_cache: + return self._prepared_query_cache[select_cql], params + else: + prepared_query = self._session.prepare(select_cql) + prepared_query.consistency_level = ConsistencyLevel.ONE + self._prepared_query_cache[select_cql] = prepared_query + return prepared_query, params diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store_tags.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store_tags.py new file mode 100644 index 000000000..90f7dfb5f --- /dev/null +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store_tags.py @@ -0,0 +1,972 @@ +from __future__ import annotations + +import json +import logging +import re +import secrets +import sys +from collections import deque +from collections.abc import Iterable +from dataclasses import asdict, dataclass, field, is_dataclass +from enum import Enum +from typing import ( + TYPE_CHECKING, + Any, + Sequence, + Set, + Union, + cast, +) + +from tqdm import tqdm + +from cassandra.cluster import ConsistencyLevel, PreparedStatement, Session, SimpleStatement +from cassio.config import check_resolve_keyspace, check_resolve_session +from typing_extensions import assert_never + +from ._mmr_helper import MmrHelper +from .concurrency import ConcurrentQueries +from .links import Link + +from concurrent.futures import ThreadPoolExecutor +from queue import Queue, Empty +import threading + +if TYPE_CHECKING: + from .embedding_model import EmbeddingModel + +logger = logging.getLogger(__name__) + +CONTENT_ID = "content_id" + +SELECT_CQL_TEMPLATE = ( + "SELECT {columns} FROM 
{table_name}{where_clause}{order_clause}{limit_clause}" +) + + +@dataclass +class Node: + """Node in the GraphStore.""" + + text: str + """Text contained by the node.""" + id: str | None = None + """Unique ID for the node. Will be generated by the GraphStore if not set.""" + embedding: list[float] = field(default_factory=list) + """Vector embedding of the text""" + metadata: dict[str, Any] = field(default_factory=dict) + """Metadata for the node.""" + links: set[Link] = field(default_factory=set) + """All the links for the node.""" + + def incoming_links(self) -> set[Link]: + return set([l for l in self.links if (l.direction in ["in", "bidir"])]) + + def outgoing_links(self) -> set[Link]: + return set([l for l in self.links if (l.direction in ["out", "bidir"])]) + + +class SetupMode(Enum): + """Mode used to create the Cassandra table.""" + + SYNC = 1 + ASYNC = 2 + OFF = 3 + + +class MetadataIndexingMode(Enum): + """Mode used to index metadata.""" + + DEFAULT_TO_UNSEARCHABLE = 1 + DEFAULT_TO_SEARCHABLE = 2 + + +MetadataIndexingType = Union[tuple[str, Iterable[str]], str] +MetadataIndexingPolicy = tuple[MetadataIndexingMode, set[str]] + + +def _is_metadata_field_indexed(field_name: str, policy: MetadataIndexingPolicy) -> bool: + p_mode, p_fields = policy + if p_mode == MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE: + return field_name in p_fields + if p_mode == MetadataIndexingMode.DEFAULT_TO_SEARCHABLE: + return field_name not in p_fields + assert_never(p_mode) + + +def _serialize_metadata(md: dict[str, Any]) -> str: + if isinstance(md.get("links"), set): + md = md.copy() + md["links"] = list(md["links"]) + return json.dumps(md) + + +def _serialize_links(links: set[Link]) -> str: + class SetAndLinkEncoder(json.JSONEncoder): + def default(self, obj: Any) -> Any: + if not isinstance(obj, type) and is_dataclass(obj): + return asdict(obj) + + if isinstance(obj, Iterable): + return list(obj) + + # Let the base class default method raise the TypeError + return super().default(obj) + + return json.dumps(list(links), cls=SetAndLinkEncoder) + + +def _deserialize_metadata(json_blob: str | None) -> dict[str, Any]: + # We don't need to convert the links list back to a set -- it will be + # converted when accessed, if needed. + return cast(dict[str, Any], json.loads(json_blob or "")) + + +def _deserialize_links(json_blob: str | None) -> set[Link]: + return { + Link(kind=link["kind"], direction=link["direction"], tag=link["tag"]) + for link in cast(list[dict[str, Any]], json.loads(json_blob or "")) + } + +def _tag_s_link_key(link: Link) -> str: + return "link_from_" + json.dumps({"kind": link.kind, "tag": link.tag}) + +def _row_to_node(row: Any) -> Node: + if hasattr(row, "metadata_blob"): + metadata_blob = getattr(row, "metadata_blob") + metadata = _deserialize_metadata(metadata_blob) + links: set[Link] = _deserialize_links(metadata.get("links")) + metadata["links"] = links + else: + metadata = {} + links = set() + return Node( + id=getattr(row, CONTENT_ID, ""), + embedding=getattr(row, "text_embedding", []), + text=getattr(row, "text_content", ""), + metadata=metadata, + links=links, + ) + + +_CQL_IDENTIFIER_PATTERN = re.compile(r"[a-zA-Z][a-zA-Z0-9_]*") + + +class GraphStore: + """A hybrid vector-and-graph store backed by Cassandra. + + Document chunks support vector-similarity search as well as edges linking + documents based on structural and semantic properties. + + Args: + embedding: The embeddings to use for the document content. 
+ setup_mode: Mode used to create the Cassandra table (SYNC, + ASYNC or OFF). + """ + + def __init__( + self, + embedding: EmbeddingModel, + *, + node_table: str = "graph_nodes", + targets_table: str = "", + session: Session | None = None, + keyspace: str | None = None, + setup_mode: SetupMode = SetupMode.SYNC, + metadata_indexing: MetadataIndexingType = "all", + insert_timeout: float = 30.0, + ): + self._insert_timeout = insert_timeout + if targets_table: + logger.warning( + "The 'targets_table' parameter is deprecated " + "and will be removed in future versions." + ) + + session = check_resolve_session(session) + keyspace = check_resolve_keyspace(keyspace) + + if not _CQL_IDENTIFIER_PATTERN.fullmatch(keyspace): + msg = f"Invalid keyspace: {keyspace}" + raise ValueError(msg) + + if not _CQL_IDENTIFIER_PATTERN.fullmatch(node_table): + msg = f"Invalid node table name: {node_table}" + raise ValueError(msg) + + self._embedding = embedding + self._node_table = node_table + self._session = session + self._keyspace = keyspace + self._prepared_query_cache: dict[str, PreparedStatement] = {} + + self._metadata_indexing_policy = self._normalize_metadata_indexing_policy( + metadata_indexing=metadata_indexing, + ) + + if setup_mode == SetupMode.SYNC: + self._apply_schema() + elif setup_mode != SetupMode.OFF: + msg = ( + f"Invalid setup mode {setup_mode.name}. " + "Only SYNC and OFF are supported at the moment" + ) + raise ValueError(msg) + + # TODO: Parent ID / source ID / etc. + self._insert_passage = session.prepare( + f""" + INSERT INTO {keyspace}.{node_table} ( + {CONTENT_ID}, text_content, text_embedding, metadata_blob, metadata_s, tag_s + ) VALUES (?, ?, ?, ?, ?, ?) + """ # noqa: S608 + ) + + self._query_by_id = session.prepare( + f""" + SELECT {CONTENT_ID}, text_content, metadata_blob + FROM {keyspace}.{node_table} + WHERE {CONTENT_ID} = ? + """ # noqa: S608 + ) + + self._query_id_and_metadata_by_id = session.prepare( + f""" + SELECT {CONTENT_ID}, metadata_blob + FROM {keyspace}.{node_table} + WHERE {CONTENT_ID} = ? 
+ """ # noqa: S608 + ) + + def table_name(self) -> str: + """Returns the fully qualified table name.""" + return f"{self._keyspace}.{self._node_table}" + + def _apply_schema(self) -> None: + """Apply the schema to the database.""" + embedding_dim = len(self._embedding.embed_query("Test Query")) + self._session.execute(f""" + CREATE TABLE IF NOT EXISTS {self.table_name()} ( + {CONTENT_ID} TEXT, + text_content TEXT, + text_embedding VECTOR, + metadata_blob TEXT, + metadata_s MAP, + tag_s SET, + PRIMARY KEY ({CONTENT_ID}) + ) + """) + + # Index on text_embedding (for similarity search) + self._session.execute(f""" + CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_text_embedding_index + ON {self.table_name()}(text_embedding) + USING 'StorageAttachedIndex'; + """) + + self._session.execute(f""" + CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_tag_s_index + ON {self.table_name()}(tag_s) + USING 'StorageAttachedIndex'; + """) + + self._session.execute(f""" + CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_metadata_s_index + ON {self.table_name()}(ENTRIES(metadata_s)) + USING 'StorageAttachedIndex'; + """) + + def _concurrent_queries(self) -> ConcurrentQueries: + return ConcurrentQueries(self._session) + + # TODO: Async (aadd_nodes) + def add_nodes( + self, + nodes: Iterable[Node], + ) -> Iterable[str]: + """Add nodes to the graph store.""" + node_ids: list[str] = [] + texts: list[str] = [] + metadata_list: list[dict[str, Any]] = [] + incoming_links_list: list[set[Link]] = [] + for node in nodes: + if not node.id: + node_ids.append(secrets.token_hex(8)) + else: + node_ids.append(node.id) + texts.append(node.text) + combined_metadata = node.metadata.copy() + combined_metadata["links"] = _serialize_links(node.links) + metadata_list.append(combined_metadata) + incoming_links_list.append(node.incoming_links()) + + text_embeddings = self._embedding.embed_texts(texts) + + with self._concurrent_queries() as cq: + tuples = zip(node_ids, texts, text_embeddings, metadata_list, incoming_links_list) + for node_id, text, text_embedding, metadata, incoming_links in tuples: + + metadata_s = { + k: self._coerce_string(v) + for k, v in metadata.items() + if _is_metadata_field_indexed(k, self._metadata_indexing_policy) + } + + tag_s = [_tag_s_link_key(l) for l in incoming_links] + + + metadata_blob = _serialize_metadata(metadata) + + cq.execute( + self._insert_passage, + parameters=( + node_id, + text, + text_embedding, + metadata_blob, + metadata_s, + tag_s, + ), + timeout=self._insert_timeout, + ) + + return node_ids + + def _nodes_with_ids( + self, + ids: Iterable[str], + ) -> list[Node]: + results: dict[str, Node | None] = {} + with self._concurrent_queries() as cq: + + def node_callback(rows: Iterable[Any]) -> None: + # Should always be exactly one row here. We don't need to check + # 1. The query is for a `ID == ?` query on the primary key. + # 2. If it doesn't exist, the `get_result` method below will + # raise an exception indicating the ID doesn't exist. + for row in rows: + results[row.content_id] = _row_to_node(row) + + for node_id in ids: + if node_id not in results: + # Mark this node ID as being fetched. 
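As an aside, the write path in `add_nodes` above stores each node's incoming links as string keys in the `tag_s` set column via `_tag_s_link_key`. A minimal, self-contained sketch of that encoding (the `Link` stand-in below only mirrors the fields this module uses, and the example tags are made up):

import json
from dataclasses import dataclass

@dataclass(frozen=True)
class Link:  # local stand-in for ragstack_knowledge_store.links.Link
    kind: str
    direction: str  # "in", "out" or "bidir"
    tag: str

def tag_s_link_key(link: Link) -> str:
    # Mirrors _tag_s_link_key above: only `kind` and `tag` end up in the key.
    return "link_from_" + json.dumps({"kind": link.kind, "tag": link.tag})

links = {
    Link(kind="kw", direction="bidir", tag="cassandra"),
    Link(kind="href", direction="out", tag="https://example.com"),
}
incoming = [l for l in links if l.direction in ("in", "bidir")]
print([tag_s_link_key(l) for l in incoming])
# ['link_from_{"kind": "kw", "tag": "cassandra"}']

Only incoming and bidirectional links become `tag_s` entries; outgoing links are recovered later from the serialized `links` stored in `metadata_blob`.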
+ results[node_id] = None + cq.execute( + self._query_by_id, parameters=(node_id,), callback=node_callback + ) + + def get_result(node_id: str) -> Node: + if (result := results[node_id]) is None: + msg = f"No node with ID '{node_id}'" + raise ValueError(msg) + return result + + return [get_result(node_id) for node_id in ids] + + def mmr_traversal_search( + self, + query: str, + *, + initial_roots: Sequence[str] = (), + k: int = 4, + depth: int = 2, + fetch_k: int = 100, + adjacent_k: int = 10, + lambda_mult: float = 0.5, + score_threshold: float = float("-inf"), + metadata_filter: dict[str, Any] = {}, # noqa: B006 + ) -> Iterable[Node]: + """Retrieve documents from this graph store using MMR-traversal. + + This strategy first retrieves the top `fetch_k` results by similarity to + the question. It then selects the top `k` results based on + maximum-marginal relevance using the given `lambda_mult`. + + At each step, it considers the (remaining) documents from `fetch_k` as + well as any documents connected by edges to a selected document + retrieved based on similarity (a "root"). + + Args: + query: The query string to search for. + initial_roots: Optional list of document IDs to use for initializing search. + The top `adjacent_k` nodes adjacent to each initial root will be + included in the set of initial candidates. To fetch only in the + neighborhood of these nodes, set `ftech_k = 0`. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of initial Documents to fetch via similarity. + Will be added to the nodes adjacent to `initial_roots`. + Defaults to 100. + adjacent_k: Number of adjacent Documents to fetch. + Defaults to 10. + depth: Maximum depth of a node (number of edges) from a node + retrieved via similarity. Defaults to 2. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding to maximum + diversity and 1 to minimum diversity. Defaults to 0.5. + score_threshold: Only documents with a score greater than or equal + this threshold will be chosen. Defaults to -infinity. + metadata_filter: Optional metadata to filter the results. + """ + query_embedding = self._embedding.embed_query(query) + helper = MmrHelper( + k=k, + query_embedding=query_embedding, + lambda_mult=lambda_mult, + score_threshold=score_threshold, + ) + + # For each unselected node, stores the outgoing links. + outgoing_link_keys_map: dict[str, set[str]] = {} + visited_link_keys: set[str] = set() + + def fetch_neighborhood(neighborhood: Sequence[str]) -> None: + nonlocal outgoing_link_keys_map + nonlocal visited_link_keys + + # Put the neighborhood into the outgoing links, to avoid adding it + # to the candidate set in the future. + outgoing_link_keys_map.update({content_id: set() for content_id in neighborhood}) + + # Initialize the visited_links with the set of outgoing links from the + # neighborhood. This prevents re-visiting them. + visited_link_keys = self._get_outgoing_link_keys(neighborhood) + + # Call `self._get_adjacent` to fetch the candidates. 
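The `visited_link_keys` set seeded above is what keeps the traversal from querying the same link key twice. A small standalone sketch of that bookkeeping, mirroring the `difference_update`/`update` calls used in the selection loop further down (the keys are invented):

already_visited = {'link_from_{"kind": "kw", "tag": "graph"}'}
selected_keys = {
    'link_from_{"kind": "kw", "tag": "graph"}',      # seen before: skipped
    'link_from_{"kind": "kw", "tag": "cassandra"}',  # new: will be queried
}
selected_keys.difference_update(already_visited)
print(selected_keys)
# {'link_from_{"kind": "kw", "tag": "cassandra"}'}
already_visited.update(selected_keys)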
+ adjacent_nodes: Iterable[Node] = self._get_adjacent_nodes(
+ outgoing_link_keys=visited_link_keys,
+ query_embedding=query_embedding,
+ k_per_link=adjacent_k,
+ metadata_filter=metadata_filter,
+ )
+
+ candidates: dict[str, list[float]] = {}
+ for node in adjacent_nodes:
+ if node.id not in outgoing_link_keys_map:
+ outgoing_link_keys = [_tag_s_link_key(l) for l in node.outgoing_links()]
+ outgoing_link_keys_map[node.id] = set(outgoing_link_keys)
+ candidates[node.id] = node.embedding
+
+ helper.add_candidates(candidates)
+
+ def fetch_initial_candidates() -> None:
+ nonlocal outgoing_link_keys_map
+ nonlocal visited_link_keys
+
+ initial_query, initial_params = self._get_search_cql_and_params(
+ columns=f"{CONTENT_ID}, text_embedding, metadata_blob",
+ limit=fetch_k,
+ metadata=metadata_filter,
+ embedding=query_embedding,
+ )
+
+ rows = self._session.execute(
+ query=initial_query, parameters=initial_params
+ )
+ candidates: dict[str, list[float]] = {}
+ for row in rows:
+ node = _row_to_node(row)
+ if node.id not in outgoing_link_keys_map:
+ outgoing_link_keys = [_tag_s_link_key(l) for l in node.outgoing_links()]
+ outgoing_link_keys_map[node.id] = set(outgoing_link_keys)
+ candidates[node.id] = node.embedding
+ helper.add_candidates(candidates)
+
+ if initial_roots:
+ fetch_neighborhood(initial_roots)
+ if fetch_k > 0:
+ fetch_initial_candidates()
+
+ # Tracks the depth of each candidate.
+ depths = {candidate_id: 0 for candidate_id in helper.candidate_ids()}
+
+ # Select the best item, K times.
+ for _ in range(k):
+ selected_id = helper.pop_best()
+
+ if selected_id is None:
+ break
+
+ next_depth = depths[selected_id] + 1
+ if next_depth < depth:
+ # If the next nodes would not exceed the depth limit, find the
+ # adjacent nodes.
+ #
+ # TODO: For a big performance win, we should track which links we've
+ # already incorporated. We don't need to issue adjacent queries for
+ # those.
+
+ # Find the outgoing link keys linked to from the selected ID.
+ selected_outgoing_link_keys = outgoing_link_keys_map.pop(selected_id)
+
+ # Don't re-visit already visited links.
+ selected_outgoing_link_keys.difference_update(visited_link_keys)
+
+ # Find the nodes with incoming links from those tags.
+ adjacent_nodes: Iterable[Node] = self._get_adjacent_nodes(
+ outgoing_link_keys=selected_outgoing_link_keys,
+ query_embedding=query_embedding,
+ k_per_link=adjacent_k,
+ metadata_filter=metadata_filter,
+ )
+
+ # Record the selected_outgoing_link_keys as visited.
+ visited_link_keys.update(selected_outgoing_link_keys)
+
+ candidates: dict[str, list[float]] = {}
+ for node in adjacent_nodes:
+ if node.id not in outgoing_link_keys_map:
+ outgoing_link_keys = [_tag_s_link_key(l) for l in node.outgoing_links()]
+ outgoing_link_keys_map[node.id] = set(outgoing_link_keys)
+ candidates[node.id] = node.embedding
+
+ if next_depth < depths.get(node.id, depth + 1):
+ # If this is a new shortest depth, or there was no
+ # previous depth, update the depths. This ensures that
+ # when we discover a node we will have the shortest
+ # depth available.
+ #
+ # NOTE: No effort is made to traverse from nodes that
+ # were previously selected if they become reachable via
+ # a shorter path via nodes selected later. This is
+ # currently "intended", but may be worth experimenting
+ # with.
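The assignment that follows applies that rule. As a standalone illustration of the shortest-depth bookkeeping (node ids are invented):

depth_limit = 2
depths = {"root-1": 0}              # found by the initial similarity search
next_depth = depths["root-1"] + 1   # depth of nodes reached from root-1

for node_id in ("chunk-a", "chunk-b"):
    if next_depth < depths.get(node_id, depth_limit + 1):
        # first time we see this node, or we found a shorter path to it
        depths[node_id] = next_depth

print(depths)  # {'root-1': 0, 'chunk-a': 1, 'chunk-b': 1}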
+ depths[node.id] = next_depth + helper.add_candidates(candidates) + + return self._nodes_with_ids(helper.selected_ids) + + + + def traversal_search( + self, + query: str, + *, + k: int = 4, + depth: int = 1, + metadata_filter: dict[str, Any] = {}, # noqa: B006 + ) -> Iterable[Node]: + """Retrieve documents from this knowledge store. + + First, `k` nodes are retrieved using a vector search for the `query` string. + Then, additional nodes are discovered up to the given `depth` from those + starting nodes. + + Args: + query: The query string. + k: The number of Documents to return from the initial vector search. + Defaults to 4. + depth: The maximum depth of edges to traverse. Defaults to 1. + metadata_filter: Optional metadata to filter the results. + + Returns: + Collection of retrieved documents. + """ + visited_ids: dict[str, int] = {} + visited_link_keys: dict[str, int] = {} + + work_queue = deque() + + # Initial traversal query + traversal_query, traversal_params = self._get_search_cql_and_params( + columns=f"{CONTENT_ID}, metadata_blob", + metadata=metadata_filter, + embedding=self._embedding.embed_query(query), + limit=k + ) + + # Execute the initial query synchronously + initial_rows = self._session.execute(traversal_query, traversal_params) + + for row in initial_rows: + node = _row_to_node(row=row) + work_queue.append((node, 0)) + + while work_queue: + node, d = work_queue.popleft() + # Check if node has been visited at a lower depth + if d <= visited_ids.get(node.id, depth): + visited_ids[node.id] = d + if d < depth: + # Get outgoing link keys + for outgoing_link in node.outgoing_links(): + link_key = _tag_s_link_key(link=outgoing_link) + if d <= visited_link_keys.get(link_key, depth): + visited_link_keys[link_key] = d + # Query nodes with this link key + query, params = self._get_search_cql_and_params( + columns=CONTENT_ID, + metadata=metadata_filter, + tag=link_key, + ) + target_rows = self._session.execute(query, params) + for row in target_rows: + target_node_id = getattr(row, CONTENT_ID) + if d < visited_ids.get(target_node_id, depth): + # Fetch node by ID + node_query = self._query_id_and_metadata_by_id + node_params = (target_node_id,) + node_rows = self._session.execute(node_query, node_params) + for node_row in node_rows: + target_node = _row_to_node(node_row) + work_queue.append((target_node, d + 1)) + + return self._nodes_with_ids(visited_ids.keys()) + + def traversal_search_async( + self, + query: str, + *, + k: int = 4, + depth: int = 1, + metadata_filter: dict[str, Any] = {}, # noqa: B006 + ) -> Iterable[Node]: + """Retrieve documents from this knowledge store. + + First, `k` nodes are retrieved using a vector search for the `query` string. + Then, additional nodes are discovered up to the given `depth` from those + starting nodes. + + Args: + query: The query string. + k: The number of Documents to return from the initial vector search. + Defaults to 4. + depth: The maximum depth of edges to traverse. Defaults to 1. + metadata_filter: Optional metadata to filter the results. + + Returns: + Collection of retrieved documents. 
+ """ + visited_ids: dict[str, int] = {} + visited_link_keys: dict[str, int] = {} + + # Locks for thread safety + visited_ids_lock = threading.Lock() + visited_link_keys_lock = threading.Lock() + + work_queue = Queue() + + # Initial traversal query + traversal_query, traversal_params = self._get_search_cql_and_params( + columns=f"{CONTENT_ID}, metadata_blob", + metadata=metadata_filter, + embedding=self._embedding.embed_query(query), + limit=k + ) + + # Execute the initial query synchronously + initial_rows = self._session.execute(traversal_query, traversal_params) + + for row in initial_rows: + node = _row_to_node(row=row) + work_queue.put((node, 0)) + + def worker(): + while True: + try: + node, d = work_queue.get(timeout=1) + except Empty: + # If no work is available after timeout, exit the worker + return + + with visited_ids_lock: + if d <= visited_ids.get(node.id, depth): + visited_ids[node.id] = d + else: + # Node already visited at a lower depth + work_queue.task_done() + continue + + if d < depth: + # Get outgoing link keys + outgoing_links = node.outgoing_links() + for outgoing_link in outgoing_links: + link_key = _tag_s_link_key(link=outgoing_link) + with visited_link_keys_lock: + if d <= visited_link_keys.get(link_key, depth): + visited_link_keys[link_key] = d + else: + continue # Already visited at lower depth + + # Query nodes with this link key + query, params = self._get_search_cql_and_params( + columns=CONTENT_ID, + metadata=metadata_filter, + tag=link_key, + ) + target_rows = self._session.execute(query, params) + for row in target_rows: + target_node_id = getattr(row, CONTENT_ID) + with visited_ids_lock: + if d < visited_ids.get(target_node_id, depth): + # Fetch node by ID + node_query = self._query_id_and_metadata_by_id + node_params = (target_node_id,) + node_rows = self._session.execute(node_query, node_params) + for node_row in node_rows: + target_node = _row_to_node(node_row) + work_queue.put((target_node, d + 1)) + work_queue.task_done() + + num_workers = 10 # Adjust the number of worker threads as needed + threads = [] + for _ in range(num_workers): + t = threading.Thread(target=worker) + t.start() + threads.append(t) + + # Wait for all items to be processed + work_queue.join() + + # Wait for all worker threads to finish + for t in threads: + t.join() + + return self._nodes_with_ids(visited_ids.keys()) + + + def similarity_search( + self, + embedding: list[float], + k: int = 4, + metadata_filter: dict[str, Any] = {}, # noqa: B006 + ) -> Iterable[Node]: + """Retrieve nodes similar to the given embedding, optionally filtered by metadata.""" # noqa: E501 + query, params = self._get_search_cql_and_params( + columns=f"{CONTENT_ID}, text_content, metadata_blob", + embedding=embedding, + limit=k, + metadata=metadata_filter, + ) + + for row in self._session.execute(query, params): + yield _row_to_node(row) + + def metadata_search( + self, + metadata: dict[str, Any] = {}, # noqa: B006 + n: int = 5, + ) -> Iterable[Node]: + """Retrieve nodes based on their metadata.""" + query, params = self._get_search_cql_and_params( + columns=f"{CONTENT_ID}, text_content, metadata_blob", + metadata=metadata, + limit=n, + ) + + for row in self._session.execute(query, params): + yield _row_to_node(row) + + def get_node(self, content_id: str) -> Node: + """Get a node by its id.""" + return self._nodes_with_ids(ids=[content_id])[0] + + def _get_outgoing_link_keys( + self, + source_ids: Iterable[str], + ) -> set[str]: + """Return the set of outgoing links for the given source ID(s). 
+
+ Args:
+ source_ids: The IDs of the source nodes to retrieve outgoing links for.
+ """
+ outgoing_links: Set[str] = set()
+
+ def add_sources(rows: Iterable[Any]) -> None:
+ for row in rows:
+ node = _row_to_node(row=row)
+ outgoing_link_keys = [_tag_s_link_key(l) for l in node.outgoing_links()]
+ outgoing_links.update(outgoing_link_keys)
+
+ with self._concurrent_queries() as cq:
+ for source_id in source_ids:
+ cq.execute(
+ self._query_id_and_metadata_by_id,
+ (source_id,),
+ callback=add_sources,
+ )
+
+ return outgoing_links
+
+ def _get_adjacent_nodes(
+ self,
+ outgoing_link_keys: set[str],
+ query_embedding: list[float],
+ k_per_link: int = 10,
+ metadata_filter: dict[str, Any] = {},
+ ) -> Iterable[Node]:
+ """Return the target nodes with incoming links matching any of the given outgoing link keys.
+
+ Args:
+ outgoing_link_keys: The link keys to search for.
+ query_embedding: The query embedding. Used to rank target nodes.
+ k_per_link: The number of target nodes to fetch for each outgoing link.
+ metadata_filter: Optional metadata to filter the results.
+
+ Returns:
+ The adjacent (target) nodes.
+ """
+ targets: dict[str, Node] = {}
+
+ columns = f"{CONTENT_ID}, text_embedding, metadata_blob"
+
+ def add_targets(rows: Iterable[Any]) -> None:
+ nonlocal targets
+
+ for row in rows:
+ target_node = _row_to_node(row)
+ if target_node.id not in targets:
+ targets[target_node.id] = target_node
+
+ with self._concurrent_queries() as cq:
+ for outgoing_link_key in outgoing_link_keys:
+ query, params = self._get_search_cql_and_params(
+ columns=columns,
+ limit=k_per_link,
+ metadata=metadata_filter,
+ embedding=query_embedding,
+ tag=outgoing_link_key,
+ )
+
+ cq.execute(
+ query=query,
+ parameters=params,
+ callback=add_targets,
+ )
+
+ # TODO: Consider a combined limit based on the similarity and/or
+ # predicted MMR score?
+ return targets.values()
+
+ @staticmethod
+ def _normalize_metadata_indexing_policy(
+ metadata_indexing: tuple[str, Iterable[str]] | str,
+ ) -> MetadataIndexingPolicy:
+ mode: MetadataIndexingMode
+ fields: set[str]
+ # metadata indexing policy normalization:
+ if isinstance(metadata_indexing, str):
+ if metadata_indexing.lower() == "all":
+ mode, fields = (MetadataIndexingMode.DEFAULT_TO_SEARCHABLE, set())
+ elif metadata_indexing.lower() == "none":
+ mode, fields = (MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE, set())
+ else:
+ msg = f"Unsupported metadata_indexing value '{metadata_indexing}'"
+ raise ValueError(msg)
+ else:
+ # it's a 2-tuple (mode, fields) still to normalize
+ _mode, _field_spec = metadata_indexing
+ fields = {_field_spec} if isinstance(_field_spec, str) else set(_field_spec)
+ if _mode.lower() in {
+ "default_to_unsearchable",
+ "allowlist",
+ "allow",
+ "allow_list",
+ }:
+ mode = MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE
+ elif _mode.lower() in {
+ "default_to_searchable",
+ "denylist",
+ "deny",
+ "deny_list",
+ }:
+ mode = MetadataIndexingMode.DEFAULT_TO_SEARCHABLE
+ else:
+ msg = f"Unsupported metadata indexing mode specification '{_mode}'"
+ raise ValueError(msg)
+ return mode, fields
+
+ @staticmethod
+ def _coerce_string(value: Any) -> str:
+ if isinstance(value, str):
+ return value
+ if isinstance(value, bool):
+ # bool MUST come before int in this chain of ifs!
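The ordering matters because `bool` is a subclass of `int` in Python, so the `int` branch below would otherwise also capture `True`/`False`. A quick check:

import json

assert isinstance(True, int)       # bool is a subclass of int
print(json.dumps(True))            # 'true' -- what the bool branch stores
print(json.dumps(float(True)))     # '1.0'  -- what the int branch would store instead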
+ return json.dumps(value) + if isinstance(value, int): + # we don't want to store '1' and '1.0' differently + # for the sake of metadata-filtered retrieval: + return json.dumps(float(value)) + if isinstance(value, float) or value is None: + return json.dumps(value) + # when all else fails ... + return str(value) + + def _extract_where_clause_cql( + self, + metadata: dict[str, Any] = {}, + tag: str | None = None + ) -> str: + wc_blocks: list[str] = [] + + # Use SimpleStatements if querying for tags + item_placeholder = "%s" if tag is not None else "?" + + if tag is not None: + wc_blocks.append(f"tag_s CONTAINS {item_placeholder}") + + for key in sorted(metadata.keys()): + if _is_metadata_field_indexed(key, self._metadata_indexing_policy): + wc_blocks.append(f"metadata_s['{key}'] = {item_placeholder}") + else: + msg = "Non-indexed metadata fields cannot be used in queries." + raise ValueError(msg) + + if len(wc_blocks) == 0: + return "" + + return " WHERE " + " AND ".join(wc_blocks) + + def _extract_where_clause_params( + self, + metadata: dict[str, Any], + tag: str | None = None + ) -> list[Any]: + params: list[Any] = [] + + if tag is not None: + params.append(tag) + + for key, value in sorted(metadata.items()): + if _is_metadata_field_indexed(key, self._metadata_indexing_policy): + params.append(self._coerce_string(value=value)) + else: + msg = "Non-indexed metadata fields cannot be used in queries." + raise ValueError(msg) + + return params + + def _get_search_cql_and_params( + self, + columns: str, + limit: int | None = None, + metadata: dict[str, Any] | None = None, + embedding: list[float] | None = None, + tag: str | None = None, + ) -> tuple[PreparedStatement|SimpleStatement, tuple[Any, ...]]: + + where_clause = self._extract_where_clause_cql(metadata=metadata, tag=tag) + limit_clause = " LIMIT ?" if limit is not None else "" + order_clause = " ORDER BY text_embedding ANN OF ?" 
if embedding is not None else "" + + select_cql = SELECT_CQL_TEMPLATE.format( + columns=columns, + table_name=self.table_name(), + where_clause=where_clause, + order_clause=order_clause, + limit_clause=limit_clause, + ) + + where_params = self._extract_where_clause_params(metadata=metadata, tag=tag) + limit_params = [limit] if limit is not None else [] + order_params = [embedding] if embedding is not None else [] + + params = tuple(list(where_params) + order_params + limit_params) + + if tag is not None: + return SimpleStatement(query_string=select_cql), params + elif select_cql in self._prepared_query_cache: + return self._prepared_query_cache[select_cql], params + else: + prepared_query = self._session.prepare(select_cql) + prepared_query.consistency_level = ConsistencyLevel.ONE + self._prepared_query_cache[select_cql] = prepared_query + return prepared_query, params diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store_tags_async.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store_tags_async.py new file mode 100644 index 000000000..c3657c053 --- /dev/null +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store_tags_async.py @@ -0,0 +1,939 @@ +from __future__ import annotations + +import json +import logging +import re +import secrets +import sys +from collections import deque +from collections.abc import Iterable +from dataclasses import asdict, dataclass, field, is_dataclass +from enum import Enum +from typing import ( + TYPE_CHECKING, + Any, + Sequence, + Set, + Union, + cast, +) + +from tqdm import tqdm + +from cassandra.cluster import ConsistencyLevel, PreparedStatement, Session, SimpleStatement +from cassio.config import check_resolve_keyspace, check_resolve_session +from typing_extensions import assert_never + +from ._mmr_helper import MmrHelper +from .concurrency import ConcurrentQueries +from .links import Link + +from concurrent.futures import ThreadPoolExecutor +from queue import Queue, Empty +import threading + +if TYPE_CHECKING: + from .embedding_model import EmbeddingModel + +logger = logging.getLogger(__name__) + +CONTENT_ID = "content_id" + +SELECT_CQL_TEMPLATE = ( + "SELECT {columns} FROM {table_name}{where_clause}{order_clause}{limit_clause}" +) + + +@dataclass +class Node: + """Node in the GraphStore.""" + + text: str + """Text contained by the node.""" + id: str | None = None + """Unique ID for the node. 
Will be generated by the GraphStore if not set.""" + embedding: list[float] = field(default_factory=list) + """Vector embedding of the text""" + metadata: dict[str, Any] = field(default_factory=dict) + """Metadata for the node.""" + links: set[Link] = field(default_factory=set) + """All the links for the node.""" + + def incoming_links(self) -> set[Link]: + return set([l for l in self.links if (l.direction in ["in", "bidir"])]) + + def outgoing_links(self) -> set[Link]: + return set([l for l in self.links if (l.direction in ["out", "bidir"])]) + + +class SetupMode(Enum): + """Mode used to create the Cassandra table.""" + + SYNC = 1 + ASYNC = 2 + OFF = 3 + + +class MetadataIndexingMode(Enum): + """Mode used to index metadata.""" + + DEFAULT_TO_UNSEARCHABLE = 1 + DEFAULT_TO_SEARCHABLE = 2 + + +MetadataIndexingType = Union[tuple[str, Iterable[str]], str] +MetadataIndexingPolicy = tuple[MetadataIndexingMode, set[str]] + + +def _is_metadata_field_indexed(field_name: str, policy: MetadataIndexingPolicy) -> bool: + p_mode, p_fields = policy + if p_mode == MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE: + return field_name in p_fields + if p_mode == MetadataIndexingMode.DEFAULT_TO_SEARCHABLE: + return field_name not in p_fields + assert_never(p_mode) + + +def _serialize_metadata(md: dict[str, Any]) -> str: + if isinstance(md.get("links"), set): + md = md.copy() + md["links"] = list(md["links"]) + return json.dumps(md) + + +def _serialize_links(links: set[Link]) -> str: + class SetAndLinkEncoder(json.JSONEncoder): + def default(self, obj: Any) -> Any: + if not isinstance(obj, type) and is_dataclass(obj): + return asdict(obj) + + if isinstance(obj, Iterable): + return list(obj) + + # Let the base class default method raise the TypeError + return super().default(obj) + + return json.dumps(list(links), cls=SetAndLinkEncoder) + + +def _deserialize_metadata(json_blob: str | None) -> dict[str, Any]: + # We don't need to convert the links list back to a set -- it will be + # converted when accessed, if needed. + return cast(dict[str, Any], json.loads(json_blob or "")) + + +def _deserialize_links(json_blob: str | None) -> set[Link]: + return { + Link(kind=link["kind"], direction=link["direction"], tag=link["tag"]) + for link in cast(list[dict[str, Any]], json.loads(json_blob or "")) + } + +def _tag_s_link_key(link: Link) -> str: + return "link_from_" + json.dumps({"kind": link.kind, "tag": link.tag}) + +def _row_to_node(row: Any) -> Node: + if hasattr(row, "metadata_blob"): + metadata_blob = getattr(row, "metadata_blob") + metadata = _deserialize_metadata(metadata_blob) + links: set[Link] = _deserialize_links(metadata.get("links")) + metadata["links"] = links + else: + metadata = {} + links = set() + return Node( + id=getattr(row, CONTENT_ID, ""), + embedding=getattr(row, "text_embedding", []), + text=getattr(row, "text_content", ""), + metadata=metadata, + links=links, + ) + + +_CQL_IDENTIFIER_PATTERN = re.compile(r"[a-zA-Z][a-zA-Z0-9_]*") + + +class GraphStore: + """A hybrid vector-and-graph store backed by Cassandra. + + Document chunks support vector-similarity search as well as edges linking + documents based on structural and semantic properties. + + Args: + embedding: The embeddings to use for the document content. + setup_mode: Mode used to create the Cassandra table (SYNC, + ASYNC or OFF). 
+ """ + + def __init__( + self, + embedding: EmbeddingModel, + *, + node_table: str = "graph_nodes", + targets_table: str = "", + session: Session | None = None, + keyspace: str | None = None, + setup_mode: SetupMode = SetupMode.SYNC, + metadata_indexing: MetadataIndexingType = "all", + insert_timeout: float = 30.0, + ): + self._insert_timeout = insert_timeout + if targets_table: + logger.warning( + "The 'targets_table' parameter is deprecated " + "and will be removed in future versions." + ) + + session = check_resolve_session(session) + keyspace = check_resolve_keyspace(keyspace) + + if not _CQL_IDENTIFIER_PATTERN.fullmatch(keyspace): + msg = f"Invalid keyspace: {keyspace}" + raise ValueError(msg) + + if not _CQL_IDENTIFIER_PATTERN.fullmatch(node_table): + msg = f"Invalid node table name: {node_table}" + raise ValueError(msg) + + self._embedding = embedding + self._node_table = node_table + self._session = session + self._keyspace = keyspace + self._prepared_query_cache: dict[str, PreparedStatement] = {} + + self._metadata_indexing_policy = self._normalize_metadata_indexing_policy( + metadata_indexing=metadata_indexing, + ) + + if setup_mode == SetupMode.SYNC: + self._apply_schema() + elif setup_mode != SetupMode.OFF: + msg = ( + f"Invalid setup mode {setup_mode.name}. " + "Only SYNC and OFF are supported at the moment" + ) + raise ValueError(msg) + + # TODO: Parent ID / source ID / etc. + self._insert_passage = session.prepare( + f""" + INSERT INTO {keyspace}.{node_table} ( + {CONTENT_ID}, text_content, text_embedding, metadata_blob, metadata_s, tag_s + ) VALUES (?, ?, ?, ?, ?, ?) + """ # noqa: S608 + ) + + self._query_by_id = session.prepare( + f""" + SELECT {CONTENT_ID}, text_content, metadata_blob + FROM {keyspace}.{node_table} + WHERE {CONTENT_ID} = ? + """ # noqa: S608 + ) + + self._query_id_and_metadata_by_id = session.prepare( + f""" + SELECT {CONTENT_ID}, metadata_blob + FROM {keyspace}.{node_table} + WHERE {CONTENT_ID} = ? 
+ """ # noqa: S608 + ) + + def table_name(self) -> str: + """Returns the fully qualified table name.""" + return f"{self._keyspace}.{self._node_table}" + + def _apply_schema(self) -> None: + """Apply the schema to the database.""" + embedding_dim = len(self._embedding.embed_query("Test Query")) + self._session.execute(f""" + CREATE TABLE IF NOT EXISTS {self.table_name()} ( + {CONTENT_ID} TEXT, + text_content TEXT, + text_embedding VECTOR, + metadata_blob TEXT, + metadata_s MAP, + tag_s SET, + PRIMARY KEY ({CONTENT_ID}) + ) + """) + + # Index on text_embedding (for similarity search) + self._session.execute(f""" + CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_text_embedding_index + ON {self.table_name()}(text_embedding) + USING 'StorageAttachedIndex'; + """) + + self._session.execute(f""" + CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_tag_s_index + ON {self.table_name()}(tag_s) + USING 'StorageAttachedIndex'; + """) + + self._session.execute(f""" + CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_metadata_s_index + ON {self.table_name()}(ENTRIES(metadata_s)) + USING 'StorageAttachedIndex'; + """) + + def _concurrent_queries(self) -> ConcurrentQueries: + return ConcurrentQueries(self._session) + + # TODO: Async (aadd_nodes) + def add_nodes( + self, + nodes: Iterable[Node], + ) -> Iterable[str]: + """Add nodes to the graph store.""" + node_ids: list[str] = [] + texts: list[str] = [] + metadata_list: list[dict[str, Any]] = [] + incoming_links_list: list[set[Link]] = [] + for node in nodes: + if not node.id: + node_ids.append(secrets.token_hex(8)) + else: + node_ids.append(node.id) + texts.append(node.text) + combined_metadata = node.metadata.copy() + combined_metadata["links"] = _serialize_links(node.links) + metadata_list.append(combined_metadata) + incoming_links_list.append(node.incoming_links()) + + text_embeddings = self._embedding.embed_texts(texts) + + with self._concurrent_queries() as cq: + tuples = zip(node_ids, texts, text_embeddings, metadata_list, incoming_links_list) + for node_id, text, text_embedding, metadata, incoming_links in tuples: + + metadata_s = { + k: self._coerce_string(v) + for k, v in metadata.items() + if _is_metadata_field_indexed(k, self._metadata_indexing_policy) + } + + tag_s = [_tag_s_link_key(l) for l in incoming_links] + + + metadata_blob = _serialize_metadata(metadata) + + cq.execute( + self._insert_passage, + parameters=( + node_id, + text, + text_embedding, + metadata_blob, + metadata_s, + tag_s, + ), + timeout=self._insert_timeout, + ) + + return node_ids + + def _nodes_with_ids( + self, + ids: Iterable[str], + ) -> list[Node]: + results: dict[str, Node | None] = {} + with self._concurrent_queries() as cq: + + def node_callback(rows: Iterable[Any]) -> None: + # Should always be exactly one row here. We don't need to check + # 1. The query is for a `ID == ?` query on the primary key. + # 2. If it doesn't exist, the `get_result` method below will + # raise an exception indicating the ID doesn't exist. + for row in rows: + results[row.content_id] = _row_to_node(row) + + for node_id in ids: + if node_id not in results: + # Mark this node ID as being fetched. 
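For context on the `metadata_s` filtering in `add_nodes` above: whether a field is written to the indexed map is decided by the policy helpers defined earlier in this module. A hedged sketch reusing `_is_metadata_field_indexed` and `MetadataIndexingMode` from this file; the field names are made up:

# ("allowlist", [...]) normalizes to DEFAULT_TO_UNSEARCHABLE plus the listed fields.
policy = (MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE, {"topic", "source"})

print(_is_metadata_field_indexed("topic", policy))     # True  -> stored in metadata_s, filterable
print(_is_metadata_field_indexed("raw_html", policy))  # False -> only kept in metadata_blob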
+ results[node_id] = None + cq.execute( + self._query_by_id, parameters=(node_id,), callback=node_callback + ) + + def get_result(node_id: str) -> Node: + if (result := results[node_id]) is None: + msg = f"No node with ID '{node_id}'" + raise ValueError(msg) + return result + + return [get_result(node_id) for node_id in ids] + + def mmr_traversal_search( + self, + query: str, + *, + initial_roots: Sequence[str] = (), + k: int = 4, + depth: int = 2, + fetch_k: int = 100, + adjacent_k: int = 10, + lambda_mult: float = 0.5, + score_threshold: float = float("-inf"), + metadata_filter: dict[str, Any] = {}, # noqa: B006 + ) -> Iterable[Node]: + """Retrieve documents from this graph store using MMR-traversal. + + This strategy first retrieves the top `fetch_k` results by similarity to + the question. It then selects the top `k` results based on + maximum-marginal relevance using the given `lambda_mult`. + + At each step, it considers the (remaining) documents from `fetch_k` as + well as any documents connected by edges to a selected document + retrieved based on similarity (a "root"). + + Args: + query: The query string to search for. + initial_roots: Optional list of document IDs to use for initializing search. + The top `adjacent_k` nodes adjacent to each initial root will be + included in the set of initial candidates. To fetch only in the + neighborhood of these nodes, set `ftech_k = 0`. + k: Number of Documents to return. Defaults to 4. + fetch_k: Number of initial Documents to fetch via similarity. + Will be added to the nodes adjacent to `initial_roots`. + Defaults to 100. + adjacent_k: Number of adjacent Documents to fetch. + Defaults to 10. + depth: Maximum depth of a node (number of edges) from a node + retrieved via similarity. Defaults to 2. + lambda_mult: Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding to maximum + diversity and 1 to minimum diversity. Defaults to 0.5. + score_threshold: Only documents with a score greater than or equal + this threshold will be chosen. Defaults to -infinity. + metadata_filter: Optional metadata to filter the results. + """ + query_embedding = self._embedding.embed_query(query) + helper = MmrHelper( + k=k, + query_embedding=query_embedding, + lambda_mult=lambda_mult, + score_threshold=score_threshold, + ) + + # For each unselected node, stores the outgoing links. + outgoing_link_keys_map: dict[str, set[str]] = {} + visited_link_keys: set[str] = set() + + def fetch_neighborhood(neighborhood: Sequence[str]) -> None: + nonlocal outgoing_link_keys_map + nonlocal visited_link_keys + + # Put the neighborhood into the outgoing links, to avoid adding it + # to the candidate set in the future. + outgoing_link_keys_map.update({content_id: set() for content_id in neighborhood}) + + # Initialize the visited_links with the set of outgoing links from the + # neighborhood. This prevents re-visiting them. + visited_link_keys = self._get_outgoing_link_keys(neighborhood) + + # Call `self._get_adjacent` to fetch the candidates. 
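For reference, a hedged sketch of how a caller might restrict this search to the neighborhood of known chunks, per the `initial_roots`/`fetch_k` note in the docstring above; the `store` instance and chunk ids are assumed, not taken from this PR:

nodes = store.mmr_traversal_search(
    "what ties these two documents together?",
    initial_roots=["chunk-123", "chunk-456"],
    fetch_k=0,        # skip the global similarity seed; start only from the roots' neighborhood
    k=4,
    adjacent_k=10,
    depth=2,
    lambda_mult=0.5,
)
for node in nodes:
    print(node.id, node.text[:80])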
+ adjacent_nodes: Iterable[Node] = self._get_adjacent_nodes( + outgoing_link_keys=visited_link_keys, + query_embedding=query_embedding, + k_per_tag=adjacent_k, + metadata_filter=metadata_filter, + ) + + candidates: dict[str, list[float]] = {} + for node in adjacent_nodes: + if node.id not in outgoing_link_keys_map: + outgoing_link_keys_map[node.id] = node.outgoing_links() + candidates[node.id] = node.embedding + + helper.add_candidates(candidates) + + def fetch_initial_candidates() -> None: + nonlocal outgoing_link_keys_map + nonlocal visited_link_keys + + initial_query, initial_params = self._get_search_cql_and_params( + columns=f"{CONTENT_ID}, text_embedding, metadata_blob", + limit=fetch_k, + metadata=metadata_filter, + embedding=query_embedding, + ) + + rows = self._session.execute( + query=initial_query, parameters=initial_params + ) + candidates: dict[str, list[float]] = {} + for row in rows: + node = _row_to_node(row) + if node.id not in outgoing_link_keys_map: + outgoing_link_keys = [_tag_s_link_key(l) for l in node.outgoing_links()] + outgoing_link_keys_map[node.id] = set(outgoing_link_keys) + candidates[node.id] = node.embedding + helper.add_candidates(candidates) + + if initial_roots: + fetch_neighborhood(initial_roots) + if fetch_k > 0: + fetch_initial_candidates() + + # Tracks the depth of each candidate. + depths = {candidate_id: 0 for candidate_id in helper.candidate_ids()} + + # Select the best item, K times. + for _ in range(k): + selected_id = helper.pop_best() + + if selected_id is None: + break + + next_depth = depths[selected_id] + 1 + if next_depth < depth: + # If the next nodes would not exceed the depth limit, find the + # adjacent nodes. + # + # TODO: For a big performance win, we should track which links we've + # already incorporated. We don't need to issue adjacent queries for + # those. + + # Find the outgoing links linked to from the selected ID. + selected_outgoing_link_keys = outgoing_link_keys_map.pop(selected_id) + + # Don't re-visit already visited links. + outgoing_link_keys_map.difference_update(visited_link_keys) + + # Find the nodes with incoming links from those tags. + adjacent_nodes: Iterable[Node] = self._get_adjacent_nodes( + outgoing_link_keys=selected_outgoing_link_keys, + query_embedding=query_embedding, + k_per_tag=adjacent_k, + metadata_filter=metadata_filter, + ) + + # Record the selected_outgoing_links as visited. + visited_link_keys.update(selected_outgoing_link_keys) + + candidates: dict[str, list[float]] = {} + for node in adjacent_nodes: + if node.id not in outgoing_link_keys_map: + outgoing_link_keys = [_tag_s_link_key(l) for l in node.outgoing_links()] + outgoing_link_keys_map[node.id] = set(outgoing_link_keys) + candidates[node.id] = node.embedding + + if next_depth < depths.get(node.id, depth + 1): + # If this is a new shortest depth, or there was no + # previous depth, update the depths. This ensures that + # when we discover a node we will have the shortest + # depth available. + # + # NOTE: No effort is made to traverse from nodes that + # were previously selected if they become reachable via + # a shorter path via nodes selected later. This is + # currently "intended", but may be worth experimenting + # with. 
+ depths[node.id] = next_depth + helper.add_candidates(candidates) + + return self._nodes_with_ids(helper.selected_ids) + + + + def traversal_search( + self, + query: str, + *, + k: int = 4, + depth: int = 1, + metadata_filter: dict[str, Any] = {}, # noqa: B006 + ) -> Iterable[Node]: + """Retrieve documents from this knowledge store. + + First, `k` nodes are retrieved using a vector search for the `query` string. + Then, additional nodes are discovered up to the given `depth` from those + starting nodes. + + Args: + query: The query string. + k: The number of Documents to return from the initial vector search. + Defaults to 4. + depth: The maximum depth of edges to traverse. Defaults to 1. + metadata_filter: Optional metadata to filter the results. + + Returns: + Collection of retrieved documents. + """ + # Depth 0: + # Query for `k` nodes similar to the question. + # Retrieve `content_id` and `link_to_tags` (via `metadata-blob`). + # + # Depth 1: + # Query for nodes that have an incoming tag in the `link_to_tags` set. + # Combine node IDs. + # Query for `link_to_tags` of those "new" node IDs. + # + # ... + + + with self._concurrent_queries() as cq: + # Map from visited ID to depth + visited_ids: dict[str, int] = {} + + # Map from visited link keys to depth. Allows skipping queries + # for link keys that we've already traversed. + visited_link_keys: dict[str, int] = {} + + def visit_nodes(d: int, rows: Sequence[Any]) -> None: + nonlocal visited_ids + nonlocal visited_link_keys + + # Visit nodes at the given depth. + # Each node has `content_id` and `link_to_tags` (via `metadata_blob`). + + # Iterate over nodes, tracking the *new* outgoing kind links for this + # depth. This is links that are either new, or newly discovered at a + # lower depth. + outgoing_link_keys: Set[str] = set() + for row in rows: + node = _row_to_node(row=row) + + # Add visited ID. If it is closer it is a new node at this depth: + if d <= visited_ids.get(node.id, depth): + visited_ids[node.id] = d + # If we can continue traversing from this node, + if d < depth: + # Record any new (or newly discovered at a lower depth) + # links to the set to traverse. + for outgoing_link in node.outgoing_links(): + link_key = _tag_s_link_key(link=outgoing_link) + if d <= visited_link_keys.get(link_key, depth): + # Record that we'll query this link at the + # given depth, so we don't fetch it again + # (unless we find it an earlier depth) + visited_link_keys[link_key] = d + outgoing_link_keys.add(link_key) + + # If there are new link keys to visit at the next depth, query for the + # node IDs. 
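A note on the callbacks used just below: they bind the current depth with a `d=d` default argument because Python closures capture variables by reference, not by value. A quick illustration of the difference:

late_bound = [lambda: d for d in range(3)]
early_bound = [lambda d=d: d for d in range(3)]

print([f() for f in late_bound])   # [2, 2, 2] -- every closure sees the final d
print([f() for f in early_bound])  # [0, 1, 2] -- each default captured d at definition time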
+ if d == 0: + for link_key in tqdm(outgoing_link_keys, desc="outgoing link key queries"): + query, params = self._get_search_cql_and_params( + columns=CONTENT_ID, + metadata=metadata_filter, + tag=link_key, + ) + + # print(query) + # print(params) + + # target_rows = self._session.execute(query, params) + # visit_targets(d=d, rows=target_rows) + + cq.execute( + query=query, + parameters=params, + callback=lambda target_rows, d=d: visit_targets(d, rows=target_rows), + timeout=4, + ) + else: + for link_key in outgoing_link_keys: + query, params = self._get_search_cql_and_params( + columns=CONTENT_ID, + metadata=metadata_filter, + tag=link_key, + ) + + # target_rows = self._session.execute(query, params) + # visit_targets(d=d, rows=target_rows) + + cq.execute( + query=query, + parameters=params, + callback=lambda target_rows, d=d: visit_targets(d, rows=target_rows), + ) + + def visit_targets(d: int, rows: Sequence[Any]) -> None: + nonlocal visited_ids + + new_node_ids_at_next_depth: Set[int] = set() + for row in rows: + target_node = _row_to_node(row=row) + if d < visited_ids.get(target_node.id, depth): + new_node_ids_at_next_depth.add(target_node.id) + + for node_id in new_node_ids_at_next_depth: + # node_rows = self._session.execute( + # query=self._query_id_and_metadata_by_id, + # parameters=(node_id,), + # ) + # visit_nodes(d=d+1, rows=node_rows) + + cq.execute( + self._query_id_and_metadata_by_id, + parameters=(node_id,), + callback=lambda node_rows, d=d: visit_nodes(d + 1, node_rows), + ) + + traversal_query, traversal_params = self._get_search_cql_and_params( + columns=f"{CONTENT_ID}, metadata_blob", + metadata=metadata_filter, + embedding=self._embedding.embed_query(query), + limit=k + ) + + # initial_rows = self._session.execute(query=traversal_query, parameters=traversal_params) + # visit_nodes(d=0, rows=initial_rows) + + cq.execute( + traversal_query, + parameters=traversal_params, + callback=lambda initial_rows: visit_nodes(0, rows=initial_rows), + ) + + return self._nodes_with_ids(visited_ids.keys()) + + + def similarity_search( + self, + embedding: list[float], + k: int = 4, + metadata_filter: dict[str, Any] = {}, # noqa: B006 + ) -> Iterable[Node]: + """Retrieve nodes similar to the given embedding, optionally filtered by metadata.""" # noqa: E501 + query, params = self._get_search_cql_and_params( + columns=f"{CONTENT_ID}, text_content, metadata_blob", + embedding=embedding, + limit=k, + metadata=metadata_filter, + ) + + for row in self._session.execute(query, params): + yield _row_to_node(row) + + def metadata_search( + self, + metadata: dict[str, Any] = {}, # noqa: B006 + n: int = 5, + ) -> Iterable[Node]: + """Retrieve nodes based on their metadata.""" + query, params = self._get_search_cql_and_params( + columns=f"{CONTENT_ID}, text_content, metadata_blob", + metadata=metadata, + limit=n, + ) + + for row in self._session.execute(query, params): + yield _row_to_node(row) + + def get_node(self, content_id: str) -> Node: + """Get a node by its id.""" + return self._nodes_with_ids(ids=[content_id])[0] + + def _get_outgoing_link_keys( + self, + source_ids: Iterable[str], + ) -> set[str]: + """Return the set of outgoing links for the given source ID(s). + + Args: + source_ids: The IDs of the source nodes to retrieve outgoing links for. 
+ """ + outgoing_links: Set[str] = set() + + def add_sources(rows: Iterable[Any]) -> None: + for row in rows: + node = _row_to_node(row=row) + outgoing_link_keys = [_tag_s_link_key(l) for l in node.outgoing_links()] + outgoing_links.update(outgoing_link_keys) + + with self._concurrent_queries() as cq: + for source_id in source_ids: + cq.execute( + self._query_id_and_metadata_by_id, + (source_id,), + callback=add_sources, + ) + + return outgoing_links + + def _get_adjacent_nodes( + self, + outgoing_link_keys: set[str], + query_embedding: list[float], + k_per_link: int = 10, + metadata_filter: dict[str, Any] = {}, + ) -> Iterable[Node]: + """Return the target nodes with incoming links from any of the given outgoing_links. + + Args: + outgoing_links: The links to search for + query_embedding: The query embedding. Used to rank target nodes. + k_per_link: The number of target nodes to fetch for each outgoing link. + metadata_filter: Optional metadata to filter the results. + + Returns: + List of adjacent edges. + """ + targets: dict[str, Node] = {} + + columns = f"{CONTENT_ID}, text_embedding, metadata_blob" + + def add_targets(rows: Iterable[Any]) -> None: + nonlocal targets + + for row in rows: + target_node = _row_to_node(row) + if target_node.id not in targets: + targets[target_node.id] = target_node + + with self._concurrent_queries() as cq: + for outgoing_link_key in outgoing_link_keys: + query, params = self._get_search_cql_and_params( + columns=columns, + limit=k_per_link, + metadata=metadata_filter, + embedding=query_embedding, + link_key=outgoing_link_key, + ) + + cq.execute( + query=query, + parameters=params, + callback=add_targets, + ) + + # TODO: Consider a combined limit based on the similarity and/or + # predicated MMR score? + return targets.values() + + @staticmethod + def _normalize_metadata_indexing_policy( + metadata_indexing: tuple[str, Iterable[str]] | str, + ) -> MetadataIndexingPolicy: + mode: MetadataIndexingMode + fields: set[str] + # metadata indexing policy normalization: + if isinstance(metadata_indexing, str): + if metadata_indexing.lower() == "all": + mode, fields = (MetadataIndexingMode.DEFAULT_TO_SEARCHABLE, set()) + elif metadata_indexing.lower() == "none": + mode, fields = (MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE, set()) + else: + msg = f"Unsupported metadata_indexing value '{metadata_indexing}'" + raise ValueError(msg) + else: + # it's a 2-tuple (mode, fields) still to normalize + _mode, _field_spec = metadata_indexing + fields = {_field_spec} if isinstance(_field_spec, str) else set(_field_spec) + if _mode.lower() in { + "default_to_unsearchable", + "allowlist", + "allow", + "allow_list", + }: + mode = MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE + elif _mode.lower() in { + "default_to_searchable", + "denylist", + "deny", + "deny_list", + }: + mode = MetadataIndexingMode.DEFAULT_TO_SEARCHABLE + else: + msg = f"Unsupported metadata indexing mode specification '{_mode}'" + raise ValueError(msg) + return mode, fields + + @staticmethod + def _coerce_string(value: Any) -> str: + if isinstance(value, str): + return value + if isinstance(value, bool): + # bool MUST come before int in this chain of ifs! + return json.dumps(value) + if isinstance(value, int): + # we don't want to store '1' and '1.0' differently + # for the sake of metadata-filtered retrieval: + return json.dumps(float(value)) + if isinstance(value, float) or value is None: + return json.dumps(value) + # when all else fails ... 
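The query builders that follow assemble statements from `SELECT_CQL_TEMPLATE` (defined at the top of this file), using `%s` placeholders so that tag queries can run as `SimpleStatement`s. A hedged sketch of the text they produce for a tag-plus-metadata lookup; the keyspace, table, and field names are invented:

select_cql = SELECT_CQL_TEMPLATE.format(
    columns="content_id",
    table_name="my_keyspace.graph_nodes",
    where_clause=" WHERE tag_s CONTAINS %s AND metadata_s['topic'] = %s",
    order_clause="",
    limit_clause="",
)
print(select_cql)
# SELECT content_id FROM my_keyspace.graph_nodes WHERE tag_s CONTAINS %s AND metadata_s['topic'] = %s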
+ return str(value) + + def _extract_where_clause_cql( + self, + metadata: dict[str, Any] = {}, + tag: str | None = None + ) -> str: + wc_blocks: list[str] = [] + + # Use SimpleStatements if querying for tags + item_placeholder = "%s" if tag is not None else "?" + + if tag is not None: + wc_blocks.append(f"tag_s CONTAINS {item_placeholder}") + + for key in sorted(metadata.keys()): + if _is_metadata_field_indexed(key, self._metadata_indexing_policy): + wc_blocks.append(f"metadata_s['{key}'] = {item_placeholder}") + else: + msg = "Non-indexed metadata fields cannot be used in queries." + raise ValueError(msg) + + if len(wc_blocks) == 0: + return "" + + return " WHERE " + " AND ".join(wc_blocks) + + def _extract_where_clause_params( + self, + metadata: dict[str, Any], + tag: str | None = None + ) -> list[Any]: + params: list[Any] = [] + + if tag is not None: + params.append(tag) + + for key, value in sorted(metadata.items()): + if _is_metadata_field_indexed(key, self._metadata_indexing_policy): + params.append(self._coerce_string(value=value)) + else: + msg = "Non-indexed metadata fields cannot be used in queries." + raise ValueError(msg) + + return params + + def _get_search_cql_and_params( + self, + columns: str, + limit: int | None = None, + metadata: dict[str, Any] | None = None, + embedding: list[float] | None = None, + tag: str | None = None, + ) -> tuple[PreparedStatement|SimpleStatement, tuple[Any, ...]]: + + where_clause = self._extract_where_clause_cql(metadata=metadata, tag=tag) + limit_clause = " LIMIT ?" if limit is not None else "" + order_clause = " ORDER BY text_embedding ANN OF ?" if embedding is not None else "" + + select_cql = SELECT_CQL_TEMPLATE.format( + columns=columns, + table_name=self.table_name(), + where_clause=where_clause, + order_clause=order_clause, + limit_clause=limit_clause, + ) + + where_params = self._extract_where_clause_params(metadata=metadata, tag=tag) + limit_params = [limit] if limit is not None else [] + order_params = [embedding] if embedding is not None else [] + + params = tuple(list(where_params) + order_params + limit_params) + + if tag is not None: + return SimpleStatement(query_string=select_cql), params + elif select_cql in self._prepared_query_cache: + return self._prepared_query_cache[select_cql], params + else: + prepared_query = self._session.prepare(select_cql) + prepared_query.consistency_level = ConsistencyLevel.ONE + self._prepared_query_cache[select_cql] = prepared_query + return prepared_query, params diff --git a/libs/knowledge-store/ragstack_knowledge_store/keybert_link_extractor.py b/libs/knowledge-store/ragstack_knowledge_store/keybert_link_extractor.py new file mode 100644 index 000000000..581cd7cdc --- /dev/null +++ b/libs/knowledge-store/ragstack_knowledge_store/keybert_link_extractor.py @@ -0,0 +1,72 @@ +from typing import Any, Dict, Iterable, Optional, Set, Union + +from langchain_core._api import beta +from langchain_core.documents import Document +from langchain_core.graph_vectorstores.links import Link + +from ragstack_knowledge_store.link_extractor import LinkExtractor +import keybert + +KeybertInput = Union[str, Document] + + +@beta() +class KeybertLinkExtractor(LinkExtractor[KeybertInput]): + def __init__( + self, + *, + kind: str = "kw", + embedding_model: str = "all-MiniLM-L6-v2", + extract_keywords_kwargs: Optional[Dict[str, Any]] = None, + ): + """Extract keywords using KeyBERT . + + Example: + + .. 
code-block:: python + + extractor = KeybertLinkExtractor() + + results = extractor.extract_one(PAGE_1) + + Args: + kind: Kind of links to produce with this extractor. + embedding_model: Name of the embedding model to use with KeyBERT. + extract_keywords_kwargs: Keyword arguments to pass to KeyBERT's + `extract_keywords` method. + """ + try: + self._kw_model = keybert.KeyBERT(model=embedding_model) + except ImportError: + raise ImportError( + "keybert is required for KeybertLinkExtractor. " + "Please install it with `pip install keybert`." + ) from None + + self._kind = kind + self._extract_keywords_kwargs = extract_keywords_kwargs or {} + + def extract_one(self, input: KeybertInput) -> Set[Link]: # noqa: A002 + keywords = self._kw_model.extract_keywords( + input if isinstance(input, str) else input.page_content, + **self._extract_keywords_kwargs, + ) + return {Link.bidir(kind=self._kind, tag=kw[0]) for kw in keywords} + + def extract_many( + self, + inputs: Iterable[KeybertInput], + ) -> Iterable[Set[Link]]: + inputs = list(inputs) + if len(inputs) == 1: + # Even though we pass a list, if it contains one item, keybert will + # flatten it. This means it's easier to just call the special case + # for one item. + yield self.extract_one(inputs[0]) + elif len(inputs) > 1: + strs = [i if isinstance(i, str) else i.page_content for i in inputs] + extracted = self._kw_model.extract_keywords( + strs, **self._extract_keywords_kwargs + ) + for keywords in extracted: + yield {Link.bidir(kind=self._kind, tag=kw[0]) for kw in keywords} diff --git a/libs/knowledge-store/ragstack_knowledge_store/langchain_cassandra.py b/libs/knowledge-store/ragstack_knowledge_store/langchain_cassandra.py new file mode 100644 index 000000000..3b2569deb --- /dev/null +++ b/libs/knowledge-store/ragstack_knowledge_store/langchain_cassandra.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, + Iterable, + List, + Optional, + Type, +) + +from langchain_core._api import beta +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.graph_vectorstores.base import ( + GraphVectorStore, + nodes_to_documents, +) + +from langchain_core.graph_vectorstores.base import Node as BaseNode +from langchain_community.utilities.cassandra import SetupMode +from langchain_core.graph_vectorstores import GraphVectorStoreRetriever + + +if TYPE_CHECKING: + from cassandra.cluster import Session + +from ragstack_knowledge_store import graph_store, Node +from ragstack_knowledge_store.embedding_model import EmbeddingModel + +@beta() +class CassandraGraphVectorStore(GraphVectorStore): + def __init__( + self, + embedding: Embeddings, + *, + node_table: str = "graph_nodes", + session: Optional[Session] = None, + keyspace: Optional[str] = None, + setup_mode: SetupMode = SetupMode.SYNC, + **kwargs: Any, + ): + """ + Create the hybrid graph store. + + Args: + embedding: The embeddings to use for the document content. + setup_mode: Mode used to create the Cassandra table (SYNC, + ASYNC or OFF). + """ + # try: + # from ragstack_knowledge_store import EmbeddingModel, graph_store + # except (ImportError, ModuleNotFoundError): + # raise ImportError( + # "Could not import ragstack_knowledge_store python package. " + # "Please install it with `pip install ragstack-ai-knowledge-store`." 
+ # ) + + self._embedding = embedding + _setup_mode = getattr(graph_store.SetupMode, setup_mode.name) + + class _EmbeddingModelAdapter(EmbeddingModel): + def __init__(self, embeddings: Embeddings): + self.embeddings = embeddings + + def embed_texts(self, texts: List[str]) -> List[List[float]]: + return self.embeddings.embed_documents(texts) + + def embed_query(self, text: str) -> List[float]: + return self.embeddings.embed_query(text) + + async def aembed_texts(self, texts: List[str]) -> List[List[float]]: + return await self.embeddings.aembed_documents(texts) + + async def aembed_query(self, text: str) -> List[float]: + return await self.embeddings.aembed_query(text) + + self.store = graph_store.GraphStore( + embedding=_EmbeddingModelAdapter(embedding), + node_table=node_table, + session=session, + keyspace=keyspace, + setup_mode=_setup_mode, + **kwargs, + ) + + @property + def embeddings(self) -> Optional[Embeddings]: + return self._embedding + + def add_nodes( + self, + nodes: Iterable[BaseNode], + **kwargs: Any, + ) -> Iterable[str]: + converted_nodes = [Node(text=n.text, id=n.id, metadata=n.metadata, links=n.links) for n in nodes] + return self.store.add_nodes(converted_nodes) + + @classmethod + def from_texts( + cls: Type["CassandraGraphVectorStore"], + texts: Iterable[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + ids: Optional[Iterable[str]] = None, + **kwargs: Any, + ) -> "CassandraGraphVectorStore": + """Return CassandraGraphVectorStore initialized from texts and embeddings.""" + store = cls(embedding, **kwargs) + store.add_texts(texts, metadatas, ids=ids) + return store + + @classmethod + def from_documents( + cls: Type["CassandraGraphVectorStore"], + documents: Iterable[Document], + embedding: Embeddings, + ids: Optional[Iterable[str]] = None, + **kwargs: Any, + ) -> "CassandraGraphVectorStore": + """Return CassandraGraphVectorStore initialized from documents and + embeddings.""" + store = cls(embedding, **kwargs) + store.add_documents(documents, ids=ids) + return store + + def similarity_search( + self, query: str, k: int = 4, **kwargs: Any + ) -> List[Document]: + embedding_vector = self._embedding.embed_query(query) + return self.similarity_search_by_vector( + embedding_vector, + k=k, + ) + + def similarity_search_by_vector( + self, embedding: List[float], k: int = 4, **kwargs: Any + ) -> List[Document]: + nodes = self.store.similarity_search(embedding, k=k) + return list(nodes_to_documents(nodes)) + + def traversal_search( + self, + query: str, + *, + k: int = 4, + depth: int = 1, + **kwargs: Any, + ) -> Iterable[Document]: + nodes = self.store.traversal_search(query, k=k, depth=depth) + return nodes_to_documents(nodes) + + def mmr_traversal_search( + self, + query: str, + *, + k: int = 4, + depth: int = 2, + fetch_k: int = 100, + adjacent_k: int = 10, + lambda_mult: float = 0.5, + score_threshold: float = float("-inf"), + **kwargs: Any, + ) -> Iterable[Document]: + nodes = self.store.mmr_traversal_search( + query, + k=k, + depth=depth, + fetch_k=fetch_k, + adjacent_k=adjacent_k, + lambda_mult=lambda_mult, + score_threshold=score_threshold, + ) + return nodes_to_documents(nodes) diff --git a/libs/knowledge-store/ragstack_knowledge_store/langchain_cassandra_tags.py b/libs/knowledge-store/ragstack_knowledge_store/langchain_cassandra_tags.py new file mode 100644 index 000000000..72aa16ff8 --- /dev/null +++ b/libs/knowledge-store/ragstack_knowledge_store/langchain_cassandra_tags.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +from typing 
import ( + TYPE_CHECKING, + Any, + Iterable, + List, + Optional, + Type, +) + +from langchain_core._api import beta +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.graph_vectorstores.base import ( + GraphVectorStore, + nodes_to_documents, +) + +from langchain_core.graph_vectorstores.base import Node as BaseNode +from langchain_community.utilities.cassandra import SetupMode +from langchain_core.graph_vectorstores import GraphVectorStoreRetriever + + +if TYPE_CHECKING: + from cassandra.cluster import Session + +from ragstack_knowledge_store import graph_store_tags, Node +from ragstack_knowledge_store.embedding_model import EmbeddingModel + +@beta() +class CassandraGraphVectorStore(GraphVectorStore): + def __init__( + self, + embedding: Embeddings, + *, + node_table: str = "graph_nodes", + session: Optional[Session] = None, + keyspace: Optional[str] = None, + setup_mode: SetupMode = SetupMode.SYNC, + **kwargs: Any, + ): + """ + Create the hybrid graph store. + + Args: + embedding: The embeddings to use for the document content. + setup_mode: Mode used to create the Cassandra table (SYNC, + ASYNC or OFF). + """ + # try: + # from ragstack_knowledge_store import EmbeddingModel, graph_store + # except (ImportError, ModuleNotFoundError): + # raise ImportError( + # "Could not import ragstack_knowledge_store python package. " + # "Please install it with `pip install ragstack-ai-knowledge-store`." + # ) + + self._embedding = embedding + _setup_mode = getattr(graph_store_tags.SetupMode, setup_mode.name) + + class _EmbeddingModelAdapter(EmbeddingModel): + def __init__(self, embeddings: Embeddings): + self.embeddings = embeddings + + def embed_texts(self, texts: List[str]) -> List[List[float]]: + return self.embeddings.embed_documents(texts) + + def embed_query(self, text: str) -> List[float]: + return self.embeddings.embed_query(text) + + async def aembed_texts(self, texts: List[str]) -> List[List[float]]: + return await self.embeddings.aembed_documents(texts) + + async def aembed_query(self, text: str) -> List[float]: + return await self.embeddings.aembed_query(text) + + self.store = graph_store_tags.GraphStore( + embedding=_EmbeddingModelAdapter(embedding), + node_table=node_table, + session=session, + keyspace=keyspace, + setup_mode=_setup_mode, + **kwargs, + ) + + @property + def embeddings(self) -> Optional[Embeddings]: + return self._embedding + + def add_nodes( + self, + nodes: Iterable[BaseNode], + **kwargs: Any, + ) -> Iterable[str]: + converted_nodes = [Node(text=n.text, id=n.id, metadata=n.metadata, links=n.links) for n in nodes] + return self.store.add_nodes(converted_nodes) + + @classmethod + def from_texts( + cls: Type["CassandraGraphVectorStore"], + texts: Iterable[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + ids: Optional[Iterable[str]] = None, + **kwargs: Any, + ) -> "CassandraGraphVectorStore": + """Return CassandraGraphVectorStore initialized from texts and embeddings.""" + store = cls(embedding, **kwargs) + store.add_texts(texts, metadatas, ids=ids) + return store + + @classmethod + def from_documents( + cls: Type["CassandraGraphVectorStore"], + documents: Iterable[Document], + embedding: Embeddings, + ids: Optional[Iterable[str]] = None, + **kwargs: Any, + ) -> "CassandraGraphVectorStore": + """Return CassandraGraphVectorStore initialized from documents and + embeddings.""" + store = cls(embedding, **kwargs) + store.add_documents(documents, ids=ids) + return store + + def 
similarity_search( + self, query: str, k: int = 4, **kwargs: Any + ) -> List[Document]: + embedding_vector = self._embedding.embed_query(query) + return self.similarity_search_by_vector( + embedding_vector, + k=k, + ) + + def similarity_search_by_vector( + self, embedding: List[float], k: int = 4, **kwargs: Any + ) -> List[Document]: + nodes = self.store.similarity_search(embedding, k=k) + return list(nodes_to_documents(nodes)) + + def traversal_search( + self, + query: str, + *, + k: int = 4, + depth: int = 1, + **kwargs: Any, + ) -> Iterable[Document]: + nodes = self.store.traversal_search(query, k=k, depth=depth) + return nodes_to_documents(nodes) + + def mmr_traversal_search( + self, + query: str, + *, + k: int = 4, + depth: int = 2, + fetch_k: int = 100, + adjacent_k: int = 10, + lambda_mult: float = 0.5, + score_threshold: float = float("-inf"), + **kwargs: Any, + ) -> Iterable[Document]: + nodes = self.store.mmr_traversal_search( + query, + k=k, + depth=depth, + fetch_k=fetch_k, + adjacent_k=adjacent_k, + lambda_mult=lambda_mult, + score_threshold=score_threshold, + ) + return nodes_to_documents(nodes) diff --git a/libs/knowledge-store/ragstack_knowledge_store/langchain_cassandra_tags_async.py b/libs/knowledge-store/ragstack_knowledge_store/langchain_cassandra_tags_async.py new file mode 100644 index 000000000..defa4ae4f --- /dev/null +++ b/libs/knowledge-store/ragstack_knowledge_store/langchain_cassandra_tags_async.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +from typing import ( + TYPE_CHECKING, + Any, + Iterable, + List, + Optional, + Type, +) + +from langchain_core._api import beta +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.graph_vectorstores.base import ( + GraphVectorStore, + nodes_to_documents, +) + +from langchain_core.graph_vectorstores.base import Node as BaseNode +from langchain_community.utilities.cassandra import SetupMode +from langchain_core.graph_vectorstores import GraphVectorStoreRetriever + + +if TYPE_CHECKING: + from cassandra.cluster import Session + +from ragstack_knowledge_store import graph_store_tags_async, Node +from ragstack_knowledge_store.embedding_model import EmbeddingModel + +@beta() +class CassandraGraphVectorStore(GraphVectorStore): + def __init__( + self, + embedding: Embeddings, + *, + node_table: str = "graph_nodes", + session: Optional[Session] = None, + keyspace: Optional[str] = None, + setup_mode: SetupMode = SetupMode.SYNC, + **kwargs: Any, + ): + """ + Create the hybrid graph store. + + Args: + embedding: The embeddings to use for the document content. + setup_mode: Mode used to create the Cassandra table (SYNC, + ASYNC or OFF). + """ + # try: + # from ragstack_knowledge_store import EmbeddingModel, graph_store + # except (ImportError, ModuleNotFoundError): + # raise ImportError( + # "Could not import ragstack_knowledge_store python package. " + # "Please install it with `pip install ragstack-ai-knowledge-store`." 
+ # ) + + self._embedding = embedding + _setup_mode = getattr(graph_store_tags_async.SetupMode, setup_mode.name) + + class _EmbeddingModelAdapter(EmbeddingModel): + def __init__(self, embeddings: Embeddings): + self.embeddings = embeddings + + def embed_texts(self, texts: List[str]) -> List[List[float]]: + return self.embeddings.embed_documents(texts) + + def embed_query(self, text: str) -> List[float]: + return self.embeddings.embed_query(text) + + async def aembed_texts(self, texts: List[str]) -> List[List[float]]: + return await self.embeddings.aembed_documents(texts) + + async def aembed_query(self, text: str) -> List[float]: + return await self.embeddings.aembed_query(text) + + self.store = graph_store_tags_async.GraphStore( + embedding=_EmbeddingModelAdapter(embedding), + node_table=node_table, + session=session, + keyspace=keyspace, + setup_mode=_setup_mode, + **kwargs, + ) + + @property + def embeddings(self) -> Optional[Embeddings]: + return self._embedding + + def add_nodes( + self, + nodes: Iterable[BaseNode], + **kwargs: Any, + ) -> Iterable[str]: + converted_nodes = [Node(text=n.text, id=n.id, metadata=n.metadata, links=n.links) for n in nodes] + return self.store.add_nodes(converted_nodes) + + @classmethod + def from_texts( + cls: Type["CassandraGraphVectorStore"], + texts: Iterable[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + ids: Optional[Iterable[str]] = None, + **kwargs: Any, + ) -> "CassandraGraphVectorStore": + """Return CassandraGraphVectorStore initialized from texts and embeddings.""" + store = cls(embedding, **kwargs) + store.add_texts(texts, metadatas, ids=ids) + return store + + @classmethod + def from_documents( + cls: Type["CassandraGraphVectorStore"], + documents: Iterable[Document], + embedding: Embeddings, + ids: Optional[Iterable[str]] = None, + **kwargs: Any, + ) -> "CassandraGraphVectorStore": + """Return CassandraGraphVectorStore initialized from documents and + embeddings.""" + store = cls(embedding, **kwargs) + store.add_documents(documents, ids=ids) + return store + + def similarity_search( + self, query: str, k: int = 4, **kwargs: Any + ) -> List[Document]: + embedding_vector = self._embedding.embed_query(query) + return self.similarity_search_by_vector( + embedding_vector, + k=k, + ) + + def similarity_search_by_vector( + self, embedding: List[float], k: int = 4, **kwargs: Any + ) -> List[Document]: + nodes = self.store.similarity_search(embedding, k=k) + return list(nodes_to_documents(nodes)) + + def traversal_search( + self, + query: str, + *, + k: int = 4, + depth: int = 1, + **kwargs: Any, + ) -> Iterable[Document]: + nodes = self.store.traversal_search(query, k=k, depth=depth) + return nodes_to_documents(nodes) + + def mmr_traversal_search( + self, + query: str, + *, + k: int = 4, + depth: int = 2, + fetch_k: int = 100, + adjacent_k: int = 10, + lambda_mult: float = 0.5, + score_threshold: float = float("-inf"), + **kwargs: Any, + ) -> Iterable[Document]: + nodes = self.store.mmr_traversal_search( + query, + k=k, + depth=depth, + fetch_k=fetch_k, + adjacent_k=adjacent_k, + lambda_mult=lambda_mult, + score_threshold=score_threshold, + ) + return nodes_to_documents(nodes) diff --git a/libs/knowledge-store/ragstack_knowledge_store/link_extractor.py b/libs/knowledge-store/ragstack_knowledge_store/link_extractor.py new file mode 100644 index 000000000..45b8a526f --- /dev/null +++ b/libs/knowledge-store/ragstack_knowledge_store/link_extractor.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from abc import ABC, 
abstractmethod
+from typing import Generic, Iterable, Set, TypeVar
+
+from langchain_core._api import beta
+from langchain_core.graph_vectorstores import Link
+
+InputT = TypeVar("InputT")
+
+METADATA_LINKS_KEY = "links"
+
+
+@beta()
+class LinkExtractor(ABC, Generic[InputT]):
+    """Interface for extracting links (incoming, outgoing, bidirectional)."""
+
+    @abstractmethod
+    def extract_one(self, input: InputT) -> Set[Link]:
+        """Extract the links from a single input.
+
+        Args:
+            input: The input content to extract links from.
+
+        Returns:
+            Set of links extracted from the input.
+        """
+
+    def extract_many(self, inputs: Iterable[InputT]) -> Iterable[Set[Link]]:
+        """Extract the links from each of the given inputs.
+
+        Args:
+            inputs: The input contents to extract links from.
+
+        Returns:
+            Iterable over the sets of links extracted from each input.
+        """
+        return map(self.extract_one, inputs)
diff --git a/libs/knowledge-store/ragstack_knowledge_store/metadata_size.py b/libs/knowledge-store/ragstack_knowledge_store/metadata_size.py
new file mode 100644
index 000000000..13f543468
--- /dev/null
+++ b/libs/knowledge-store/ragstack_knowledge_store/metadata_size.py
@@ -0,0 +1,25 @@
+import cassio
+from cassio.config import check_resolve_keyspace, check_resolve_session
+from dotenv import load_dotenv
+
+load_dotenv()
+
+KEYSPACE = "legal_graph_store"
+TABLE_NAME = "metadata_based"
+
+cassio.init(auto=True)
+session = check_resolve_session()
+keyspace = check_resolve_keyspace(KEYSPACE)
+
+
+# Query the id and metadata blob of every row in the table
+rows = session.execute(f"SELECT content_id, metadata_blob FROM {keyspace}.{TABLE_NAME}")
+
+# Report the size of the metadata_blob column for each row
+for row in rows:
+    blob = row.metadata_blob or ""
+    blob_size = len(blob.encode("utf-8"))  # size in bytes
+    print(f"ID: {row.content_id}, size of metadata_blob: {blob_size} bytes")
+
+# Close the connection
+session.shutdown()
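
Usage note (illustrative, not part of the diff): the sketch below shows one way the pieces introduced above could be wired together end to end. It is a minimal, untested example; the OpenAIEmbeddings model, the sample documents, and reuse of the "legal_graph_store" keyspace are assumptions, and it relies on the base GraphVectorStore picking links up from doc.metadata["links"], i.e. the METADATA_LINKS_KEY defined in link_extractor.py.

# Minimal sketch (assumptions noted above): extract KeyBERT keyword links,
# attach them to document metadata, then index and query through the
# CassandraGraphVectorStore wrapper added in this change.
import cassio
from cassio.config import check_resolve_keyspace, check_resolve_session
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings  # assumption: any Embeddings implementation works

from ragstack_knowledge_store.keybert_link_extractor import KeybertLinkExtractor
from ragstack_knowledge_store.langchain_cassandra import CassandraGraphVectorStore
from ragstack_knowledge_store.link_extractor import METADATA_LINKS_KEY

cassio.init(auto=True)  # resolve session/keyspace from the environment, as in metadata_size.py

docs = [
    Document(page_content="Cassandra is a distributed wide-column database."),
    Document(page_content="Vector search enables retrieval over embeddings."),
]

# Tag each document with bidirectional keyword links so graph traversal can
# hop between documents that share a keyword.
extractor = KeybertLinkExtractor()
for doc, links in zip(docs, extractor.extract_many(docs)):
    doc.metadata[METADATA_LINKS_KEY] = list(links)

store = CassandraGraphVectorStore(
    embedding=OpenAIEmbeddings(),
    node_table="graph_nodes",
    session=check_resolve_session(),
    keyspace=check_resolve_keyspace("legal_graph_store"),  # keyspace name assumed
)
store.add_documents(docs)

for doc in store.mmr_traversal_search("distributed databases", k=4, depth=2):
    print(doc.page_content)

Swapping graph_store for graph_store_tags or graph_store_tags_async (via the corresponding langchain_cassandra_tags modules) should only require changing the import, since the three wrappers expose the same surface.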