diff --git a/changelog.d/721.feature b/changelog.d/721.feature new file mode 100644 index 00000000..d2da078b --- /dev/null +++ b/changelog.d/721.feature @@ -0,0 +1 @@ +Introduction of sets and support of basic operations. \ No newline at end of file diff --git a/django_redis/cache.py b/django_redis/cache.py index d26c33fa..f9bddd09 100644 --- a/django_redis/cache.py +++ b/django_redis/cache.py @@ -204,3 +204,75 @@ def hkeys(self, *args, **kwargs): @omit_exception def hexists(self, *args, **kwargs): return self.client.hexists(*args, **kwargs) + + @omit_exception + def sadd(self, *args, **kwargs): + return self.client.sadd(*args, **kwargs) + + @omit_exception + def scard(self, *args, **kwargs): + return self.client.scard(*args, **kwargs) + + @omit_exception + def sdiff(self, *args, **kwargs): + return self.client.sdiff(*args, **kwargs) + + @omit_exception + def sdiffstore(self, *args, **kwargs): + return self.client.sdiffstore(*args, **kwargs) + + @omit_exception + def sinter(self, *args, **kwargs): + return self.client.sinter(*args, **kwargs) + + @omit_exception + def sinterstore(self, *args, **kwargs): + return self.client.sinterstore(*args, **kwargs) + + @omit_exception + def sismember(self, *args, **kwargs): + return self.client.sismember(*args, **kwargs) + + @omit_exception + def smismember(self, *args, **kwargs): + return self.client.smismember(*args, **kwargs) + + @omit_exception + def smembers(self, *args, **kwargs): + return self.client.smembers(*args, **kwargs) + + @omit_exception + def smove(self, *args, **kwargs): + return self.client.smove(*args, **kwargs) + + @omit_exception + def spop(self, *args, **kwargs): + return self.client.spop(*args, **kwargs) + + @omit_exception + def srandmember(self, *args, **kwargs): + return self.client.srandmember(*args, **kwargs) + + @omit_exception + def srem(self, *args, **kwargs): + return self.client.srem(*args, **kwargs) + + @omit_exception + def sunion(self, *args, **kwargs): + return self.client.sunion(*args, **kwargs) + + @omit_exception + def sunionstore(self, *args, **kwargs): + return self.client.sunionstore(*args, **kwargs) + + @omit_exception + def sintercard(self, *args, **kwargs): + return self.client.sintercard(*args, **kwargs) + + @omit_exception + def smismember(self, *args, **kwargs): + return self.client.smismember(*args, **kwargs) + + @omit_exception + def sscan(self, *args, **kwargs): + return self.client.sscan(*args, **kwargs) diff --git a/django_redis/client/default.py b/django_redis/client/default.py index 7850d3c7..523f81cd 100644 --- a/django_redis/client/default.py +++ b/django_redis/client/default.py @@ -3,7 +3,7 @@ import socket from collections import OrderedDict from contextlib import suppress -from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union from django.conf import settings from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache, get_key_func @@ -889,3 +889,265 @@ def hexists( client = self.get_client(write=False) nkey = self.make_key(key, version=version) return bool(client.hexists(name, nkey)) + + def sadd( + self, + key: Any, + *values: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + key = self.make_key(key, version=version) + values = [self.encode(value) for value in values] + return int(client.sadd(key, *values)) + + def scard( + self, + key: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + return int(client.scard(key)) + + def sdiff( + self, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Set: + if client is None: + client = self.get_client(write=False) + + keys = [self.make_key(key, version=version) for key in keys] + return {self.decode(value) for value in client.sdiff(*keys)} + + def sdiffstore( + self, + dest: Any, + key: Any, + *keys, + version_dest: Optional[int] = None, + version_minuend: Optional[int] = None, + version_subtrahend: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + dest = self.make_key(dest, version=version_dest) + minuend_key = self.make_key(key, version=version_minuend) + subtrahend_keys: Set[str] = { + self.make_key(key_, version=version_subtrahend) for key_ in keys + } + return int(client.sdiffstore(dest, minuend_key, *subtrahend_keys)) + + def sinter( + self, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Set: + if client is None: + client = self.get_client(write=False) + + keys = [self.make_key(key, version=version) for key in keys] + return {self.decode(value) for value in client.sinter(*keys)} + + def sinterstore( + self, + dest: Any, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + dest = self.make_key(dest, version=version) + keys = [self.make_key(key, version=version) for key in keys] + return int(client.sinterstore(dest, *keys)) + + def sismember( + self, + key: Any, + member: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> bool: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + member = self.encode(member) + return bool(client.sismember(key, member)) + + def smembers( + self, + key: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Set: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + return {self.decode(value) for value in client.smembers(key)} + + def smove( + self, + source: Any, + destination: Any, + member: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> bool: + if client is None: + client = self.get_client(write=True) + + source = self.make_key(source, version=version) + destination = self.make_key(destination) + member = self.encode(member) + return bool(client.smove(source, destination, member)) + + def spop( + self, + key: Any, + count: Optional[int] = None, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Union[Set, Any]: + if client is None: + client = self.get_client(write=True) + + key = self.make_key(key, version=version) + result = client.spop(key, count) + if isinstance(result, list): + return {self.decode(value) for value in result} + return self.decode(result) + + def srandmember( + self, + key: Any, + count: Optional[int] = None, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Union[Set, Any]: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + result = client.srandmember(key, count) + if isinstance(result, list): + return {self.decode(value) for value in result} + return self.decode(result) + + def srem( + self, + key: Any, + *members, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + key = self.make_key(key, version=version) + members = [self.decode(member) for member in members] + return int(client.srem(key, *members)) + + def sunion( + self, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Set: + if client is None: + client = self.get_client(write=False) + + keys = [self.make_key(key, version=version) for key in keys] + return {self.decode(value) for value in client.sunion(*keys)} + + def sunionstore( + self, + destination: Any, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + destination = self.make_key(destination, version=version) + keys = [self.make_key(key, version=version) for key in keys] + return int(client.sunionstore(destination, *keys)) + + def sintercard( + self, + *keys, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> int: + if client is None: + client = self.get_client(write=True) + + keys = [self.make_key(key, version=version) for key in keys] + + result_key = "__temp_inter_key__" + client.sinterstore(result_key, *keys) + cardinality = client.scard(result_key) + client.delete(result_key) + + return cardinality + + def smismember( + self, + key: Any, + members: Any, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> bool: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + members = [self.encode(member) for member in members] + + with client.pipeline() as pipe: + for member in members: + pipe.sismember(key, member) + + results = pipe.execute() + + return all(bool(result) for result in results) + + def sscan( + self, + key: Any, + cursor: int = 0, + match: Optional[str] = None, + count: Optional[int] = None, + version: Optional[int] = None, + client: Optional[Redis] = None, + ) -> Set[Any]: + if client is None: + client = self.get_client(write=False) + + key = self.make_key(key, version=version) + elements = set() + + while True: + result = client.sscan(key, cursor, match=match, count=count) + cursor, partial_elements = result + elements.update(self.decode(value) for value in partial_elements) + + if cursor == 0: + break + + return elements diff --git a/tests/test_backend.py b/tests/test_backend.py index 550ce79c..2fb2a4ba 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -845,3 +845,96 @@ def test_hexists(self, cache: RedisCache): cache.hset("foo_hash5", "foo1", "bar1") assert cache.hexists("foo_hash5", "foo1") assert not cache.hexists("foo_hash5", "foo") + + def test_sadd(self, cache: RedisCache): + assert cache.sadd("foo", "bar") == 1 + assert cache.smembers("foo") == {"bar"} + + def test_scard(self, cache: RedisCache): + cache.sadd("foo", "bar", "bar2") + assert cache.scard("foo") == 2 + + def test_sdiff(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sdiff("foo1", "foo2") == {"bar1"} + + def test_sdiffstore(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sdiffstore("foo3", "foo1", "foo2") == 1 + assert cache.smembers("foo3") == {"bar1"} + + def test_sinter(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sinter("foo1", "foo2") == {"bar2"} + + def test_interstore(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sinterstore("foo3", "foo1", "foo2") == 1 + assert cache.smembers("foo3") == {"bar2"} + + def test_sismember(self, cache: RedisCache): + cache.sadd("foo", "bar") + assert cache.sismember("foo", "bar") is True + assert cache.sismember("foo", "bar2") is False + + def test_smove(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.smove("foo1", "foo2", "bar1") is True + assert cache.smove("foo1", "foo2", "bar4") is False + assert cache.smembers("foo1") == {"bar2"} + assert cache.smembers("foo2") == {"bar1", "bar2", "bar3"} + + def test_spop_default_count(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.spop("foo") in {"bar1", "bar2"} + assert cache.smembers("foo") in {{"bar1"}, {"bar2"}} + + def test_spop(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.spop("foo", 1) in {{"bar1"}, {"bar2"}} + assert cache.smembers("foo") in {{"bar1"}, {"bar2"}} + + def test_srandmember_default_count(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.srandmember("foo") in {"bar1", "bar2"} + + def test_srandmember(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.srandmember("foo", 1) in {{"bar1"}, {"bar2"}} + + def test_srem(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2") + assert cache.srem("foo", "bar1") == 1 + assert cache.srem("foo", "bar3") == 0 + + def test_sunion(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sunion("foo1", "foo2") == {"bar1", "bar2", "bar3"} + + def test_sunionstore(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sunionstore("foo3", "foo1", "foo2") == 3 + assert cache.smembers("foo3") == {"bar1", "bar2", "bar3"} + + def test_sintercard(self, cache: RedisCache): + cache.sadd("foo1", "bar1", "bar2") + cache.sadd("foo2", "bar2", "bar3") + assert cache.sintercard("foo1", "foo2") == 1 + + def test_smismember(self, cache: RedisCache): + cache.sadd("foo", "bar") + assert cache.smismember("foo", "bar") is True + assert cache.smismember("foo", "bar2") is False + + def test_sscan(self, cache: RedisCache): + cache.sadd("foo", "bar1", "bar2", "bar3", "bar4", "bar5") + cursor, members = cache.sscan("foo", match="bar*") + assert cursor == 0 # Assuming there is only one iteration for simplicity + assert set(members) == {"bar1", "bar2", "bar3", "bar4", "bar5"}