From 4a8d37c7e5772eac67860abc9f8db2e219704285 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Mon, 2 Jun 2025 16:26:18 +0300 Subject: [PATCH 1/2] Fixing mypy errors in redis/commands/search/query.py --- redis/commands/search/query.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/redis/commands/search/query.py b/redis/commands/search/query.py index a8312a2ad2..615e6d10fa 100644 --- a/redis/commands/search/query.py +++ b/redis/commands/search/query.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union from redis.commands.search.dialect import DEFAULT_DIALECT @@ -31,7 +31,7 @@ def __init__(self, query_string: str) -> None: self._with_scores: bool = False self._scorer: Optional[str] = None self._filters: List = list() - self._ids: Optional[List[str]] = None + self._ids: Optional[Tuple[str]] = None self._slop: int = -1 self._timeout: Optional[float] = None self._in_order: bool = False @@ -81,7 +81,7 @@ def return_field( self._return_fields += ("AS", as_field) return self - def _mk_field_list(self, fields: List[str]) -> List: + def _mk_field_list(self, fields: Optional[Union[List[str], str]]) -> List: if not fields: return [] return [fields] if isinstance(fields, str) else list(fields) @@ -126,7 +126,7 @@ def summarize( def highlight( self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None - ) -> None: + ) -> "Query": """ Apply specified markup to matched term(s) within the returned field(s). @@ -187,16 +187,16 @@ def scorer(self, scorer: str) -> "Query": self._scorer = scorer return self - def get_args(self) -> List[str]: + def get_args(self) -> List[Union[str, int, float]]: """Format the redis arguments for this query and return them.""" - args = [self._query_string] + args: List[Union[str, int, float]] = [self._query_string] args += self._get_args_tags() args += self._summarize_fields + self._highlight_fields args += ["LIMIT", self._offset, self._num] return args - def _get_args_tags(self) -> List[str]: - args = [] + def _get_args_tags(self) -> List[Union[str, int, float]]: + args: List[Union[str, int, float]] = [] if self._no_content: args.append("NOCONTENT") if self._fields: @@ -288,14 +288,14 @@ def with_scores(self) -> "Query": self._with_scores = True return self - def limit_fields(self, *fields: List[str]) -> "Query": + def limit_fields(self, *fields: str) -> "Query": """ Limit the search to specific TEXT fields only. - - **fields**: A list of strings, case sensitive field names + - **fields**: Each element should be a string, case sensitive field name from the defined schema. """ - self._fields = fields + self._fields = list(fields) return self def add_filter(self, flt: "Filter") -> "Query": @@ -340,7 +340,7 @@ def dialect(self, dialect: int) -> "Query": class Filter: - def __init__(self, keyword: str, field: str, *args: List[str]) -> None: + def __init__(self, keyword: str, field: str, *args: Union[str, float]) -> None: self.args = [keyword, field] + list(args) From 4a284e1e123593284c5dcfca80c897980d6be494 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Mon, 16 Jun 2025 15:46:56 +0300 Subject: [PATCH 2/2] Fixing mypy errors in redis/commands/search/aggregation.py --- redis/commands/search/aggregation.py | 58 ++++++++++++++-------------- redis/commands/search/commands.py | 4 +- 2 files changed, 30 insertions(+), 32 deletions(-) diff --git a/redis/commands/search/aggregation.py b/redis/commands/search/aggregation.py index 13edefa081..45e4325022 100644 --- a/redis/commands/search/aggregation.py +++ b/redis/commands/search/aggregation.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import List, Optional, Union from redis.commands.search.dialect import DEFAULT_DIALECT @@ -26,10 +26,10 @@ class Reducer: NAME = None - def __init__(self, *args: List[str]) -> None: - self._args = args - self._field = None - self._alias = None + def __init__(self, *args: str) -> None: + self._args: tuple[str, ...] = args + self._field: Optional[str] = None + self._alias: Optional[str] = None def alias(self, alias: str) -> "Reducer": """ @@ -49,13 +49,14 @@ def alias(self, alias: str) -> "Reducer": if alias is FIELDNAME: if not self._field: raise ValueError("Cannot use FIELDNAME alias with no field") - # Chop off initial '@' - alias = self._field[1:] + else: + # Chop off initial '@' + alias = self._field[1:] self._alias = alias return self @property - def args(self) -> List[str]: + def args(self) -> tuple[str, ...]: return self._args @@ -64,7 +65,7 @@ class SortDirection: This special class is used to indicate sort direction. """ - DIRSTRING = None + DIRSTRING: Optional[str] = None def __init__(self, field: str) -> None: self.field = field @@ -104,19 +105,19 @@ def __init__(self, query: str = "*") -> None: All member methods (except `build_args()`) return the object itself, making them useful for chaining. """ - self._query = query - self._aggregateplan = [] - self._loadfields = [] - self._loadall = False - self._max = 0 - self._with_schema = False - self._verbatim = False - self._cursor = [] - self._dialect = DEFAULT_DIALECT - self._add_scores = False - self._scorer = "TFIDF" - - def load(self, *fields: List[str]) -> "AggregateRequest": + self._query: str = query + self._aggregateplan: List[str] = [] + self._loadfields: List[str] = [] + self._loadall: bool = False + self._max: int = 0 + self._with_schema: bool = False + self._verbatim: bool = False + self._cursor: List[str] = [] + self._dialect: int = DEFAULT_DIALECT + self._add_scores: bool = False + self._scorer: str = "TFIDF" + + def load(self, *fields: str) -> "AggregateRequest": """ Indicate the fields to be returned in the response. These fields are returned in addition to any others implicitly specified. @@ -133,7 +134,7 @@ def load(self, *fields: List[str]) -> "AggregateRequest": return self def group_by( - self, fields: List[str], *reducers: Union[Reducer, List[Reducer]] + self, fields: Union[str, List[str]], *reducers: Reducer ) -> "AggregateRequest": """ Specify by which fields to group the aggregation. @@ -147,7 +148,6 @@ def group_by( `aggregation` module. """ fields = [fields] if isinstance(fields, str) else fields - reducers = [reducers] if isinstance(reducers, Reducer) else reducers ret = ["GROUPBY", str(len(fields)), *fields] for reducer in reducers: @@ -223,7 +223,7 @@ def limit(self, offset: int, num: int) -> "AggregateRequest": self._aggregateplan.extend(_limit.build_args()) return self - def sort_by(self, *fields: List[str], **kwargs) -> "AggregateRequest": + def sort_by(self, *fields: str, **kwargs) -> "AggregateRequest": """ Indicate how the results should be sorted. This can also be used for *top-N* style queries @@ -251,12 +251,10 @@ def sort_by(self, *fields: List[str], **kwargs) -> "AggregateRequest": .sort_by(Desc("@paid"), max=10) ``` """ - if isinstance(fields, (str, SortDirection)): - fields = [fields] fields_args = [] for f in fields: - if isinstance(f, SortDirection): + if isinstance(f, (Asc, Desc)): fields_args += [f.field, f.DIRSTRING] else: fields_args += [f] @@ -356,7 +354,7 @@ def build_args(self) -> List[str]: ret.extend(self._loadfields) if self._dialect: - ret.extend(["DIALECT", self._dialect]) + ret.extend(["DIALECT", str(self._dialect)]) ret.extend(self._aggregateplan) @@ -393,7 +391,7 @@ def __init__(self, rows, cursor: Cursor, schema) -> None: self.cursor = cursor self.schema = schema - def __repr__(self) -> (str, str): + def __repr__(self) -> str: cid = self.cursor.cid if self.cursor else -1 return ( f"<{self.__class__.__name__} at 0x{id(self):x} " diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index bc48fa9aa8..cdf71723f6 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -464,7 +464,7 @@ def info(self): return self._parse_results(INFO_CMD, res) def get_params_args( - self, query_params: Union[Dict[str, Union[str, int, float, bytes]], None] + self, query_params: Optional[Dict[str, Union[str, int, float, bytes]]] ): if query_params is None: return [] @@ -543,7 +543,7 @@ def explain_cli(self, query: Union[str, Query]): # noqa def aggregate( self, query: Union[str, Query], - query_params: Dict[str, Union[str, int, float]] = None, + query_params: Optional[Dict[str, Union[str, int, float]]] = None, ): """ Issue an aggregation query.