diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py
index b392f65b1..24cf8c326 100644
--- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py
+++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py
@@ -11,17 +11,17 @@
     TYPE_CHECKING,
     Any,
     Sequence,
    Union,
     cast,
 )
 
 from cassandra.cluster import ConsistencyLevel, PreparedStatement, Session
+from cassandra.query import SimpleStatement
 from cassio.config import check_resolve_keyspace, check_resolve_session
 from typing_extensions import assert_never
 
 from ._mmr_helper import MmrHelper
 from .concurrency import ConcurrentQueries
-from .content import Kind
 from .links import Link
 
 if TYPE_CHECKING:
@@ -31,10 +31,8 @@
 
 CONTENT_ID = "content_id"
 
-CONTENT_COLUMNS = "content_id, kind, text_content, links_blob, metadata_blob"
-
 SELECT_CQL_TEMPLATE = (
-    "SELECT {columns} FROM {table_name}{where_clause}{order_clause}{limit_clause};"
+    "SELECT {columns} FROM {table_name}{where_clause}{order_clause}{limit_clause}"
 )
@@ -46,11 +44,18 @@ class Node:
     """Text contained by the node."""
     id: str | None = None
     """Unique ID for the node. Will be generated by the GraphStore if not set."""
+    embedding: list[float] = field(default_factory=list)
+    """Vector embedding of the text."""
     metadata: dict[str, Any] = field(default_factory=dict)
     """Metadata for the node."""
     links: set[Link] = field(default_factory=set)
     """Links for the node."""
 
+    def incoming_links(self) -> set[Link]:
+        """Return the links that point into this node ("in" or "bidir")."""
+        return {link for link in self.links if link.direction in ("in", "bidir")}
+
+    def outgoing_links(self) -> set[Link]:
+        """Return the links that point out of this node ("out" or "bidir")."""
+        return {link for link in self.links if link.direction in ("out", "bidir")}
+
 
 class SetupMode(Enum):
     """Mode used to create the Cassandra table."""
@@ -114,26 +119,41 @@ def _deserialize_links(json_blob: str | None) -> set[Link]:
         for link in cast(list[dict[str, Any]], json.loads(json_blob or ""))
     }
 
 
+def _metadata_s_link_key(link: Link) -> str:
+    """Key under which an incoming link is stored in ``metadata_s``."""
+    return "link_from_" + json.dumps({"kind": link.kind, "tag": link.tag})
+
+
+def _metadata_s_link_value() -> str:
+    """Marker value stored for link entries in ``metadata_s``."""
+    return "link_from"
+
+
 def _row_to_node(row: Any) -> Node:
-    metadata = _deserialize_metadata(row.metadata_blob)
-    links = _deserialize_links(row.links_blob)
+    if hasattr(row, "metadata_blob"):
+        metadata = _deserialize_metadata(row.metadata_blob)
+        links: set[Link] = _deserialize_links(metadata.pop("links", None))
+    else:
+        metadata = {}
+        links = set()
     return Node(
-        id=row.content_id,
-        text=row.text_content,
+        id=getattr(row, CONTENT_ID, ""),
+        embedding=getattr(row, "text_embedding", []),
+        text=getattr(row, "text_content", ""),
         metadata=metadata,
         links=links,
     )
 
 
+def _get_metadata_filter(
+    metadata: dict[str, Any] | None = None,
+    outgoing_link: Link | None = None,
+) -> dict[str, Any]:
+    """Combine a metadata filter with a filter on an incoming-link entry."""
+    if outgoing_link is None:
+        return metadata or {}
 
-_CQL_IDENTIFIER_PATTERN = re.compile(r"[a-zA-Z][a-zA-Z0-9_]*")
-
+    metadata_filter = {} if metadata is None else metadata.copy()
+    metadata_filter[_metadata_s_link_key(link=outgoing_link)] = (
+        _metadata_s_link_value()
+    )
+    return metadata_filter
 
-@dataclass
-class _Edge:
-    target_content_id: str
-    target_text_embedding: list[float]
-    target_link_to_tags: set[tuple[str, str]]
+
+_CQL_IDENTIFIER_PATTERN = re.compile(r"[a-zA-Z][a-zA-Z0-9_]*")
 
 
 class GraphStore:
@@ -201,23 +221,22 @@ def __init__(
         self._insert_passage = session.prepare(
             f"""
             INSERT INTO {keyspace}.{node_table} (
-                content_id, kind, text_content, text_embedding, link_to_tags,
-                link_from_tags, links_blob, metadata_blob, metadata_s
-            ) VALUES (?, '{Kind.passage}', ?, ?, ?, ?, ?, ?, ?)
+                content_id, text_content, text_embedding, metadata_blob, metadata_s
+            ) VALUES (?, ?, ?, ?, ?)
             """  # noqa: S608
         )
 
         self._query_by_id = session.prepare(
             f"""
-            SELECT {CONTENT_COLUMNS}
+            SELECT content_id, text_content, metadata_blob
             FROM {keyspace}.{node_table}
             WHERE content_id = ?
             """  # noqa: S608
         )
 
-        self._query_ids_and_link_to_tags_by_id = session.prepare(
+        self._query_id_and_metadata_by_id = session.prepare(
             f"""
-            SELECT content_id, link_to_tags
+            SELECT content_id, metadata_blob
             FROM {keyspace}.{node_table}
             WHERE content_id = ?
             """  # noqa: S608
         )
@@ -233,16 +252,10 @@ def _apply_schema(self) -> None:
         self._session.execute(f"""
             CREATE TABLE IF NOT EXISTS {self.table_name()} (
                 content_id TEXT,
-                kind TEXT,
                 text_content TEXT,
                 text_embedding VECTOR<FLOAT, {embedding_dim}>,
-
-                link_to_tags SET<TUPLE<TEXT, TEXT>>,
-                link_from_tags SET<TUPLE<TEXT, TEXT>>,
-                links_blob TEXT,
                 metadata_blob TEXT,
                 metadata_s MAP<TEXT, TEXT>,
-
                 PRIMARY KEY (content_id)
             )
         """)
@@ -254,12 +267,6 @@ def _apply_schema(self) -> None:
             USING 'StorageAttachedIndex';
         """)
 
-        self._session.execute(f"""
-            CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_link_from_tags
-            ON {self.table_name()}(link_from_tags)
-            USING 'StorageAttachedIndex';
-        """)
-
         self._session.execute(f"""
             CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_metadata_s_index
             ON {self.table_name()}(ENTRIES(metadata_s))
             USING 'StorageAttachedIndex';
@@ -277,32 +284,24 @@ def add_nodes(
         """Add nodes to the graph store."""
         node_ids: list[str] = []
         texts: list[str] = []
-        metadatas: list[dict[str, Any]] = []
-        nodes_links: list[set[Link]] = []
+        metadata_list: list[dict[str, Any]] = []
+        incoming_links_list: list[set[Link]] = []
         for node in nodes:
             if not node.id:
                 node_ids.append(secrets.token_hex(8))
             else:
                 node_ids.append(node.id)
             texts.append(node.text)
-            metadatas.append(node.metadata)
-            nodes_links.append(node.links)
+            combined_metadata = node.metadata.copy()
+            combined_metadata["links"] = _serialize_links(node.links)
+            metadata_list.append(combined_metadata)
+            incoming_links_list.append(node.incoming_links())
 
         text_embeddings = self._embedding.embed_texts(texts)
 
         with self._concurrent_queries() as cq:
-            tuples = zip(node_ids, texts, text_embeddings, metadatas, nodes_links)
-            for node_id, text, text_embedding, metadata, links in tuples:
-                link_to_tags = set()  # link to these tags
-                link_from_tags = set()  # link from these tags
-
-                for tag in links:
-                    if tag.direction in {"in", "bidir"}:
-                        # An incoming link should be linked *from* nodes with the given
-                        # tag.
-                        link_from_tags.add((tag.kind, tag.tag))
-                    if tag.direction in {"out", "bidir"}:
-                        link_to_tags.add((tag.kind, tag.tag))
+            tuples = zip(
+                node_ids, texts, text_embeddings, metadata_list, incoming_links_list
+            )
+            for node_id, text, text_embedding, metadata, incoming_links in tuples:
 
                 metadata_s = {
                     k: self._coerce_string(v)
@@ -310,17 +309,17 @@
                     if _is_metadata_field_indexed(k, self._metadata_indexing_policy)
                 }
 
+                for incoming_link in incoming_links:
+                    metadata_s[_metadata_s_link_key(link=incoming_link)] = (
+                        _metadata_s_link_value()
+                    )
+
                 metadata_blob = _serialize_metadata(metadata)
-                links_blob = _serialize_links(links)
+
                 cq.execute(
                     self._insert_passage,
                     parameters=(
                         node_id,
                         text,
                         text_embedding,
-                        link_to_tags,
-                        link_from_tags,
-                        links_blob,
                         metadata_blob,
                         metadata_s,
                     ),
                 )
@@ -412,74 +411,58 @@ def mmr_traversal_search(
             score_threshold=score_threshold,
         )
 
-        # For each unselected node, stores the outgoing tags.
-        outgoing_tags: dict[str, set[tuple[str, str]]] = {}
+        # For each unselected node, stores the outgoing links.
+        outgoing_links_map: dict[str, set[Link]] = {}
+        visited_links: set[Link] = set()
 
-        # Fetch the initial candidates and add them to the helper and
-        # outgoing_tags.
-        columns = "content_id, text_embedding, link_to_tags"
-        adjacent_query = self._get_search_cql(
-            has_limit=True,
-            columns=columns,
-            metadata_keys=list(metadata_filter.keys()),
-            has_embedding=True,
-            has_link_from_tags=True,
-        )
-
-        visited_tags: set[tuple[str, str]] = set()
 
         def fetch_neighborhood(neighborhood: Sequence[str]) -> None:
-            # Put the neighborhood into the outgoing tags, to avoid adding it
+            nonlocal visited_links
+
+            # Put the neighborhood into the outgoing links, to avoid adding it
             # to the candidate set in the future.
-            outgoing_tags.update({content_id: set() for content_id in neighborhood})
+            outgoing_links_map.update(
+                {content_id: set() for content_id in neighborhood}
+            )
 
-            # Initialize the visited_tags with the set of outgoing from the
+            # Initialize visited_links with the set of outgoing links from the
             # neighborhood. This prevents re-visiting them.
-            visited_tags = self._get_outgoing_tags(neighborhood)
+            visited_links = self._get_outgoing_links(neighborhood)
 
             # Call `self._get_adjacent` to fetch the candidates.
-            adjacents = self._get_adjacent(
-                visited_tags,
-                adjacent_query=adjacent_query,
+            adjacent_nodes = self._get_adjacent(
+                links=visited_links,
                 query_embedding=query_embedding,
-                k_per_tag=adjacent_k,
+                k_per_link=adjacent_k,
                 metadata_filter=metadata_filter,
             )
 
-            new_candidates = {}
-            for adjacent in adjacents:
-                if adjacent.target_content_id not in outgoing_tags:
-                    outgoing_tags[adjacent.target_content_id] = (
-                        adjacent.target_link_to_tags
-                    )
-
-                    new_candidates[adjacent.target_content_id] = (
-                        adjacent.target_text_embedding
-                    )
+            new_candidates: dict[str, list[float]] = {}
+            for adjacent_node in adjacent_nodes:
+                if adjacent_node.id not in outgoing_links_map:
+                    outgoing_links_map[adjacent_node.id] = (
+                        adjacent_node.outgoing_links()
+                    )
+                    new_candidates[adjacent_node.id] = adjacent_node.embedding
             helper.add_candidates(new_candidates)
 
         def fetch_initial_candidates() -> None:
-            initial_candidates_query = self._get_search_cql(
-                has_limit=True,
-                columns=columns,
-                metadata_keys=list(metadata_filter.keys()),
-                has_embedding=True,
-            )
-
-            params = self._get_search_params(
+            initial_candidates_query, params = self._get_search_cql_and_params(
+                columns="content_id, text_embedding, metadata_blob",
                 limit=fetch_k,
                 metadata=metadata_filter,
                 embedding=query_embedding,
             )
 
-            fetched = self._session.execute(
+            rows = self._session.execute(
                 query=initial_candidates_query, parameters=params
             )
-            candidates = {}
-            for row in fetched:
-                if row.content_id not in outgoing_tags:
-                    candidates[row.content_id] = row.text_embedding
-                    outgoing_tags[row.content_id] = set(row.link_to_tags or [])
+            candidates: dict[str, list[float]] = {}
+            for row in rows:
+                if row.content_id not in outgoing_links_map:
+                    node = _row_to_node(row=row)
+                    outgoing_links_map[node.id] = node.outgoing_links()
+                    candidates[node.id] = node.embedding
             helper.add_candidates(candidates)
 
         if initial_roots:
@@ -502,40 +485,33 @@ def fetch_initial_candidates() -> None:
             # If the next nodes would not exceed the depth limit, find the
             # adjacent nodes.
             #
-            # TODO: For a big performance win, we should track which tags we've
+            # TODO: For a big performance win, we should track which links we've
             # already incorporated. We don't need to issue adjacent queries for
             # those.
 
-            # Find the tags linked to from the selected ID.
-            link_to_tags = outgoing_tags.pop(selected_id)
+            # Find the outgoing links from the selected ID.
+            selected_outgoing_links = outgoing_links_map.pop(selected_id)
 
-            # Don't re-visit already visited tags.
-            link_to_tags.difference_update(visited_tags)
+            # Don't re-visit already visited links.
+            selected_outgoing_links.difference_update(visited_links)
 
-            # Find the nodes with incoming links from those tags.
-            adjacents = self._get_adjacent(
-                link_to_tags,
-                adjacent_query=adjacent_query,
+            # Find the nodes with incoming links matching those links.
+            adjacent_nodes = self._get_adjacent(
+                links=selected_outgoing_links,
                 query_embedding=query_embedding,
-                k_per_tag=adjacent_k,
+                k_per_link=adjacent_k,
                 metadata_filter=metadata_filter,
            )
 
-            # Record the link_to_tags as visited.
-            visited_tags.update(link_to_tags)
+            # Record the selected outgoing links as visited.
+            visited_links.update(selected_outgoing_links)
 
             new_candidates = {}
-            for adjacent in adjacents:
-                if adjacent.target_content_id not in outgoing_tags:
-                    outgoing_tags[adjacent.target_content_id] = (
-                        adjacent.target_link_to_tags
-                    )
-                    new_candidates[adjacent.target_content_id] = (
-                        adjacent.target_text_embedding
-                    )
-                    if next_depth < depths.get(
-                        adjacent.target_content_id, depth + 1
-                    ):
+            for adjacent_node in adjacent_nodes:
+                if adjacent_node.id not in outgoing_links_map:
+                    outgoing_links_map[adjacent_node.id] = (
+                        adjacent_node.outgoing_links()
+                    )
+                    new_candidates[adjacent_node.id] = adjacent_node.embedding
+                    if next_depth < depths.get(adjacent_node.id, depth + 1):
                         # If this is a new shortest depth, or there was no
                         # previous depth, update the depths. This ensures that
                         # when we discover a node we will have the shortest
@@ -546,7 +522,7 @@ def fetch_initial_candidates() -> None:
                         # a shorter path via nodes selected later. This is
                         # currently "intended", but may be worth experimenting
                         # with.
-                        depths[adjacent.target_content_id] = next_depth
+                        depths[adjacent_node.id] = next_depth
             helper.add_candidates(new_candidates)
 
         return self._nodes_with_ids(helper.selected_ids)
@@ -577,72 +553,64 @@ def traversal_search(
         """
         # Depth 0:
         #   Query for `k` nodes similar to the question.
-        #   Retrieve `content_id` and `link_to_tags`.
+        #   Retrieve `content_id` and the outgoing links.
         #
         # Depth 1:
-        #   Query for nodes that have an incoming tag in the `link_to_tags` set.
+        #   Query for nodes with an incoming link matching one of those
+        #   outgoing links.
         #   Combine node IDs.
-        #   Query for `link_to_tags` of those "new" node IDs.
+        #   Query for the outgoing links of those "new" node IDs.
         #
         # ...
 
-        traversal_query = self._get_search_cql(
-            columns="content_id, link_to_tags",
-            has_limit=True,
-            metadata_keys=list(metadata_filter.keys()),
-            has_embedding=True,
-        )
-
-        visit_nodes_query = self._get_search_cql(
-            columns="content_id AS target_content_id",
-            has_link_from_tags=True,
-            metadata_keys=list(metadata_filter.keys()),
-        )
-
         with self._concurrent_queries() as cq:
             # Map from visited ID to depth
             visited_ids: dict[str, int] = {}
 
-            # Map from visited tag `(kind, tag)` to depth. Allows skipping queries
-            # for tags that we've already traversed.
+            # Map from visited link to depth. Allows skipping queries for links
+            # that we've already traversed.
+            visited_links: dict[Link, int] = {}
 
-            def visit_nodes(d: int, nodes: Sequence[Any]) -> None:
+            def visit_nodes(d: int, rows: Sequence[Any]) -> None:
                 nonlocal visited_ids
-                nonlocal visited_tags
+                nonlocal visited_links
 
                 # Visit nodes at the given depth.
-                # Each node has `content_id` and `link_to_tags`.
-
-                # Iterate over nodes, tracking the *new* outgoing kind tags for this
-                # depth. This is tags that are either new, or newly discovered at a
+                # Iterate over nodes, tracking the *new* outgoing links for this
+                # depth. These are links that are either new, or newly discovered at a
                 # lower depth.
-                outgoing_tags = set()
-                for node in nodes:
-                    content_id = node.content_id
+                outgoing_links: set[Link] = set()
+                for row in rows:
+                    content_id = row.content_id
 
                     # Add visited ID. If it is closer it is a new node at this depth:
                     if d <= visited_ids.get(content_id, depth):
                         visited_ids[content_id] = d
 
                         # If we can continue traversing from this node,
-                        if d < depth and node.link_to_tags:
+                        if d < depth:
+                            node = _row_to_node(row=row)
                             # Record any new (or newly discovered at a lower depth)
-                            # tags to the set to traverse.
-                            for kind, value in node.link_to_tags:
-                                if d <= visited_tags.get((kind, value), depth):
-                                    # Record that we'll query this tag at the
+                            # links to the set to traverse.
+                            for link in node.outgoing_links():
+                                if d <= visited_links.get(link, depth):
+                                    # Record that we'll query this link at the
                                     # given depth, so we don't fetch it again
-                                    # (unless we find it an earlier depth)
-                                    visited_tags[(kind, value)] = d
-                                    outgoing_tags.add((kind, value))
+                                    # (unless we find it at an earlier depth).
+                                    visited_links[link] = d
+                                    outgoing_links.add(link)
 
-                if outgoing_tags:
-                    # If there are new tags to visit at the next depth, query for the
+                if outgoing_links:
+                    # If there are new links to visit at the next depth, query for the
                     # node IDs.
-                    for kind, value in outgoing_tags:
-                        params = self._get_search_params(
-                            link_from_tags=(kind, value), metadata=metadata_filter
+                    for outgoing_link in outgoing_links:
+                        visit_nodes_query, params = self._get_search_cql_and_params(
+                            columns="content_id AS target_content_id",
+                            metadata=metadata_filter,
+                            outgoing_link=outgoing_link,
                         )
                         cq.execute(
                             query=visit_nodes_query,
@@ -650,35 +618,34 @@ def visit_nodes(d: int, nodes: Sequence[Any]) -> None:
                             callback=lambda rows, d=d: visit_targets(d, rows),
                         )
 
-            def visit_targets(d: int, targets: Sequence[Any]) -> None:
+            def visit_targets(d: int, rows: Sequence[Any]) -> None:
                 nonlocal visited_ids
 
-                # target_content_id, tag=(kind,value)
-                new_nodes_at_next_depth = set()
-                for target in targets:
-                    content_id = target.target_content_id
+                new_node_ids_at_next_depth = set()
+                for row in rows:
+                    content_id = row.target_content_id
                     if d < visited_ids.get(content_id, depth):
-                        new_nodes_at_next_depth.add(content_id)
+                        new_node_ids_at_next_depth.add(content_id)
 
-                if new_nodes_at_next_depth:
-                    for node_id in new_nodes_at_next_depth:
+                if new_node_ids_at_next_depth:
+                    for node_id in new_node_ids_at_next_depth:
                         cq.execute(
-                            self._query_ids_and_link_to_tags_by_id,
+                            self._query_id_and_metadata_by_id,
                            parameters=(node_id,),
                             callback=lambda rows, d=d: visit_nodes(d + 1, rows),
                         )
 
-            query_embedding = self._embedding.embed_query(query)
-            params = self._get_search_params(
+            initial_query, params = self._get_search_cql_and_params(
+                columns="content_id, metadata_blob",
                 limit=k,
                 metadata=metadata_filter,
-                embedding=query_embedding,
+                embedding=self._embedding.embed_query(query),
             )
 
             cq.execute(
-                traversal_query,
+                initial_query,
                 parameters=params,
-                callback=lambda nodes: visit_nodes(0, nodes),
+                callback=lambda initial_rows: visit_nodes(0, initial_rows),
             )
 
         return self._nodes_with_ids(visited_ids.keys())
@@ -691,7 +658,10 @@ def similarity_search(
     ) -> Iterable[Node]:
         """Retrieve nodes similar to the given embedding, optionally filtered by metadata."""  # noqa: E501
         query, params = self._get_search_cql_and_params(
-            embedding=embedding, limit=k, metadata=metadata_filter
+            columns=f"{CONTENT_ID}, text_content, metadata_blob",
+            embedding=embedding,
+            limit=k,
+            metadata=metadata_filter,
         )
 
         for row in self._session.execute(query, params):
             yield _row_to_node(row)
@@ -703,7 +673,11 @@ def metadata_search(
         n: int = 5,
     ) -> Iterable[Node]:
         """Retrieve nodes based on their metadata."""
-        query, params = self._get_search_cql_and_params(metadata=metadata, limit=n)
+        query, params = self._get_search_cql_and_params(
+            columns=f"{CONTENT_ID}, text_content, metadata_blob",
+            metadata=metadata,
+            limit=n,
+        )
 
         for row in self._session.execute(query, params):
             yield _row_to_node(row)
@@ -712,74 +686,71 @@ def get_node(self, content_id: str) -> Node:
         """Get a node by its id."""
         return self._nodes_with_ids(ids=[content_id])[0]
 
-    def _get_outgoing_tags(
+    def _get_outgoing_links(
         self,
         source_ids: Iterable[str],
-    ) -> set[tuple[str, str]]:
-        """Return the set of outgoing tags for the given source ID(s).
+    ) -> set[Link]:
+        """Return the set of outgoing links for the given source ID(s).
 
         Args:
-            source_ids: The IDs of the source nodes to retrieve outgoing tags for.
+            source_ids: The IDs of the source nodes to retrieve outgoing links for.
""" - tags = set() + links = set() def add_sources(rows: Iterable[Any]) -> None: for row in rows: - if row.link_to_tags: - tags.update(row.link_to_tags) + node = _row_to_node(row=row) + links.update(node.outgoing_links()) with self._concurrent_queries() as cq: for source_id in source_ids: cq.execute( - self._query_ids_and_link_to_tags_by_id, + self._query_id_and_metadata_by_id, (source_id,), callback=add_sources, ) - return tags + return links def _get_adjacent( self, - tags: set[tuple[str, str]], - adjacent_query: PreparedStatement, + links: set[Link], query_embedding: list[float], - k_per_tag: int | None = None, + k_per_link: int | None = None, metadata_filter: dict[str, Any] | None = None, - ) -> Iterable[_Edge]: - """Return the target nodes with incoming links from any of the given tags. + ) -> Iterable[Node]: + """Return the target nodes with incoming links from any of the given links. Args: - tags: The tags to look for links *from*. - adjacent_query: Prepared query for adjacent nodes. + links: The links to look for. query_embedding: The query embedding. Used to rank target nodes. - k_per_tag: The number of target nodes to fetch for each outgoing tag. + k_per_link: The number of target nodes to fetch for each link. metadata_filter: Optional metadata to filter the results. Returns: List of adjacent edges. """ - targets: dict[str, _Edge] = {} + targets: dict[str, Node] = {} def add_targets(rows: Iterable[Any]) -> None: + nonlocal targets + # TODO: Figure out how to use the "kind" on the edge. # This is tricky, since we currently issue one query for anything # adjacent via any kind, and we don't have enough information to # determine which kind(s) a given target was reached from. for row in rows: if row.content_id not in targets: - targets[row.content_id] = _Edge( - target_content_id=row.content_id, - target_text_embedding=row.text_embedding, - target_link_to_tags=set(row.link_to_tags or []), - ) + targets[row.content_id] = _row_to_node(row=row) with self._concurrent_queries() as cq: - for kind, value in tags: - params = self._get_search_params( - limit=k_per_tag or 10, + for link in links: + adjacent_query, params = self._get_search_cql_and_params( + columns = "content_id, text_embedding, metadata_blob", + limit=k_per_link or 10, metadata=metadata_filter, embedding=query_embedding, - link_from_tags=(kind, value), + outgoing_link=link, ) cq.execute( @@ -848,21 +819,13 @@ def _coerce_string(value: Any) -> str: def _extract_where_clause_cql( self, - has_id: bool = False, metadata_keys: Sequence[str] = (), - has_link_from_tags: bool = False, ) -> str: wc_blocks: list[str] = [] - if has_id: - wc_blocks.append("content_id == ?") - - if has_link_from_tags: - wc_blocks.append("link_from_tags CONTAINS (?, ?)") - for key in sorted(metadata_keys): if _is_metadata_field_indexed(key, self._metadata_indexing_policy): - wc_blocks.append(f"metadata_s['{key}'] = ?") + wc_blocks.append(f"metadata_s['{key}'] = %s") else: msg = "Non-indexed metadata fields cannot be used in queries." 
                 raise ValueError(msg)
@@ -875,14 +838,9 @@ def _extract_where_clause_cql(
 
     def _extract_where_clause_params(
         self,
         metadata: dict[str, Any],
-        link_from_tags: tuple[str, str] | None = None,
     ) -> list[Any]:
         params: list[Any] = []
 
-        if link_from_tags is not None:
-            params.append(link_from_tags[0])
-            params.append(link_from_tags[1])
-
         for key, value in sorted(metadata.items()):
             if _is_metadata_field_indexed(key, self._metadata_indexing_policy):
                 params.append(self._coerce_string(value=value))
@@ -892,22 +850,24 @@ def _extract_where_clause_params(
 
         return params
 
-    def _get_search_cql(
+    def _get_search_cql_and_params(
         self,
-        has_limit: bool = False,
-        columns: str | None = CONTENT_COLUMNS,
-        metadata_keys: Sequence[str] = (),
-        has_id: bool = False,
-        has_embedding: bool = False,
-        has_link_from_tags: bool = False,
-    ) -> PreparedStatement:
-        where_clause = self._extract_where_clause_cql(
-            has_id=has_id,
-            metadata_keys=metadata_keys,
-            has_link_from_tags=has_link_from_tags,
-        )
-        limit_clause = " LIMIT ?" if has_limit else ""
-        order_clause = " ORDER BY text_embedding ANN OF ?" if has_embedding else ""
+        columns: str,
+        limit: int | None = None,
+        metadata: dict[str, Any] | None = None,
+        embedding: list[float] | None = None,
+        outgoing_link: Link | None = None,
+    ) -> tuple[PreparedStatement | SimpleStatement, tuple[Any, ...]]:
+        metadata_filter = _get_metadata_filter(
+            metadata=metadata, outgoing_link=outgoing_link
+        )
+        metadata_keys = list(metadata_filter.keys())
+
+        # Metadata-filtered queries are built per key set and run as
+        # SimpleStatements with %s placeholders; unfiltered queries are
+        # prepared once and use ? markers. Keep all placeholders consistent.
+        placeholder = "%s" if metadata_keys else "?"
+
+        where_clause = self._extract_where_clause_cql(metadata_keys=metadata_keys)
+        limit_clause = f" LIMIT {placeholder}" if limit is not None else ""
+        order_clause = (
+            f" ORDER BY text_embedding ANN OF {placeholder}"
+            if embedding is not None
+            else ""
+        )
 
         select_cql = SELECT_CQL_TEMPLATE.format(
             columns=columns,
@@ -917,50 +877,19 @@
             limit_clause=limit_clause,
         )
 
-        if select_cql in self._prepared_query_cache:
-            return self._prepared_query_cache[select_cql]
-
-        prepared_query = self._session.prepare(select_cql)
-        prepared_query.consistency_level = ConsistencyLevel.ONE
-        self._prepared_query_cache[select_cql] = prepared_query
-
-        return prepared_query
-
-    def _get_search_params(
-        self,
-        limit: int | None = None,
-        metadata: dict[str, Any] | None = None,
-        embedding: list[float] | None = None,
-        link_from_tags: tuple[str, str] | None = None,
-    ) -> tuple[PreparedStatement, tuple[Any, ...]]:
-        where_params = self._extract_where_clause_params(
-            metadata=metadata or {}, link_from_tags=link_from_tags
-        )
-
+        where_params = self._extract_where_clause_params(metadata=metadata_filter)
         limit_params = [limit] if limit is not None else []
        order_params = [embedding] if embedding is not None else []
-        return tuple(list(where_params) + order_params + limit_params)
+        params = tuple(list(where_params) + order_params + limit_params)
+
+        if metadata_keys:
+            # The CQL text varies with the key set, so don't prepare/cache it.
+            return SimpleStatement(query_string=select_cql, fetch_size=100), params
+        if select_cql in self._prepared_query_cache:
+            return self._prepared_query_cache[select_cql], params
+
+        prepared_query = self._session.prepare(select_cql)
+        prepared_query.consistency_level = ConsistencyLevel.ONE
+        self._prepared_query_cache[select_cql] = prepared_query
+        return prepared_query, params
 
-    def _get_search_cql_and_params(
-        self,
-        limit: int | None = None,
-        columns: str | None = CONTENT_COLUMNS,
-        metadata: dict[str, Any] | None = None,
-        embedding: list[float] | None = None,
-        link_from_tags: tuple[str, str] | None = None,
-    ) -> tuple[PreparedStatement, tuple[Any, ...]]:
-        query = self._get_search_cql(
-            has_limit=limit is not None,
-            columns=columns,
-            metadata_keys=list(metadata.keys()) if metadata else (),
-            has_embedding=embedding is not None,
-            has_link_from_tags=link_from_tags is not None,
-        )
-        params = self._get_search_params(
-            limit=limit,
-            metadata=metadata,
-            embedding=embedding,
-            link_from_tags=link_from_tags,
-        )
-        return query, params
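
Reviewer note: a minimal, runnable sketch (not part of the patch) of the new link
encoding, assuming `Link` carries `kind`, `direction`, and `tag` as used above;
the example kind/tag values are illustrative only.

    import json

    def metadata_s_link_key(kind: str, tag: str) -> str:
        # Mirrors _metadata_s_link_key: an incoming link on a node is stored
        # as an entry in the indexed metadata_s map at write time.
        return "link_from_" + json.dumps({"kind": kind, "tag": tag})

    # Write path: a node with an incoming link (kind="kw", tag="python") gets:
    key = metadata_s_link_key("kw", "python")
    print(key)  # link_from_{"kind": "kw", "tag": "python"}

    # Read path: following the matching *outgoing* link during traversal
    # becomes an equality filter on that same entry, served by the
    # ENTRIES(metadata_s) SAI index that replaces the dropped link_from_tags
    # index:
    print(f"metadata_s['{key}'] = %s", ("link_from",))

The net effect is that link traversal and user metadata filters share one code
path and one index; the trade-off is that link-filtered queries run as
unprepared SimpleStatements rather than cached prepared statements.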