diff --git a/examples/main_disabled.py b/examples/main_disabled.py new file mode 100644 index 0000000..215e300 --- /dev/null +++ b/examples/main_disabled.py @@ -0,0 +1,36 @@ +import aioredis +import uvicorn +from fastapi import Depends, FastAPI + +from fastapi_limiter import FastAPILimiter +from fastapi_limiter.depends import RateLimiter + +app = FastAPI() + + +@app.on_event("startup") +async def startup(): + await FastAPILimiter.init(None, enabled=False) + +@app.on_event("shutdown") +async def shutdown(): + await FastAPILimiter.close() + + +@app.get("/", dependencies=[Depends(RateLimiter(times=2, seconds=5))]) +async def index(): + return {"msg": "Hello World"} + +@app.get( + "/multiple", + dependencies=[ + Depends(RateLimiter(times=1, seconds=5)), + Depends(RateLimiter(times=2, seconds=15)), + ], +) +async def multiple(): + return {"msg": "Hello World"} + + +if __name__ == "__main__": + uvicorn.run("main:app", debug=True, reload=True) diff --git a/fastapi_limiter/__init__.py b/fastapi_limiter/__init__.py index 89373a0..ca0d515 100644 --- a/fastapi_limiter/__init__.py +++ b/fastapi_limiter/__init__.py @@ -37,6 +37,7 @@ class FastAPILimiter: lua_sha: str = None identifier: Callable = None callback: Callable = None + enabled: bool = True lua_script = """local key = KEYS[1] local limit = tonumber(ARGV[1]) local expire_time = ARGV[2] @@ -61,14 +62,22 @@ async def init( prefix: str = "fastapi-limiter", identifier: Callable = default_identifier, callback: Callable = default_callback, + enabled: bool = True ): cls.redis = redis cls.prefix = prefix cls.identifier = identifier cls.callback = callback - cls.lua_sha = await redis.script_load(cls.lua_script) + cls.enabled = enabled + + if enabled: + cls.lua_sha = await redis.script_load(cls.lua_script) + else: + cls.lua_sha = None @classmethod async def close(cls): - cls.redis.close() - await cls.redis.wait_closed() + if cls.enabled: + cls.redis.close() + await cls.redis.wait_closed() + diff --git a/fastapi_limiter/depends.py b/fastapi_limiter/depends.py index ec20c57..65000c7 100644 --- a/fastapi_limiter/depends.py +++ b/fastapi_limiter/depends.py @@ -24,23 +24,24 @@ def __init__( self.callback = callback async def __call__(self, request: Request, response: Response): - if not FastAPILimiter.redis: - raise Exception("You must call FastAPILimiter.init in startup event of fastapi!") - index = 0 - for route in request.app.routes: - if route.path == request.scope["path"]: - for idx, dependency in enumerate(route.dependencies): - if self is dependency.dependency: - index = idx - break - # moved here because constructor run before app startup - identifier = self.identifier or FastAPILimiter.identifier - callback = self.callback or FastAPILimiter.callback - redis = FastAPILimiter.redis - rate_key = await identifier(request) - key = f"{FastAPILimiter.prefix}:{rate_key}:{index}" - pexpire = await redis.evalsha( - FastAPILimiter.lua_sha, keys=[key], args=[self.times, self.milliseconds] - ) - if pexpire != 0: - return await callback(request, response, pexpire) + if FastAPILimiter.enabled: + if not FastAPILimiter.redis: + raise Exception("You must call FastAPILimiter.init in startup event of fastapi!") + index = 0 + for route in request.app.routes: + if route.path == request.scope["path"]: + for idx, dependency in enumerate(route.dependencies): + if self is dependency.dependency: + index = idx + break + # moved here because constructor run before app startup + identifier = self.identifier or FastAPILimiter.identifier + callback = self.callback or FastAPILimiter.callback + redis = FastAPILimiter.redis + rate_key = await identifier(request) + key = f"{FastAPILimiter.prefix}:{rate_key}:{index}" + pexpire = await redis.evalsha( + FastAPILimiter.lua_sha, keys=[key], args=[self.times, self.milliseconds] + ) + if pexpire != 0: + return await callback(request, response, pexpire) diff --git a/tests/test_depends.py b/tests/test_depends.py index b37b02c..56e0c7e 100644 --- a/tests/test_depends.py +++ b/tests/test_depends.py @@ -3,6 +3,7 @@ from starlette.testclient import TestClient from examples.main import app +from examples.main_disabled import app as app_disabled def test_limiter(): @@ -19,6 +20,21 @@ def test_limiter(): response = client.get("/") assert response.status_code == 200 +def test_limiter_disabled(): + # Runs the same requests as test_limiter, but with RateLimiter disabled + with TestClient(app_disabled) as client: + response = client.get("/") + assert response.status_code == 200 + + client.get("/") + + response = client.get("/") + assert response.status_code == 200 + sleep(5) + + response = client.get("/") + assert response.status_code == 200 + def test_limiter_multiple(): with TestClient(app) as client: @@ -38,3 +54,23 @@ def test_limiter_multiple(): response = client.get("/multiple") assert response.status_code == 200 + +def test_limiter_multiple_disabled(): + # Runs the same requests as test_limiter_multiple, but with RateLimiter disabled + with TestClient(app_disabled) as client: + response = client.get("/multiple") + assert response.status_code == 200 + + response = client.get("/multiple") + assert response.status_code == 200 + sleep(5) + + response = client.get("/multiple") + assert response.status_code == 200 + + response = client.get("/multiple") + assert response.status_code == 200 + sleep(10) + + response = client.get("/multiple") + assert response.status_code == 200