diff --git a/examples/main.py b/examples/main.py index 3e0a2b0..6b0f532 100644 --- a/examples/main.py +++ b/examples/main.py @@ -54,4 +54,4 @@ async def websocket_endpoint(websocket: WebSocket): if __name__ == "__main__": - uvicorn.run("main:app", debug=True, reload=True) + uvicorn.run("main:app", reload=True) diff --git a/fastapi_limiter/depends.py b/fastapi_limiter/depends.py index 295df94..d5b855a 100644 --- a/fastapi_limiter/depends.py +++ b/fastapi_limiter/depends.py @@ -60,6 +60,14 @@ async def __call__(self, request: Request, response: Response): if pexpire != 0: return await callback(request, response, pexpire) + def __hash__(self) -> int: + return hash(f"limiter-{(self.times, self.milliseconds)}") + + def __eq__(self, other: object) -> bool: + if isinstance(other, RateLimiter): + return (self.times, self.milliseconds) == (other.times, other.milliseconds) + return False + class WebSocketRateLimiter(RateLimiter): async def __call__(self, ws: WebSocket, context_key=""): diff --git a/tests/test_depends.py b/tests/test_depends.py index 3979784..0497eeb 100644 --- a/tests/test_depends.py +++ b/tests/test_depends.py @@ -1,8 +1,10 @@ from time import sleep +import pytest from starlette.testclient import TestClient from examples.main import app +from fastapi_limiter.depends import RateLimiter def test_limiter(): @@ -65,3 +67,49 @@ def test_limiter_websockets(): data = ws.receive_text() assert data == "Hello, world" ws.close() + + +@pytest.mark.parametrize( + ("eq_left", "eq_right", "neq_left1", "neq_right1", "neq_left2", "neq_right2"), + [ + ( + RateLimiter(times=1, milliseconds=5), + RateLimiter(times=1, milliseconds=5), + RateLimiter(times=1, milliseconds=5), + RateLimiter(times=2, milliseconds=5), + RateLimiter(times=1, milliseconds=5), + RateLimiter(times=1, milliseconds=10), + ), + ( + RateLimiter(times=1, seconds=5), + RateLimiter(times=1, seconds=5), + RateLimiter(times=1, seconds=5), + RateLimiter(times=2, seconds=5), + RateLimiter(times=1, seconds=5), + RateLimiter(times=1, seconds=10), + ), + ( + RateLimiter(times=1, minutes=5), + RateLimiter(times=1, minutes=5), + RateLimiter(times=1, minutes=5), + RateLimiter(times=2, minutes=5), + RateLimiter(times=1, minutes=5), + RateLimiter(times=1, minutes=10), + ), + ( + RateLimiter(times=1, hours=5), + RateLimiter(times=1, hours=5), + RateLimiter(times=1, hours=5), + RateLimiter(times=2, hours=5), + RateLimiter(times=1, hours=5), + RateLimiter(times=1, hours=10), + ), + ], +) +def test_limiter_equality(eq_left, eq_right, neq_left1, neq_right1, neq_left2, neq_right2): + assert hash(eq_left) == hash(eq_right) + assert eq_left == eq_right + assert hash(neq_left1) != hash(neq_right1) + assert neq_left1 != neq_right1 + assert hash(neq_left2) != hash(neq_right2) + assert neq_left2 != neq_right2