Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 56 additions & 5 deletions redisvl/query/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(
return_fields: Optional[List[str]] = None,
stopwords: Optional[Union[str, Set[str]]] = "english",
dialect: int = 2,
text_weights: Optional[Dict[str, float]] = None,
):
"""
Instantiates a HybridQuery object.
Expand All @@ -119,6 +120,9 @@ def __init__(
set, or tuple of strings is provided then those will be used as stopwords.
Defaults to "english". if set to "None" then no stopwords will be removed.
dialect (int, optional): The Redis dialect version. Defaults to 2.
text_weights (Optional[Dict[str, float]]): The importance weighting of individual words
within the query text. Defaults to None, as no modifications will be made to the
text_scorer score.

Raises:
ValueError: If the text string is empty, or if the text string becomes empty after
Expand All @@ -138,6 +142,7 @@ def __init__(
self._dtype = dtype
self._num_results = num_results
self._set_stopwords(stopwords)
self._text_weights = self._parse_text_weights(text_weights)

query_string = self._build_query_string()
super().__init__(query_string)
Expand Down Expand Up @@ -185,6 +190,7 @@ def _set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
language will be used. if a list, set, or tuple of strings is provided then those
will be used as stopwords. Defaults to "english". if set to "None" then no stopwords
will be removed.

Raises:
TypeError: If the stopwords are not a set, list, or tuple of strings.
"""
Expand Down Expand Up @@ -214,6 +220,7 @@ def _tokenize_and_escape_query(self, user_query: str) -> str:

Returns:
str: The tokenized and escaped query string.

Raises:
ValueError: If the text string becomes empty after stopwords are removed.
"""
Expand All @@ -225,13 +232,57 @@ def _tokenize_and_escape_query(self, user_query: str) -> str:
)
for token in user_query.split()
]
tokenized = " | ".join(
[token for token in tokens if token and token not in self._stopwords]
)

if not tokenized:
token_list = [
token for token in tokens if token and token not in self._stopwords
]
for i, token in enumerate(token_list):
if token in self._text_weights:
token_list[i] = f"{token}=>{{$weight:{self._text_weights[token]}}}"

if not token_list:
raise ValueError("text string cannot be empty after removing stopwords")
return tokenized
return " | ".join(token_list)

def _parse_text_weights(
self, weights: Optional[Dict[str, float]]
) -> Dict[str, float]:
parsed_weights: Dict[str, float] = {}
if not weights:
return parsed_weights
for word, weight in weights.items():
word = word.strip().lower()
if not word or " " in word:
raise ValueError(
f"Only individual words may be weighted. Got {{ {word}:{weight} }}"
)
if (
not (isinstance(weight, float) or isinstance(weight, int))
or weight < 0.0
):
raise ValueError(
f"Weights must be positive number. Got {{ {word}:{weight} }}"
)
parsed_weights[word] = weight
return parsed_weights

def set_text_weights(self, weights: Dict[str, float]):
    """Replace the per-word text weights and rebuild the query string.

    Args:
        weights: Dictionary of word:weight mappings. It is passed through
            ``_parse_text_weights`` before being stored.
    """
    validated = self._parse_text_weights(weights)
    self._text_weights = validated
    # Regenerate the stored query so it reflects the new weights.
    rebuilt = self._build_query_string()
    self._query = rebuilt

@property
def text_weights(self) -> Dict[str, float]:
    """Get the text weights.

    Returns:
        A copy of the word:weight mappings. A copy is returned so that
        mutating the result cannot bypass validation or leave the already
        built query string out of date; use ``set_text_weights`` to change
        the weights.
    """
    return dict(self._text_weights)

def _build_query_string(self) -> str:
"""Build the full query string for text search with optional filtering."""
Expand Down
57 changes: 53 additions & 4 deletions redisvl/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,7 @@ def __init__(
in_order: bool = False,
params: Optional[Dict[str, Any]] = None,
stopwords: Optional[Union[str, Set[str]]] = "english",
text_weights: Optional[Dict[str, float]] = None,
):
"""A query for running a full text search, along with an optional filter expression.

Expand Down Expand Up @@ -1064,13 +1065,16 @@ def __init__(
a default set of stopwords for that language will be used. Users may specify
their own stop words by providing a List or Set of words. if set to None,
then no words will be removed. Defaults to 'english'.

text_weights (Optional[Dict[str, float]]): The importance weighting of individual words
within the query text. Defaults to None, as no modifications will be made to the
text_scorer score.
Raises:
ValueError: if stopwords language string cannot be loaded.
TypeError: If stopwords is not a valid iterable set of strings.
"""
self._text = text
self._field_weights = self._parse_field_weights(text_field_name)
self._text_weights = self._parse_text_weights(text_weights)
self._num_results = num_results

self._set_stopwords(stopwords)
Expand Down Expand Up @@ -1151,9 +1155,14 @@ def _tokenize_and_escape_query(self, user_query: str) -> str:
)
for token in user_query.split()
]
return " | ".join(
[token for token in tokens if token and token not in self._stopwords]
)
token_list = [
token for token in tokens if token and token not in self._stopwords
]
for i, token in enumerate(token_list):
if token in self._text_weights:
token_list[i] = f"{token}=>{{$weight:{self._text_weights[token]}}}"

return " | ".join(token_list)

def _parse_field_weights(
self, field_spec: Union[str, Dict[str, float]]
Expand Down Expand Up @@ -1220,6 +1229,46 @@ def text_field_name(self) -> Union[str, Dict[str, float]]:
return field
return self._field_weights.copy()

def _parse_text_weights(
self, weights: Optional[Dict[str, float]]
) -> Dict[str, float]:
parsed_weights: Dict[str, float] = {}
if not weights:
return parsed_weights
for word, weight in weights.items():
word = word.strip().lower()
if not word or " " in word:
raise ValueError(
f"Only individual words may be weighted. Got {{ {word}:{weight} }}"
)
if (
not (isinstance(weight, float) or isinstance(weight, int))
or weight < 0.0
):
raise ValueError(
f"Weights must be positive number. Got {{ {word}:{weight} }}"
)
parsed_weights[word] = weight
return parsed_weights

def set_text_weights(self, weights: Dict[str, float]):
    """Replace the per-word text weights for the query.

    Args:
        weights: Dictionary of word:weight mappings. It is passed through
            ``_parse_text_weights`` before being stored.
    """
    validated = self._parse_text_weights(weights)
    self._text_weights = validated
    # Clear the cached query string so it no longer reflects the old weights.
    self._built_query_string = None

@property
def text_weights(self) -> Dict[str, float]:
    """Get the text weights.

    Returns:
        A copy of the word:weight mappings. A copy is returned so that
        mutating the result cannot bypass validation or leave a previously
        built query string out of date; use ``set_text_weights`` to change
        the weights.
    """
    return dict(self._text_weights)

def _build_query_string(self) -> str:
"""Build the full query string for text search with optional filtering."""
filter_expression = self._filter_expression
Expand Down
76 changes: 76 additions & 0 deletions tests/integration/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,82 @@ def test_hybrid_query_with_text_filter(index):
assert "research" not in result[text_field].lower()


@pytest.mark.parametrize("scorer", ["BM25", "BM25STD", "TFIDF", "TFIDF.DOCNORM"])
def test_hybrid_query_word_weights(index, scorer):
    """Exercise HybridQuery per-word text weights against a live index.

    Checks, for each scorer, that a weighted query runs, that weighting raises
    the text score of results mentioning a weighted word, that the results
    differ from the unweighted ones, and that ``set_text_weights`` gives the
    same results as passing ``text_weights`` to the constructor.
    """
    skip_if_redis_version_below(index.client, "7.2.0")

    def build_query(text_weights):
        # All queries in this test share everything except the weights.
        return HybridQuery(
            text="a medical professional with expertise in lung cancers",
            text_field_name="description",
            vector=[0.1, 0.1, 0.5],
            vector_field_name="user_embedding",
            return_fields=["description"],
            text_scorer=scorer,
            text_weights=text_weights,
        )

    weights = {"medical": 3.4, "cancers": 5}

    # test we can run a query with text weights
    weighted_results = index.query(build_query(weights))
    assert len(weighted_results) == 7

    # test that weights do change the scores on results
    unweighted_results = index.query(build_query({}))

    for boosted, plain in zip(weighted_results, unweighted_results):
        mentions_weighted_word = any(
            term in boosted["description"] or term in plain["description"]
            for term in weights
        )
        if mentions_weighted_word:
            assert float(boosted["text_score"]) > float(plain["text_score"])

    # test that weights do change the document score and order of results
    weights = {"medical": 5, "cancers": 3.4}  # switch the weights
    weighted_results = index.query(build_query(weights))
    assert weighted_results != unweighted_results

    # test assigning weights on construction is equivalent to setting them on the query object
    new_query = build_query(None)
    new_query.set_text_weights(weights)

    new_weighted_results = index.query(new_query)
    assert new_weighted_results == weighted_results


def test_multivector_query(index):
skip_if_redis_version_below(index.client, "7.2.0")

Expand Down
66 changes: 66 additions & 0 deletions tests/integration/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,72 @@ def test_text_query_with_text_filter(index):
assert "research" not in result[text_field]


@pytest.mark.parametrize("scorer", ["BM25", "BM25STD", "TFIDF", "TFIDF.DOCNORM"])
def test_text_query_word_weights(index, scorer):
    """Exercise TextQuery per-word text weights against a live index.

    Checks, for each scorer, that a weighted query runs, that weighting raises
    the score of results mentioning a weighted word, that the results differ
    from the unweighted ones, and that ``set_text_weights`` gives the same
    results as passing ``text_weights`` to the constructor.
    """
    skip_if_redis_version_below(index.client, "7.2.0")

    def build_query(text_weights):
        # All queries in this test share everything except the weights.
        return TextQuery(
            text="a medical professional with expertise in lung cancers",
            text_field_name="description",
            return_fields=["description"],
            text_scorer=scorer,
            text_weights=text_weights,
        )

    weights = {"medical": 3.4, "cancers": 5}

    # test we can run a query with text weights
    weighted_results = index.query(build_query(weights))
    assert len(weighted_results) == 4

    # test that weights do change the scores on results
    unweighted_results = index.query(build_query({}))

    for boosted, plain in zip(weighted_results, unweighted_results):
        mentions_weighted_word = any(
            term in boosted["description"] or term in plain["description"]
            for term in weights
        )
        if mentions_weighted_word:
            assert boosted["score"] > plain["score"]

    # test that weights do change the document score and order of results
    weights = {"medical": 5, "cancers": 3.4}  # switch the weights
    weighted_results = index.query(build_query(weights))
    assert weighted_results != unweighted_results

    # test assigning weights on construction is equivalent to setting them on the query object
    new_query = build_query(None)
    new_query.set_text_weights(weights)

    new_weighted_results = index.query(new_query)
    assert new_weighted_results == weighted_results


def test_vector_query_with_ef_runtime(index, vector_query, sample_data):
"""
Integration test: Verify that setting EF_RUNTIME on a VectorQuery works correctly.
Expand Down
Loading