diff --git a/dev_requirements.txt b/dev_requirements.txt index 449b938cb9..ddc8df5660 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -28,3 +28,5 @@ numpy>=1.24.0,<2.0 ; platform_python_implementation == "PyPy" redis-entraid==1.0.0 pybreaker>=1.4.0 + +xxhash==3.6.0 diff --git a/pyproject.toml b/pyproject.toml index f1cedefead..5d3fb843cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,12 +28,19 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ['async-timeout>=4.0.3; python_full_version<"3.11.3"'] +dependencies = [ + 'async-timeout>=4.0.3; python_full_version<"3.11.3"', +] [project.optional-dependencies] hiredis = [ "hiredis>=3.2.0", ] + +xxhash = [ + 'xxhash~=3.6.0', +] + ocsp = [ "cryptography>=36.0.1", "pyopenssl>=20.0.1", diff --git a/redis/commands/core.py b/redis/commands/core.py index 525b31c99d..0e3c783004 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -2,6 +2,15 @@ import datetime import hashlib + +# Try to import the xxhash library as an optional dependency +try: + import xxhash + + HAS_XXHASH = True +except ImportError: + HAS_XXHASH = False + import warnings from enum import Enum from typing import ( @@ -1889,7 +1898,42 @@ def expiretime(self, key: str) -> int: return self.execute_command("EXPIRETIME", key) @experimental_method() - def digest(self, name: KeyT) -> Optional[str]: + def digest_local(self, value: Union[bytes, str]) -> Union[bytes, str]: + """ + Compute the hexadecimal digest of the value locally, without sending it to the server. + + This is useful for conditional operations like IFDEQ/IFDNE where you need to + compute the digest client-side before sending a command. + + Warning: + **Experimental** - This API may change or be removed without notice. + + Arguments: + - value: Union[bytes, str] - the value to compute the digest of. + + Returns: + - (str | bytes) the XXH3 digest of the value as a hex string (16 hex characters) + + For more information, see https://redis.io/commands/digest + """ + if not HAS_XXHASH: + raise NotImplementedError( + "XXHASH support requires the optional 'xxhash' library. " + "Install it with 'pip install xxhash' or use this package's extra with " + "'pip install redis[xxhash]' to enable this feature." + ) + + local_digest = xxhash.xxh3_64(value).hexdigest() + + # To align with digest, we want to return bytes if decode_responses is False. + # The following should work because Python's mixin approach. + if not self.get_encoder().decode_responses: + local_digest = local_digest.encode() + + return local_digest + + @experimental_method() + def digest(self, name: KeyT) -> Union[str, bytes, None]: """ Return the digest of the value stored at the specified key. diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index 34a6017d22..fedc518f75 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -1264,6 +1264,30 @@ async def test_digest_response_when_available(self, r, value): assert len(res) == 16 + @skip_if_server_version_lt("8.3.224") + @pytest.mark.parametrize( + "value", [b"", b"abc", b"The quick brown fox jumps over the lazy dog"] + ) + async def test_local_digest_matches_server(self, r, value): + key = "k:digest" + await r.delete(key) + await r.set(key, value) + + res_server = await r.digest(key) + + # Caution! This one is not executing execute_command and it is not async + res_local = r.digest_local(value) + + # got is str if decode_responses=True; ensure bytes->str for comparison + if isinstance(res_server, bytes): + assert isinstance(res_local, bytes) + + assert res_server is not None + assert len(res_server) == 16 + assert res_local is not None + assert len(res_local) == 16 + assert res_server == res_local + @skip_if_server_version_lt("8.3.224") async def test_pipeline_digest(self, r): k1, k2 = "k:d1{42}", "k:d2{42}" diff --git a/tests/test_commands.py b/tests/test_commands.py index 4efc26f5c9..988a18d397 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -1807,6 +1807,7 @@ def test_digest_response_when_available(self, r, value): r.set(key, value) res = r.digest(key) + # got is str if decode_responses=True; ensure bytes->str for comparison if isinstance(res, bytes): res = res.decode() @@ -1815,6 +1816,28 @@ def test_digest_response_when_available(self, r, value): assert len(res) == 16 + @skip_if_server_version_lt("8.3.224") + @pytest.mark.parametrize( + "value", [b"", b"abc", b"The quick brown fox jumps over the lazy dog"] + ) + def test_local_digest_matches_server(self, r, value): + key = "k:digest" + r.delete(key) + r.set(key, value) + + res_server = r.digest(key) + res_local = r.digest_local(value) + + # got is str if decode_responses=True; ensure bytes->str for comparison + if isinstance(res_server, bytes): + assert isinstance(res_local, bytes) + + assert res_server is not None + assert len(res_server) == 16 + assert res_local is not None + assert len(res_local) == 16 + assert res_server == res_local + @skip_if_server_version_lt("8.3.224") def test_pipeline_digest(self, r): k1, k2 = "k:d1{42}", "k:d2{42}" @@ -2588,6 +2611,9 @@ def test_set_ifdeq_and_ifdne(self, r, val): d = self._server_xxh3_digest(r, "k") assert d is not None + # sanity check: local digest matches server's + assert d == self._ensure_str(r.digest_local(val)) + # IFDEQ must match to set; if key missing => won't create assert r.set("k", b"X", ifdeq=d) is True assert r.get("k") == b"X"