diff --git a/docs/user_guide/08_semantic_router.ipynb b/docs/user_guide/08_semantic_router.ipynb index 360108d0..514d7bfa 100644 --- a/docs/user_guide/08_semantic_router.ipynb +++ b/docs/user_guide/08_semantic_router.ipynb @@ -48,7 +48,7 @@ " \"what's trending in tech?\"\n", " ],\n", " metadata={\"category\": \"tech\", \"priority\": 1},\n", - " distance_threshold=1.0\n", + " distance_threshold=0.71\n", ")\n", "\n", "sports = Route(\n", @@ -61,7 +61,7 @@ " \"basketball and football\"\n", " ],\n", " metadata={\"category\": \"sports\", \"priority\": 2},\n", - " distance_threshold=0.5\n", + " distance_threshold=0.72\n", ")\n", "\n", "entertainment = Route(\n", @@ -95,17 +95,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "/Users/robert.shelton/.pyenv/versions/3.11.9/lib/python3.11/site-packages/huggingface_hub/file_download.py:1142: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", - " warnings.warn(\n", - "/Users/robert.shelton/.pyenv/versions/3.11.9/lib/python3.11/site-packages/huggingface_hub/file_download.py:1142: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "14:07:31 redisvl.index.index INFO Index already exists, overwriting.\n" + "/Users/robert.shelton/Library/Caches/pypoetry/virtualenvs/redisvl-56gG2io_-py3.11/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], @@ -130,26 +121,6 @@ "cell_type": "code", "execution_count": 3, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "HFTextVectorizer(model='sentence-transformers/all-mpnet-base-v2', dims=768)" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "router.vectorizer" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, "outputs": [ { "name": "stdout", @@ -164,13 +135,14 @@ "│ topic-router │ HASH │ ['topic-router'] │ [] │ 0 │\n", "╰──────────────┴────────────────┴──────────────────┴─────────────────┴────────────╯\n", "Index Fields:\n", - "╭────────────┬─────────────┬────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬─────────────────┬────────────────╮\n", - "│ Name │ Attribute │ Type │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │\n", - "├────────────┼─────────────┼────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼─────────────────┼────────────────┤\n", - "│ route_name │ route_name │ TAG │ SEPARATOR │ , │ │ │ │ │ │ │\n", - "│ reference │ reference │ TEXT │ WEIGHT │ 1 │ │ │ │ │ │ │\n", - "│ vector │ vector │ VECTOR │ algorithm │ FLAT │ data_type │ FLOAT32 │ dim │ 768 │ distance_metric │ COSINE │\n", - "╰────────────┴─────────────┴────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴─────────────────┴────────────────╯\n" + "╭──────────────┬──────────────┬────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬─────────────────┬────────────────╮\n", + "│ Name │ Attribute │ Type │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │\n", + "├──────────────┼──────────────┼────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼─────────────────┼────────────────┤\n", + "│ reference_id │ reference_id │ TAG │ SEPARATOR │ , │ │ │ │ │ │ │\n", + "│ route_name │ route_name │ TAG │ SEPARATOR │ , │ │ │ │ │ │ │\n", + "│ reference │ reference │ TEXT │ WEIGHT │ 1 │ │ │ │ │ │ │\n", + "│ vector │ vector │ VECTOR │ algorithm │ FLAT │ data_type │ FLOAT32 │ dim │ 768 │ distance_metric │ COSINE │\n", + "╰──────────────┴──────────────┴────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴─────────────────┴────────────────╯\n" ] } ], @@ -179,60 +151,58 @@ "!rvl index info -i topic-router" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Simple routing" - ] - }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "RouteMatch(name='technology', distance=0.119614303112)" + "11" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# Query the router with a statement\n", - "route_match = router(\"Can you tell me about the latest in artificial intelligence?\")\n", - "route_match" + "router._index.info()[\"num_docs\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Simple routing" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "RouteMatch(name=None, distance=None)" + "RouteMatch(name='technology', distance=0.419145862261)" ] }, - "execution_count": 6, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# Query the router with a statement and return a miss\n", - "route_match = router(\"are aliens real?\")\n", + "# Query the router with a statement\n", + "route_match = router(\"Can you tell me about the latest in artificial intelligence?\")\n", "route_match" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -241,14 +211,14 @@ "RouteMatch(name=None, distance=None)" ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# Toggle the runtime distance threshold\n", - "route_match = router(\"Which basketball team will win the NBA finals?\")\n", + "# Query the router with a statement and return a miss\n", + "route_match = router(\"are aliens real?\")\n", "route_match" ] }, @@ -261,38 +231,40 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[]" + "[RouteMatch(name='technology', distance=0.556493759155),\n", + " RouteMatch(name='sports', distance=0.67106004556)]" ] }, - "execution_count": 8, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Perform multi-class classification with route_many() -- toggle the max_k and the distance_threshold\n", - "route_matches = router.route_many(\"Lebron James\", max_k=3)\n", + "route_matches = router.route_many(\"How is AI used in basketball?\", max_k=3)\n", "route_matches" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[]" + "[RouteMatch(name='technology', distance=0.556493759155),\n", + " RouteMatch(name='sports', distance=0.62926441431)]" ] }, - "execution_count": 9, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -301,7 +273,7 @@ "# Toggle the aggregation method -- note the different distances in the result\n", "from redisvl.extensions.router.schema import DistanceAggregationMethod\n", "\n", - "route_matches = router.route_many(\"Lebron James\", aggregation_method=DistanceAggregationMethod.min, max_k=3)\n", + "route_matches = router.route_many(\"How is AI used in basketball?\", aggregation_method=DistanceAggregationMethod.min, max_k=3)\n", "route_matches" ] }, @@ -321,7 +293,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -334,16 +306,16 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[]" + "[RouteMatch(name='sports', distance=0.663253903389)]" ] }, - "execution_count": 11, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -362,7 +334,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -373,30 +345,28 @@ " 'references': ['what are the latest advancements in AI?',\n", " 'tell me about the newest gadgets',\n", " \"what's trending in tech?\"],\n", - " 'metadata': {'category': 'tech', 'priority': '1'},\n", - " 'distance_threshold': 1.0},\n", + " 'metadata': {'category': 'tech', 'priority': 1},\n", + " 'distance_threshold': 0.71},\n", " {'name': 'sports',\n", " 'references': ['who won the game last night?',\n", " 'tell me about the upcoming sports events',\n", " \"what's the latest in the world of sports?\",\n", " 'sports',\n", " 'basketball and football'],\n", - " 'metadata': {'category': 'sports', 'priority': '2'},\n", - " 'distance_threshold': 0.5},\n", + " 'metadata': {'category': 'sports', 'priority': 2},\n", + " 'distance_threshold': 0.72},\n", " {'name': 'entertainment',\n", " 'references': ['what are the top movies right now?',\n", " 'who won the best actor award?',\n", " \"what's new in the entertainment industry?\"],\n", - " 'metadata': {'category': 'entertainment', 'priority': '3'},\n", + " 'metadata': {'category': 'entertainment', 'priority': 3},\n", " 'distance_threshold': 0.7}],\n", " 'vectorizer': {'type': 'hf',\n", " 'model': 'sentence-transformers/all-mpnet-base-v2'},\n", - " 'routing_config': {'distance_threshold': 0.5,\n", - " 'max_k': 3,\n", - " 'aggregation_method': 'min'}}" + " 'routing_config': {'max_k': 3, 'aggregation_method': 'min'}}" ] }, - "execution_count": 12, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -407,14 +377,14 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "14:07:34 redisvl.index.index INFO Index already exists, not overwriting.\n" + "\u001b[32m11:57:04\u001b[0m \u001b[34mredisvl.index.index\u001b[0m \u001b[1;30mINFO\u001b[0m Index already exists, not overwriting.\n" ] } ], @@ -426,7 +396,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -435,14 +405,14 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "14:07:34 redisvl.index.index INFO Index already exists, not overwriting.\n" + "\u001b[32m11:57:06\u001b[0m \u001b[34mredisvl.index.index\u001b[0m \u001b[1;30mINFO\u001b[0m Index already exists, not overwriting.\n" ] } ], @@ -456,13 +426,165 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Clean up the router" + "# Add route references" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['topic-router:technology:f243fb2d073774e81c7815247cb3013794e6225df3cbe3769cee8c6cefaca777',\n", + " 'topic-router:technology:7e4bca5853c1c3298b4d001de13c3c7a79a6e0f134f81acc2e7cddbd6845961f']" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "router.add_route_references(route_name=\"technology\", references=[\"latest AI trends\", \"new tech gadgets\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Get route references" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'id': 'topic-router:technology:f243fb2d073774e81c7815247cb3013794e6225df3cbe3769cee8c6cefaca777',\n", + " 'reference_id': 'f243fb2d073774e81c7815247cb3013794e6225df3cbe3769cee8c6cefaca777',\n", + " 'route_name': 'technology',\n", + " 'reference': 'latest AI trends'},\n", + " {'id': 'topic-router:technology:851f51cce5a9ccfbbcb66993908be6b7871479af3e3a4b139ad292a1bf7e0676',\n", + " 'reference_id': '851f51cce5a9ccfbbcb66993908be6b7871479af3e3a4b139ad292a1bf7e0676',\n", + " 'route_name': 'technology',\n", + " 'reference': 'what are the latest advancements in AI?'},\n", + " {'id': 'topic-router:technology:149a9c9919c58534aa0f369e85ad95ba7f00aa0513e0f81e2aff2ea4a717b0e0',\n", + " 'reference_id': '149a9c9919c58534aa0f369e85ad95ba7f00aa0513e0f81e2aff2ea4a717b0e0',\n", + " 'route_name': 'technology',\n", + " 'reference': \"what's trending in tech?\"},\n", + " {'id': 'topic-router:technology:85cc73a1437df27caa2f075a29c497e5a2e532023fbb75378aedbae80779ab37',\n", + " 'reference_id': '85cc73a1437df27caa2f075a29c497e5a2e532023fbb75378aedbae80779ab37',\n", + " 'route_name': 'technology',\n", + " 'reference': 'tell me about the newest gadgets'},\n", + " {'id': 'topic-router:technology:7e4bca5853c1c3298b4d001de13c3c7a79a6e0f134f81acc2e7cddbd6845961f',\n", + " 'reference_id': '7e4bca5853c1c3298b4d001de13c3c7a79a6e0f134f81acc2e7cddbd6845961f',\n", + " 'route_name': 'technology',\n", + " 'reference': 'new tech gadgets'}]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# by route name\n", + "refs = router.get_route_references(route_name=\"technology\")\n", + "refs" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'id': 'topic-router:technology:f243fb2d073774e81c7815247cb3013794e6225df3cbe3769cee8c6cefaca777',\n", + " 'reference_id': 'f243fb2d073774e81c7815247cb3013794e6225df3cbe3769cee8c6cefaca777',\n", + " 'route_name': 'technology',\n", + " 'reference': 'latest AI trends'}]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# by reference id\n", + "refs = router.get_route_references(reference_ids=[refs[0][\"reference_id\"]])\n", + "refs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Delete route references" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# by route name\n", + "deleted_count = router.delete_route_references(route_name=\"sports\")\n", + "deleted_count" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# by id\n", + "deleted_count = router.delete_route_references(reference_ids=[refs[0][\"reference_id\"]])\n", + "deleted_count" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Clean up the router" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, "outputs": [], "source": [ "# Use clear to flush all routes from the index\n", @@ -471,7 +593,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -482,7 +604,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "redisvl-56gG2io_-py3.11", "language": "python", "name": "python3" }, diff --git a/docs/user_guide/router.yaml b/docs/user_guide/router.yaml index b743aaf6..f5c33af0 100644 --- a/docs/user_guide/router.yaml +++ b/docs/user_guide/router.yaml @@ -8,7 +8,7 @@ routes: metadata: category: tech priority: 1 - distance_threshold: 1.0 + distance_threshold: 0.71 - name: sports references: - who won the game last night? @@ -19,7 +19,7 @@ routes: metadata: category: sports priority: 2 - distance_threshold: 0.5 + distance_threshold: 0.72 - name: entertainment references: - what are the top movies right now? diff --git a/redisvl/extensions/router/schema.py b/redisvl/extensions/router/schema.py index 441cbd76..bb2c492c 100644 --- a/redisvl/extensions/router/schema.py +++ b/redisvl/extensions/router/schema.py @@ -100,6 +100,7 @@ def from_params(cls, name: str, vector_dims: int, dtype: str): return cls( index={"name": name, "prefix": name}, # type: ignore fields=[ # type: ignore + {"name": "reference_id", "type": "tag"}, {"name": "route_name", "type": "tag"}, {"name": "reference", "type": "text"}, { diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index c2ebb50d..ab86f240 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Type, Union import redis.commands.search.reducers as reducers import yaml @@ -8,6 +8,7 @@ from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer from redis.exceptions import ResponseError +from redisvl.exceptions import RedisModuleVersionError from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME from redisvl.extensions.router.schema import ( DistanceAggregationMethod, @@ -17,10 +18,12 @@ SemanticRouterIndexSchema, ) from redisvl.index import SearchIndex -from redisvl.query import VectorRangeQuery +from redisvl.query import FilterQuery, VectorRangeQuery +from redisvl.query.filter import Tag +from redisvl.redis.connection import RedisConnectionFactory from redisvl.redis.utils import convert_bytes, hashify, make_dict from redisvl.utils.log import get_logger -from redisvl.utils.utils import deprecated_argument, model_to_dict +from redisvl.utils.utils import deprecated_argument, model_to_dict, scan_by_pattern from redisvl.utils.vectorize.base import BaseVectorizer from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer @@ -98,9 +101,41 @@ def __init__( routes=routes, vectorizer=vectorizer, routing_config=routing_config, + redis_url=redis_url, + redis_client=redis_client, ) + self._initialize_index(redis_client, redis_url, overwrite, **connection_kwargs) + self._index.client.json().set(f"{self.name}:route_config", f".", self.to_dict()) # type: ignore + + @classmethod + def from_existing( + cls, + name: str, + redis_client: Optional[Redis] = None, + redis_url: str = "redis://localhost:6379", + **kwargs, + ) -> "SemanticRouter": + """Return SemanticRouter instance from existing index.""" + try: + if redis_url: + redis_client = RedisConnectionFactory.get_redis_connection( + redis_url=redis_url, + **kwargs, + ) + elif redis_client: + RedisConnectionFactory.validate_sync_redis(redis_client) + except RedisModuleVersionError as e: + raise RedisModuleVersionError( + f"Loading from existing index failed. {str(e)}" + ) + + router_dict = redis_client.json().get(f"{name}:route_config") # type: ignore + return cls.from_dict( + router_dict, redis_url=redis_url, redis_client=redis_client + ) + @deprecated_argument("dtype") def _initialize_index( self, @@ -111,9 +146,11 @@ def _initialize_index( **connection_kwargs, ): """Initialize the search index and handle Redis connection.""" + schema = SemanticRouterIndexSchema.from_params( self.name, self.vectorizer.dims, self.vectorizer.dtype # type: ignore ) + self._index = SearchIndex( schema=schema, redis_client=redis_client, @@ -174,10 +211,10 @@ def update_route_thresholds(self, route_thresholds: Dict[str, Optional[float]]): if route.name in route_thresholds: route.distance_threshold = route_thresholds[route.name] # type: ignore - def _route_ref_key(self, route_name: str, reference: str) -> str: + @staticmethod + def _route_ref_key(index: SearchIndex, route_name: str, reference_hash: str) -> str: """Generate the route reference key.""" - reference_hash = hashify(reference) - return f"{self._index.prefix}:{route_name}:{reference_hash}" + return f"{index.prefix}:{route_name}:{reference_hash}" def _add_routes(self, routes: List[Route]): """Add routes to the router and index. @@ -195,14 +232,18 @@ def _add_routes(self, routes: List[Route]): ) # set route references for i, reference in enumerate(route.references): + reference_hash = hashify(reference) route_references.append( { + "reference_id": reference_hash, "route_name": route.name, "reference": reference, "vector": reference_vectors[i], } ) - keys.append(self._route_ref_key(route.name, reference)) + keys.append( + self._route_ref_key(self._index, route.name, reference_hash) + ) # set route if does not yet exist client side if not self.get(route.name): @@ -438,7 +479,7 @@ def remove_route(self, route_name: str) -> None: else: self._index.drop_keys( [ - self._route_ref_key(route.name, reference) + self._route_ref_key(self._index, route.name, hashify(reference)) for reference in route.references ] ) @@ -596,3 +637,155 @@ def to_yaml(self, file_path: str, overwrite: bool = True) -> None: with open(fp, "w") as f: yaml_data = self.to_dict() yaml.dump(yaml_data, f, sort_keys=False) + + # reference methods + def add_route_references( + self, + route_name: str, + references: Union[str, List[str]], + ) -> List[str]: + """Add a reference(s) to an existing route. + + Args: + router_name (str): The name of the router. + references (Union[str, List[str]]): The reference or list of references to add. + + Returns: + List[str]: The list of added references keys. + """ + + if isinstance(references, str): + references = [references] + + route_references: List[Dict[str, Any]] = [] + keys: List[str] = [] + + # embed route references as a single batch + reference_vectors = self.vectorizer.embed_many(references, as_buffer=True) + + # set route references + for i, reference in enumerate(references): + reference_hash = hashify(reference) + + route_references.append( + { + "reference_id": reference_hash, + "route_name": route_name, + "reference": reference, + "vector": reference_vectors[i], + } + ) + keys.append(self._route_ref_key(self._index, route_name, reference_hash)) + + keys = self._index.load(route_references, keys=keys) + + route = self.get(route_name) + if not route: + raise ValueError(f"Route {route_name} not found in the SemanticRouter") + route.references.extend(references) + self._update_router_state() + return keys + + @staticmethod + def _make_filter_queries(ids: List[str]) -> List[FilterQuery]: + """Create a filter query for the given ids.""" + + queries = [] + + for id in ids: + fe = Tag("reference_id") == id + fq = FilterQuery( + return_fields=["reference_id", "route_name", "reference"], + filter_expression=fe, + ) + queries.append(fq) + + return queries + + def get_route_references( + self, + route_name: str = "", + reference_ids: List[str] = [], + keys: List[str] = [], + ) -> List[Dict[str, Any]]: + """Get references for an existing route route. + + Args: + router_name (str): The name of the router. + references (Union[str, List[str]]): The reference or list of references to add. + + Returns: + List[Dict[str, Any]]]: Reference objects stored + """ + + if reference_ids: + queries = self._make_filter_queries(reference_ids) + elif route_name: + if not keys: + keys = scan_by_pattern( + self._index.client, f"{self._index.prefix}:{route_name}:*" # type: ignore + ) + + queries = self._make_filter_queries( + [key.split(":")[-1] for key in convert_bytes(keys)] + ) + else: + raise ValueError( + "Must provide a route name, reference ids, or keys to get references" + ) + + res = self._index.batch_query(queries) + + return [r[0] for r in res if len(r) > 0] + + def delete_route_references( + self, + route_name: str = "", + reference_ids: List[str] = [], + keys: List[str] = [], + ) -> int: + """Get references for an existing semantic router route. + + Args: + router_name Optional(str): The name of the router. + reference_ids Optional(List[str]]): The reference or list of references to delete. + keys Optional(List[str]]): List of fully qualified keys (prefix:router:reference_id) to delete. + + Returns: + int: Number of objects deleted + """ + + if reference_ids and not keys: + queries = self._make_filter_queries(reference_ids) + res = self._index.batch_query(queries) + keys = [r[0]["id"] for r in res if len(r) > 0] + elif not keys: + keys = scan_by_pattern( + self._index.client, f"{self._index.prefix}:{route_name}:*" # type: ignore + ) + + if not keys: + raise ValueError(f"No references found for route {route_name}") + + to_be_deleted = [] + for key in keys: + route_name = key.split(":")[-2] + to_be_deleted.append( + (route_name, convert_bytes(self._index.client.hgetall(key))) # type: ignore + ) + + deleted = self._index.drop_keys(keys) + + for route_name, delete in to_be_deleted: + route = self.get(route_name) + if not route: + raise ValueError(f"Route {route_name} not found in the SemanticRouter") + route.references.remove(delete["reference"]) + + self._update_router_state() + + return deleted + + def _update_router_state(self) -> None: + """Update the router configuration in Redis.""" + self._index.client.json().set(f"{self.name}:route_config", f".", self.to_dict()) # type: ignore diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 9c189948..77271e4e 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -14,6 +14,7 @@ Iterable, List, Optional, + Sequence, Tuple, Union, ) @@ -833,7 +834,7 @@ def search(self, *args, **kwargs) -> "Result": raise RedisSearchError(f"Error while searching: {str(e)}") from e def batch_query( - self, queries: List[BaseQuery], batch_size: int = 10 + self, queries: Sequence[BaseQuery], batch_size: int = 10 ) -> List[List[Dict[str, Any]]]: """Execute a batch of queries and process results.""" results = self.batch_search( diff --git a/redisvl/utils/optimize/router.py b/redisvl/utils/optimize/router.py index b384e127..846b4243 100644 --- a/redisvl/utils/optimize/router.py +++ b/redisvl/utils/optimize/router.py @@ -18,9 +18,9 @@ def _generate_run_router(test_data: List[LabeledData], router: SemanticRouter) - run_dict[td.id] = {} route_match = router(td.query) if route_match and route_match.name == td.query_match: - run_dict[td.id][td.query_match] = 1 + run_dict[td.id][td.query_match] = np.int64(1) else: - run_dict[td.id][NULL_RESPONSE_KEY] = 1 + run_dict[td.id][NULL_RESPONSE_KEY] = np.int64(1) return Run(run_dict) diff --git a/redisvl/utils/optimize/utils.py b/redisvl/utils/optimize/utils.py index bebc1c79..1c8e66f7 100644 --- a/redisvl/utils/optimize/utils.py +++ b/redisvl/utils/optimize/utils.py @@ -1,5 +1,6 @@ from typing import List +import numpy as np from ranx import Qrels from redisvl.utils.optimize.schema import LabeledData @@ -13,10 +14,10 @@ def _format_qrels(test_data: List[LabeledData]) -> Qrels: for td in test_data: if td.query_match: - qrels_dict[td.id] = {td.query_match: 1} + qrels_dict[td.id] = {td.query_match: np.int64(1)} else: # This is for capturing true negatives from test set - qrels_dict[td.id] = {NULL_RESPONSE_KEY: 1} + qrels_dict[td.id] = {NULL_RESPONSE_KEY: np.int64(1)} return Qrels(qrels_dict) diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py index 3b6511f9..ff52fcb0 100644 --- a/redisvl/utils/utils.py +++ b/redisvl/utils/utils.py @@ -7,10 +7,11 @@ from enum import Enum from functools import wraps from time import time -from typing import Any, Callable, Coroutine, Dict, Optional +from typing import Any, Callable, Coroutine, Dict, Optional, Sequence from warnings import warn from pydantic import BaseModel +from redis import Redis from ulid import ULID @@ -213,3 +214,22 @@ def norm_l2_distance(value: float) -> float: Normalize the L2 distance. """ return 1 / (1 + value) + + +def scan_by_pattern( + redis_client: Redis, + pattern: str, +) -> Sequence[str]: + """ + Scan the Redis database for keys matching a specific pattern. + + Args: + redis (Redis): The Redis client instance. + pattern (str): The pattern to match keys against. + + Returns: + List[str]: A dictionary containing the keys and their values. + """ + from redisvl.redis.utils import convert_bytes + + return convert_bytes(list(redis_client.scan_iter(match=pattern))) diff --git a/schemas/semantic_router.yaml b/schemas/semantic_router.yaml index d3e24f85..0c77691b 100644 --- a/schemas/semantic_router.yaml +++ b/schemas/semantic_router.yaml @@ -1,4 +1,4 @@ -name: test-router +name: test-router-01JSHK4MJ79HH51PS6WEK6M9MF routes: - name: greeting references: diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py index 14749b1d..14f3cec7 100644 --- a/tests/integration/test_query.py +++ b/tests/integration/test_query.py @@ -22,6 +22,7 @@ Timestamp, ) from redisvl.redis.utils import array_to_buffer +from redisvl.utils.utils import create_ulid # TODO expand to multiple schema types and sync + async @@ -145,11 +146,12 @@ def sorted_range_query(): @pytest.fixture def index(sample_data, redis_url): # construct a search index from the schema + idx = f"user_index_{create_ulid()}" index = SearchIndex.from_dict( { "index": { - "name": "user_index", - "prefix": "v1", + "name": idx, + "prefix": idx, "storage_type": "hash", }, "fields": [ @@ -190,17 +192,20 @@ def hash_preprocess(item: dict) -> dict: yield index # clean up - index.delete(drop=True) + index.clear() + index.delete() @pytest.fixture def L2_index(sample_data, redis_url): # construct a search index from the schema + idx = f"L2_index_{create_ulid()}" + index = SearchIndex.from_dict( { "index": { - "name": "L2_index", - "prefix": "L2_index", + "name": idx, + "prefix": idx, "storage_type": "hash", }, "fields": [ @@ -240,7 +245,8 @@ def hash_preprocess(item: dict) -> dict: yield index # clean up - index.delete(drop=True) + index.clear() + index.delete() def test_search_and_query(index): diff --git a/tests/integration/test_semantic_router.py b/tests/integration/test_semantic_router.py index 07750b42..b7bb6b40 100644 --- a/tests/integration/test_semantic_router.py +++ b/tests/integration/test_semantic_router.py @@ -1,9 +1,9 @@ -import os import pathlib import warnings import pytest from redis.exceptions import ConnectionError +from ulid import ULID from redisvl.exceptions import RedisModuleVersionError from redisvl.extensions.router import SemanticRouter @@ -41,13 +41,14 @@ def routes(): @pytest.fixture def semantic_router(client, routes): router = SemanticRouter( - name="test-router", + name=f"test-router-{str(ULID())}", routes=routes, routing_config=RoutingConfig(max_k=2), redis_client=client, overwrite=False, ) yield router + router.clear() router.delete() @@ -59,7 +60,7 @@ def disable_deprecation_warnings(): def test_initialize_router(semantic_router): - assert semantic_router.name == "test-router" + assert semantic_router.name == semantic_router.name assert len(semantic_router.routes) == 2 assert semantic_router.routing_config.max_k == 2 @@ -208,7 +209,11 @@ def test_from_yaml(semantic_router): new_router = SemanticRouter.from_yaml( yaml_file, redis_client=semantic_router._index.client, overwrite=True ) - assert new_router.to_dict() == semantic_router.to_dict() + nr = new_router.to_dict() + nr.pop("name") + sr = semantic_router.to_dict() + sr.pop("name") + assert nr == sr def test_to_dict_missing_fields(): @@ -332,7 +337,7 @@ def test_vectorizer_dtype_mismatch(routes, redis_url): ) -def test_invalid_vectorizer(routes, redis_url): +def test_invalid_vectorizer(redis_url): with pytest.raises(TypeError): SemanticRouter( name="test_invalid_vectorizer", @@ -424,3 +429,98 @@ def test_routes_different_distance_thresholds_get_one( matches = router.route_many("hello", max_k=2) assert len(matches) == 1 assert matches[0].name == "greeting" + + +def test_add_delete_route_references(semantic_router): + redis_version = semantic_router._index.client.info()["redis_version"] + if not compare_versions(redis_version, "7.0.0"): + pytest.skip("Not using a late enough version of Redis") + + # Add new references to an existing route + added_refs = semantic_router.add_route_references( + route_name="greeting", references=["good morning", "hey there"] + ) + + # Verify references were added + assert len(added_refs) == 2 + + # Test that we can match against the new references + match = semantic_router("hey there") + assert match.name == "greeting" + + # delete by route + deleted_count = semantic_router.delete_route_references( + route_name="farewell", + ) + + if deleted_count < 2: + pytest.skip("Flaky test - skip") + + assert deleted_count == 2 + + # delete by ref_id + deleted = semantic_router.delete_route_references( + reference_ids=[added_refs[0].split(":")[-1]] + ) + + assert deleted == 1 + + # delete by key + deleted = semantic_router.delete_route_references(keys=[added_refs[1]]) + + assert deleted == 1 + + router_dict = semantic_router.to_dict() + assert len(router_dict["routes"][0]["references"]) == 2 + assert len(router_dict["routes"][1]["references"]) == 0 + + +def test_from_existing(client, redis_url, routes): + if not compare_versions(client.info()["redis_version"], "7.0.0"): + pytest.skip("Not using a late enough version of Redis") + + # connect separately + router = SemanticRouter( + name=f"test-router-{str(ULID())}", + routes=routes, + routing_config=RoutingConfig(max_k=2), + redis_url=redis_url, + overwrite=False, + ) + + router2 = SemanticRouter.from_existing( + name=router.name, + redis_url=redis_url, + ) + + assert router.to_dict() == router2.to_dict() + + +def test_get_route_references(semantic_router): + # Get references for a specific route + refs = semantic_router.get_route_references(route_name="greeting") + + if len(refs) < 2: + pytest.skip("Flaky test - skip") + + # Should return at least the initial references + assert len(refs) == 2 + + # Reference IDs should be present + reference_id = refs[0]["reference_id"] + # Get references by ID + id_refs = semantic_router.get_route_references(reference_ids=[reference_id]) + assert len(id_refs) == 1 + + with pytest.raises(ValueError): + semantic_router.get_route_references() + + +def test_delete_route_references(semantic_router): + # Get references for a specific route + deleted = semantic_router.delete_route_references(route_name="greeting") + + assert deleted == 2 + + router_dict = semantic_router.to_dict() + assert len(router_dict["routes"][0]["references"]) == 0 diff --git a/tests/integration/test_threshold_optimizer.py b/tests/integration/test_threshold_optimizer.py index 5242fd4f..8be39414 100644 --- a/tests/integration/test_threshold_optimizer.py +++ b/tests/integration/test_threshold_optimizer.py @@ -113,7 +113,7 @@ def test_routes_different_distance_thresholds_optimizer_default( # now run optimizer router_optimizer = RouterThresholdOptimizer(router, test_data_optimization) - router_optimizer.optimize(max_iterations=10, search_step=0.5) + router_optimizer.optimize(max_iterations=20, search_step=0.5) # test that it updated thresholds beyond the null case for route in routes: @@ -150,7 +150,7 @@ def test_routes_different_distance_thresholds_optimizer_precision( router_optimizer = RouterThresholdOptimizer( router, test_data_optimization, eval_metric="precision" ) - router_optimizer.optimize(max_iterations=10, search_step=0.5) + router_optimizer.optimize(max_iterations=20, search_step=0.5) # test that it updated thresholds beyond the null case for route in routes: @@ -186,7 +186,7 @@ def test_routes_different_distance_thresholds_optimizer_recall( router_optimizer = RouterThresholdOptimizer( router, test_data_optimization, eval_metric="recall" ) - router_optimizer.optimize(max_iterations=10, search_step=0.5) + router_optimizer.optimize(max_iterations=20, search_step=0.5) # test that it updated thresholds beyond the null case for route in routes: