diff --git a/django_redis/client/default.py b/django_redis/client/default.py index 1df90a27..f1a074bb 100644 --- a/django_redis/client/default.py +++ b/django_redis/client/default.py @@ -132,6 +132,7 @@ def set( client: Optional[Redis] = None, nx: bool = False, xx: bool = False, + enforce_encoding: bool = False, ) -> bool: """ Persist a value to the cache, and set an optional expiration time. @@ -140,7 +141,7 @@ def set( setnx instead of set. """ nkey = self.make_key(key, version=version) - nvalue = self.encode(value) + nvalue = self.encode(value, enforce_encoding=enforce_encoding) if timeout is DEFAULT_TIMEOUT: timeout = self._backend.default_timeout @@ -448,12 +449,16 @@ def decode(self, value: Union[bytes, int]) -> Any: value = self._serializer.loads(value) return value - def encode(self, value: Any) -> Union[bytes, Any]: + def encode(self, value: Any, enforce_encoding: bool = False) -> Union[bytes, Any]: """ Encode the given value. """ - if isinstance(value, bool) or not isinstance(value, int): + if ( + isinstance(value, bool) + or not isinstance(value, int) + or enforce_encoding is True + ): value = self._serializer.dumps(value) value = self._compressor.compress(value) return value diff --git a/tests/test_backend.py b/tests/test_backend.py index 0e8e1fdf..28c1c09c 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -2,6 +2,7 @@ import threading import time from datetime import timedelta +from enum import IntEnum from typing import List, Union, cast from unittest.mock import patch @@ -19,6 +20,11 @@ herd.CACHE_HERD_TIMEOUT = 2 +class Values2(IntEnum): + SOMETHING_1 = 1 + SOMETHING_2 = 2 + + class TestDjangoRedisCache: def test_setnx(self, cache: RedisCache): # we should ensure there is no test_key_nx in redis @@ -650,6 +656,16 @@ def test_expire_at(self, cache: RedisCache): expiration_time = datetime.datetime.now() + timedelta(hours=2) assert cache.expire_at("not-existent-key", expiration_time) is False + def test_intenum(self, cache: RedisCache): + + cache.set("hello", Values2.SOMETHING_1, enforce_encoding=True) + + value = cache.get("hello") + + assert value == Values2.SOMETHING_1 + + assert isinstance(value, Values2) + def test_lock(self, cache: RedisCache): lock = cache.lock("foobar") lock.acquire(blocking=True)