Skip to content

Commit e8276df

Browse files
committed
feat: add sentinel support
1 parent 3ad7661 commit e8276df

File tree

6 files changed

+272
-74
lines changed

6 files changed

+272
-74
lines changed

flask_redis/__init__.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
from .client import FlaskRedis
22

3-
4-
__version__ = "0.5.0.dev0"
3+
__version__ = "0.6.0"
54

65
__title__ = "flask-redis"
7-
__description__ = "A nice way to use Redis in your Flask app"
8-
__url__ = "https://github.com/underyx/flask-redis/"
6+
__description__ = "A nice way to use Redis in your Flask app with sentinel support"
7+
__url__ = "https://github.com/cyrinux/flask-redis/"
98
__uri__ = __url__
109

11-
__author__ = "Bence Nagy"
12-
__email__ = "[email protected]"
10+
__author__ = "Cyrinux"
11+
__email__ = "[email protected]"
1312

14-
__license__ = "Blue Oak License"
15-
__copyright__ = "Copyright (c) 2019 Bence Nagy"
13+
__license__ = "Blue Oak Model License"
14+
__copyright__ = "Copyright (c) 2024"
1615

1716
__all__ = [FlaskRedis]

flask_redis/client.py

Lines changed: 170 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,206 @@
1+
import ssl
2+
from urllib.parse import parse_qs, unquote, urlparse
3+
14
try:
25
import redis
6+
from redis.sentinel import Sentinel
37
except ImportError:
4-
# We can still allow custom provider-only usage without redis-py being installed
8+
# Allow usage without redis-py being installed
59
redis = None
10+
Sentinel = None
611

712

813
class FlaskRedis(object):
9-
def __init__(self, app=None, strict=True, config_prefix="REDIS", **kwargs):
14+
def __init__(
15+
self,
16+
app=None,
17+
strict=True,
18+
config_prefix="REDIS",
19+
decode_responses=True,
20+
**kwargs,
21+
):
1022
self._redis_client = None
1123
self.provider_class = redis.StrictRedis if strict else redis.Redis
12-
self.provider_kwargs = kwargs
1324
self.config_prefix = config_prefix
25+
self.decode_responses = decode_responses
26+
self.provider_kwargs = kwargs
1427

1528
if app is not None:
1629
self.init_app(app)
1730

1831
@classmethod
1932
def from_custom_provider(cls, provider, app=None, **kwargs):
20-
assert provider is not None, "your custom provider is None, come on"
33+
assert provider is not None, "Your custom provider is None."
2134

22-
# We never pass the app parameter here, so we can call init_app
23-
# ourselves later, after the provider class has been set
2435
instance = cls(**kwargs)
25-
2636
instance.provider_class = provider
2737
if app is not None:
2838
instance.init_app(app)
2939
return instance
3040

3141
def init_app(self, app, **kwargs):
3242
redis_url = app.config.get(
33-
"{0}_URL".format(self.config_prefix), "redis://localhost:6379/0"
43+
f"{self.config_prefix}_URL", "redis://localhost:6379/0"
3444
)
3545

3646
self.provider_kwargs.update(kwargs)
37-
self._redis_client = self.provider_class.from_url(
38-
redis_url, **self.provider_kwargs
39-
)
47+
48+
parsed_url = urlparse(redis_url)
49+
scheme = parsed_url.scheme
50+
51+
if scheme in ["redis+sentinel", "rediss+sentinel"]:
52+
if Sentinel is None:
53+
raise ImportError("redis-py must be installed to use Redis Sentinel.")
54+
self._init_sentinel_client(parsed_url)
55+
else:
56+
self._init_standard_client(redis_url)
4057

4158
if not hasattr(app, "extensions"):
4259
app.extensions = {}
4360
app.extensions[self.config_prefix.lower()] = self
4461

62+
def _init_standard_client(self, redis_url):
63+
self._redis_client = self.provider_class.from_url(
64+
redis_url, decode_responses=self.decode_responses, **self.provider_kwargs
65+
)
66+
67+
def _init_sentinel_client(self, parsed_url):
68+
sentinel_kwargs, client_kwargs = self._parse_sentinel_parameters(parsed_url)
69+
70+
sentinel = Sentinel(
71+
sentinel_kwargs["hosts"],
72+
socket_timeout=sentinel_kwargs["socket_timeout"],
73+
**sentinel_kwargs["ssl_params"],
74+
**sentinel_kwargs["auth_params"],
75+
**self.provider_kwargs,
76+
)
77+
78+
self._redis_client = sentinel.master_for(
79+
sentinel_kwargs["master_name"],
80+
db=client_kwargs["db"],
81+
socket_timeout=client_kwargs["socket_timeout"],
82+
decode_responses=self.decode_responses,
83+
**client_kwargs["ssl_params"],
84+
**client_kwargs["auth_params"],
85+
**self.provider_kwargs,
86+
)
87+
88+
def _parse_sentinel_parameters(self, parsed_url):
89+
username, password = self._extract_credentials(parsed_url)
90+
hosts = self._parse_hosts(parsed_url)
91+
master_name, db = self._parse_master_and_db(parsed_url)
92+
query_params = parse_qs(parsed_url.query)
93+
94+
socket_timeout = self._parse_socket_timeout(query_params)
95+
ssl_enabled = self._parse_ssl_enabled(parsed_url.scheme, query_params)
96+
ssl_params = self._parse_ssl_params(query_params, ssl_enabled)
97+
auth_params = self._parse_auth_params(username, password)
98+
99+
sentinel_kwargs = {
100+
"hosts": hosts,
101+
"socket_timeout": socket_timeout,
102+
"ssl_params": ssl_params,
103+
"auth_params": auth_params,
104+
"master_name": master_name,
105+
}
106+
107+
client_kwargs = {
108+
"db": db,
109+
"socket_timeout": socket_timeout,
110+
"ssl_params": ssl_params,
111+
"auth_params": auth_params,
112+
}
113+
114+
return sentinel_kwargs, client_kwargs
115+
116+
def _extract_credentials(self, parsed_url):
117+
username = parsed_url.username
118+
password = parsed_url.password
119+
return username, password
120+
121+
def _parse_hosts(self, parsed_url):
122+
netloc = parsed_url.netloc
123+
if "@" in netloc:
124+
hosts_part = netloc.split("@", 1)[1]
125+
else:
126+
hosts_part = netloc
127+
128+
hosts = []
129+
for host_port in hosts_part.split(","):
130+
if ":" in host_port:
131+
host, port = host_port.split(":", 1)
132+
port = int(port)
133+
else:
134+
host = host_port
135+
port = 26379 # Default Sentinel port
136+
hosts.append((host, port))
137+
return hosts
138+
139+
def _parse_master_and_db(self, parsed_url):
140+
path = parsed_url.path.lstrip("/")
141+
if "/" in path:
142+
master_name, db_part = path.split("/", 1)
143+
db = int(db_part)
144+
else:
145+
master_name = path
146+
db = 0 # Default DB
147+
return master_name, db
148+
149+
def _parse_socket_timeout(self, query_params):
150+
socket_timeout = query_params.get("socket_timeout", [None])[0]
151+
if socket_timeout is not None:
152+
return float(socket_timeout)
153+
return None
154+
155+
def _parse_ssl_enabled(self, scheme, query_params):
156+
if scheme == "rediss+sentinel":
157+
return True
158+
ssl_param = query_params.get("ssl", ["False"])[0].lower()
159+
return ssl_param == "true"
160+
161+
def _parse_ssl_params(self, query_params, ssl_enabled):
162+
ssl_params = {}
163+
if ssl_enabled:
164+
ssl_cert_reqs = self._parse_ssl_cert_reqs(query_params)
165+
ssl_keyfile = query_params.get("ssl_keyfile", [None])[0]
166+
ssl_certfile = query_params.get("ssl_certfile", [None])[0]
167+
ssl_ca_certs = query_params.get("ssl_ca_certs", [None])[0]
168+
169+
ssl_params = {"ssl": True}
170+
if ssl_cert_reqs is not None:
171+
ssl_params["ssl_cert_reqs"] = ssl_cert_reqs
172+
if ssl_keyfile:
173+
ssl_params["ssl_keyfile"] = ssl_keyfile
174+
if ssl_certfile:
175+
ssl_params["ssl_certfile"] = ssl_certfile
176+
if ssl_ca_certs:
177+
ssl_params["ssl_ca_certs"] = ssl_ca_certs
178+
return ssl_params
179+
180+
def _parse_ssl_cert_reqs(self, query_params):
181+
ssl_cert_reqs = query_params.get("ssl_cert_reqs", [None])[0]
182+
if ssl_cert_reqs:
183+
ssl_cert_reqs = ssl_cert_reqs.lower()
184+
return {
185+
"required": ssl.CERT_REQUIRED,
186+
"optional": ssl.CERT_OPTIONAL,
187+
"none": ssl.CERT_NONE,
188+
}.get(ssl_cert_reqs)
189+
return None
190+
191+
def _parse_auth_params(self, username, password):
192+
auth_params = {}
193+
if username:
194+
auth_params["username"] = username
195+
if password:
196+
auth_params["password"] = password
197+
return auth_params
198+
199+
def hmset(self, name, mapping):
200+
# Implement hmset for compatibility
201+
# Use hset with mapping parameter
202+
return self._redis_client.hset(name, mapping=mapping)
203+
45204
def __getattr__(self, name):
46205
return getattr(self._redis_client, name)
47206

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
Flask>=0.9
2-
redis>=2.6.2
2+
redis>=5.0.0

0 commit comments

Comments
 (0)