77from starlette .websockets import WebSocket
88
99from fastapi_limiter import FastAPILimiter
10+ from fastapi_limiter .utils import RateLimitType
1011
1112
1213class RateLimiter :
@@ -19,16 +20,30 @@ def __init__(
1920 hours : Annotated [int , Field (ge = - 1 )] = 0 ,
2021 identifier : Optional [Callable ] = None ,
2122 callback : Optional [Callable ] = None ,
23+ rate_limit_type : RateLimitType = RateLimitType .FIXED_WINDOW
2224 ):
2325 self .times = times
2426 self .milliseconds = milliseconds + 1000 * seconds + 60000 * minutes + 3600000 * hours
2527 self .identifier = identifier
2628 self .callback = callback
29+ self .rate_limit_type = rate_limit_type
2730
28- async def _check (self , key ):
31+ def _get_lua_sha (self , specific_lua_sha = None ):
32+ if specific_lua_sha :
33+ return specific_lua_sha
34+ elif self .rate_limit_type is RateLimitType .SLIDING_WINDOW :
35+ return FastAPILimiter .lua_sha_sliding_window
36+ return FastAPILimiter .lua_sha_fix_window
37+
38+
39+ async def _check (self , key , specific_lua_sha = None ):
2940 redis = FastAPILimiter .redis
3041 pexpire = await redis .evalsha (
31- FastAPILimiter .lua_sha , 1 , key , str (self .times ), str (self .milliseconds )
42+ self ._get_lua_sha (specific_lua_sha ),
43+ 1 ,
44+ key ,
45+ str (self .times ),
46+ str (self .milliseconds )
3247 )
3348 return pexpire
3449
@@ -53,10 +68,7 @@ async def __call__(self, request: Request, response: Response):
5368 try :
5469 pexpire = await self ._check (key )
5570 except pyredis .exceptions .NoScriptError :
56- FastAPILimiter .lua_sha = await FastAPILimiter .redis .script_load (
57- FastAPILimiter .lua_script
58- )
59- pexpire = await self ._check (key )
71+ pexpire = await self ._check (key , specific_lua_sha = FastAPILimiter .lua_sha_fix_window )
6072 if pexpire != 0 :
6173 return await callback (request , response , pexpire )
6274
0 commit comments