diff --git a/django_mongodb_backend/cache.py b/django_mongodb_backend/cache.py index 00b903afe..7c1ed8e2e 100644 --- a/django_mongodb_backend/cache.py +++ b/django_mongodb_backend/cache.py @@ -1,33 +1,61 @@ import pickle from datetime import datetime, timezone +from hashlib import blake2b +from typing import Any, Optional, Tuple from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache from django.core.cache.backends.db import Options +from django.core.exceptions import SuspiciousOperation from django.db import connections, router from django.utils.functional import cached_property from pymongo import ASCENDING, DESCENDING, IndexModel, ReturnDocument from pymongo.errors import DuplicateKeyError, OperationFailure +from django.conf import settings class MongoSerializer: - def __init__(self, protocol=None): + def __init__(self, protocol=None, signer=None): self.protocol = pickle.HIGHEST_PROTOCOL if protocol is None else protocol + self.signer = signer - def dumps(self, obj): - # For better incr() and decr() atomicity, don't pickle integers. - # Using type() rather than isinstance() matches only integers and not - # subclasses like bool. - if type(obj) is int: # noqa: E721 - return obj - return pickle.dumps(obj, self.protocol) + def _get_signature(self, data) -> Optional[bytes]: + if self.signer is None: + return None + s = self.signer.copy() + s.update(data) + return s.digest() - def loads(self, data): - try: - return int(data) - except (ValueError, TypeError): - return pickle.loads(data) # noqa: S301 + def _get_pickled(self, obj: Any) -> bytes: + return pickle.dumps(obj, protocol=self.protocol) # noqa: S301 + def dumps(self, obj) -> Tuple[Any, bool, Optional[str]]: + # Serialize the object to a format suitable for MongoDB storage. + # The return value is a tuple of (data, pickled, signature). + match obj: + case int() | str() | bytes(): + return (obj, False, None) + case _: + pickled_data = self._get_pickled(obj) + return (pickled_data, True, self._get_signature(pickled_data) if self.signer else None) + def loads(self, data:Any, pickled:bool, signature=None) -> Any: + if pickled: + try: + if self.signer is not None: + # constant time compare is not required due to how data is retrieved + if signature and (signature == self._get_signature(data)): + return pickle.loads(data) # noqa: S301 + else: + raise SuspiciousOperation(f"Pickeled cache data is missing signature") + else: + return pickle.loads(data) + except (ValueError, TypeError): + # ValueError: Invalid signature + # TypeError: Data wasn't a byte string + raise SuspiciousOperation(f'Invalid pickle signature: {{"signature": {signature}, "data":{data}}}') + else: + return data + class MongoDBCache(BaseCache): pickle_protocol = pickle.HIGHEST_PROTOCOL @@ -39,6 +67,17 @@ class CacheEntry: _meta = Options(collection_name) self.cache_model_class = CacheEntry + self._sign_cache = params.get("ENABLE_SIGNING", True) + + self._key = params.get("KEY", settings.SECRET_KEY[:64]) + if len(self._key) == 0: + self._key = settings.SECRET_KEY[:64] + if isinstance(self._key, str): + self._key = self._key.encode() + + self._salt = params.get("SALT", "") + if isinstance(self._salt, str): + self._salt = self._salt.encode() def create_indexes(self): expires_index = IndexModel("expires_at", expireAfterSeconds=0) @@ -47,7 +86,10 @@ def create_indexes(self): @cached_property def serializer(self): - return MongoSerializer(self.pickle_protocol) + signer = None + if self._sign_cache: + signer = blake2b(key=self._key[:64], salt=self._salt[:16], person=self._collection_name[:16].encode()) + return MongoSerializer(self.pickle_protocol, signer) @property def collection_for_read(self): @@ -84,19 +126,30 @@ def get_many(self, keys, version=None): with self.collection_for_read.find( {"key": {"$in": tuple(keys_map)}, **self._filter_expired(expired=False)} ) as cursor: - return {keys_map[row["key"]]: self.serializer.loads(row["value"]) for row in cursor} + results = {} + for row in cursor: + try: + results[keys_map[row["key"]]] = self.serializer.loads(row["value"], row["pickled"], row["signature"]) + except SuspiciousOperation as e: + self.delete(row["key"]) + e.add_note(f"Cache entry with key '{row['key']}' was deleted due to suspicious data") + raise e + return results def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None): key = self.make_and_validate_key(key, version=version) num = self.collection_for_write.count_documents({}, hint="_id_") if num >= self._max_entries: self._cull(num) + value, pickled, signature = self.serializer.dumps(value) self.collection_for_write.update_one( {"key": key}, { "$set": { "key": key, - "value": self.serializer.dumps(value), + "value": value, + "pickled": pickled, + "signature": signature, "expires_at": self.get_backend_timeout(timeout), } }, @@ -109,12 +162,15 @@ def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None): if num >= self._max_entries: self._cull(num) try: + value, pickled, signature = self.serializer.dumps(value) self.collection_for_write.update_one( {"key": key, **self._filter_expired(expired=True)}, { "$set": { "key": key, - "value": self.serializer.dumps(value), + "value": value, + "pickled": pickled, + "signature": signature, "expires_at": self.get_backend_timeout(timeout), } }, diff --git a/docs/source/topics/cache.rst b/docs/source/topics/cache.rst index 881e1b78b..cdf54daff 100644 --- a/docs/source/topics/cache.rst +++ b/docs/source/topics/cache.rst @@ -32,6 +32,25 @@ In addition, the cache is culled based on ``CULL_FREQUENCY`` when ``add()`` or ``set()`` is called, if ``MAX_ENTRIES`` is exceeded. See :ref:`django:cache_arguments` for an explanation of these two options. +Cache entries include a HMAC signature to ensure data integrity by default. +You can disable this by setting ``ENABLE_SIGNING`` to ``False``. +Signatures can also include an optional key and salt parameter by setting +``KEY`` and ``SALT`` repectively. Signatures are provided by the Blake2 hash +function, making Key sizes limited to 64 bytes, and salt sizes limited to 16 +bytes. If a key is not provided, cache entries will be signed using the +``SECRET_KEY``. + +In this example, the cache collection is configured with a key and salt:: + + CACHES = { + "default": { + "BACKEND": "django_mongodb_backend.cache.MongoDBCache", + "LOCATION": "my_cache_collection", + "KEY": "my_secret_key", + "SALT": "my_salt", + }, + } + Creating the cache collection ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/cache_/tests.py b/tests/cache_/tests.py index c28b549e5..ad382bb9f 100644 --- a/tests/cache_/tests.py +++ b/tests/cache_/tests.py @@ -28,6 +28,8 @@ def f(): class C: def m(n): return 24 + def __eq__(self, other): + return isinstance(other, type(self)) class Unpicklable: @@ -97,6 +99,7 @@ def caches_setting_for_tests(base=None, exclude=None, **params): BACKEND="django_mongodb_backend.cache.MongoDBCache", # Spaces are used in the name to ensure quoting/escaping works. LOCATION="test cache collection", + ENABLE_SIGNING=False, ), ) @modify_settings( @@ -950,10 +953,59 @@ def test_collection_has_indexes(self): ) def test_serializer_dumps(self): - self.assertEqual(cache.serializer.dumps(123), 123) - self.assertIsInstance(cache.serializer.dumps(True), bytes) - self.assertIsInstance(cache.serializer.dumps("abc"), bytes) + self.assertTupleEqual(cache.serializer.dumps(123), (123, False, None)) + self.assertTupleEqual(cache.serializer.dumps(True), (True, False, None)) + self.assertTupleEqual(cache.serializer.dumps("abc"), ("abc", False, None)) + self.assertTupleEqual(cache.serializer.dumps(b"abc"), (b"abc", False, None)) + c = C() + pickled_c = pickle.dumps(c, protocol=pickle.HIGHEST_PROTOCOL) + self.assertTupleEqual(cache.serializer.dumps(c), (pickled_c, True, None)) + + def test_serializer_loads(self): + self.assertEqual(cache.serializer.loads(123, False, None), 123) + self.assertEqual(cache.serializer.loads(True, False, None), True) + self.assertEqual(cache.serializer.loads("abc", False, None), "abc") + self.assertEqual(cache.serializer.loads(b"abc", False, None), b"abc") + + c = C() + pickled_c = pickle.dumps(c, protocol=pickle.HIGHEST_PROTOCOL) + self.assertEqual(cache.serializer.loads(pickled_c, True, None), c) + + + +@override_settings( + CACHES=caches_setting_for_tests( + BACKEND="django_mongodb_backend.cache.MongoDBCache", + # Spaces are used in the name to ensure quoting/escaping works. + LOCATION="test cache collection", + ENABLE_SIGNING=True, + SALT="test-salt", + ), +) +class SignedCacheTests(CacheTests): + def test_serializer_dumps(self): + self.assertTupleEqual(cache.serializer.dumps(123), (123, False, None)) + self.assertTupleEqual(cache.serializer.dumps(True), (True, False, None)) + self.assertTupleEqual(cache.serializer.dumps("abc"), ("abc", False, None)) + self.assertTupleEqual(cache.serializer.dumps(b"abc"), (b"abc", False, None)) + + c = C() + pickled_c = pickle.dumps(c, protocol=pickle.HIGHEST_PROTOCOL) + self.assertTupleEqual(cache.serializer.dumps(c), (pickled_c, True, cache.serializer._get_signature(pickled_c))) + + def test_serializer_loads(self): + self.assertEqual(cache.serializer.loads(123, False, None), 123) + self.assertEqual(cache.serializer.loads(True, False, None), True) + self.assertEqual(cache.serializer.loads("abc", False, None), "abc") + self.assertEqual(cache.serializer.loads(b"abc", False, None), b"abc") + + c = C() + pickled_c = pickle.dumps(c, protocol=pickle.HIGHEST_PROTOCOL) + self.assertEqual(cache.serializer.loads(pickled_c, True, cache.serializer._get_signature(pickled_c)), c) + + with self.assertRaises(Exception): + cache.serializer.loads(pickled_c, True, "invalid-signature") class DBCacheRouter: """A router that puts the cache table on the 'other' database."""