diff --git a/backend/app/alembic/versions/050_add_userproject_table.py b/backend/app/alembic/versions/050_add_userproject_table.py new file mode 100644 index 000000000..7b71ff942 --- /dev/null +++ b/backend/app/alembic/versions/050_add_userproject_table.py @@ -0,0 +1,60 @@ +"""Add userproject table + +Revision ID: 050 +Revises: 049 +Create Date: 2026-04-01 12:17:42.165482 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "050" +down_revision = "049" +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "user_project", + sa.Column( + "user_id", sa.Integer(), nullable=False, comment="Reference to the user" + ), + sa.Column( + "organization_id", + sa.Integer(), + nullable=False, + comment="Reference to the organization", + ), + sa.Column( + "project_id", + sa.Integer(), + nullable=False, + comment="Reference to the project", + ), + sa.Column( + "id", + sa.Integer(), + nullable=False, + comment="Unique identifier for the user-project mapping", + ), + sa.Column( + "inserted_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the mapping was created", + ), + sa.ForeignKeyConstraint( + ["organization_id"], ["organization.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["project_id"], ["project.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("user_id", "project_id", name="uq_user_project"), + ) + + +def downgrade(): + op.drop_table("user_project") diff --git a/backend/app/api/deps.py b/backend/app/api/deps.py index 9f2c81a62..a99e0d824 100644 --- a/backend/app/api/deps.py +++ b/backend/app/api/deps.py @@ -4,19 +4,26 @@ import jwt from fastapi import Depends, HTTPException, Request, status from fastapi.security import APIKeyHeader, OAuth2PasswordBearer -from jwt.exceptions import InvalidTokenError +from jwt.exceptions import ExpiredSignatureError, InvalidTokenError from pydantic import ValidationError -from sqlmodel import Session, select +from sqlmodel import Session + +from sqlmodel import and_, select from app.core import security from app.core.config import settings from app.core.db import engine from app.core.security import api_key_manager from app.crud.organization import validate_organization +from app.crud.project import validate_project from app.models import ( + APIKey, AuthContext, + Organization, + Project, TokenPayload, User, + UserProject, ) @@ -35,57 +42,123 @@ def get_db() -> Generator[Session, None, None]: TokenDep = Annotated[str, Depends(reusable_oauth2)] +def _authenticate_with_jwt(session: Session, token: str) -> AuthContext: + """Validate a JWT token and return the authenticated user context.""" + try: + payload = jwt.decode( + token, settings.SECRET_KEY, algorithms=[security.ALGORITHM] + ) + token_data = TokenPayload(**payload) + except ExpiredSignatureError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Token has expired", + ) + except (InvalidTokenError, ValidationError): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Could not validate credentials", + ) + + # Reject refresh tokens — they should only be used at /auth/refresh + if token_data.type == "refresh": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Refresh tokens cannot be used for API access", + ) + + user = session.get(User, token_data.sub) + if not user or not user.is_active: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="User access has been revoked", + ) + + organization: Organization | None = None + project: Project | None = None + + if token_data.org_id: + organization = validate_organization(session=session, org_id=token_data.org_id) + if token_data.project_id: + project = validate_project(session=session, project_id=token_data.project_id) + + # Verify user still has access to this project + if project: + has_access = session.exec( + select(UserProject.id) + .where( + and_( + UserProject.user_id == user.id, + UserProject.project_id == project.id, + ) + ) + .limit(1) + ).first() + + if not has_access: + # Fallback: check APIKey table for backward compatibility + has_api_key = session.exec( + select(APIKey.id) + .where( + and_( + APIKey.user_id == user.id, + APIKey.project_id == project.id, + APIKey.is_deleted.is_(False), + ) + ) + .limit(1) + ).first() + + if not has_api_key: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="User access to this project has been revoked", + ) + + return AuthContext(user=user, organization=organization, project=project) + + def get_auth_context( + request: Request, session: SessionDep, token: TokenDep, api_key: Annotated[str, Depends(api_key_header)], ) -> AuthContext: """ - Verify valid authentication (API Key or JWT token) and return authenticated user context. + Verify valid authentication (API Key, JWT token, or cookie) and return authenticated user context. Returns AuthContext with user info, project_id, and organization_id. Authorization logic should be handled in routes. + + Authentication priority: + 1. X-API-KEY header + 2. Authorization: Bearer header + 3. access_token cookie """ + # 1. Try X-API-KEY header if api_key: auth_context = api_key_manager.verify(session, api_key) - if not auth_context: - raise HTTPException(status_code=401, detail="Invalid API Key") + if auth_context: + if not auth_context.user.is_active: + raise HTTPException(status_code=403, detail="Inactive user") - if not auth_context.user.is_active: - raise HTTPException(status_code=403, detail="Inactive user") + if not auth_context.organization.is_active: + raise HTTPException(status_code=403, detail="Inactive Organization") - if not auth_context.organization.is_active: - raise HTTPException(status_code=403, detail="Inactive Organization") + if not auth_context.project.is_active: + raise HTTPException(status_code=403, detail="Inactive Project") - if not auth_context.project.is_active: - raise HTTPException(status_code=403, detail="Inactive Project") + return auth_context - return auth_context + # 2. Try Authorization: Bearer header + if token: + return _authenticate_with_jwt(session, token) - elif token: - try: - payload = jwt.decode( - token, settings.SECRET_KEY, algorithms=[security.ALGORITHM] - ) - token_data = TokenPayload(**payload) - except (InvalidTokenError, ValidationError): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Could not validate credentials", - ) - - user = session.get(User, token_data.sub) - if not user: - raise HTTPException(status_code=404, detail="User not found") - if not user.is_active: - raise HTTPException(status_code=403, detail="Inactive user") - - auth_context = AuthContext( - user=user, - ) - return auth_context + # 3. Try access_token cookie + cookie_token = request.cookies.get("access_token") + if cookie_token: + return _authenticate_with_jwt(session, cookie_token) - else: - raise HTTPException(status_code=401, detail="Invalid Authorization format") + raise HTTPException(status_code=401, detail="Invalid Authorization format") AuthContextDep = Annotated[AuthContext, Depends(get_auth_context)] diff --git a/backend/app/api/docs/auth/google.md b/backend/app/api/docs/auth/google.md new file mode 100644 index 000000000..52de201b9 --- /dev/null +++ b/backend/app/api/docs/auth/google.md @@ -0,0 +1,40 @@ +# Google OAuth Authentication + +Authenticate a user via Google Sign-In by verifying the Google ID token. + +## Request + +- **token** (required): The Google ID token obtained from the frontend Google Sign-In flow. + +## Behavior + +1. Verifies the Google ID token against Google's public keys and the configured `GOOGLE_CLIENT_ID`. +2. Extracts user information (email, name, picture) from the verified token. +3. Looks up the user by email in the database. +4. If the user exists and was inactive (first login), activates the account. +5. Generates a JWT access token and refresh token, set as **HTTP-only secure cookies**. +6. If the user has exactly one project, it is auto-selected and embedded in the JWT. +7. If the user has multiple projects, `requires_project_selection: true` is returned with the list. + +## Response Format + +All responses follow the standard `APIResponse` format: +```json +{ + "success": true, + "data": { + "access_token": "...", + "token_type": "bearer", + "user": { ... }, + "google_profile": { ... }, + "requires_project_selection": false, + "available_projects": [ ... ] + } +} +``` + +## Error Responses + +- **400**: Invalid or expired Google token, or email not verified by Google. +- **401**: No account found for the Google email address. +- **500**: `GOOGLE_CLIENT_ID` is not configured. diff --git a/backend/app/api/docs/user_project/add.md b/backend/app/api/docs/user_project/add.md new file mode 100644 index 000000000..2a191c347 --- /dev/null +++ b/backend/app/api/docs/user_project/add.md @@ -0,0 +1,17 @@ +Add one or more users to a project by email. **Requires superuser access.** + +**Request Body:** +- `organization_id` (required): The ID of the organization the project belongs to. +- `project_id` (required): The ID of the project to add users to. +- `users` (required): Array of user objects. + - `email` (required): User's email address. + - `full_name` (optional): User's full name. + +**Examples:** +- **Single user**: `{"organization_id": 1, "project_id": 1, "users": [{"email": "user@gmail.com", "full_name": "User Name"}]}` +- **Multiple users**: `{"organization_id": 1, "project_id": 1, "users": [{"email": "a@gmail.com"}, {"email": "b@gmail.com"}]}` + +**Behavior per email:** +- If the user does not exist, a new account is created with `is_active: false`. The user will be activated on their first Google login. +- If the user already exists and is already in this project, they are skipped. +- If the user exists but is not in this project, they are added. diff --git a/backend/app/api/docs/user_project/delete.md b/backend/app/api/docs/user_project/delete.md new file mode 100644 index 000000000..4b765c47f --- /dev/null +++ b/backend/app/api/docs/user_project/delete.md @@ -0,0 +1,9 @@ +Remove a user from a project. **Requires superuser access.** + +**Path Parameters:** +- `user_id` (required): The ID of the user to remove. + +**Query Parameters:** +- `project_id` (required): The ID of the project to remove the user from. + +This only removes the user-project mapping — the user account itself is not deleted. You cannot remove yourself from a project. diff --git a/backend/app/api/docs/user_project/list.md b/backend/app/api/docs/user_project/list.md new file mode 100644 index 000000000..8c1cd3693 --- /dev/null +++ b/backend/app/api/docs/user_project/list.md @@ -0,0 +1,6 @@ +List all users that belong to a project. + +**Query Parameters:** +- `project_id` (required): The ID of the project to list users for. + +Returns user details including their active status — users added via invitation will have `is_active: false` until they complete their first Google login. diff --git a/backend/app/api/main.py b/backend/app/api/main.py index 5ab1cbd9e..98ce324c5 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -7,6 +7,7 @@ config, doc_transformation_job, documents, + auth, login, languages, llm, @@ -17,6 +18,7 @@ responses, private, threads, + user_project, users, utils, onboarding, @@ -39,6 +41,7 @@ api_router.include_router(cron.router) api_router.include_router(documents.router) api_router.include_router(doc_transformation_job.router) +api_router.include_router(auth.router) api_router.include_router(evaluations.router) api_router.include_router(languages.router) api_router.include_router(llm.router) @@ -50,6 +53,7 @@ api_router.include_router(project.router) api_router.include_router(responses.router) api_router.include_router(threads.router) +api_router.include_router(user_project.router) api_router.include_router(users.router) api_router.include_router(utils.router) api_router.include_router(fine_tuning.router) diff --git a/backend/app/api/routes/auth.py b/backend/app/api/routes/auth.py new file mode 100644 index 000000000..a7a7465b1 --- /dev/null +++ b/backend/app/api/routes/auth.py @@ -0,0 +1,200 @@ +import logging + +from fastapi import APIRouter, HTTPException, Request, status +from fastapi.responses import JSONResponse +from google.auth.transport import requests as google_requests +from google.oauth2 import id_token + +from app.api.deps import AuthContextDep, SessionDep +from app.core.config import settings +from app.crud import get_user_by_email +from app.crud.auth import get_user_accessible_projects +from app.models import ( + GoogleAuthRequest, + GoogleAuthResponse, + Message, + SelectProjectRequest, + Token, +) +from app.services.auth import ( + build_google_auth_response, + build_token_response, + clear_auth_cookies, + validate_refresh_token, +) +from app.utils import APIResponse, load_description + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/auth", tags=["Authentication"]) + + +@router.post( + "/google", + description=load_description("auth/google.md"), + response_model=APIResponse[GoogleAuthResponse], +) +def google_auth(session: SessionDep, body: GoogleAuthRequest) -> JSONResponse: + """Authenticate a user via Google OAuth ID token.""" + + if not settings.GOOGLE_CLIENT_ID: + logger.error("[google_auth] GOOGLE_CLIENT_ID is not configured") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Google authentication is not configured", + ) + + # Verify the Google ID token + try: + idinfo = id_token.verify_oauth2_token( + body.token, + google_requests.Request(), + settings.GOOGLE_CLIENT_ID, + ) + except ValueError as e: + logger.warning(f"[google_auth] Invalid Google token: {e}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid or expired Google token", + ) + + if not idinfo.get("email_verified", False): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Google email is not verified", + ) + + email: str = idinfo["email"] + + user = get_user_by_email(session=session, email=email) + if not user: + logger.info(f"[google_auth] No account found for email: {email}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="No account found for this Google email. Please Contact Support to add your account.", + ) + + # Activate user on first Google login + if not user.is_active: + user.is_active = True + session.add(user) + session.commit() + session.refresh(user) + logger.info(f"[google_auth] User activated on first login | user_id: {user.id}") + + google_profile = { + "email": idinfo.get("email"), + "name": idinfo.get("name"), + "picture": idinfo.get("picture"), + "given_name": idinfo.get("given_name"), + "family_name": idinfo.get("family_name"), + } + + available_projects = get_user_accessible_projects(session=session, user_id=user.id) + + if len(available_projects) == 1: + proj = available_projects[0] + logger.info( + f"[google_auth] User authenticated via Google (auto-selected project) | user_id: {user.id}" + ) + return build_google_auth_response( + user=user, + google_profile=google_profile, + available_projects=available_projects, + organization_id=proj["organization_id"], + project_id=proj["project_id"], + ) + elif len(available_projects) > 1: + logger.info( + f"[google_auth] User authenticated via Google (requires project selection) | user_id: {user.id}" + ) + return build_google_auth_response( + user=user, + google_profile=google_profile, + available_projects=available_projects, + requires_project_selection=True, + ) + else: + logger.info( + f"[google_auth] User authenticated via Google (no projects) | user_id: {user.id}" + ) + return build_google_auth_response( + user=user, + google_profile=google_profile, + available_projects=[], + ) + + +@router.post( + "/select-project", + response_model=APIResponse[Token], +) +def select_project( + session: SessionDep, + auth_context: AuthContextDep, + body: SelectProjectRequest, +) -> JSONResponse: + """Select a project and get a new JWT token with org/project embedded.""" + + user = auth_context.user + + available_projects = get_user_accessible_projects(session=session, user_id=user.id) + matching = [p for p in available_projects if p["project_id"] == body.project_id] + + if not matching: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have access to this project", + ) + + proj = matching[0] + response = build_token_response( + user_id=user.id, + organization_id=proj["organization_id"], + project_id=proj["project_id"], + ) + + logger.info( + f"[select_project] Project selected | user_id: {user.id}, project_id: {body.project_id}" + ) + return response + + +@router.post( + "/refresh", + response_model=APIResponse[Token], +) +def refresh_access_token(request: Request, session: SessionDep) -> JSONResponse: + """Use a refresh token to get a new access token without re-authenticating.""" + + refresh_token = request.cookies.get("refresh_token") + if not refresh_token: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Refresh token not found", + ) + + user, token_data = validate_refresh_token(session, refresh_token) + + response = build_token_response( + user_id=user.id, + organization_id=token_data.org_id, + project_id=token_data.project_id, + ) + + logger.info(f"[refresh_access_token] Token refreshed | user_id: {user.id}") + return response + + +@router.post( + "/logout", + response_model=APIResponse[Message], +) +def logout() -> JSONResponse: + """Clear auth cookies to log the user out.""" + api_response = APIResponse.success_response( + data=Message(message="Logged out successfully") + ) + response = JSONResponse(content=api_response.model_dump()) + clear_auth_cookies(response) + return response diff --git a/backend/app/api/routes/project.py b/backend/app/api/routes/project.py index c8c50738b..71fcf50ee 100644 --- a/backend/app/api/routes/project.py +++ b/backend/app/api/routes/project.py @@ -85,11 +85,11 @@ def update_project(*, session: SessionDep, project_id: int, project_in: ProjectU raise HTTPException(status_code=404, detail="Project not found") project_data = project_in.model_dump(exclude_unset=True) - project = project.model_copy(update=project_data) + project.sqlmodel_update(project_data) session.add(project) session.commit() - session.flush() + session.refresh(project) logger.info( f"[update_project] Project updated successfully | project_id={project.id}" ) diff --git a/backend/app/api/routes/user_project.py b/backend/app/api/routes/user_project.py new file mode 100644 index 000000000..5052761f7 --- /dev/null +++ b/backend/app/api/routes/user_project.py @@ -0,0 +1,137 @@ +import logging +from typing import Any + +from fastapi import APIRouter, Depends, HTTPException, status + +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.crud.user_project import ( + add_user_to_project, + get_users_by_project, + remove_user_from_project, +) +from app.models import ( + AddUsersToProjectRequest, + Message, + UserProjectPublic, +) +from app.utils import APIResponse, load_description + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/user-projects", tags=["User Projects"]) + + +@router.get( + "/", + description=load_description("user_project/list.md"), + response_model=APIResponse[list[UserProjectPublic]], +) +def list_project_users( + session: SessionDep, + auth_context: AuthContextDep, + project_id: int, +) -> Any: + """List all users in a project.""" + users = get_users_by_project(session=session, project_id=project_id) + return APIResponse.success_response(data=users) + + +@router.post( + "/", + dependencies=[Depends(require_permission(Permission.SUPERUSER))], + description=load_description("user_project/add.md"), + response_model=APIResponse[list[UserProjectPublic]], + status_code=status.HTTP_201_CREATED, +) +def add_project_users( + session: SessionDep, + body: AddUsersToProjectRequest, +) -> Any: + """Add one or more users to a project by email.""" + same_project_emails = [] + different_project_emails = [] + + for entry in body.users: + _, add_status = add_user_to_project( + session=session, + email=str(entry.email), + organization_id=body.organization_id, + project_id=body.project_id, + full_name=entry.full_name, + ) + if add_status == "same_project": + same_project_emails.append(str(entry.email)) + elif add_status == "different_project": + different_project_emails.append(str(entry.email)) + + if same_project_emails or different_project_emails: + session.rollback() + errors = [] + if same_project_emails: + errors.append( + f"Already added to this project: {', '.join(same_project_emails)}" + ) + if different_project_emails: + errors.append( + f"Already assigned to another project: {', '.join(different_project_emails)}" + ) + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="; ".join(errors), + ) + + session.commit() + + # Re-fetch all users for this project to return the full list + results = get_users_by_project(session=session, project_id=body.project_id) + + logger.info( + f"[add_project_users] Users added to project | " + f"project_id: {body.project_id}, count: {len(body.users)}" + ) + + return APIResponse.success_response(data=results) + + +@router.delete( + "/{user_id}", + dependencies=[Depends(require_permission(Permission.SUPERUSER))], + description=load_description("user_project/delete.md"), + response_model=APIResponse[Message], +) +def delete_project_user( + session: SessionDep, + auth_context: AuthContextDep, + user_id: int, + project_id: int, +) -> Any: + """Remove a user from a project.""" + if user_id == auth_context.user.id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="You cannot remove yourself from the project", + ) + + removed = remove_user_from_project( + session=session, + user_id=user_id, + project_id=project_id, + ) + + if not removed: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="User not found in this project", + ) + + session.commit() + + logger.info( + f"[delete_project_user] User removed from project | " + f"user_id: {user_id}, project_id: {project_id}" + ) + + return APIResponse.success_response( + data=Message(message="User removed from project successfully") + ) diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 44a7d7771..49cf7e5f6 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -37,6 +37,8 @@ class Settings(BaseSettings): SECRET_KEY: str = secrets.token_urlsafe(32) # 60 minutes * 24 hours * 1 days = 1 days ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 1 + # 60 minutes * 24 hours * 7 days = 7 days + REFRESH_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 7 ENVIRONMENT: Literal[ "development", "testing", "staging", "production" ] = "development" @@ -52,6 +54,9 @@ class Settings(BaseSettings): KAAPI_GUARDRAILS_AUTH: str = "" KAAPI_GUARDRAILS_URL: str = "" + # Google OAuth + GOOGLE_CLIENT_ID: str = "" + @computed_field # type: ignore[prop-decorator] @property def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn: diff --git a/backend/app/core/security.py b/backend/app/core/security.py index 8cee6e982..e5f6ac3f8 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -67,19 +67,57 @@ def get_fernet() -> Fernet: return _fernet -def create_access_token(subject: str | Any, expires_delta: timedelta) -> str: +def create_access_token( + subject: str | Any, + expires_delta: timedelta, + organization_id: int | None = None, + project_id: int | None = None, +) -> str: """ Create a JWT access token. Args: subject: The subject of the token (typically user ID) expires_delta: Token expiration time delta + organization_id: Optional organization ID to embed in the token + project_id: Optional project ID to embed in the token Returns: str: Encoded JWT token """ expire = datetime.now(timezone.utc) + expires_delta - to_encode = {"exp": expire, "sub": str(subject)} + to_encode: dict[str, Any] = {"exp": expire, "sub": str(subject), "type": "access"} + if organization_id is not None: + to_encode["org_id"] = organization_id + if project_id is not None: + to_encode["project_id"] = project_id + return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) + + +def create_refresh_token( + subject: str | Any, + expires_delta: timedelta, + organization_id: int | None = None, + project_id: int | None = None, +) -> str: + """ + Create a JWT refresh token. + + Args: + subject: The subject of the token (typically user ID) + expires_delta: Token expiration time delta + organization_id: Optional organization ID to embed in the token + project_id: Optional project ID to embed in the token + + Returns: + str: Encoded JWT refresh token + """ + expire = datetime.now(timezone.utc) + expires_delta + to_encode: dict[str, Any] = {"exp": expire, "sub": str(subject), "type": "refresh"} + if organization_id is not None: + to_encode["org_id"] = organization_id + if project_id is not None: + to_encode["project_id"] = project_id return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index 9baa5defd..d15e5df89 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -81,6 +81,13 @@ from .onboarding import onboard_project +from .user_project import ( + add_user_to_project, + get_user_projects, + get_users_by_project, + remove_user_from_project, +) + from .file import ( create_file, get_file_by_id, diff --git a/backend/app/crud/auth.py b/backend/app/crud/auth.py new file mode 100644 index 000000000..39147b86e --- /dev/null +++ b/backend/app/crud/auth.py @@ -0,0 +1,63 @@ +import logging + +from sqlmodel import Session, and_, select + +from app.models import ( + APIKey, + Organization, + Project, + UserProject, +) + +logger = logging.getLogger(__name__) + + +def get_user_accessible_projects(*, session: Session, user_id: int) -> list[dict]: + """ + Query distinct org/project pairs for a user from both + the UserProject table and the APIKey table (backward compatibility). + """ + # Query from UserProject table + from_user_project = ( + select(Organization.id, Organization.name, Project.id, Project.name) + .select_from(UserProject) + .join(Organization, Organization.id == UserProject.organization_id) + .join(Project, Project.id == UserProject.project_id) + .where( + and_( + UserProject.user_id == user_id, + Organization.is_active.is_(True), + Project.is_active.is_(True), + ) + ) + ) + + # Query from APIKey table (backward compatibility) + from_api_key = ( + select(Organization.id, Organization.name, Project.id, Project.name) + .select_from(APIKey) + .join(Organization, Organization.id == APIKey.organization_id) + .join(Project, Project.id == APIKey.project_id) + .where( + and_( + APIKey.user_id == user_id, + APIKey.is_deleted.is_(False), + Organization.is_active.is_(True), + Project.is_active.is_(True), + ) + ) + ) + + # Union both queries and deduplicate + combined = from_user_project.union(from_api_key) + results = session.exec(combined).all() + + return [ + { + "organization_id": org_id, + "organization_name": org_name, + "project_id": proj_id, + "project_name": proj_name, + } + for org_id, org_name, proj_id, proj_name in results + ] diff --git a/backend/app/crud/user_project.py b/backend/app/crud/user_project.py new file mode 100644 index 000000000..fec57c01d --- /dev/null +++ b/backend/app/crud/user_project.py @@ -0,0 +1,138 @@ +import logging +import secrets +from typing import Sequence + +from sqlmodel import Session, and_, select + +from app.core.security import get_password_hash +from app.models import ( + User, + UserProject, + UserProjectPublic, +) + +logger = logging.getLogger(__name__) + + +def get_users_by_project( + *, session: Session, project_id: int +) -> list[UserProjectPublic]: + """Get all users mapped to a project.""" + statement = ( + select( + User.id, User.email, User.full_name, User.is_active, UserProject.inserted_at + ) + .join(UserProject, UserProject.user_id == User.id) + .where(UserProject.project_id == project_id) + .order_by(UserProject.inserted_at.desc()) + ) + results = session.exec(statement).all() + return [ + UserProjectPublic( + user_id=user_id, + email=email, + full_name=full_name, + is_active=is_active, + inserted_at=inserted_at, + ) + for user_id, email, full_name, is_active, inserted_at in results + ] + + +def add_user_to_project( + *, + session: Session, + email: str, + organization_id: int, + project_id: int, + full_name: str | None = None, +) -> tuple[User, str]: + """ + Add a user to a project. Creates the user if they don't exist (is_active=False). + + Returns: + Tuple of (user, status) where status is one of: + - "added": User was successfully added to the project + - "same_project": User is already in this project + - "different_project": User is already assigned to another project + """ + user = session.exec(select(User).where(User.email == email)).first() + + if not user: + user = User( + email=email, + full_name=full_name, + is_active=False, + hashed_password=get_password_hash(secrets.token_urlsafe(16)), + ) + session.add(user) + session.flush() + elif full_name and not user.full_name: + user.full_name = full_name + session.add(user) + session.flush() + + # Check if user is already assigned to any project + existing = session.exec( + select(UserProject).where(UserProject.user_id == user.id) + ).first() + + if existing: + if existing.project_id == project_id: + return user, "same_project" + else: + return user, "different_project" + + user_project = UserProject( + user_id=user.id, + organization_id=organization_id, + project_id=project_id, + ) + session.add(user_project) + session.flush() + + return user, "added" + + +def remove_user_from_project( + *, session: Session, user_id: int, project_id: int +) -> bool: + """ + Remove a user from a project. If this was their last project, + deactivate the user account. + + Returns True if removed, False if not found. + """ + user_project = session.exec( + select(UserProject).where( + and_( + UserProject.user_id == user_id, + UserProject.project_id == project_id, + ) + ) + ).first() + + if not user_project: + return False + + session.delete(user_project) + session.flush() + + # Check if user has any remaining projects + remaining = session.exec( + select(UserProject.id).where(UserProject.user_id == user_id).limit(1) + ).first() + + if not remaining: + user = session.get(User, user_id) + if user and not user.is_superuser: + session.delete(user) + session.flush() + + return True + + +def get_user_projects(*, session: Session, user_id: int) -> Sequence[UserProject]: + """Get all project mappings for a user.""" + statement = select(UserProject).where(UserProject.user_id == user_id) + return session.exec(statement).all() diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index b5cb3f0c6..9341401aa 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,6 +1,13 @@ from sqlmodel import SQLModel -from .auth import AuthContext, Token, TokenPayload +from .auth import ( + AuthContext, + GoogleAuthRequest, + GoogleAuthResponse, + SelectProjectRequest, + Token, + TokenPayload, +) from .api_key import ( APIKey, @@ -173,3 +180,10 @@ UsersPublic, UpdatePassword, ) + +from .user_project import ( + UserProject, + AddUsersToProjectRequest, + UserEntry, + UserProjectPublic, +) diff --git a/backend/app/models/auth.py b/backend/app/models/auth.py index 26b42ef8a..bfbda1f6a 100644 --- a/backend/app/models/auth.py +++ b/backend/app/models/auth.py @@ -1,5 +1,5 @@ from sqlmodel import Field, SQLModel -from app.models.user import User +from app.models.user import User, UserPublic from app.models.organization import Organization from app.models.project import Project from typing import TYPE_CHECKING @@ -14,6 +14,27 @@ class Token(SQLModel): # Contents of JWT token class TokenPayload(SQLModel): sub: str | None = None + org_id: int | None = None + project_id: int | None = None + type: str = "access" + + +# Google OAuth +class GoogleAuthRequest(SQLModel): + token: str + + +class GoogleAuthResponse(SQLModel): + access_token: str + token_type: str = "bearer" + user: UserPublic + google_profile: dict + requires_project_selection: bool = False + available_projects: list[dict] = [] + + +class SelectProjectRequest(SQLModel): + project_id: int class AuthContext(SQLModel): diff --git a/backend/app/models/user_project.py b/backend/app/models/user_project.py new file mode 100644 index 000000000..c231c6d0f --- /dev/null +++ b/backend/app/models/user_project.py @@ -0,0 +1,74 @@ +from datetime import datetime + +from pydantic import EmailStr +from sqlmodel import Field, SQLModel, UniqueConstraint + +from app.core.util import now + + +class UserProjectBase(SQLModel): + """Base model for user-project mapping.""" + + user_id: int = Field( + foreign_key="user.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the user"}, + ) + organization_id: int = Field( + foreign_key="organization.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the organization"}, + ) + project_id: int = Field( + foreign_key="project.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={"comment": "Reference to the project"}, + ) + + +class UserProject(UserProjectBase, table=True): + """Maps users to projects within organizations.""" + + __tablename__ = "user_project" + __table_args__ = ( + UniqueConstraint("user_id", "project_id", name="uq_user_project"), + ) + + id: int = Field( + default=None, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the user-project mapping"}, + ) + inserted_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the mapping was created"}, + ) + + +class UserEntry(SQLModel): + """A single user entry with email and optional name.""" + + email: EmailStr + full_name: str | None = Field(default=None, max_length=255) + + +class AddUsersToProjectRequest(SQLModel): + """Request to add one or more users to a project.""" + + organization_id: int + project_id: int + users: list[UserEntry] = Field(min_length=1) + + +class UserProjectPublic(SQLModel): + """Public response model for a user in a project.""" + + user_id: int + email: str + full_name: str | None + is_active: bool + inserted_at: datetime diff --git a/backend/app/services/auth.py b/backend/app/services/auth.py new file mode 100644 index 000000000..5c266092e --- /dev/null +++ b/backend/app/services/auth.py @@ -0,0 +1,178 @@ +import logging +from datetime import timedelta + +import jwt as pyjwt +from fastapi import HTTPException, status +from fastapi.responses import JSONResponse +from jwt.exceptions import ExpiredSignatureError, InvalidTokenError +from sqlmodel import Session + +from app.core import security +from app.core.config import settings +from app.models import ( + GoogleAuthResponse, + Token, + TokenPayload, + User, + UserPublic, +) +from app.utils import APIResponse + +logger = logging.getLogger(__name__) + + +def create_token_pair( + user_id: int, + organization_id: int | None = None, + project_id: int | None = None, +) -> tuple[str, str]: + """Create an access token and refresh token pair.""" + access_token = security.create_access_token( + user_id, + expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES), + organization_id=organization_id, + project_id=project_id, + ) + refresh_token = security.create_refresh_token( + user_id, + expires_delta=timedelta(minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES), + organization_id=organization_id, + project_id=project_id, + ) + return access_token, refresh_token + + +def set_auth_cookies( + response: JSONResponse, + access_token: str, + refresh_token: str, +) -> None: + """Set access_token and refresh_token as HTTP-only cookies on the response.""" + is_secure = settings.ENVIRONMENT in ("staging", "production") + + response.set_cookie( + key="access_token", + value=access_token, + httponly=True, + secure=is_secure, + samesite="lax", + max_age=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60, + path="/", + ) + response.set_cookie( + key="refresh_token", + value=refresh_token, + httponly=True, + secure=is_secure, + samesite="lax", + max_age=settings.REFRESH_TOKEN_EXPIRE_MINUTES * 60, + path="/", + ) + + +def clear_auth_cookies(response: JSONResponse) -> None: + """Clear access_token and refresh_token cookies from the response.""" + is_secure = settings.ENVIRONMENT in ("staging", "production") + + response.delete_cookie( + key="access_token", + httponly=True, + secure=is_secure, + samesite="lax", + path="/", + ) + response.delete_cookie( + key="refresh_token", + httponly=True, + secure=is_secure, + samesite="lax", + path="/", + ) + + +def build_google_auth_response( + user: User, + google_profile: dict, + available_projects: list[dict], + organization_id: int | None = None, + project_id: int | None = None, + requires_project_selection: bool = False, +) -> JSONResponse: + """Create JWT token pair, build Google auth response, and set cookies.""" + access_token, refresh_token = create_token_pair( + user.id, + organization_id=organization_id, + project_id=project_id, + ) + + response_data = GoogleAuthResponse( + access_token=access_token, + user=UserPublic.model_validate(user), + google_profile=google_profile, + requires_project_selection=requires_project_selection, + available_projects=available_projects, + ) + + api_response = APIResponse.success_response(data=response_data) + response = JSONResponse(content=api_response.model_dump()) + set_auth_cookies(response, access_token, refresh_token) + return response + + +def build_token_response( + user_id: int, + organization_id: int | None = None, + project_id: int | None = None, +) -> JSONResponse: + """Create JWT token pair, build token response, and set cookies.""" + access_token, refresh_token = create_token_pair( + user_id, + organization_id=organization_id, + project_id=project_id, + ) + + api_response = APIResponse.success_response(data=Token(access_token=access_token)) + response = JSONResponse(content=api_response.model_dump()) + set_auth_cookies(response, access_token, refresh_token) + return response + + +def validate_refresh_token( + session: Session, refresh_token_value: str +) -> tuple[User, TokenPayload]: + """ + Validate a refresh token and return the user and token data. + + Raises HTTPException on invalid/expired token or inactive user. + """ + try: + payload = pyjwt.decode( + refresh_token_value, + settings.SECRET_KEY, + algorithms=[security.ALGORITHM], + ) + token_data = TokenPayload(**payload) + except ExpiredSignatureError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Refresh token has expired. Please login again.", + ) + except InvalidTokenError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid refresh token", + ) + + if token_data.type != "refresh": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid token type", + ) + + user = session.get(User, token_data.sub) + if not user: + raise HTTPException(status_code=404, detail="User not found") + if not user.is_active: + raise HTTPException(status_code=403, detail="Inactive user") + + return user, token_data diff --git a/backend/app/tests/api/test_auth.py b/backend/app/tests/api/test_auth.py new file mode 100644 index 000000000..1ecafcd3c --- /dev/null +++ b/backend/app/tests/api/test_auth.py @@ -0,0 +1,300 @@ +from datetime import timedelta +from unittest.mock import patch + +from fastapi.testclient import TestClient +from sqlmodel import Session + +from app.core.config import settings +from app.core.security import create_access_token, create_refresh_token +from app.tests.utils.auth import TestAuthContext +from app.tests.utils.user import create_random_user + +GOOGLE_AUTH_URL = f"{settings.API_V1_STR}/auth/google" +SELECT_PROJECT_URL = f"{settings.API_V1_STR}/auth/select-project" +REFRESH_URL = f"{settings.API_V1_STR}/auth/refresh" +LOGOUT_URL = f"{settings.API_V1_STR}/auth/logout" + +MOCK_GOOGLE_PROFILE = { + "email": None, # set per test + "email_verified": True, + "name": "Test User", + "picture": "https://example.com/photo.jpg", + "given_name": "Test", + "family_name": "User", +} + + +def _mock_idinfo(email: str, email_verified: bool = True) -> dict: + return {**MOCK_GOOGLE_PROFILE, "email": email, "email_verified": email_verified} + + +class TestGoogleAuth: + """Test suite for POST /auth/google endpoint.""" + + @patch("app.api.routes.auth.settings") + def test_google_auth_not_configured(self, mock_settings, client: TestClient): + """Test returns 500 when GOOGLE_CLIENT_ID is not set.""" + mock_settings.GOOGLE_CLIENT_ID = "" + resp = client.post(GOOGLE_AUTH_URL, json={"token": "fake"}) + assert resp.status_code == 500 + assert "not configured" in resp.json()["error"] + + @patch("app.api.routes.auth.id_token.verify_oauth2_token") + @patch("app.api.routes.auth.settings") + def test_google_auth_invalid_token( + self, mock_settings, mock_verify, client: TestClient + ): + """Test returns 400 for invalid Google token.""" + mock_settings.GOOGLE_CLIENT_ID = "test-client-id" + mock_settings.ACCESS_TOKEN_EXPIRE_MINUTES = 1440 + mock_settings.REFRESH_TOKEN_EXPIRE_MINUTES = 10080 + mock_settings.ENVIRONMENT = "testing" + mock_settings.API_V1_STR = settings.API_V1_STR + mock_verify.side_effect = ValueError("Invalid token") + + resp = client.post(GOOGLE_AUTH_URL, json={"token": "bad-token"}) + assert resp.status_code == 400 + assert "Invalid or expired" in resp.json()["error"] + + @patch("app.api.routes.auth.id_token.verify_oauth2_token") + @patch("app.api.routes.auth.settings") + def test_google_auth_unverified_email( + self, mock_settings, mock_verify, client: TestClient + ): + """Test returns 400 when Google email is not verified.""" + mock_settings.GOOGLE_CLIENT_ID = "test-client-id" + mock_verify.return_value = _mock_idinfo( + "test@example.com", email_verified=False + ) + + resp = client.post(GOOGLE_AUTH_URL, json={"token": "fake"}) + assert resp.status_code == 400 + assert "not verified" in resp.json()["error"] + + @patch("app.api.routes.auth.id_token.verify_oauth2_token") + @patch("app.api.routes.auth.settings") + def test_google_auth_user_not_found( + self, mock_settings, mock_verify, client: TestClient + ): + """Test returns 401 when no user exists for the email.""" + mock_settings.GOOGLE_CLIENT_ID = "test-client-id" + mock_verify.return_value = _mock_idinfo("nonexistent@example.com") + + resp = client.post(GOOGLE_AUTH_URL, json={"token": "fake"}) + assert resp.status_code == 401 + assert "No account found" in resp.json()["error"] + + @patch("app.api.routes.auth.id_token.verify_oauth2_token") + @patch("app.api.routes.auth.settings") + def test_google_auth_activates_inactive_user( + self, mock_settings, mock_verify, db: Session, client: TestClient + ): + """Test that inactive user is activated on first Google login.""" + user = create_random_user(db) + user.is_active = False + db.add(user) + db.commit() + db.refresh(user) + + mock_settings.GOOGLE_CLIENT_ID = "test-client-id" + mock_settings.ACCESS_TOKEN_EXPIRE_MINUTES = 1440 + mock_settings.REFRESH_TOKEN_EXPIRE_MINUTES = 10080 + mock_settings.ENVIRONMENT = "testing" + mock_settings.API_V1_STR = settings.API_V1_STR + mock_settings.SECRET_KEY = settings.SECRET_KEY + mock_verify.return_value = _mock_idinfo(user.email) + + resp = client.post(GOOGLE_AUTH_URL, json={"token": "fake"}) + assert resp.status_code == 200 + + db.refresh(user) + assert user.is_active is True + + @patch("app.api.routes.auth.id_token.verify_oauth2_token") + @patch("app.api.routes.auth.settings") + def test_google_auth_success_no_projects( + self, mock_settings, mock_verify, db: Session, client: TestClient + ): + """Test successful login for user with no projects.""" + user = create_random_user(db) + + mock_settings.GOOGLE_CLIENT_ID = "test-client-id" + mock_settings.ACCESS_TOKEN_EXPIRE_MINUTES = 1440 + mock_settings.REFRESH_TOKEN_EXPIRE_MINUTES = 10080 + mock_settings.ENVIRONMENT = "testing" + mock_settings.API_V1_STR = settings.API_V1_STR + mock_settings.SECRET_KEY = settings.SECRET_KEY + mock_verify.return_value = _mock_idinfo(user.email) + + resp = client.post(GOOGLE_AUTH_URL, json={"token": "fake"}) + assert resp.status_code == 200 + + body = resp.json() + assert body["success"] is True + data = body["data"] + assert "access_token" in data + assert data["requires_project_selection"] is False + assert data["available_projects"] == [] + assert "access_token" in resp.cookies + + @patch("app.api.routes.auth.id_token.verify_oauth2_token") + @patch("app.api.routes.auth.settings") + def test_google_auth_success_single_project_via_api_key( + self, + mock_settings, + mock_verify, + db: Session, + client: TestClient, + user_api_key: TestAuthContext, + ): + """Test successful login auto-selects single project from API key.""" + mock_settings.GOOGLE_CLIENT_ID = "test-client-id" + mock_settings.ACCESS_TOKEN_EXPIRE_MINUTES = 1440 + mock_settings.REFRESH_TOKEN_EXPIRE_MINUTES = 10080 + mock_settings.ENVIRONMENT = "testing" + mock_settings.API_V1_STR = settings.API_V1_STR + mock_settings.SECRET_KEY = settings.SECRET_KEY + mock_verify.return_value = _mock_idinfo(user_api_key.user.email) + + resp = client.post(GOOGLE_AUTH_URL, json={"token": "fake"}) + assert resp.status_code == 200 + + data = resp.json()["data"] + assert data["requires_project_selection"] is False + assert len(data["available_projects"]) == 1 + + +class TestSelectProject: + """Test suite for POST /auth/select-project endpoint.""" + + def test_select_project_unauthenticated(self, client: TestClient): + """Test returns 401 when not authenticated.""" + resp = client.post(SELECT_PROJECT_URL, json={"project_id": 1}) + assert resp.status_code == 401 + + def test_select_project_no_access( + self, client: TestClient, normal_user_token_headers: dict[str, str] + ): + """Test returns 403 when user has no access to the project.""" + resp = client.post( + SELECT_PROJECT_URL, + json={"project_id": 99999}, + headers=normal_user_token_headers, + ) + assert resp.status_code == 403 + assert "do not have access" in resp.json()["error"] + + def test_select_project_success( + self, + db: Session, + client: TestClient, + user_api_key: TestAuthContext, + normal_user_token_headers: dict[str, str], + ): + """Test successful project selection returns new token with cookies.""" + resp = client.post( + SELECT_PROJECT_URL, + json={"project_id": user_api_key.project.id}, + headers=normal_user_token_headers, + ) + assert resp.status_code == 200 + + body = resp.json() + assert body["success"] is True + assert "access_token" in body["data"] + assert "access_token" in resp.cookies + + +class TestRefreshToken: + """Test suite for POST /auth/refresh endpoint.""" + + def test_refresh_no_cookie(self, client: TestClient): + """Test returns 401 when no refresh token cookie is present.""" + resp = client.post(REFRESH_URL) + assert resp.status_code == 401 + assert "not found" in resp.json()["error"] + + def test_refresh_with_access_token_instead(self, db: Session, client: TestClient): + """Test returns 401 when access token is used instead of refresh token.""" + user = create_random_user(db) + access_token = create_access_token( + subject=str(user.id), expires_delta=timedelta(minutes=30) + ) + client.cookies.set("refresh_token", access_token) + + resp = client.post(REFRESH_URL) + assert resp.status_code == 401 + assert "Invalid token type" in resp.json()["error"] + + def test_refresh_with_expired_token(self, db: Session, client: TestClient): + """Test returns 401 when refresh token is expired.""" + user = create_random_user(db) + expired_refresh = create_refresh_token( + subject=str(user.id), expires_delta=timedelta(minutes=-1) + ) + client.cookies.set("refresh_token", expired_refresh) + + resp = client.post(REFRESH_URL) + assert resp.status_code == 401 + assert "expired" in resp.json()["error"] + + def test_refresh_success(self, db: Session, client: TestClient): + """Test successful refresh returns new tokens.""" + user = create_random_user(db) + refresh_token = create_refresh_token( + subject=str(user.id), expires_delta=timedelta(days=7) + ) + client.cookies.set("refresh_token", refresh_token) + + resp = client.post(REFRESH_URL) + assert resp.status_code == 200 + + body = resp.json() + assert body["success"] is True + assert "access_token" in body["data"] + assert "access_token" in resp.cookies + + def test_refresh_with_org_project( + self, db: Session, client: TestClient, user_api_key: TestAuthContext + ): + """Test refresh preserves org/project claims.""" + refresh_token = create_refresh_token( + subject=str(user_api_key.user.id), + expires_delta=timedelta(days=7), + organization_id=user_api_key.organization.id, + project_id=user_api_key.project.id, + ) + client.cookies.set("refresh_token", refresh_token) + + resp = client.post(REFRESH_URL) + assert resp.status_code == 200 + assert "access_token" in resp.json()["data"] + + def test_refresh_inactive_user(self, db: Session, client: TestClient): + """Test returns 403 when user is inactive.""" + user = create_random_user(db) + refresh_token = create_refresh_token( + subject=str(user.id), expires_delta=timedelta(days=7) + ) + + user.is_active = False + db.add(user) + db.commit() + + client.cookies.set("refresh_token", refresh_token) + + resp = client.post(REFRESH_URL) + assert resp.status_code == 403 + + +class TestLogout: + """Test suite for POST /auth/logout endpoint.""" + + def test_logout_success(self, client: TestClient): + """Test logout returns success response and clears cookies.""" + resp = client.post(LOGOUT_URL) + assert resp.status_code == 200 + + body = resp.json() + assert body["success"] is True + assert body["data"]["message"] == "Logged out successfully" diff --git a/backend/app/tests/api/test_deps.py b/backend/app/tests/api/test_deps.py index 5824f898e..ddb09cb2e 100644 --- a/backend/app/tests/api/test_deps.py +++ b/backend/app/tests/api/test_deps.py @@ -1,3 +1,5 @@ +from unittest.mock import MagicMock + import pytest from sqlmodel import Session from fastapi import HTTPException @@ -11,9 +13,17 @@ from app.tests.utils.auth import TestAuthContext from app.tests.utils.user import authentication_token_from_email, create_random_user from app.core.config import settings +from app.core.security import create_access_token, create_refresh_token from app.tests.utils.test_data import create_test_api_key +def _mock_request(cookies: dict | None = None) -> MagicMock: + """Create a mock Request object with optional cookies.""" + request = MagicMock() + request.cookies = cookies or {} + return request + + class TestGetAuthContext: """Test suite for get_auth_context function""" @@ -22,6 +32,7 @@ def test_get_auth_context_with_valid_api_key( ) -> None: """Test successful authentication with valid API key""" auth_context = get_auth_context( + request=_mock_request(), session=db, token=None, api_key=user_api_key.key, @@ -33,18 +44,19 @@ def test_get_auth_context_with_valid_api_key( assert auth_context.organization == user_api_key.organization def test_get_auth_context_with_invalid_api_key(self, db: Session) -> None: - """Test authentication fails with invalid API key""" + """Test authentication fails with invalid API key when no other auth is provided""" invalid_api_key = "ApiKey InvalidKeyThatDoesNotExist123456789" with pytest.raises(HTTPException) as exc_info: get_auth_context( + request=_mock_request(), session=db, token=None, api_key=invalid_api_key, ) assert exc_info.value.status_code == 401 - assert exc_info.value.detail == "Invalid API Key" + assert exc_info.value.detail == "Invalid Authorization format" def test_get_auth_context_with_valid_token( self, db: Session, normal_user_token_headers: dict[str, str] @@ -52,6 +64,7 @@ def test_get_auth_context_with_valid_token( """Test successful authentication with valid token""" token = normal_user_token_headers["Authorization"].replace("Bearer ", "") auth_context = get_auth_context( + request=_mock_request(), session=db, token=token, api_key=None, @@ -67,6 +80,7 @@ def test_get_auth_context_with_invalid_token(self, db: Session) -> None: with pytest.raises(HTTPException) as exc_info: get_auth_context( + request=_mock_request(), session=db, token=invalid_token, api_key=None, @@ -78,6 +92,7 @@ def test_get_auth_context_with_no_credentials(self, db: Session) -> None: """Test authentication fails when neither API key nor token is provided""" with pytest.raises(HTTPException) as exc_info: get_auth_context( + request=_mock_request(), session=db, token=None, api_key=None, @@ -98,6 +113,7 @@ def test_get_auth_context_with_inactive_user_via_api_key(self, db: Session) -> N with pytest.raises(HTTPException) as exc_info: get_auth_context( + request=_mock_request(), session=db, token=None, api_key=api_key.key, @@ -122,13 +138,14 @@ def test_get_auth_context_with_inactive_user_via_token( with pytest.raises(HTTPException) as exc_info: get_auth_context( + request=_mock_request(), session=db, token=token, api_key=None, ) assert exc_info.value.status_code == 403 - assert exc_info.value.detail == "Inactive user" + assert exc_info.value.detail == "User access has been revoked" def test_get_auth_context_with_inactive_organization( self, db: Session, user_api_key: TestAuthContext @@ -142,6 +159,7 @@ def test_get_auth_context_with_inactive_organization( with pytest.raises(HTTPException) as exc_info: get_auth_context( + request=_mock_request(), session=db, token=None, api_key=user_api_key.key, @@ -162,6 +180,7 @@ def test_get_auth_context_with_inactive_project( with pytest.raises(HTTPException) as exc_info: get_auth_context( + request=_mock_request(), session=db, token=None, api_key=user_api_key.key, @@ -169,3 +188,83 @@ def test_get_auth_context_with_inactive_project( assert exc_info.value.status_code == 403 assert exc_info.value.detail == "Inactive Project" + + def test_get_auth_context_with_cookie_token( + self, db: Session, normal_user_token_headers: dict[str, str] + ) -> None: + """Test successful authentication via access_token cookie""" + token = normal_user_token_headers["Authorization"].replace("Bearer ", "") + auth_context = get_auth_context( + request=_mock_request(cookies={"access_token": token}), + session=db, + token=None, + api_key=None, + ) + + assert isinstance(auth_context, AuthContext) + assert auth_context.user.email == settings.EMAIL_TEST_USER + + def test_get_auth_context_with_expired_token(self, db: Session) -> None: + """Test authentication fails with expired token""" + from datetime import timedelta + + expired_token = create_access_token( + subject="1", expires_delta=timedelta(minutes=-1) + ) + + with pytest.raises(HTTPException) as exc_info: + get_auth_context( + request=_mock_request(), + session=db, + token=expired_token, + api_key=None, + ) + + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Token has expired" + + def test_get_auth_context_rejects_refresh_token(self, db: Session) -> None: + """Test that refresh tokens are rejected for API access""" + from datetime import timedelta + + refresh_token = create_refresh_token( + subject="1", expires_delta=timedelta(minutes=60) + ) + + with pytest.raises(HTTPException) as exc_info: + get_auth_context( + request=_mock_request(), + session=db, + token=refresh_token, + api_key=None, + ) + + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Refresh tokens cannot be used for API access" + + def test_get_auth_context_jwt_with_org_and_project( + self, db: Session, user_api_key: TestAuthContext + ) -> None: + """Test JWT token with org_id and project_id populates AuthContext""" + from datetime import timedelta + + token = create_access_token( + subject=str(user_api_key.user.id), + expires_delta=timedelta(minutes=60), + organization_id=user_api_key.organization.id, + project_id=user_api_key.project.id, + ) + + auth_context = get_auth_context( + request=_mock_request(), + session=db, + token=token, + api_key=None, + ) + + assert isinstance(auth_context, AuthContext) + assert auth_context.user.id == user_api_key.user.id + assert auth_context.organization is not None + assert auth_context.organization.id == user_api_key.organization.id + assert auth_context.project is not None + assert auth_context.project.id == user_api_key.project.id diff --git a/backend/app/tests/api/test_permissions.py b/backend/app/tests/api/test_permissions.py index e08e73617..b8fde3bc9 100644 --- a/backend/app/tests/api/test_permissions.py +++ b/backend/app/tests/api/test_permissions.py @@ -1,3 +1,5 @@ +from unittest.mock import MagicMock + import pytest from fastapi import HTTPException from sqlmodel import Session @@ -8,6 +10,13 @@ from app.tests.utils.test_data import create_test_api_key +def _mock_request() -> MagicMock: + """Create a mock Request object with empty cookies.""" + request = MagicMock() + request.cookies = {} + return request + + class TestHasPermission: """Test suite for has_permission function""" @@ -21,7 +30,10 @@ def test_superuser_permission_with_superuser(self, db: Session) -> None: db.refresh(user) auth_context = get_auth_context( - session=db, token=None, api_key=api_key_response.key + request=_mock_request(), + session=db, + token=None, + api_key=api_key_response.key, ) result = has_permission(auth_context, Permission.SUPERUSER, db) @@ -33,7 +45,10 @@ def test_superuser_permission_with_regular_user(self, db: Session) -> None: api_key_response = create_test_api_key(db) auth_context = get_auth_context( - session=db, token=None, api_key=api_key_response.key + request=_mock_request(), + session=db, + token=None, + api_key=api_key_response.key, ) result = has_permission(auth_context, Permission.SUPERUSER, db) @@ -47,7 +62,10 @@ def test_require_organization_permission_with_organization( api_key_response = create_test_api_key(db) auth_context = get_auth_context( - session=db, token=None, api_key=api_key_response.key + request=_mock_request(), + session=db, + token=None, + api_key=api_key_response.key, ) result = has_permission(auth_context, Permission.REQUIRE_ORGANIZATION, db) @@ -61,7 +79,10 @@ def test_require_organization_permission_without_organization( api_key_response = create_test_api_key(db) auth_context = get_auth_context( - session=db, token=None, api_key=api_key_response.key + request=_mock_request(), + session=db, + token=None, + api_key=api_key_response.key, ) auth_context.organization = None @@ -75,7 +96,10 @@ def test_require_project_permission_with_project(self, db: Session) -> None: api_key_response = create_test_api_key(db) auth_context = get_auth_context( - session=db, token=None, api_key=api_key_response.key + request=_mock_request(), + session=db, + token=None, + api_key=api_key_response.key, ) result = has_permission(auth_context, Permission.REQUIRE_PROJECT, db) @@ -87,7 +111,10 @@ def test_require_project_permission_without_project(self, db: Session) -> None: api_key_response = create_test_api_key(db) auth_context = get_auth_context( - session=db, token=None, api_key=api_key_response.key + request=_mock_request(), + session=db, + token=None, + api_key=api_key_response.key, ) auth_context.project = None @@ -115,7 +142,10 @@ def test_permission_checker_passes_with_valid_permission(self, db: Session) -> N db.commit() db.refresh(user) auth_context = get_auth_context( - session=db, token=None, api_key=api_key_response.key + request=_mock_request(), + session=db, + token=None, + api_key=api_key_response.key, ) permission_checker = require_permission(Permission.SUPERUSER) @@ -127,7 +157,10 @@ def test_permission_checker_raises_403_without_permission( """Test that permission checker raises HTTPException with 403 when user lacks permission""" api_key_response = create_test_api_key(db) auth_context = get_auth_context( - session=db, token=None, api_key=api_key_response.key + request=_mock_request(), + session=db, + token=None, + api_key=api_key_response.key, ) permission_checker = require_permission(Permission.SUPERUSER) diff --git a/backend/app/tests/api/test_user_project.py b/backend/app/tests/api/test_user_project.py new file mode 100644 index 000000000..b984fd060 --- /dev/null +++ b/backend/app/tests/api/test_user_project.py @@ -0,0 +1,237 @@ +from fastapi.testclient import TestClient +from sqlmodel import Session + +from app.core.config import settings +from app.crud.user_project import add_user_to_project +from app.models import UserProject +from app.tests.utils.auth import TestAuthContext +from app.tests.utils.test_data import create_test_project +from app.tests.utils.utils import random_email + +USER_PROJECTS_URL = f"{settings.API_V1_STR}/user-projects" + + +class TestListProjectUsers: + """Test suite for GET /user-projects/""" + + def test_list_returns_empty( + self, + client: TestClient, + superuser_token_headers: dict[str, str], + ): + """Test listing users for a project with no users.""" + resp = client.get( + f"{USER_PROJECTS_URL}/?project_id=99999", + headers=superuser_token_headers, + ) + assert resp.status_code == 200 + assert resp.json()["data"] == [] + + def test_list_returns_users( + self, + db: Session, + client: TestClient, + superuser_token_headers: dict[str, str], + ): + """Test listing users returns mapped users.""" + project = create_test_project(db) + email = random_email() + add_user_to_project( + session=db, + email=email, + organization_id=project.organization_id, + project_id=project.id, + ) + db.commit() + + resp = client.get( + f"{USER_PROJECTS_URL}/?project_id={project.id}", + headers=superuser_token_headers, + ) + assert resp.status_code == 200 + data = resp.json()["data"] + assert len(data) == 1 + assert data[0]["email"] == email + + +class TestAddProjectUsers: + """Test suite for POST /user-projects/""" + + def test_add_user_requires_superuser( + self, + db: Session, + client: TestClient, + normal_user_token_headers: dict[str, str], + ): + """Test non-superuser cannot add users.""" + project = create_test_project(db) + resp = client.post( + f"{USER_PROJECTS_URL}/", + json={ + "organization_id": project.organization_id, + "project_id": project.id, + "users": [{"email": random_email()}], + }, + headers=normal_user_token_headers, + ) + assert resp.status_code == 403 + + def test_add_single_user( + self, + db: Session, + client: TestClient, + superuser_token_headers: dict[str, str], + ): + """Test adding a single user.""" + project = create_test_project(db) + email = random_email() + + resp = client.post( + f"{USER_PROJECTS_URL}/", + json={ + "organization_id": project.organization_id, + "project_id": project.id, + "users": [{"email": email, "full_name": "Test User"}], + }, + headers=superuser_token_headers, + ) + assert resp.status_code == 201 + data = resp.json()["data"] + assert len(data) >= 1 + emails = [u["email"] for u in data] + assert email in emails + + def test_add_duplicate_user_same_project( + self, + db: Session, + client: TestClient, + superuser_token_headers: dict[str, str], + ): + """Test adding same user to same project returns 409.""" + project = create_test_project(db) + email = random_email() + + # Add first time + add_user_to_project( + session=db, + email=email, + organization_id=project.organization_id, + project_id=project.id, + ) + db.commit() + + # Try adding again + resp = client.post( + f"{USER_PROJECTS_URL}/", + json={ + "organization_id": project.organization_id, + "project_id": project.id, + "users": [{"email": email}], + }, + headers=superuser_token_headers, + ) + assert resp.status_code == 409 + assert "Already added to this project" in resp.json()["error"] + + def test_add_user_different_project_returns_409( + self, + db: Session, + client: TestClient, + superuser_token_headers: dict[str, str], + ): + """Test adding user already in another project returns 409.""" + project1 = create_test_project(db) + project2 = create_test_project(db) + email = random_email() + + add_user_to_project( + session=db, + email=email, + organization_id=project1.organization_id, + project_id=project1.id, + ) + db.commit() + + resp = client.post( + f"{USER_PROJECTS_URL}/", + json={ + "organization_id": project2.organization_id, + "project_id": project2.id, + "users": [{"email": email}], + }, + headers=superuser_token_headers, + ) + assert resp.status_code == 409 + assert "Already assigned to another project" in resp.json()["error"] + + +class TestDeleteProjectUser: + """Test suite for DELETE /user-projects/{user_id}""" + + def test_delete_requires_superuser( + self, + client: TestClient, + normal_user_token_headers: dict[str, str], + ): + """Test non-superuser cannot delete users.""" + resp = client.delete( + f"{USER_PROJECTS_URL}/99999?project_id=1", + headers=normal_user_token_headers, + ) + assert resp.status_code == 403 + + def test_delete_nonexistent_user( + self, + client: TestClient, + superuser_token_headers: dict[str, str], + ): + """Test deleting non-existent user returns 404.""" + resp = client.delete( + f"{USER_PROJECTS_URL}/99999?project_id=99999", + headers=superuser_token_headers, + ) + assert resp.status_code == 404 + assert "User not found" in resp.json()["error"] + + def test_delete_user_success( + self, + db: Session, + client: TestClient, + superuser_token_headers: dict[str, str], + ): + """Test successfully removing a user from a project.""" + project = create_test_project(db) + email = random_email() + + user, _ = add_user_to_project( + session=db, + email=email, + organization_id=project.organization_id, + project_id=project.id, + ) + db.commit() + + resp = client.delete( + f"{USER_PROJECTS_URL}/{user.id}?project_id={project.id}", + headers=superuser_token_headers, + ) + assert resp.status_code == 200 + assert "removed" in resp.json()["data"]["message"] + + def test_cannot_delete_self( + self, + db: Session, + client: TestClient, + superuser_token_headers: dict[str, str], + superuser_api_key: TestAuthContext, + ): + """Test superuser cannot remove themselves.""" + project = create_test_project(db) + user_id = superuser_api_key.user.id + + resp = client.delete( + f"{USER_PROJECTS_URL}/{user_id}?project_id={project.id}", + headers=superuser_token_headers, + ) + assert resp.status_code == 400 + assert "cannot remove yourself" in resp.json()["error"] diff --git a/backend/app/tests/core/test_security.py b/backend/app/tests/core/test_security.py index efdebcd42..438bf8b05 100644 --- a/backend/app/tests/core/test_security.py +++ b/backend/app/tests/core/test_security.py @@ -1,6 +1,13 @@ +from datetime import timedelta + +import jwt from sqlmodel import Session +from app.core.config import settings from app.core.security import ( + ALGORITHM, + create_access_token, + create_refresh_token, get_encryption_key, APIKeyManager, ) @@ -190,3 +197,72 @@ def test_generate_creates_verifiable_key(self, db: Session): assert auth_context is not None assert auth_context.user.id == api_key_response.user_id + + +class TestCreateAccessToken: + """Test suite for create_access_token function.""" + + def test_creates_valid_jwt(self): + """Test that a valid JWT is created.""" + token = create_access_token(subject="42", expires_delta=timedelta(minutes=30)) + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) + + assert payload["sub"] == "42" + assert payload["type"] == "access" + assert "exp" in payload + + def test_includes_org_and_project(self): + """Test that org_id and project_id are embedded in the token.""" + token = create_access_token( + subject="1", + expires_delta=timedelta(minutes=30), + organization_id=10, + project_id=20, + ) + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) + + assert payload["org_id"] == 10 + assert payload["project_id"] == 20 + + def test_omits_org_and_project_when_none(self): + """Test that org_id and project_id are omitted when not provided.""" + token = create_access_token(subject="1", expires_delta=timedelta(minutes=30)) + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) + + assert "org_id" not in payload + assert "project_id" not in payload + + +class TestCreateRefreshToken: + """Test suite for create_refresh_token function.""" + + def test_creates_valid_refresh_jwt(self): + """Test that a valid refresh JWT is created.""" + token = create_refresh_token(subject="42", expires_delta=timedelta(days=7)) + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) + + assert payload["sub"] == "42" + assert payload["type"] == "refresh" + assert "exp" in payload + + def test_includes_org_and_project(self): + """Test that org_id and project_id are embedded in the refresh token.""" + token = create_refresh_token( + subject="1", + expires_delta=timedelta(days=7), + organization_id=10, + project_id=20, + ) + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) + + assert payload["org_id"] == 10 + assert payload["project_id"] == 20 + assert payload["type"] == "refresh" + + def test_omits_org_and_project_when_none(self): + """Test that org_id and project_id are omitted when not provided.""" + token = create_refresh_token(subject="1", expires_delta=timedelta(days=7)) + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) + + assert "org_id" not in payload + assert "project_id" not in payload diff --git a/backend/app/tests/crud/test_user_project.py b/backend/app/tests/crud/test_user_project.py new file mode 100644 index 000000000..917686e6b --- /dev/null +++ b/backend/app/tests/crud/test_user_project.py @@ -0,0 +1,226 @@ +import pytest +from sqlmodel import Session + +from app.crud.user_project import ( + add_user_to_project, + get_user_projects, + get_users_by_project, + remove_user_from_project, +) +from app.models import User, UserProject +from app.tests.utils.test_data import create_test_project +from app.tests.utils.user import create_random_user +from app.tests.utils.utils import random_email + + +class TestAddUserToProject: + """Test suite for add_user_to_project CRUD function.""" + + def test_add_new_user_creates_user_and_mapping(self, db: Session): + """Test adding a new email creates user (inactive) and project mapping.""" + project = create_test_project(db) + email = random_email() + + user, add_status = add_user_to_project( + session=db, + email=email, + organization_id=project.organization_id, + project_id=project.id, + ) + + assert add_status == "added" + assert user.email == email + assert user.is_active is False + + def test_add_new_user_with_full_name(self, db: Session): + """Test adding a new user with full_name.""" + project = create_test_project(db) + email = random_email() + + user, add_status = add_user_to_project( + session=db, + email=email, + organization_id=project.organization_id, + project_id=project.id, + full_name="Test User", + ) + + assert add_status == "added" + assert user.full_name == "Test User" + + def test_add_existing_user_updates_full_name(self, db: Session): + """Test adding existing user without full_name updates it.""" + project = create_test_project(db) + user = create_random_user(db) + user.full_name = None + db.add(user) + db.flush() + + returned_user, add_status = add_user_to_project( + session=db, + email=user.email, + organization_id=project.organization_id, + project_id=project.id, + full_name="Updated Name", + ) + + assert add_status == "added" + assert returned_user.full_name == "Updated Name" + + def test_add_user_same_project_returns_same_project(self, db: Session): + """Test adding user already in same project returns same_project.""" + project = create_test_project(db) + email = random_email() + + add_user_to_project( + session=db, + email=email, + organization_id=project.organization_id, + project_id=project.id, + ) + + _, add_status = add_user_to_project( + session=db, + email=email, + organization_id=project.organization_id, + project_id=project.id, + ) + + assert add_status == "same_project" + + def test_add_user_different_project_returns_different_project(self, db: Session): + """Test adding user already in another project returns different_project.""" + project1 = create_test_project(db) + project2 = create_test_project(db) + email = random_email() + + add_user_to_project( + session=db, + email=email, + organization_id=project1.organization_id, + project_id=project1.id, + ) + + _, add_status = add_user_to_project( + session=db, + email=email, + organization_id=project2.organization_id, + project_id=project2.id, + ) + + assert add_status == "different_project" + + +class TestGetUsersByProject: + """Test suite for get_users_by_project CRUD function.""" + + def test_returns_empty_for_project_with_no_users(self, db: Session): + """Test returns empty list when no users are mapped.""" + project = create_test_project(db) + result = get_users_by_project(session=db, project_id=project.id) + assert result == [] + + def test_returns_users_for_project(self, db: Session): + """Test returns mapped users for a project.""" + project = create_test_project(db) + email = random_email() + + add_user_to_project( + session=db, + email=email, + organization_id=project.organization_id, + project_id=project.id, + ) + + result = get_users_by_project(session=db, project_id=project.id) + assert len(result) == 1 + assert result[0].email == email + + +class TestRemoveUserFromProject: + """Test suite for remove_user_from_project CRUD function.""" + + def test_remove_existing_mapping(self, db: Session): + """Test removing a user from a project.""" + project = create_test_project(db) + email = random_email() + + user, _ = add_user_to_project( + session=db, + email=email, + organization_id=project.organization_id, + project_id=project.id, + ) + + removed = remove_user_from_project( + session=db, user_id=user.id, project_id=project.id + ) + assert removed is True + + def test_remove_nonexistent_mapping_returns_false(self, db: Session): + """Test removing a non-existent mapping returns False.""" + removed = remove_user_from_project(session=db, user_id=99999, project_id=99999) + assert removed is False + + def test_remove_last_project_deletes_user(self, db: Session): + """Test removing user from their last project deletes the user.""" + project = create_test_project(db) + email = random_email() + + user, _ = add_user_to_project( + session=db, + email=email, + organization_id=project.organization_id, + project_id=project.id, + ) + user_id = user.id + + remove_user_from_project(session=db, user_id=user_id, project_id=project.id) + + assert db.get(User, user_id) is None + + def test_remove_last_project_preserves_superuser(self, db: Session): + """Test superuser is not deleted when removed from last project.""" + project = create_test_project(db) + user = create_random_user(db) + user.is_superuser = True + db.add(user) + db.flush() + + mapping = UserProject( + user_id=user.id, + organization_id=project.organization_id, + project_id=project.id, + ) + db.add(mapping) + db.flush() + + remove_user_from_project(session=db, user_id=user.id, project_id=project.id) + + assert db.get(User, user.id) is not None + + +class TestGetUserProjects: + """Test suite for get_user_projects CRUD function.""" + + def test_returns_empty_for_user_with_no_projects(self, db: Session): + """Test returns empty when user has no project mappings.""" + user = create_random_user(db) + result = get_user_projects(session=db, user_id=user.id) + assert len(result) == 0 + + def test_returns_projects_for_user(self, db: Session): + """Test returns project mappings for a user.""" + project = create_test_project(db) + email = random_email() + + user, _ = add_user_to_project( + session=db, + email=email, + organization_id=project.organization_id, + project_id=project.id, + ) + + result = get_user_projects(session=db, user_id=user.id) + assert len(result) == 1 + assert result[0].project_id == project.id diff --git a/backend/app/tests/services/collections/providers/test_openai_provider.py b/backend/app/tests/services/collections/providers/test_openai_provider.py index 8ae028a0a..044c94a14 100644 --- a/backend/app/tests/services/collections/providers/test_openai_provider.py +++ b/backend/app/tests/services/collections/providers/test_openai_provider.py @@ -25,15 +25,10 @@ def test_create_openai_vector_store_only() -> None: ) storage = MagicMock() - document_crud = MagicMock() - - fake_batches = [["doc1"], ["doc2"]] + docs_batches = [["doc1"], ["doc2"]] vector_store_id = generate_openai_id("vs_") with patch( - "app.services.collections.providers.openai.batch_documents", - return_value=fake_batches, - ), patch( "app.services.collections.providers.openai.OpenAIVectorStoreCrud" ) as vector_store_crud_cls: vector_store_crud = vector_store_crud_cls.return_value @@ -43,7 +38,7 @@ def test_create_openai_vector_store_only() -> None: collection = provider.create( collection_request, storage, - document_crud, + docs_batches, ) assert isinstance(collection, Collection) @@ -64,16 +59,11 @@ def test_create_openai_with_assistant() -> None: ) storage = MagicMock() - document_crud = MagicMock() - - fake_batches = [["doc1"]] + docs_batches = [["doc1"]] vector_store_id = generate_openai_id("vs_") assistant_id = generate_openai_id("asst_") with patch( - "app.services.collections.providers.openai.batch_documents", - return_value=fake_batches, - ), patch( "app.services.collections.providers.openai.OpenAIVectorStoreCrud" ) as vector_store_crud_cls, patch( "app.services.collections.providers.openai.OpenAIAssistantCrud" @@ -88,7 +78,7 @@ def test_create_openai_with_assistant() -> None: collection = provider.create( collection_request, storage, - document_crud, + docs_batches, ) assert collection.llm_service_id == assistant_id @@ -145,12 +135,13 @@ def test_create_propagates_exception() -> None: ) with patch( - "app.services.collections.providers.openai.batch_documents", - side_effect=RuntimeError("boom"), - ): + "app.services.collections.providers.openai.OpenAIVectorStoreCrud" + ) as vector_store_crud_cls: + vector_store_crud_cls.return_value.create.side_effect = RuntimeError("boom") + with pytest.raises(RuntimeError): provider.create( collection_request, MagicMock(), - MagicMock(), + [["doc1"]], ) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 395f0120b..769ac259f 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ "indic-nlp-library>=0.92", "whisper-normalizer>=0.1.12", "elevenlabs>=2.38.1", + "google-auth>=2.49.1", "gevent>=25.9.1", ] diff --git a/backend/uv.lock b/backend/uv.lock index e8fd3ca4c..b0ca01d84 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -228,6 +228,7 @@ dependencies = [ { name = "fastapi", extra = ["standard"] }, { name = "flower" }, { name = "gevent" }, + { name = "google-auth" }, { name = "google-genai" }, { name = "httpx" }, { name = "indic-nlp-library" }, @@ -282,6 +283,7 @@ requires-dist = [ { name = "fastapi", extras = ["standard"], specifier = ">=0.116.0" }, { name = "flower", specifier = ">=2.0.1" }, { name = "gevent", specifier = ">=25.9.1" }, + { name = "google-auth", specifier = ">=2.49.1" }, { name = "google-genai", specifier = ">=1.59.0" }, { name = "httpx", specifier = ">=0.25.1,<1.0.0" }, { name = "indic-nlp-library", specifier = ">=0.92" }, @@ -1288,7 +1290,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ea/ab/1608e5a7578e62113506740b88066bf09888322a311cff602105e619bd87/greenlet-3.3.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:ac8d61d4343b799d1e526db579833d72f23759c71e07181c2d2944e429eb09cd", size = 280358, upload-time = "2026-02-20T20:17:43.971Z" }, { url = "https://files.pythonhosted.org/packages/a5/23/0eae412a4ade4e6623ff7626e38998cb9b11e9ff1ebacaa021e4e108ec15/greenlet-3.3.2-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3ceec72030dae6ac0c8ed7591b96b70410a8be370b6a477b1dbc072856ad02bd", size = 601217, upload-time = "2026-02-20T20:47:31.462Z" }, { url = "https://files.pythonhosted.org/packages/f8/16/5b1678a9c07098ecb9ab2dd159fafaf12e963293e61ee8d10ecb55273e5e/greenlet-3.3.2-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a2a5be83a45ce6188c045bcc44b0ee037d6a518978de9a5d97438548b953a1ac", size = 611792, upload-time = "2026-02-20T20:55:58.423Z" }, - { url = "https://files.pythonhosted.org/packages/5c/c5/cc09412a29e43406eba18d61c70baa936e299bc27e074e2be3806ed29098/greenlet-3.3.2-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ae9e21c84035c490506c17002f5c8ab25f980205c3e61ddb3a2a2a2e6c411fcb", size = 626250, upload-time = "2026-02-20T21:02:46.596Z" }, { url = "https://files.pythonhosted.org/packages/50/1f/5155f55bd71cabd03765a4aac9ac446be129895271f73872c36ebd4b04b6/greenlet-3.3.2-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:43e99d1749147ac21dde49b99c9abffcbc1e2d55c67501465ef0930d6e78e070", size = 613875, upload-time = "2026-02-20T20:21:01.102Z" }, { url = "https://files.pythonhosted.org/packages/fc/dd/845f249c3fcd69e32df80cdab059b4be8b766ef5830a3d0aa9d6cad55beb/greenlet-3.3.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4c956a19350e2c37f2c48b336a3afb4bff120b36076d9d7fb68cb44e05d95b79", size = 1571467, upload-time = "2026-02-20T20:49:33.495Z" }, { url = "https://files.pythonhosted.org/packages/2a/50/2649fe21fcc2b56659a452868e695634722a6655ba245d9f77f5656010bf/greenlet-3.3.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6c6f8ba97d17a1e7d664151284cb3315fc5f8353e75221ed4324f84eb162b395", size = 1640001, upload-time = "2026-02-20T20:21:09.154Z" }, @@ -1297,7 +1298,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ac/48/f8b875fa7dea7dd9b33245e37f065af59df6a25af2f9561efa8d822fde51/greenlet-3.3.2-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:aa6ac98bdfd716a749b84d4034486863fd81c3abde9aa3cf8eff9127981a4ae4", size = 279120, upload-time = "2026-02-20T20:19:01.9Z" }, { url = "https://files.pythonhosted.org/packages/49/8d/9771d03e7a8b1ee456511961e1b97a6d77ae1dea4a34a5b98eee706689d3/greenlet-3.3.2-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ab0c7e7901a00bc0a7284907273dc165b32e0d109a6713babd04471327ff7986", size = 603238, upload-time = "2026-02-20T20:47:32.873Z" }, { url = "https://files.pythonhosted.org/packages/59/0e/4223c2bbb63cd5c97f28ffb2a8aee71bdfb30b323c35d409450f51b91e3e/greenlet-3.3.2-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:d248d8c23c67d2291ffd47af766e2a3aa9fa1c6703155c099feb11f526c63a92", size = 614219, upload-time = "2026-02-20T20:55:59.817Z" }, - { url = "https://files.pythonhosted.org/packages/94/2b/4d012a69759ac9d77210b8bfb128bc621125f5b20fc398bce3940d036b1c/greenlet-3.3.2-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ccd21bb86944ca9be6d967cf7691e658e43417782bce90b5d2faeda0ff78a7dd", size = 628268, upload-time = "2026-02-20T21:02:48.024Z" }, { url = "https://files.pythonhosted.org/packages/7a/34/259b28ea7a2a0c904b11cd36c79b8cef8019b26ee5dbe24e73b469dea347/greenlet-3.3.2-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b6997d360a4e6a4e936c0f9625b1c20416b8a0ea18a8e19cabbefc712e7397ab", size = 616774, upload-time = "2026-02-20T20:21:02.454Z" }, { url = "https://files.pythonhosted.org/packages/0a/03/996c2d1689d486a6e199cb0f1cf9e4aa940c500e01bdf201299d7d61fa69/greenlet-3.3.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:64970c33a50551c7c50491671265d8954046cb6e8e2999aacdd60e439b70418a", size = 1571277, upload-time = "2026-02-20T20:49:34.795Z" }, { url = "https://files.pythonhosted.org/packages/d9/c4/2570fc07f34a39f2caf0bf9f24b0a1a0a47bc2e8e465b2c2424821389dfc/greenlet-3.3.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1a9172f5bf6bd88e6ba5a84e0a68afeac9dc7b6b412b245dd64f52d83c81e55b", size = 1640455, upload-time = "2026-02-20T20:21:10.261Z" }, @@ -1306,7 +1306,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/ae/8bffcbd373b57a5992cd077cbe8858fff39110480a9d50697091faea6f39/greenlet-3.3.2-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:8d1658d7291f9859beed69a776c10822a0a799bc4bfe1bd4272bb60e62507dab", size = 279650, upload-time = "2026-02-20T20:18:00.783Z" }, { url = "https://files.pythonhosted.org/packages/d1/c0/45f93f348fa49abf32ac8439938726c480bd96b2a3c6f4d949ec0124b69f/greenlet-3.3.2-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:18cb1b7337bca281915b3c5d5ae19f4e76d35e1df80f4ad3c1a7be91fadf1082", size = 650295, upload-time = "2026-02-20T20:47:34.036Z" }, { url = "https://files.pythonhosted.org/packages/b3/de/dd7589b3f2b8372069ab3e4763ea5329940fc7ad9dcd3e272a37516d7c9b/greenlet-3.3.2-cp314-cp314-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c2e47408e8ce1c6f1ceea0dffcdf6ebb85cc09e55c7af407c99f1112016e45e9", size = 662163, upload-time = "2026-02-20T20:56:01.295Z" }, - { url = "https://files.pythonhosted.org/packages/cd/ac/85804f74f1ccea31ba518dcc8ee6f14c79f73fe36fa1beba38930806df09/greenlet-3.3.2-cp314-cp314-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:e3cb43ce200f59483eb82949bf1835a99cf43d7571e900d7c8d5c62cdf25d2f9", size = 675371, upload-time = "2026-02-20T21:02:49.664Z" }, { url = "https://files.pythonhosted.org/packages/d2/d8/09bfa816572a4d83bccd6750df1926f79158b1c36c5f73786e26dbe4ee38/greenlet-3.3.2-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:63d10328839d1973e5ba35e98cccbca71b232b14051fd957b6f8b6e8e80d0506", size = 664160, upload-time = "2026-02-20T20:21:04.015Z" }, { url = "https://files.pythonhosted.org/packages/48/cf/56832f0c8255d27f6c35d41b5ec91168d74ec721d85f01a12131eec6b93c/greenlet-3.3.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:8e4ab3cfb02993c8cc248ea73d7dae6cec0253e9afa311c9b37e603ca9fad2ce", size = 1619181, upload-time = "2026-02-20T20:49:36.052Z" }, { url = "https://files.pythonhosted.org/packages/0a/23/b90b60a4aabb4cec0796e55f25ffbfb579a907c3898cd2905c8918acaa16/greenlet-3.3.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:94ad81f0fd3c0c0681a018a976e5c2bd2ca2d9d94895f23e7bb1af4e8af4e2d5", size = 1687713, upload-time = "2026-02-20T20:21:11.684Z" }, @@ -1315,7 +1314,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/98/6d/8f2ef704e614bcf58ed43cfb8d87afa1c285e98194ab2cfad351bf04f81e/greenlet-3.3.2-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:e26e72bec7ab387ac80caa7496e0f908ff954f31065b0ffc1f8ecb1338b11b54", size = 286617, upload-time = "2026-02-20T20:19:29.856Z" }, { url = "https://files.pythonhosted.org/packages/5e/0d/93894161d307c6ea237a43988f27eba0947b360b99ac5239ad3fe09f0b47/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8b466dff7a4ffda6ca975979bab80bdadde979e29fc947ac3be4451428d8b0e4", size = 655189, upload-time = "2026-02-20T20:47:35.742Z" }, { url = "https://files.pythonhosted.org/packages/f5/2c/d2d506ebd8abcb57386ec4f7ba20f4030cbe56eae541bc6fd6ef399c0b41/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:b8bddc5b73c9720bea487b3bffdb1840fe4e3656fba3bd40aa1489e9f37877ff", size = 658225, upload-time = "2026-02-20T20:56:02.527Z" }, - { url = "https://files.pythonhosted.org/packages/d1/67/8197b7e7e602150938049d8e7f30de1660cfb87e4c8ee349b42b67bdb2e1/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:59b3e2c40f6706b05a9cd299c836c6aa2378cabe25d021acd80f13abf81181cf", size = 666581, upload-time = "2026-02-20T21:02:51.526Z" }, { url = "https://files.pythonhosted.org/packages/8e/30/3a09155fbf728673a1dea713572d2d31159f824a37c22da82127056c44e4/greenlet-3.3.2-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b26b0f4428b871a751968285a1ac9648944cea09807177ac639b030bddebcea4", size = 657907, upload-time = "2026-02-20T20:21:05.259Z" }, { url = "https://files.pythonhosted.org/packages/f3/fd/d05a4b7acd0154ed758797f0a43b4c0962a843bedfe980115e842c5b2d08/greenlet-3.3.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:1fb39a11ee2e4d94be9a76671482be9398560955c9e568550de0224e41104727", size = 1618857, upload-time = "2026-02-20T20:49:37.309Z" }, { url = "https://files.pythonhosted.org/packages/6f/e1/50ee92a5db521de8f35075b5eff060dd43d39ebd46c2181a2042f7070385/greenlet-3.3.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:20154044d9085151bc309e7689d6f7ba10027f8f5a8c0676ad398b951913d89e", size = 1680010, upload-time = "2026-02-20T20:21:13.427Z" },