diff --git a/alembic/env.py b/alembic/env.py index f7d9f2f..371975e 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -15,8 +15,8 @@ # Import your models and Base from paste.database import Base + # Import all your models here -from paste.models import Paste # this is the Alembic Config object config = context.config @@ -46,6 +46,7 @@ def run_migrations_offline() -> None: with context.begin_transaction(): context.run_migrations() + def run_migrations_online() -> None: connectable = engine_from_config( config.get_section(config.config_ini_section, {}), @@ -54,14 +55,13 @@ def run_migrations_online() -> None: ) with connectable.connect() as connection: - context.configure( - connection=connection, target_metadata=target_metadata - ) + context.configure(connection=connection, target_metadata=target_metadata) with context.begin_transaction(): context.run_migrations() + if context.is_offline_mode(): run_migrations_offline() else: - run_migrations_online() \ No newline at end of file + run_migrations_online() diff --git a/alembic/versions/9513acd42747_initial_migration.py b/alembic/versions/9513acd42747_initial_migration.py index d5362c0..dee74fb 100644 --- a/alembic/versions/9513acd42747_initial_migration.py +++ b/alembic/versions/9513acd42747_initial_migration.py @@ -12,7 +12,7 @@ # revision identifiers, used by Alembic. -revision: str = '9513acd42747' +revision: str = "9513acd42747" down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -20,19 +20,20 @@ def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.create_table('pastes', - sa.Column('pasteID', sa.String(length=4), nullable=False), - sa.Column('content', sa.Text(), nullable=True), - sa.Column('extension', sa.String(length=50), nullable=True), - sa.Column('s3_link', sa.String(length=500), nullable=True), - sa.Column('created_at', sa.DateTime(), nullable=True), - sa.Column('expiresat', sa.DateTime(), nullable=True), - sa.PrimaryKeyConstraint('pasteID') + op.create_table( + "pastes", + sa.Column("pasteID", sa.String(length=4), nullable=False), + sa.Column("content", sa.Text(), nullable=True), + sa.Column("extension", sa.String(length=50), nullable=True), + sa.Column("s3_link", sa.String(length=500), nullable=True), + sa.Column("created_at", sa.DateTime(), nullable=True), + sa.Column("expiresat", sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint("pasteID"), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('pastes') + op.drop_table("pastes") # ### end Alembic commands ### diff --git a/sdk/example.py b/sdk/example.py index efca28c..2c9a248 100644 --- a/sdk/example.py +++ b/sdk/example.py @@ -1,5 +1,6 @@ from sdk.module import PasteBinSDK + def test_pastebin_sdk(): sdk = PasteBinSDK() @@ -23,5 +24,6 @@ def test_pastebin_sdk(): except RuntimeError as e: print(f"An error occurred: {e}") + if __name__ == "__main__": - test_pastebin_sdk() \ No newline at end of file + test_pastebin_sdk() diff --git a/sdk/sdk/module.py b/sdk/sdk/module.py index 5b7c428..585704e 100644 --- a/sdk/sdk/module.py +++ b/sdk/sdk/module.py @@ -1,7 +1,8 @@ import requests -from typing import Optional, Union +from typing import Union from pathlib import Path + class PasteBinSDK: def __init__(self, base_url: str = "https://paste.fosscu.org"): self.base_url = base_url @@ -15,17 +16,14 @@ def create_paste(self, content: Union[str, Path], file_extension: str) -> str: """ try: if isinstance(content, Path): - with open(content, 'r', encoding='utf-8') as f: + with open(content, "r", encoding="utf-8") as f: content = f.read() - data = { - 'content': content, - 'extension': file_extension - } + data = {"content": content, "extension": file_extension} response = requests.post(f"{self.base_url}/api/paste", json=data) response.raise_for_status() result = response.json() - return result['uuid'] + return result["uuid"] except requests.RequestException as e: raise RuntimeError(f"Error creating paste: {str(e)}") @@ -65,4 +63,4 @@ def get_languages(self) -> dict: response.raise_for_status() return response.json() except requests.RequestException as e: - raise RuntimeError(f"Error fetching languages: {str(e)}") \ No newline at end of file + raise RuntimeError(f"Error fetching languages: {str(e)}") diff --git a/src/paste/main.py b/src/paste/main.py index 077516a..df04386 100644 --- a/src/paste/main.py +++ b/src/paste/main.py @@ -5,20 +5,17 @@ from datetime import datetime, timedelta, timezone from logging.config import dictConfig from pathlib import Path -from typing import Any, Awaitable, Callable, List, Optional, Union +from typing import Awaitable, List, Optional, Union -from fastapi import (Depends, FastAPI, File, Form, Header, HTTPException, - Query, Request, Response, UploadFile, status) -from fastapi.exception_handlers import http_exception_handler +from fastapi import Depends, FastAPI, File, Form, Header, HTTPException, Query, Request, Response, UploadFile, status from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import (HTMLResponse, JSONResponse, PlainTextResponse, - RedirectResponse) +from fastapi.responses import HTMLResponse, JSONResponse, PlainTextResponse, RedirectResponse from fastapi.templating import Jinja2Templates from pygments import highlight from pygments.formatters import HtmlFormatter from pygments.lexers import get_lexer_by_name, guess_lexer from pygments.util import ClassNotFound -from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi import Limiter from slowapi.errors import RateLimitExceeded from slowapi.util import get_remote_address from sqlalchemy import text @@ -34,8 +31,7 @@ from .middleware import LimitUploadSize from .minio import get_object_data, post_object_data from .models import Paste -from .schema import (HealthErrorResponse, HealthResponse, PasteCreate, - PasteDetails, PasteResponse) +from .schema import HealthErrorResponse, HealthResponse, PasteCreate, PasteDetails, PasteResponse from .utils import _filter_object_name_from_link, extract_uuid # -------------------------------------------------------------------- @@ -54,7 +50,6 @@ async def delete_expired_urls() -> None: while True: try: - db: Session = Session_Local() current_time = datetime.utcnow() @@ -95,9 +90,7 @@ async def delete_expired_urls() -> None: app.state.limiter = limiter -def rate_limit_exceeded_handler( - request: Request, exc: Exception -) -> Union[Response, Awaitable[Response]]: +def rate_limit_exceeded_handler(request: Request, exc: Exception) -> Union[Response, Awaitable[Response]]: if isinstance(exc, RateLimitExceeded): return Response(content="Rate limit exceeded", status_code=429) return Response(content="An error occurred", status_code=500) @@ -107,9 +100,7 @@ def rate_limit_exceeded_handler( @app.exception_handler(StarletteHTTPException) -async def custom_http_exception_handler( - request: Request, exc: StarletteHTTPException -) -> Response: +async def custom_http_exception_handler(request: Request, exc: StarletteHTTPException) -> Response: # Check if it's an API route if request.url.path.startswith("/api/"): return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail}) @@ -121,16 +112,12 @@ async def custom_http_exception_handler( if is_browser_request: try: - return templates.TemplateResponse( - "404.html", {"request": request}, status_code=404 - ) + return templates.TemplateResponse("404.html", {"request": request}, status_code=404) except Exception as e: logger.error(f"Template error: {e}") return PlainTextResponse("404: Template Error", status_code=404) else: - return PlainTextResponse( - "404: The requested resource was not found", status_code=404 - ) + return PlainTextResponse("404: The requested resource was not found", status_code=404) return PlainTextResponse(str(exc.detail), status_code=exc.status_code) @@ -176,9 +163,7 @@ async def indexpage(request: Request) -> Response: "/health", status_code=status.HTTP_200_OK, response_model=HealthResponse, - responses={ - 503: {"model": HealthErrorResponse, "description": "Database connection failed"} - }, + responses={503: {"model": HealthErrorResponse, "description": "Database connection failed"}}, ) async def health(db: Session = Depends(get_db)) -> HealthResponse: """ @@ -193,9 +178,7 @@ async def health(db: Session = Depends(get_db)) -> HealthResponse: db.execute(text("SELECT 1")) end_time = time.time() - return HealthResponse( - db_response_time_ms=round((end_time - start_time) * 1000, 2) - ) + return HealthResponse(db_response_time_ms=round((end_time - start_time) * 1000, 2)) except Exception as e: db.rollback() @@ -287,9 +270,7 @@ async def get_paste_data( async def post_as_a_file( request: Request, file: UploadFile = File(...), - expiration: Optional[str] = Query( - None, description="Expiration time: '1h', '1d', '1w', '1m', or ISO datetime" - ), + expiration: Optional[str] = Query(None, description="Expiration time: '1h', '1d', '1w', '1m', or ISO datetime"), db: Session = Depends(get_db), ) -> PlainTextResponse: try: @@ -316,9 +297,7 @@ async def post_as_a_file( else: # Try parsing as ISO format datetime try: - expiration_time = datetime.fromisoformat( - expiration.replace("Z", "+00:00") - ) + expiration_time = datetime.fromisoformat(expiration.replace("Z", "+00:00")) if expiration_time <= current_time: raise HTTPException( detail="Expiration time must be in the future", @@ -340,24 +319,20 @@ async def post_as_a_file( db.commit() db.refresh(file_data) _uuid = file_data.pasteID - return PlainTextResponse( - f"{BASE_URL}/paste/{_uuid}", status_code=status.HTTP_201_CREATED - ) + return PlainTextResponse(f"{BASE_URL}/paste/{_uuid}", status_code=status.HTTP_201_CREATED) else: file_data = Paste(content=file_content, extension=file_extension) db.add(file_data) db.commit() db.refresh(file_data) _uuid = file_data.pasteID - return PlainTextResponse( - f"{BASE_URL}/paste/{_uuid}", status_code=status.HTTP_201_CREATED - ) + return PlainTextResponse(f"{BASE_URL}/paste/{_uuid}", status_code=status.HTTP_201_CREATED) except Exception as e: db.rollback() logger.error(f"Error uploading file: {e}") raise HTTPException( - detail=f"There was an error uploading the file", + detail="There was an error uploading the file", status_code=status.HTTP_403_FORBIDDEN, ) finally: @@ -375,14 +350,12 @@ async def delete_paste(uuid: str, db: Session = Depends(get_db)) -> PlainTextRes db.commit() return PlainTextResponse(f"File successfully deleted {uuid}") else: - raise HTTPException( - detail="File Not Found", status_code=status.HTTP_404_NOT_FOUND - ) + raise HTTPException(detail="File Not Found", status_code=status.HTTP_404_NOT_FOUND) except Exception as e: db.rollback() raise HTTPException( logger.error(f"Error deleting paste: {e}"), - detail=f"There is an error happend.", + detail="There is an error happend.", status_code=status.HTTP_409_CONFLICT, ) finally: @@ -427,9 +400,7 @@ async def web_post( elif expiration == "custom" and custom_expiry: # Parse the custom expiry datetime string try: - expiration_time = datetime.fromisoformat( - custom_expiry.replace("Z", "+00:00") - ) + expiration_time = datetime.fromisoformat(custom_expiry.replace("Z", "+00:00")) except ValueError: raise HTTPException( detail="Invalid custom expiry date format", @@ -448,20 +419,14 @@ async def web_post( db.commit() db.refresh(file) _uuid = file.pasteID - return RedirectResponse( - f"{BASE_URL}/paste/{_uuid}", status_code=status.HTTP_303_SEE_OTHER - ) + return RedirectResponse(f"{BASE_URL}/paste/{_uuid}", status_code=status.HTTP_303_SEE_OTHER) else: - file = Paste( - content=content, extension=extension, expiresat=expiration_time - ) + file = Paste(content=content, extension=extension, expiresat=expiration_time) db.add(file) db.commit() db.refresh(file) _uuid = file.pasteID - return RedirectResponse( - f"{BASE_URL}/paste/{_uuid}", status_code=status.HTTP_303_SEE_OTHER - ) + return RedirectResponse(f"{BASE_URL}/paste/{_uuid}", status_code=status.HTTP_303_SEE_OTHER) except Exception as e: db.rollback() raise HTTPException( @@ -479,9 +444,7 @@ async def web_post( @app.get("/api/paste/{uuid}", response_model=PasteDetails) @limiter.limit("100/minute") -async def get_paste_details( - request: Request, uuid: str, db: Session = Depends(get_db) -) -> JSONResponse: +async def get_paste_details(request: Request, uuid: str, db: Session = Depends(get_db)) -> JSONResponse: try: uuid = extract_uuid(uuid) data = db.query(Paste).filter(Paste.pasteID == uuid).first() @@ -499,10 +462,10 @@ async def get_paste_details( detail="Paste not found", status_code=status.HTTP_404_NOT_FOUND, ) - except Exception as e: + except Exception: db.rollback() raise HTTPException( - detail=f"Error retrieving paste", + detail="Error retrieving paste", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) finally: @@ -511,11 +474,8 @@ async def get_paste_details( @app.post("/api/paste", response_model=PasteResponse) @limiter.limit("100/minute") -async def create_paste( - request: Request, paste: PasteCreate, db: Session = Depends(get_db) -) -> JSONResponse: +async def create_paste(request: Request, paste: PasteCreate, db: Session = Depends(get_db)) -> JSONResponse: try: - # Calculate expiration time if provided expiration_time = None if paste.expiration: @@ -552,9 +512,7 @@ async def create_paste( db.refresh(file) _uuid = file.pasteID return JSONResponse( - content=PasteResponse( - uuid=_uuid, url=f"{BASE_URL}/paste/{_uuid}" - ).model_dump(), + content=PasteResponse(uuid=_uuid, url=f"{BASE_URL}/paste/{_uuid}").model_dump(), status_code=status.HTTP_201_CREATED, ) else: @@ -568,18 +526,16 @@ async def create_paste( db.refresh(file) _uuid = file.pasteID return JSONResponse( - content=PasteResponse( - uuid=_uuid, url=f"{BASE_URL}/paste/{_uuid}" - ).model_dump(), + content=PasteResponse(uuid=_uuid, url=f"{BASE_URL}/paste/{_uuid}").model_dump(), status_code=status.HTTP_201_CREATED, ) except HTTPException: db.rollback() raise - except Exception as e: + except Exception: db.rollback() raise HTTPException( - detail=f"There was an error creating the paste", + detail="There was an error creating the paste", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) finally: @@ -605,6 +561,6 @@ async def get_languages() -> JSONResponse: except Exception as e: logger.error(f"Error reading languages file: {e}") raise HTTPException( - detail=f"Error reading languages file", + detail="Error reading languages file", status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) diff --git a/src/paste/middleware.py b/src/paste/middleware.py index 65b8005..5210507 100644 --- a/src/paste/middleware.py +++ b/src/paste/middleware.py @@ -1,6 +1,5 @@ from starlette import status -from starlette.middleware.base import (BaseHTTPMiddleware, - RequestResponseEndpoint) +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request from starlette.responses import Response from starlette.types import ASGIApp @@ -11,9 +10,7 @@ def __init__(self, app: ASGIApp, max_upload_size: int) -> None: super().__init__(app) self.max_upload_size: int = max_upload_size - async def dispatch( - self, request: Request, call_next: RequestResponseEndpoint - ) -> Response: + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: if request.method == "POST": if "content-length" not in request.headers: return Response(status_code=status.HTTP_411_LENGTH_REQUIRED) diff --git a/src/paste/minio.py b/src/paste/minio.py index e639a4d..eef4343 100644 --- a/src/paste/minio.py +++ b/src/paste/minio.py @@ -15,20 +15,16 @@ ) -def get_object_data( - object_name: str, bucket_name: str = get_settings().MINIO_BUCKET_NAME -) -> str | None: +def get_object_data(object_name: str, bucket_name: str = get_settings().MINIO_BUCKET_NAME) -> str | None: response = None data = None try: response = client.get_object(bucket_name, object_name) data = response.read() except S3Error as exc: - raise Exception(f"error occured.", exc) + raise Exception("error occured.", exc) except Exception as exc: - raise FileNotFoundError( - f"Failed to retrieve file '{object_name}' from bucket '{bucket_name}': {exc}" - ) + raise FileNotFoundError(f"Failed to retrieve file '{object_name}' from bucket '{bucket_name}': {exc}") finally: if response: response.close() @@ -65,9 +61,7 @@ def post_object_data( ) return object_url except S3Error as exc: - raise Exception( - f"Failed to upload file '{object_name}' to bucket '{bucket_name}': {exc}" - ) + raise Exception(f"Failed to upload file '{object_name}' to bucket '{bucket_name}': {exc}") def post_object_data_as_file( @@ -81,17 +75,11 @@ def post_object_data_as_file( client.fput_object(bucket_name, object_name, source_file_path) except S3Error as exc: - raise Exception( - f"Failed to upload file '{object_name}' to bucket '{bucket_name}': {exc}" - ) + raise Exception(f"Failed to upload file '{object_name}' to bucket '{bucket_name}': {exc}") -def delete_object_data( - object_name: str, bucket_name: str = get_settings().MINIO_BUCKET_NAME -) -> None: +def delete_object_data(object_name: str, bucket_name: str = get_settings().MINIO_BUCKET_NAME) -> None: try: client.remove_object(bucket_name, object_name) except S3Error as exc: - raise Exception( - f"Failed to delete file '{object_name}' from bucket '{bucket_name}': {exc}" - ) + raise Exception(f"Failed to delete file '{object_name}' from bucket '{bucket_name}': {exc}") diff --git a/src/paste/utils.py b/src/paste/utils.py index e300a7d..8103d5b 100644 --- a/src/paste/utils.py +++ b/src/paste/utils.py @@ -34,11 +34,7 @@ def _find_without_extension(file_name: str) -> str: file_list: list = os.listdir("data") pattern_with_dot: Pattern[str] = re.compile(r"^(" + re.escape(file_name) + r")\.") pattern_without_dot: Pattern[str] = re.compile(r"^" + file_name + "$") - math_pattern: list = [ - x - for x in file_list - if pattern_with_dot.match(x) or pattern_without_dot.match(x) - ] + math_pattern: list = [x for x in file_list if pattern_with_dot.match(x) or pattern_without_dot.match(x)] if len(math_pattern) == 0: return str() else: diff --git a/tests/test_api.py b/tests/test_api.py index 69020c3..a2124dc 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -71,7 +71,6 @@ def test_post_file_route_failure() -> None: # Add body assertion in future. - def test_post_file_route_size_limit() -> None: large_file_name: str = "large_file.txt" file_size: int = 20 * 1024 * 1024 # 20 MB in bytes @@ -91,11 +90,9 @@ def test_post_file_route_size_limit() -> None: assert response.status_code == 413 assert "File is too large" in response.text + def test_post_api_paste_route() -> None: - paste_data = { - "content": "This is a test paste content", - "extension": "txt" - } + paste_data = {"content": "This is a test paste content", "extension": "txt"} response = client.post("/api/paste", json=paste_data) assert response.status_code == 201 response_json = response.json() @@ -109,12 +106,10 @@ def test_post_api_paste_route() -> None: delete_response = client.delete(f"/paste/{uuid}") assert delete_response.status_code == 200 + def test_get_api_paste_route() -> None: # First, create a paste - paste_data = { - "content": "This is a test paste content for GET", - "extension": "md" - } + paste_data = {"content": "This is a test paste content for GET", "extension": "md"} create_response = client.post("/api/paste", json=paste_data) assert create_response.status_code == 201 created_uuid = create_response.json()["uuid"] @@ -131,6 +126,7 @@ def test_get_api_paste_route() -> None: delete_response = client.delete(f"/paste/{created_uuid}") assert delete_response.status_code == 200 + def test_get_api_paste_route_not_found() -> None: response = client.get("/api/paste/nonexistent_uuid.txt") assert response.status_code == 404