Skip to content

Commit 0d0cfe6

Browse files
authored
Redis Search/Aggregate improved type annotations (#3676)
* Fix ft.aggregate variable arguments type annotation * Fix ft.aggregate query type annotation
1 parent cc5f129 commit 0d0cfe6

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

redis/commands/search/aggregation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class Reducer:
2626

2727
NAME = None
2828

29-
def __init__(self, *args: List[str]) -> None:
29+
def __init__(self, *args: str) -> None:
3030
self._args = args
3131
self._field = None
3232
self._alias = None
@@ -116,7 +116,7 @@ def __init__(self, query: str = "*") -> None:
116116
self._add_scores = False
117117
self._scorer = "TFIDF"
118118

119-
def load(self, *fields: List[str]) -> "AggregateRequest":
119+
def load(self, *fields: str) -> "AggregateRequest":
120120
"""
121121
Indicate the fields to be returned in the response. These fields are
122122
returned in addition to any others implicitly specified.
@@ -223,7 +223,7 @@ def limit(self, offset: int, num: int) -> "AggregateRequest":
223223
self._aggregateplan.extend(_limit.build_args())
224224
return self
225225

226-
def sort_by(self, *fields: List[str], **kwargs) -> "AggregateRequest":
226+
def sort_by(self, *fields: str, **kwargs) -> "AggregateRequest":
227227
"""
228228
Indicate how the results should be sorted. This can also be used for
229229
*top-N* style queries

redis/commands/search/commands.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ def explain_cli(self, query: Union[str, Query]): # noqa
542542

543543
def aggregate(
544544
self,
545-
query: Union[str, Query],
545+
query: Union[AggregateRequest, Cursor],
546546
query_params: Dict[str, Union[str, int, float]] = None,
547547
):
548548
"""
@@ -573,7 +573,7 @@ def aggregate(
573573
)
574574

575575
def _get_aggregate_result(
576-
self, raw: List, query: Union[str, Query, AggregateRequest], has_cursor: bool
576+
self, raw: List, query: Union[AggregateRequest, Cursor], has_cursor: bool
577577
):
578578
if has_cursor:
579579
if isinstance(query, Cursor):
@@ -967,7 +967,7 @@ async def search(
967967

968968
async def aggregate(
969969
self,
970-
query: Union[str, Query],
970+
query: Union[AggregateResult, Cursor],
971971
query_params: Dict[str, Union[str, int, float]] = None,
972972
):
973973
"""

0 commit comments

Comments
 (0)