diff --git a/redisvl/query/aggregate.py b/redisvl/query/aggregate.py index a3a31e05..7dea3907 100644 --- a/redisvl/query/aggregate.py +++ b/redisvl/query/aggregate.py @@ -94,6 +94,7 @@ def __init__( return_fields: Optional[List[str]] = None, stopwords: Optional[Union[str, Set[str]]] = "english", dialect: int = 2, + text_weights: Optional[Dict[str, float]] = None, ): """ Instantiates a HybridQuery object. @@ -119,6 +120,9 @@ def __init__( set, or tuple of strings is provided then those will be used as stopwords. Defaults to "english". if set to "None" then no stopwords will be removed. dialect (int, optional): The Redis dialect version. Defaults to 2. + text_weights (Optional[Dict[str, float]]): The importance weighting of individual words + within the query text. Defaults to None, in which case the + text_scorer score is not modified. Raises: ValueError: If the text string is empty, or if the text string becomes empty after @@ -138,6 +142,7 @@ def __init__( self._dtype = dtype self._num_results = num_results self._set_stopwords(stopwords) + self._text_weights = self._parse_text_weights(text_weights) query_string = self._build_query_string() super().__init__(query_string) @@ -185,6 +190,7 @@ def _set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"): language will be used. if a list, set, or tuple of strings is provided then those will be used as stopwords. Defaults to "english". if set to "None" then no stopwords will be removed. + Raises: TypeError: If the stopwords are not a set, list, or tuple of strings. """ @@ -214,6 +220,7 @@ def _tokenize_and_escape_query(self, user_query: str) -> str: Returns: str: The tokenized and escaped query string. + Raises: ValueError: If the text string becomes empty after stopwords are removed. 
""" @@ -225,13 +232,57 @@ def _tokenize_and_escape_query(self, user_query: str) -> str: ) for token in user_query.split() ] - tokenized = " | ".join( - [token for token in tokens if token and token not in self._stopwords] - ) - if not tokenized: + token_list = [ + token for token in tokens if token and token not in self._stopwords + ] + for i, token in enumerate(token_list): + if token in self._text_weights: + token_list[i] = f"{token}=>{{$weight:{self._text_weights[token]}}}" + + if not token_list: raise ValueError("text string cannot be empty after removing stopwords") - return tokenized + return " | ".join(token_list) + + def _parse_text_weights( + self, weights: Optional[Dict[str, float]] + ) -> Dict[str, float]: + parsed_weights: Dict[str, float] = {} + if not weights: + return parsed_weights + for word, weight in weights.items(): + word = word.strip().lower() + if not word or " " in word: + raise ValueError( + f"Only individual words may be weighted. Got {{ {word}:{weight} }}" + ) + if ( + not (isinstance(weight, float) or isinstance(weight, int)) + or weight < 0.0 + ): + raise ValueError( + f"Weights must be positive number. Got {{ {word}:{weight} }}" + ) + parsed_weights[word] = weight + return parsed_weights + + def set_text_weights(self, weights: Dict[str, float]): + """Set or update the text weights for the query. + + Args: + weights: Dictionary of word:weight mappings + """ + self._text_weights = self._parse_text_weights(weights) + self._query = self._build_query_string() + + @property + def text_weights(self) -> Dict[str, float]: + """Get the text weights. + + Returns: + Dictionary of word:weight mappings. 
+ """ + return self._text_weights def _build_query_string(self) -> str: """Build the full query string for text search with optional filtering.""" diff --git a/redisvl/query/query.py b/redisvl/query/query.py index 9ee66173..340564be 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -1028,6 +1028,7 @@ def __init__( in_order: bool = False, params: Optional[Dict[str, Any]] = None, stopwords: Optional[Union[str, Set[str]]] = "english", + text_weights: Optional[Dict[str, float]] = None, ): """A query for running a full text search, along with an optional filter expression. @@ -1064,13 +1065,16 @@ def __init__( a default set of stopwords for that language will be used. Users may specify their own stop words by providing a List or Set of words. if set to None, then no words will be removed. Defaults to 'english'. - + text_weights (Optional[Dict[str, float]]): The importance weighting of individual words + within the query text. Defaults to None, in which case the + text_scorer score is not modified. Raises: ValueError: if stopwords language string cannot be loaded. TypeError: If stopwords is not a valid iterable set of strings. 
""" self._text = text self._field_weights = self._parse_field_weights(text_field_name) + self._text_weights = self._parse_text_weights(text_weights) self._num_results = num_results self._set_stopwords(stopwords) @@ -1151,9 +1155,14 @@ def _tokenize_and_escape_query(self, user_query: str) -> str: ) for token in user_query.split() ] - return " | ".join( - [token for token in tokens if token and token not in self._stopwords] - ) + token_list = [ + token for token in tokens if token and token not in self._stopwords + ] + for i, token in enumerate(token_list): + if token in self._text_weights: + token_list[i] = f"{token}=>{{$weight:{self._text_weights[token]}}}" + + return " | ".join(token_list) def _parse_field_weights( self, field_spec: Union[str, Dict[str, float]] @@ -1220,6 +1229,46 @@ def text_field_name(self) -> Union[str, Dict[str, float]]: return field return self._field_weights.copy() + def _parse_text_weights( + self, weights: Optional[Dict[str, float]] + ) -> Dict[str, float]: + parsed_weights: Dict[str, float] = {} + if not weights: + return parsed_weights + for word, weight in weights.items(): + word = word.strip().lower() + if not word or " " in word: + raise ValueError( + f"Only individual words may be weighted. Got {{ {word}:{weight} }}" + ) + if ( + not (isinstance(weight, float) or isinstance(weight, int)) + or weight < 0.0 + ): + raise ValueError( + f"Weights must be positive number. Got {{ {word}:{weight} }}" + ) + parsed_weights[word] = weight + return parsed_weights + + def set_text_weights(self, weights: Dict[str, float]): + """Set or update the text weights for the query. + + Args: + weights: Dictionary of word:weight mappings + """ + self._text_weights = self._parse_text_weights(weights) + self._built_query_string = None + + @property + def text_weights(self) -> Dict[str, float]: + """Get the text weights. + + Returns: + Dictionary of word:weight mappings. 
+ """ + return self._text_weights + def _build_query_string(self) -> str: """Build the full query string for text search with optional filtering.""" filter_expression = self._filter_expression diff --git a/tests/integration/test_aggregation.py b/tests/integration/test_aggregation.py index f08815a6..67c76ade 100644 --- a/tests/integration/test_aggregation.py +++ b/tests/integration/test_aggregation.py @@ -317,6 +317,82 @@ def test_hybrid_query_with_text_filter(index): assert "research" not in result[text_field].lower() +@pytest.mark.parametrize("scorer", ["BM25", "BM25STD", "TFIDF", "TFIDF.DOCNORM"]) +def test_hybrid_query_word_weights(index, scorer): + skip_if_redis_version_below(index.client, "7.2.0") + + text = "a medical professional with expertise in lung cancers" + text_field = "description" + vector = [0.1, 0.1, 0.5] + vector_field = "user_embedding" + return_fields = ["description"] + + weights = {"medical": 3.4, "cancers": 5} + + # test we can run a query with text weights + weighted_query = HybridQuery( + text=text, + text_field_name=text_field, + vector=vector, + vector_field_name=vector_field, + return_fields=return_fields, + text_scorer=scorer, + text_weights=weights, + ) + + weighted_results = index.query(weighted_query) + assert len(weighted_results) == 7 + + # test that weights do change the scores on results + unweighted_query = HybridQuery( + text=text, + text_field_name=text_field, + vector=vector, + vector_field_name=vector_field, + return_fields=return_fields, + text_scorer=scorer, + text_weights={}, + ) + + unweighted_results = index.query(unweighted_query) + + for weighted, unweighted in zip(weighted_results, unweighted_results): + for word in weights: + if word in weighted["description"] or word in unweighted["description"]: + assert float(weighted["text_score"]) > float(unweighted["text_score"]) + + # test that weights do change the document score and order of results + weights = {"medical": 5, "cancers": 3.4} # switch the weights + 
weighted_query = HybridQuery( + text=text, + text_field_name=text_field, + vector=vector, + vector_field_name=vector_field, + return_fields=return_fields, + text_scorer=scorer, + text_weights=weights, + ) + + weighted_results = index.query(weighted_query) + assert weighted_results != unweighted_results + + # test assigning weights on construction is equivalent to setting them on the query object + new_query = HybridQuery( + text=text, + text_field_name=text_field, + vector=vector, + vector_field_name=vector_field, + return_fields=return_fields, + text_scorer=scorer, + text_weights=None, + ) + + new_query.set_text_weights(weights) + + new_weighted_results = index.query(new_query) + assert new_weighted_results == weighted_results + + def test_multivector_query(index): skip_if_redis_version_below(index.client, "7.2.0") diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py index bbd56339..588ee0bd 100644 --- a/tests/integration/test_query.py +++ b/tests/integration/test_query.py @@ -888,6 +888,72 @@ def test_text_query_with_text_filter(index): assert "research" not in result[text_field] +@pytest.mark.parametrize("scorer", ["BM25", "BM25STD", "TFIDF", "TFIDF.DOCNORM"]) +def test_text_query_word_weights(index, scorer): + skip_if_redis_version_below(index.client, "7.2.0") + + text = "a medical professional with expertise in lung cancers" + text_field = "description" + return_fields = ["description"] + + weights = {"medical": 3.4, "cancers": 5} + + # test we can run a query with text weights + weighted_query = TextQuery( + text=text, + text_field_name=text_field, + return_fields=return_fields, + text_scorer=scorer, + text_weights=weights, + ) + + weighted_results = index.query(weighted_query) + assert len(weighted_results) == 4 + + # test that weights do change the scores on results + unweighted_query = TextQuery( + text=text, + text_field_name=text_field, + return_fields=return_fields, + text_scorer=scorer, + text_weights={}, + ) + + 
unweighted_results = index.query(unweighted_query) + + for weighted, unweighted in zip(weighted_results, unweighted_results): + for word in weights: + if word in weighted["description"] or word in unweighted["description"]: + assert weighted["score"] > unweighted["score"] + + # test that weights do change the document score and order of results + weights = {"medical": 5, "cancers": 3.4} # switch the weights + weighted_query = TextQuery( + text=text, + text_field_name=text_field, + return_fields=return_fields, + text_scorer=scorer, + text_weights=weights, + ) + + weighted_results = index.query(weighted_query) + assert weighted_results != unweighted_results + + # test assigning weights on construction is equivalent to setting them on the query object + new_query = TextQuery( + text=text, + text_field_name=text_field, + return_fields=return_fields, + text_scorer=scorer, + text_weights=None, + ) + + new_query.set_text_weights(weights) + + new_weighted_results = index.query(new_query) + assert new_weighted_results == weighted_results + + def test_vector_query_with_ef_runtime(index, vector_query, sample_data): """ Integration test: Verify that setting EF_RUNTIME on a VectorQuery works correctly. 
diff --git a/tests/unit/test_aggregation_types.py b/tests/unit/test_aggregation_types.py index f2b6be86..3681fb36 100644 --- a/tests/unit/test_aggregation_types.py +++ b/tests/unit/test_aggregation_types.py @@ -196,6 +196,83 @@ def test_hybrid_query_with_string_filter(): assert "AND" not in query_string_wildcard +def test_hybrid_query_text_weights(): + # verify word weights get added into the raw Redis query syntax + vector = [0.1, 0.1, 0.5] + vector_field = "user_embedding" + + query = HybridQuery( + text="query string alpha bravo delta tango alpha", + text_field_name="description", + vector=vector, + vector_field_name=vector_field, + text_weights={"alpha": 2, "delta": 0.555, "gamma": 0.95}, + ) + + assert ( + str(query) + == "(~@description:(query | string | alpha=>{$weight:2} | bravo | delta=>{$weight:0.555} | tango | alpha=>{$weight:2}))=>[KNN 10 @user_embedding $vector AS vector_distance] SCORER BM25STD ADDSCORES DIALECT 2 APPLY (2 - @vector_distance)/2 AS vector_similarity APPLY @__score AS text_score APPLY 0.30000000000000004*@text_score + 0.7*@vector_similarity AS hybrid_score SORTBY 2 @hybrid_score DESC MAX 10" + ) + + # raise an error if weights are not positive floats + with pytest.raises(ValueError): + _ = HybridQuery( + text="sample text query", + text_field_name="description", + vector=vector, + vector_field_name=vector_field, + text_weights={"first": 0.2, "second": -0.1}, + ) + + with pytest.raises(ValueError): + _ = HybridQuery( + text="sample text query", + text_field_name="description", + vector=vector, + vector_field_name=vector_field, + text_weights={"first": 0.2, "second": "0.1"}, + ) + + # no error if weights dictionary is empty or None + query = HybridQuery( + text="sample text query", + text_field_name="description", + vector=vector, + vector_field_name=vector_field, + text_weights={}, + ) + assert query + + query = HybridQuery( + text="sample text query", + text_field_name="description", + vector=vector, + vector_field_name=vector_field, + 
text_weights=None, + ) + assert query + + # no error if the words in weights dictionary don't appear in query + query = HybridQuery( + text="sample text query", + text_field_name="description", + vector=vector, + vector_field_name=vector_field, + text_weights={"alpha": 0.2, "bravo": 0.4}, + ) + assert query + + # we can access the word weights on a query object + assert query.text_weights == {"alpha": 0.2, "bravo": 0.4} + + # we can change the text weights on a query object + query.set_text_weights(weights={"new": 0.3, "words": 0.125, "here": 99}) + assert query.text_weights == {"new": 0.3, "words": 0.125, "here": 99} + + query.set_text_weights(weights={}) + assert query.text_weights == {} + + def test_multi_vector_query(): # test we require Vector objects with pytest.raises(TypeError): diff --git a/tests/unit/test_query_types.py b/tests/unit/test_query_types.py index 7fdab1d0..de76056c 100644 --- a/tests/unit/test_query_types.py +++ b/tests/unit/test_query_types.py @@ -261,24 +261,10 @@ def test_text_query(): with pytest.raises(TypeError): text_query = TextQuery(text_string, text_field_name, stopwords=[1, 2, 3]) - text_query = TextQuery( - text_string, text_field_name, stopwords=set(["the", "a", "of"]) - ) - assert text_query.stopwords == set(["the", "a", "of"]) - - text_query = TextQuery(text_string, text_field_name, stopwords="german") - assert text_query.stopwords != set([]) - # test that filter expression is set correctly text_query.set_filter(filter_expression) assert text_query.filter == filter_expression - with pytest.raises(ValueError): - text_query = TextQuery(text_string, text_field_name, stopwords="gibberish") - - with pytest.raises(TypeError): - text_query = TextQuery(text_string, text_field_name, stopwords=[1, 2, 3]) - def test_text_query_with_string_filter(): """Test that TextQuery correctly includes string filter expressions in query string. 
@@ -347,6 +333,64 @@ def test_text_query_with_string_filter(): assert "AND" not in query_string_wildcard +def test_text_query_word_weights(): + # verify word weights get added into the raw Redis query syntax + query = TextQuery( + text="query string alpha bravo delta tango alpha", + text_field_name="description", + text_weights={"alpha": 2, "delta": 0.555, "gamma": 0.95}, + ) + + assert ( + str(query) + == "@description:(query | string | alpha=>{$weight:2} | bravo | delta=>{$weight:0.555} | tango | alpha=>{$weight:2}) SCORER BM25STD WITHSCORES DIALECT 2 LIMIT 0 10" + ) + + # raise an error if weights are not positive floats + with pytest.raises(ValueError): + _ = TextQuery( + text="sample text query", + text_field_name="description", + text_weights={"first": 0.2, "second": -0.1}, + ) + + with pytest.raises(ValueError): + _ = TextQuery( + text="sample text query", + text_field_name="description", + text_weights={"first": 0.2, "second": "0.1"}, + ) + + # no error if weights dictionary is empty or None + query = TextQuery( + text="sample text query", text_field_name="description", text_weights={} + ) + assert query + + query = TextQuery( + text="sample text query", text_field_name="description", text_weights=None + ) + assert query + + # no error if the words in weights dictionary don't appear in query + query = TextQuery( + text="sample text query", + text_field_name="description", + text_weights={"alpha": 0.2, "bravo": 0.4}, + ) + assert query + + # we can access the word weights on a query object + assert query.text_weights == {"alpha": 0.2, "bravo": 0.4} + + # we can change the text weights on a query object + query.set_text_weights(weights={"new": 0.3, "words": 0.125, "here": 99}) + assert query.text_weights == {"new": 0.3, "words": 0.125, "here": 99} + + query.set_text_weights(weights={}) + assert query.text_weights == {} + + @pytest.mark.parametrize( "query", [