From 2db33e83c2c028f85d7aa4ef65f7c0ae87cf6658 Mon Sep 17 00:00:00 2001 From: Oliver Sauder Date: Tue, 9 Sep 2025 15:48:34 +0400 Subject: [PATCH] fix: Ensured transactional safety when writing to redis --- scheduler/redis_models/base.py | 66 +++++++++++++++------------------- 1 file changed, 29 insertions(+), 37 deletions(-) diff --git a/scheduler/redis_models/base.py b/scheduler/redis_models/base.py index eb088b8..2a4c4bb 100644 --- a/scheduler/redis_models/base.py +++ b/scheduler/redis_models/base.py @@ -115,20 +115,9 @@ def deserialize(cls, data: Dict[str, Any]) -> Self: class HashModel(BaseModel): created_at: Optional[datetime] = None parent: Optional[str] = None - _dirty_fields: Set[str] = dataclasses.field(default_factory=set) # fields that were changed - _save_all: bool = True # Save all fields to broker, after init, or after delete _list_key: ClassVar[str] = ":list_all:" _children_key_template: ClassVar[str] = ":children:{}:" - def __post_init__(self): - self._dirty_fields = set() - self._save_all = True - - def __setattr__(self, key, value): - if key != "_dirty_fields" and hasattr(self, "_dirty_fields"): - self._dirty_fields.add(key) - super(HashModel, self).__setattr__(key, value) - @property def _parent_key(self) -> Optional[str]: if self.parent is None: @@ -155,8 +144,10 @@ def exists(cls, name: str, connection: ConnectionType) -> bool: @classmethod def delete_many(cls, names: List[str], connection: ConnectionType) -> None: - for name in names: - connection.delete(cls._element_key_template.format(name)) + with connection.pipeline() as pipeline: + for name in names: + pipeline.delete(cls._element_key_template.format(name)) + pipeline.execute() @classmethod def get(cls, name: str, connection: ConnectionType) -> Optional[Self]: @@ -171,34 +162,35 @@ def get(cls, name: str, connection: ConnectionType) -> Optional[Self]: @classmethod def get_many(cls, names: Sequence[str], connection: ConnectionType) -> List[Optional[Self]]: - pipeline = connection.pipeline() - for name in names: - pipeline.hgetall(cls._element_key_template.format(name)) - values = pipeline.execute() - return [(cls.deserialize(decode_dict(v, set())) if v else None) for v in values] + with connection.pipeline() as pipeline: + for name in names: + pipeline.hgetall(cls._element_key_template.format(name)) + values = pipeline.execute() + return [(cls.deserialize(decode_dict(v, set())) if v else None) for v in values] def save(self, connection: ConnectionType) -> None: - connection.sadd(self._list_key, self.name) - if self._parent_key is not None: - connection.sadd(self._parent_key, self.name) - mapping = self.serialize(with_nones=True) - if not self._save_all and len(self._dirty_fields) > 0: - mapping = {k: v for k, v in mapping.items() if k in self._dirty_fields} - none_values = {k for k, v in mapping.items() if v is None} - if none_values: - connection.hdel(self._key, *none_values) - mapping = {k: v for k, v in mapping.items() if v is not None} - if mapping: - connection.hset(self._key, mapping=mapping) - self._dirty_fields = set() - self._save_all = False + with connection.pipeline() as pipeline: + pipeline.sadd(self._list_key, self.name) + if self._parent_key is not None: + pipeline.sadd(self._parent_key, self.name) + mapping = self.serialize(with_nones=True) + none_values = {k for k, v in mapping.items() if v is None} + if none_values: + pipeline.hdel(self._key, *none_values) + mapping = {k: v for k, v in mapping.items() if v is not None} + if mapping: + pipeline.hset(self._key, mapping=mapping) + + pipeline.execute() def delete(self, connection: ConnectionType) -> None: - connection.srem(self._list_key, self._key) - if self._parent_key is not None: - connection.srem(self._parent_key, 0, self._key) - connection.delete(self._key) - self._save_all = True + with connection.pipeline() as pipeline: + pipeline.srem(self._list_key, self._key) + if self._parent_key is not None: + pipeline.srem(self._parent_key, 0, self._key) + pipeline.delete(self._key) + + pipeline.execute() @classmethod def count(cls, connection: ConnectionType, parent: Optional[str] = None) -> int: