Skip to content

INTPYTHON-676: Adding security and optimization to cache collections #343

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 73 additions & 17 deletions django_mongodb_backend/cache.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,61 @@
import pickle
from datetime import datetime, timezone
from hashlib import blake2b
from typing import Any, Optional, Tuple

from django.core.cache.backends.base import DEFAULT_TIMEOUT, BaseCache
from django.core.cache.backends.db import Options
from django.core.exceptions import SuspiciousOperation
from django.db import connections, router
from django.utils.functional import cached_property
from pymongo import ASCENDING, DESCENDING, IndexModel, ReturnDocument
from pymongo.errors import DuplicateKeyError, OperationFailure
from django.conf import settings


class MongoSerializer:
def __init__(self, protocol=None):
def __init__(self, protocol=None, signer=None):
self.protocol = pickle.HIGHEST_PROTOCOL if protocol is None else protocol
self.signer = signer

def dumps(self, obj):
# For better incr() and decr() atomicity, don't pickle integers.
# Using type() rather than isinstance() matches only integers and not
# subclasses like bool.
if type(obj) is int: # noqa: E721
return obj
return pickle.dumps(obj, self.protocol)
def _get_signature(self, data) -> Optional[bytes]:
if self.signer is None:
return None
s = self.signer.copy()
s.update(data)
return s.digest()

def loads(self, data):
try:
return int(data)
except (ValueError, TypeError):
return pickle.loads(data) # noqa: S301
def _get_pickled(self, obj: Any) -> bytes:
return pickle.dumps(obj, protocol=self.protocol) # noqa: S301

def dumps(self, obj) -> Tuple[Any, bool, Optional[str]]:
# Serialize the object to a format suitable for MongoDB storage.
# The return value is a tuple of (data, pickled, signature).
match obj:
case int() | str() | bytes():
return (obj, False, None)
case _:
pickled_data = self._get_pickled(obj)
return (pickled_data, True, self._get_signature(pickled_data) if self.signer else None)

def loads(self, data:Any, pickled:bool, signature=None) -> Any:
if pickled:
try:
if self.signer is not None:
# constant time compare is not required due to how data is retrieved
if signature and (signature == self._get_signature(data)):
return pickle.loads(data) # noqa: S301
else:
raise SuspiciousOperation(f"Pickeled cache data is missing signature")
else:
return pickle.loads(data)
except (ValueError, TypeError):
# ValueError: Invalid signature
# TypeError: Data wasn't a byte string
raise SuspiciousOperation(f'Invalid pickle signature: {{"signature": {signature}, "data":{data}}}')
else:
return data

class MongoDBCache(BaseCache):
pickle_protocol = pickle.HIGHEST_PROTOCOL

Expand All @@ -39,6 +67,17 @@ class CacheEntry:
_meta = Options(collection_name)

self.cache_model_class = CacheEntry
self._sign_cache = params.get("ENABLE_SIGNING", True)

self._key = params.get("KEY", settings.SECRET_KEY[:64])
if len(self._key) == 0:
self._key = settings.SECRET_KEY[:64]
if isinstance(self._key, str):
self._key = self._key.encode()

self._salt = params.get("SALT", "")
if isinstance(self._salt, str):
self._salt = self._salt.encode()

def create_indexes(self):
expires_index = IndexModel("expires_at", expireAfterSeconds=0)
Expand All @@ -47,7 +86,10 @@ def create_indexes(self):

@cached_property
def serializer(self):
    """Build the serializer, keying a blake2b signer when signing is enabled.

    blake2b constraints: key <= 64 bytes, salt <= 16 bytes, person <= 16
    bytes -- hence the slicing. The collection name is used as the
    personalization string so equal values signed for different
    collections don't share signatures.
    """
    signer = None
    if self._sign_cache:
        signer = blake2b(
            key=self._key[:64],
            salt=self._salt[:16],
            person=self._collection_name[:16].encode(),
        )
    return MongoSerializer(self.pickle_protocol, signer)

@property
def collection_for_read(self):
Expand Down Expand Up @@ -84,19 +126,30 @@ def get_many(self, keys, version=None):
with self.collection_for_read.find(
{"key": {"$in": tuple(keys_map)}, **self._filter_expired(expired=False)}
) as cursor:
return {keys_map[row["key"]]: self.serializer.loads(row["value"]) for row in cursor}
results = {}
for row in cursor:
try:
results[keys_map[row["key"]]] = self.serializer.loads(row["value"], row["pickled"], row["signature"])
except SuspiciousOperation as e:
self.delete(row["key"])
e.add_note(f"Cache entry with key '{row['key']}' was deleted due to suspicious data")
raise e
return results

def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
key = self.make_and_validate_key(key, version=version)
num = self.collection_for_write.count_documents({}, hint="_id_")
if num >= self._max_entries:
self._cull(num)
value, pickled, signature = self.serializer.dumps(value)
self.collection_for_write.update_one(
{"key": key},
{
"$set": {
"key": key,
"value": self.serializer.dumps(value),
"value": value,
"pickled": pickled,
"signature": signature,
"expires_at": self.get_backend_timeout(timeout),
}
},
Expand All @@ -109,12 +162,15 @@ def add(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
if num >= self._max_entries:
self._cull(num)
try:
value, pickled, signature = self.serializer.dumps(value)
self.collection_for_write.update_one(
{"key": key, **self._filter_expired(expired=True)},
{
"$set": {
"key": key,
"value": self.serializer.dumps(value),
"value": value,
"pickled": pickled,
"signature": signature,
"expires_at": self.get_backend_timeout(timeout),
}
},
Expand Down
19 changes: 19 additions & 0 deletions docs/source/topics/cache.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,25 @@ In addition, the cache is culled based on ``CULL_FREQUENCY`` when ``add()``
or ``set()`` is called, if ``MAX_ENTRIES`` is exceeded. See
:ref:`django:cache_arguments` for an explanation of these two options.

Cache entries include a keyed BLAKE2 (MAC) signature to ensure data integrity
by default. You can disable this by setting ``ENABLE_SIGNING`` to ``False``.
Signatures can also use an optional key and salt by setting ``KEY`` and
``SALT`` respectively. Signatures are produced by the BLAKE2 hash function,
which limits key sizes to 64 bytes and salt sizes to 16 bytes. If a key is
not provided, cache entries are signed using the ``SECRET_KEY``.

In this example, the cache collection is configured with a key and salt::

CACHES = {
"default": {
"BACKEND": "django_mongodb_backend.cache.MongoDBCache",
"LOCATION": "my_cache_collection",
"KEY": "my_secret_key",
"SALT": "my_salt",
},
}

Creating the cache collection
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
58 changes: 55 additions & 3 deletions tests/cache_/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def f():
class C:
    """Picklable test fixture; any two instances of the same class compare equal."""

    def m(n):
        return 24

    def __eq__(self, other):
        same_class = isinstance(other, type(self))
        return same_class


class Unpicklable:
Expand Down Expand Up @@ -97,6 +99,7 @@ def caches_setting_for_tests(base=None, exclude=None, **params):
BACKEND="django_mongodb_backend.cache.MongoDBCache",
# Spaces are used in the name to ensure quoting/escaping works.
LOCATION="test cache collection",
ENABLE_SIGNING=False,
),
)
@modify_settings(
Expand Down Expand Up @@ -950,10 +953,59 @@ def test_collection_has_indexes(self):
)

def test_serializer_dumps(self):
    """With signing disabled, dumps() returns (data, pickled, None)."""
    # Primitives (including bool) are stored natively: not pickled, unsigned.
    self.assertTupleEqual(cache.serializer.dumps(123), (123, False, None))
    self.assertTupleEqual(cache.serializer.dumps(True), (True, False, None))
    self.assertTupleEqual(cache.serializer.dumps("abc"), ("abc", False, None))
    self.assertTupleEqual(cache.serializer.dumps(b"abc"), (b"abc", False, None))

    # Arbitrary objects are pickled; no signature since signing is off.
    obj = C()
    pickled_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    self.assertTupleEqual(cache.serializer.dumps(obj), (pickled_obj, True, None))

def test_serializer_loads(self):
    """loads() passes native values through and unpickles pickled ones."""
    for native in (123, True, "abc", b"abc"):
        self.assertEqual(cache.serializer.loads(native, False, None), native)

    obj = C()
    payload = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    self.assertEqual(cache.serializer.loads(payload, True, None), obj)



@override_settings(
    CACHES=caches_setting_for_tests(
        BACKEND="django_mongodb_backend.cache.MongoDBCache",
        # Spaces are used in the name to ensure quoting/escaping works.
        LOCATION="test cache collection",
        ENABLE_SIGNING=True,
        SALT="test-salt",
    ),
)
class SignedCacheTests(CacheTests):
    """Run the full cache suite with signing enabled, plus signature checks."""

    def test_serializer_dumps(self):
        # Primitives are still stored natively and carry no signature.
        for primitive in (123, True, "abc", b"abc"):
            self.assertTupleEqual(
                cache.serializer.dumps(primitive), (primitive, False, None)
            )

        obj = C()
        payload = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        expected_sig = cache.serializer._get_signature(payload)
        self.assertTupleEqual(
            cache.serializer.dumps(obj), (payload, True, expected_sig)
        )

    def test_serializer_loads(self):
        for primitive in (123, True, "abc", b"abc"):
            self.assertEqual(cache.serializer.loads(primitive, False, None), primitive)

        obj = C()
        payload = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        good_sig = cache.serializer._get_signature(payload)
        self.assertEqual(cache.serializer.loads(payload, True, good_sig), obj)

        # A wrong signature must be rejected, not unpickled.
        # NOTE(review): narrow Exception to SuspiciousOperation once its
        # import in this module is confirmed.
        with self.assertRaises(Exception):
            cache.serializer.loads(payload, True, "invalid-signature")

class DBCacheRouter:
"""A router that puts the cache table on the 'other' database."""
Expand Down