Skip to content

Commit aabe3c3

Browse files
committed
Bump version to 0.1.7 and add sliding window rate limit
1 parent 8d179c0 commit aabe3c3

File tree

7 files changed

+120
-25
lines changed

7 files changed

+120
-25
lines changed

README.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ async def websocket_endpoint(websocket: WebSocket):
143143

144144
The lua script used.
145145

146+
### fixed window
146147
```lua
147148
local key = KEYS[1]
148149
local limit = tonumber(ARGV[1])
@@ -162,6 +163,27 @@ else
162163
end
163164
```
164165

166+
### sliding window
167+
```lua
168+
local key = KEYS[1]
169+
local limit = tonumber(ARGV[1])
170+
local expire_time = tonumber(ARGV[2])
171+
local current_time = redis.call('TIME')[1]
172+
local start_time = current_time - expire_time / 1000
173+
174+
redis.call('ZREMRANGEBYSCORE', key, 0, start_time)
175+
176+
local current = redis.call('ZCARD', key)
177+
178+
if current >= limit then
179+
return redis.call("PTTL",key)
180+
else
181+
redis.call("ZADD", key, current_time, current_time)
182+
redis.call('PEXPIRE', key, expire_time)
183+
return 0
184+
end
185+
```
186+
165187
## License
166188

167189
This project is licensed under the

examples/main.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from fastapi_limiter import FastAPILimiter
88
from fastapi_limiter.depends import RateLimiter, WebSocketRateLimiter
9-
9+
from fastapi_limiter.utils import RateLimitType
1010

1111
@asynccontextmanager
1212
async def lifespan(_: FastAPI):
@@ -52,6 +52,15 @@ async def websocket_endpoint(websocket: WebSocket):
5252
except HTTPException:
5353
await websocket.send_text("Hello again")
5454

55+
@app.get(
56+
"/test_sliding_window",
57+
dependencies=[
58+
Depends(RateLimiter(times=2, seconds=5, rate_limit_type=RateLimitType.SLIDING_WINDOW))
59+
],
60+
)
61+
async def test_sliding_window():
62+
return {"msg": "Hello World"}
63+
5564

5665
if __name__ == "__main__":
5766
uvicorn.run("main:app", debug=True, reload=True)

fastapi_limiter/__init__.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from starlette.responses import Response
77
from starlette.status import HTTP_429_TOO_MANY_REQUESTS
88
from starlette.websockets import WebSocket
9+
from fastapi_limiter.utils import LuaScript
910

1011

1112
async def default_identifier(request: Union[Request, WebSocket]):
@@ -51,22 +52,8 @@ class FastAPILimiter:
5152
identifier: Optional[Callable] = None
5253
http_callback: Optional[Callable] = None
5354
ws_callback: Optional[Callable] = None
54-
lua_script = """local key = KEYS[1]
55-
local limit = tonumber(ARGV[1])
56-
local expire_time = ARGV[2]
57-
58-
local current = tonumber(redis.call('get', key) or "0")
59-
if current > 0 then
60-
if current + 1 > limit then
61-
return redis.call("PTTL",key)
62-
else
63-
redis.call("INCR", key)
64-
return 0
65-
end
66-
else
67-
redis.call("SET", key, 1,"px",expire_time)
68-
return 0
69-
end"""
55+
lua_sha_fix_window: Optional[str] = None
56+
lua_sha_sliding_window: Optional[str] = None
7057

7158
@classmethod
7259
async def init(
@@ -82,7 +69,8 @@ async def init(
8269
cls.identifier = identifier
8370
cls.http_callback = http_callback
8471
cls.ws_callback = ws_callback
85-
cls.lua_sha = await redis.script_load(cls.lua_script)
72+
cls.lua_sha_fix_window = await redis.script_load(LuaScript.FIXED_WINDOW_LIMIT_SCRIPT.value)
73+
cls.lua_sha_sliding_window = await redis.script_load(LuaScript.SLIDING_WINDOW_LIMIT_SCRIPT.value)
8674

8775
@classmethod
8876
async def close(cls) -> None:

fastapi_limiter/depends.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from starlette.websockets import WebSocket
88

99
from fastapi_limiter import FastAPILimiter
10+
from fastapi_limiter.utils import RateLimitType
1011

1112

1213
class RateLimiter:
@@ -19,16 +20,30 @@ def __init__(
1920
hours: Annotated[int, Field(ge=-1)] = 0,
2021
identifier: Optional[Callable] = None,
2122
callback: Optional[Callable] = None,
23+
rate_limit_type: RateLimitType = RateLimitType.FIXED_WINDOW
2224
):
2325
self.times = times
2426
self.milliseconds = milliseconds + 1000 * seconds + 60000 * minutes + 3600000 * hours
2527
self.identifier = identifier
2628
self.callback = callback
29+
self.rate_limit_type = rate_limit_type
2730

28-
async def _check(self, key):
31+
def _get_lua_sha(self, specific_lua_sha=None):
32+
if specific_lua_sha:
33+
return specific_lua_sha
34+
elif self.rate_limit_type is RateLimitType.SLIDING_WINDOW:
35+
return FastAPILimiter.lua_sha_sliding_window
36+
return FastAPILimiter.lua_sha_fix_window
37+
38+
39+
async def _check(self, key, specific_lua_sha=None):
2940
redis = FastAPILimiter.redis
3041
pexpire = await redis.evalsha(
31-
FastAPILimiter.lua_sha, 1, key, str(self.times), str(self.milliseconds)
42+
self._get_lua_sha(specific_lua_sha),
43+
1,
44+
key,
45+
str(self.times),
46+
str(self.milliseconds)
3247
)
3348
return pexpire
3449

@@ -53,10 +68,7 @@ async def __call__(self, request: Request, response: Response):
5368
try:
5469
pexpire = await self._check(key)
5570
except pyredis.exceptions.NoScriptError:
56-
FastAPILimiter.lua_sha = await FastAPILimiter.redis.script_load(
57-
FastAPILimiter.lua_script
58-
)
59-
pexpire = await self._check(key)
71+
pexpire = await self._check(key, specific_lua_sha=FastAPILimiter.lua_sha_fix_window)
6072
if pexpire != 0:
6173
return await callback(request, response, pexpire)
6274

fastapi_limiter/utils.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from enum import Enum
2+
3+
4+
class RateLimitType(Enum):
5+
FIXED_WINDOW = "fixed_window"
6+
SLIDING_WINDOW = "sliding_window"
7+
8+
9+
class LuaScript(Enum):
10+
FIXED_WINDOW_LIMIT_SCRIPT = """
11+
local key = KEYS[1]
12+
local limit = tonumber(ARGV[1])
13+
local expire_time = ARGV[2]
14+
15+
local current = tonumber(redis.call('get', key) or "0")
16+
17+
if current > 0 then
18+
if current + 1 > limit then
19+
return redis.call("PTTL",key)
20+
else
21+
redis.call("INCR", key)
22+
return 0
23+
end
24+
else
25+
redis.call("SET", key, 1, "px", expire_time)
26+
return 0
27+
end
28+
"""
29+
SLIDING_WINDOW_LIMIT_SCRIPT = """
30+
local key = KEYS[1]
31+
local limit = tonumber(ARGV[1])
32+
local expire_time = tonumber(ARGV[2])
33+
local current_time = redis.call('TIME')[1]
34+
local start_time = current_time - expire_time / 1000
35+
36+
redis.call('ZREMRANGEBYSCORE', key, 0, start_time)
37+
38+
local current = redis.call('ZCARD', key)
39+
40+
if current >= limit then
41+
return redis.call("PTTL",key)
42+
else
43+
redis.call("ZADD", key, current_time, current_time)
44+
redis.call('PEXPIRE', key, expire_time)
45+
return 0
46+
end
47+
"""

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ packages = [
1212
]
1313
readme = "README.md"
1414
repository = "https://github.com/long2ice/fastapi-limiter.git"
15-
version = "0.1.6"
15+
version = "0.1.7"
1616

1717
[tool.poetry.dependencies]
1818
redis = ">=4.2.0rc1"

tests/test_depends.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,20 @@ def test_limiter_websockets():
6565
data = ws.receive_text()
6666
assert data == "Hello, world"
6767
ws.close()
68+
69+
70+
def test_limiter_sliding_window():
71+
with TestClient(app) as client:
72+
def req(sleep_times, assert_code):
73+
nonlocal client
74+
response = client.get("/test_sliding_window")
75+
assert response.status_code == assert_code
76+
sleep(sleep_times)
77+
78+
req(4, 200) # 0s
79+
req(1, 200) # 4s
80+
req(1, 200) # 5s
81+
req(1, 429) # 6s
82+
req(1, 429) # 7s
83+
req(1, 429) # 8s
84+
req(1, 200) # 9s

0 commit comments

Comments
 (0)