Skip to content

Commit ae7003c

Browse files
Merge pull request #31 from Mr-Sunglasses/testss
feat: file upload limit
2 parents f4d6f69 + 0a0b3da commit ae7003c

File tree

3 files changed

+58
-18
lines changed

3 files changed

+58
-18
lines changed

src/paste/main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from slowapi import Limiter, _rate_limit_exceeded_handler
1111
from slowapi.util import get_remote_address
1212
from .utils import generate_uuid
13+
from .middleware import LimitUploadSize
1314

1415
limiter = Limiter(key_func=get_remote_address)
1516
app = FastAPI(title="paste.py 🐍")
@@ -26,6 +27,8 @@
2627
allow_headers=["*"],
2728
)
2829

30+
app.add_middleware(LimitUploadSize, max_upload_size=20_000_000) # ~20MB
31+
2932
large_uuid_storage = []
3033

3134
BASE_DIR = Path(__file__).resolve().parent
@@ -58,7 +61,7 @@ async def post_as_a_file(request: Request, file: UploadFile = File(...)):
5861

5962

6063
@app.get("/paste/{uuid}")
61-
async def post_as_a_text(uuid):
64+
async def get_paste_data(uuid):
6265
path = f"data/{uuid}"
6366
try:
6467
with open(path, "rb") as f:

src/paste/middleware.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from starlette import status
2+
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
3+
from starlette.requests import Request
4+
from starlette.responses import Response
5+
from starlette.types import ASGIApp
6+
7+
8+
class LimitUploadSize(BaseHTTPMiddleware):
9+
def __init__(self, app: ASGIApp, max_upload_size: int) -> None:
10+
super().__init__(app)
11+
self.max_upload_size = max_upload_size
12+
13+
async def dispatch(
14+
self, request: Request, call_next: RequestResponseEndpoint
15+
) -> Response:
16+
if request.method == "POST":
17+
if "content-length" not in request.headers:
18+
return Response(status_code=status.HTTP_411_LENGTH_REQUIRED)
19+
content_length = int(request.headers["content-length"])
20+
if content_length > self.max_upload_size:
21+
return Response(
22+
"File is too large",
23+
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
24+
)
25+
return await call_next(request)

tests/test_api.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from fastapi.testclient import TestClient
22
from src.paste.main import app
3+
import os
34

45
client = TestClient(app)
56

@@ -14,31 +15,29 @@ def test_get_health_route():
1415

1516

1617
def test_get_homepage_route():
17-
response_expected_headers = 'text/html; charset=utf-8'
18+
response_expected_headers = "text/html; charset=utf-8"
1819
response = client.get("/")
1920
assert response.status_code == 200
20-
assert response.headers.get(
21-
'Content-Type', '') == response_expected_headers
21+
assert response.headers.get("Content-Type", "") == response_expected_headers
2222

2323

2424
def test_get_web_route():
25-
response_expected_headers = 'text/html; charset=utf-8'
25+
response_expected_headers = "text/html; charset=utf-8"
2626
response = client.get("/web")
2727
assert response.status_code == 200
28-
assert response.headers.get(
29-
'Content-Type', '') == response_expected_headers
28+
assert response.headers.get("Content-Type", "") == response_expected_headers
3029

3130

32-
def test_get_paste_route():
33-
data = 'This is a test file.'
31+
def test_get_paste_data_route():
32+
data = "This is a test file."
3433
response = client.get("/paste/test")
3534
assert response.status_code == 200
3635
assert response.text == data
3736

3837

3938
def test_post_web_route():
40-
data = 'This is a test data'
41-
form_data = {'content': data}
39+
data = "This is a test data"
40+
form_data = {"content": data}
4241
response = client.post("/web", data=form_data)
4342
global file
4443
file = str(response.url).split("/")[-1]
@@ -54,8 +53,7 @@ def test_delete_paste_route():
5453

5554

5655
def test_post_file_route():
57-
response = client.post(
58-
"/file", files={"file": ("test.txt", b"test file content")})
56+
response = client.post("/file", files={"file": ("test.txt", b"test file content")})
5957
assert response.status_code == 201
6058
response_file_uuid = response.text
6159
response = client.get(f"/paste/{response_file_uuid}")
@@ -73,13 +71,27 @@ def test_post_file_route_failure():
7371
"detail": [
7472
{
7573
"type": "missing",
76-
"loc": [
77-
"body",
78-
"file"
79-
],
74+
"loc": ["body", "file"],
8075
"msg": "Field required",
8176
"input": None,
82-
"url": "https://errors.pydantic.dev/2.5/v/missing"
77+
"url": "https://errors.pydantic.dev/2.5/v/missing",
8378
}
8479
]
8580
}
81+
82+
83+
def test_post_file_route_size_limit():
84+
large_file_name = "large_file.txt"
85+
file_size = 20 * 1024 * 1024 # 20 MB in bytes
86+
additional_bytes = 100 # Adding some extra bytes to exceed 20 MB
87+
content = b"This is a line in the file.\n"
88+
with open(large_file_name, "wb") as file:
89+
while file.tell() < file_size:
90+
file.write(content)
91+
file.write(b"Extra bytes to exceed 20 MB\n" * additional_bytes)
92+
files = {"file": open(large_file_name, "rb")}
93+
response = client.post("/file", files=files)
94+
# cleanup
95+
os.remove(large_file_name)
96+
assert response.status_code == 413
97+
assert response.text == "File is too large"

0 commit comments

Comments
 (0)