diff --git a/backend/app/alembic/versions/543f97951bd0_add_credential_table.py b/backend/app/alembic/versions/543f97951bd0_add_credential_table.py index 24927d9b8..7b04421c7 100644 --- a/backend/app/alembic/versions/543f97951bd0_add_credential_table.py +++ b/backend/app/alembic/versions/543f97951bd0_add_credential_table.py @@ -1,4 +1,4 @@ -"""add credetial table +"""add credential table Revision ID: 543f97951bd0 Revises: 8d7a05fd0ad4 @@ -22,7 +22,7 @@ def upgrade(): "credential", sa.Column("id", sa.Integer(), nullable=False), sa.Column("is_active", sa.Boolean(), nullable=False), - sa.Column("credential", sa.JSON(), nullable=True), + sa.Column("credential", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("organization_id", sa.Integer(), nullable=False), sa.Column("inserted_at", sa.DateTime(), nullable=True), sa.Column("updated_at", sa.DateTime(), nullable=True), diff --git a/backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py b/backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py new file mode 100644 index 000000000..0991fe3de --- /dev/null +++ b/backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py @@ -0,0 +1,93 @@ +"""Added provider column to the credential table + +Revision ID: 904ed70e7dab +Revises: 79e47bc3aac6 +Create Date: 2025-05-10 11:13:17.868238 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +revision = "904ed70e7dab" +down_revision = "79e47bc3aac6" +branch_labels = None +depends_on = None + + +def upgrade(): + # Add new columns to credential table + op.add_column( + "credential", + sa.Column("provider", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + ) + op.add_column("credential", sa.Column("project_id", sa.Integer(), nullable=True)) + + # Create indexes and constraints + op.create_index( + op.f("ix_credential_provider"), "credential", ["provider"], unique=False + ) + + # Drop existing foreign keys + op.drop_constraint( + "credential_organization_id_fkey", "credential", type_="foreignkey" + ) + op.drop_constraint("project_organization_id_fkey", "project", type_="foreignkey") + + # Create all foreign keys together + op.create_foreign_key( + "credential_organization_id_fkey", + "credential", + "organization", + ["organization_id"], + ["id"], + ondelete="CASCADE", + ) + op.create_foreign_key( + None, + "project", + "organization", + ["organization_id"], + ["id"], + ) + op.create_foreign_key( + "credential_project_id_fkey", + "credential", + "project", + ["project_id"], + ["id"], + ondelete="SET NULL", + ) + + +def downgrade(): + # Drop project_id foreign key and column + op.drop_constraint("credential_project_id_fkey", "credential", type_="foreignkey") + op.drop_column("credential", "project_id") + + # Drop existing foreign keys + op.drop_constraint(None, "project", type_="foreignkey") + op.drop_constraint( + "credential_organization_id_fkey", "credential", type_="foreignkey" + ) + + # Create all foreign keys together + op.create_foreign_key( + "project_organization_id_fkey", + "project", + "organization", + ["organization_id"], + ["id"], + ondelete="CASCADE", + ) + op.create_foreign_key( + "credential_organization_id_fkey", + "credential", + "organization", + ["organization_id"], + ["id"], + ) + + op.drop_index(op.f("ix_credential_provider"), table_name="credential") + op.drop_column("credential", "provider") diff --git a/backend/app/api/routes/credentials.py b/backend/app/api/routes/credentials.py index ccb68a1bc..ca0237a9c 100644 --- a/backend/app/api/routes/credentials.py +++ b/backend/app/api/routes/credentials.py @@ -1,15 +1,22 @@ +from typing import List + from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.exc import IntegrityError + from app.api.deps import SessionDep, get_current_active_superuser from app.crud.credentials import ( get_creds_by_org, - get_key_by_org, + get_provider_credential, remove_creds_for_org, set_creds_for_org, update_creds_for_org, + remove_provider_credential, ) from app.models import CredsCreate, CredsPublic, CredsUpdate +from app.models.organization import Organization +from app.models.project import Project from app.utils import APIResponse -from datetime import datetime +from app.core.providers import validate_provider router = APIRouter(prefix="/credentials", tags=["credentials"]) @@ -17,79 +24,124 @@ @router.post( "/", dependencies=[Depends(get_current_active_superuser)], - response_model=APIResponse[CredsPublic], + response_model=APIResponse[List[CredsPublic]], + summary="Create new credentials for an organization and project", + description="Creates new credentials for a specific organization and project combination. This endpoint requires superuser privileges. Each organization can have different credentials for different providers and projects. Only one credential per provider is allowed per organization-project combination.", ) def create_new_credential(*, session: SessionDep, creds_in: CredsCreate): - new_creds = None try: - existing_creds = get_creds_by_org( - session=session, org_id=creds_in.organization_id - ) - if not existing_creds: - new_creds = set_creds_for_org(session=session, creds_add=creds_in) + # Check if organization exists + organization = session.get(Organization, creds_in.organization_id) + if not organization: + raise HTTPException(status_code=404, detail="Organization not found") + + # Check if project exists if project_id is provided + if creds_in.project_id: + project = session.get(Project, creds_in.project_id) + if not project: + raise HTTPException(status_code=404, detail="Project not found") + if project.organization_id != creds_in.organization_id: + raise HTTPException( + status_code=400, + detail="Project does not belong to the specified organization", + ) + + # Check for existing credentials for each provider + for provider in creds_in.credential.keys(): + existing_cred = get_provider_credential( + session=session, + org_id=creds_in.organization_id, + provider=provider, + project_id=creds_in.project_id, + ) + if existing_cred: + raise HTTPException( + status_code=400, + detail=f"Credentials for provider '{provider}' already exist for this organization and project combination", + ) + + # Create new credentials + new_creds = set_creds_for_org(session=session, creds_add=creds_in) + if not new_creds: + raise HTTPException(status_code=500, detail="Failed to create credentials") + return APIResponse.success_response([cred.to_public() for cred in new_creds]) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) except Exception as e: raise HTTPException( status_code=500, detail=f"An unexpected error occurred: {str(e)}" ) - # Ensure inserted_at is set during creation - new_creds.inserted_at = datetime.utcnow() - - return APIResponse.success_response(new_creds) - @router.get( "/{org_id}", dependencies=[Depends(get_current_active_superuser)], - response_model=APIResponse[CredsPublic], + response_model=APIResponse[List[CredsPublic]], + summary="Get all credentials for an organization and project", + description="Retrieves all provider credentials associated with a specific organization and project combination. If project_id is not provided, returns credentials for the organization level. This endpoint requires superuser privileges.", ) -def read_credential(*, session: SessionDep, org_id: int): - try: - creds = get_creds_by_org(session=session, org_id=org_id) - except Exception as e: - raise HTTPException( - status_code=500, detail=f"An unexpected error occurred: {str(e)}" - ) - - if creds is None: +def read_credential(*, session: SessionDep, org_id: int, project_id: int | None = None): + creds = get_creds_by_org(session=session, org_id=org_id, project_id=project_id) + if not creds: raise HTTPException(status_code=404, detail="Credentials not found") - - return APIResponse.success_response(creds) + return APIResponse.success_response([cred.to_public() for cred in creds]) @router.get( - "/{org_id}/api-key", + "/{org_id}/{provider}", dependencies=[Depends(get_current_active_superuser)], response_model=APIResponse[dict], + summary="Get specific provider credentials for an organization and project", + description="Retrieves credentials for a specific provider (e.g., 'openai', 'anthropic') for a given organization and project combination. If project_id is not provided, returns organization-level credentials. This endpoint requires superuser privileges.", ) -def read_api_key(*, session: SessionDep, org_id: int): - try: - api_key = get_key_by_org(session=session, org_id=org_id) - except Exception as e: - raise HTTPException( - status_code=500, detail=f"An unexpected error occurred: {str(e)}" - ) - - if api_key is None: - raise HTTPException(status_code=404, detail="API key not found") - - return APIResponse.success_response({"api_key": api_key}) +def read_provider_credential( + *, session: SessionDep, org_id: int, provider: str, project_id: int | None = None +): + provider_enum = validate_provider(provider) + provider_creds = get_provider_credential( + session=session, + org_id=org_id, + provider=provider_enum, + project_id=project_id, + ) + if provider_creds is None: + raise HTTPException(status_code=404, detail="Provider credentials not found") + return APIResponse.success_response(provider_creds) @router.patch( "/{org_id}", dependencies=[Depends(get_current_active_superuser)], - response_model=APIResponse[CredsPublic], + response_model=APIResponse[List[CredsPublic]], + summary="Update organization and project credentials", + description="Updates credentials for a specific organization and project combination. Can update specific provider credentials or add new providers. If project_id is provided in the update, credentials will be moved to that project. This endpoint requires superuser privileges.", ) def update_credential(*, session: SessionDep, org_id: int, creds_in: CredsUpdate): try: + if not creds_in or not creds_in.provider or not creds_in.credential: + raise HTTPException( + status_code=400, detail="Provider and credential must be provided" + ) + organization = session.get(Organization, org_id) + if not organization: + raise HTTPException(status_code=404, detail="Organization not found") updated_creds = update_creds_for_org( session=session, org_id=org_id, creds_in=creds_in ) - - updated_creds.updated_at = datetime.utcnow() - - return APIResponse.success_response(updated_creds) + if not updated_creds: + raise HTTPException(status_code=404, detail="Failed to update credentials") + return APIResponse.success_response( + [cred.to_public() for cred in updated_creds] + ) + except IntegrityError as e: + if "ForeignKeyViolation" in str(e): + raise HTTPException( + status_code=400, + detail="Invalid organization ID. Ensure the organization exists before updating credentials.", + ) + raise HTTPException( + status_code=500, detail=f"An unexpected database error occurred: {str(e)}" + ) except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) except Exception as e: @@ -98,30 +150,61 @@ def update_credential(*, session: SessionDep, org_id: int, creds_in: CredsUpdate ) -from fastapi import HTTPException, Depends -from app.crud.credentials import remove_creds_for_org -from app.utils import APIResponse -from app.api.deps import SessionDep, get_current_active_superuser - - @router.delete( - "/{org_id}/api-key", + "/{org_id}/{provider}", dependencies=[Depends(get_current_active_superuser)], response_model=APIResponse[dict], + summary="Delete specific provider credentials for an organization and project", + description="Removes credentials for a specific provider while keeping other provider credentials intact. If project_id is provided, only removes credentials for that project. This endpoint requires superuser privileges.", ) -def delete_credential(*, session: SessionDep, org_id: int): +def delete_provider_credential( + *, session: SessionDep, org_id: int, provider: str, project_id: int | None = None +): try: - creds = remove_creds_for_org(session=session, org_id=org_id) + provider_enum = validate_provider(provider) + updated_creds = remove_provider_credential( + session=session, + org_id=org_id, + provider=provider_enum, + project_id=project_id, + ) + if not updated_creds: + raise HTTPException( + status_code=404, detail="Provider credentials not found" + ) + return APIResponse.success_response( + {"message": "Provider credentials removed successfully"} + ) + except ValueError: + raise HTTPException(status_code=404, detail="Provider credentials not found") except Exception as e: raise HTTPException( status_code=500, detail=f"An unexpected error occurred: {str(e)}" ) - if creds is None: + +@router.delete( + "/{org_id}", + dependencies=[Depends(get_current_active_superuser)], + response_model=APIResponse[dict], + summary="Delete all credentials for an organization and project", + description="Removes all credentials for a specific organization and project combination. If project_id is provided, only removes credentials for that project. This is a soft delete operation that marks credentials as inactive. This endpoint requires superuser privileges.", +) +def delete_all_credentials( + *, session: SessionDep, org_id: int, project_id: int | None = None +): + try: + creds = remove_creds_for_org( + session=session, org_id=org_id, project_id=project_id + ) + if not creds: + raise HTTPException( + status_code=404, detail="Credentials for organization not found" + ) + return APIResponse.success_response( + {"message": "Credentials deleted successfully"} + ) + except Exception as e: raise HTTPException( - status_code=404, detail="Credentials for organization not found" + status_code=500, detail=f"An unexpected error occurred: {str(e)}" ) - - # No need to manually set deleted_at and is_active if it's done in remove_creds_for_org - # Simply return the success response - return APIResponse.success_response({"message": "Credentials deleted successfully"}) diff --git a/backend/app/api/routes/threads.py b/backend/app/api/routes/threads.py index d27b578e7..28275487b 100644 --- a/backend/app/api/routes/threads.py +++ b/backend/app/api/routes/threads.py @@ -12,6 +12,8 @@ from app.models import UserOrganization, OpenAIThreadCreate from app.crud import upsert_thread_result, get_thread_result from app.utils import APIResponse +from app.crud.credentials import get_provider_credential +from app.core.security import decrypt_credentials logger = logging.getLogger(__name__) router = APIRouter(tags=["threads"]) @@ -220,11 +222,33 @@ async def threads( _current_user: UserOrganization = Depends(get_current_user_org), ): """Asynchronous endpoint that processes requests in background.""" - client = OpenAI(api_key=settings.OPENAI_API_KEY) + credentials = get_provider_credential( + session=_session, + org_id=_current_user.organization_id, + provider="openai", + project_id=request.get("project_id"), + ) + if not credentials or "api_key" not in credentials: + return APIResponse.failure_response( + error="OpenAI API key not configured for this organization." + ) + client = OpenAI(api_key=credentials["api_key"]) + + langfuse_credentials = get_provider_credential( + session=_session, + org_id=_current_user.organization_id, + provider="langfuse", + project_id=request.get("project_id"), + ) + if not langfuse_credentials: + return APIResponse.failure_response( + error="LANGFUSE keys not configured for this organization." + ) + langfuse_context.configure( - secret_key=settings.LANGFUSE_SECRET_KEY, - public_key=settings.LANGFUSE_PUBLIC_KEY, - host=settings.LANGFUSE_HOST, + secret_key=langfuse_credentials["secret_key"], + public_key=langfuse_credentials["public_key"], + host=langfuse_credentials["host"], ) # Validate thread is_valid, error_message = validate_thread(client, request.get("thread_id")) @@ -259,7 +283,19 @@ async def threads_sync( _current_user: UserOrganization = Depends(get_current_user_org), ): """Synchronous endpoint that processes requests immediately.""" - client = OpenAI(api_key=settings.OPENAI_API_KEY) + + credentials = get_provider_credential( + session=_session, + org_id=_current_user.organization_id, + provider="openai", + project_id=_current_user.project_id, + ) + if not credentials or "api_key" not in credentials: + return APIResponse.failure_response( + error="OpenAI API key not configured for this organization." + ) + + client = OpenAI(api_key=credentials["api_key"]) # Validate thread is_valid, error_message = validate_thread(client, request.get("thread_id")) diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py new file mode 100644 index 000000000..f703d2e6b --- /dev/null +++ b/backend/app/core/providers.py @@ -0,0 +1,77 @@ +from typing import Dict, List, Optional +from enum import Enum +from dataclasses import dataclass + + +class Provider(str, Enum): + """Enumeration of supported credential providers.""" + + OPENAI = "openai" + AWS = "aws" + LANGFUSE = "langfuse" + + +@dataclass +class ProviderConfig: + """Configuration for a provider including its required credential fields.""" + + required_fields: List[str] + + +# Provider configurations +PROVIDER_CONFIGS: Dict[Provider, ProviderConfig] = { + Provider.OPENAI: ProviderConfig(required_fields=["api_key"]), + Provider.AWS: ProviderConfig( + required_fields=["access_key_id", "secret_access_key", "region"] + ), + Provider.LANGFUSE: ProviderConfig( + required_fields=["secret_key", "public_key", "host"] + ), +} + + +def validate_provider(provider: str) -> Provider: + """Validate that the provider name is supported and return the Provider enum. + + Args: + provider: The provider name to validate + + Returns: + Provider: The validated provider enum + + Raises: + ValueError: If the provider is not supported + """ + try: + return Provider(provider.lower()) + except ValueError: + supported = ", ".join(p.value for p in Provider) + raise ValueError( + f"Unsupported provider: {provider}. Supported providers are: {supported}" + ) + + +def validate_provider_credentials(provider: str, credentials: Dict[str, str]) -> None: + """Validate that the credentials contain all required fields for the provider. + + Args: + provider: The provider name to validate credentials for + credentials: Dictionary containing the provider credentials + + Raises: + ValueError: If required fields are missing from the credentials + """ + provider_enum = validate_provider(provider) + required_fields = PROVIDER_CONFIGS[provider_enum].required_fields + + if missing_fields := [ + field for field in required_fields if field not in credentials + ]: + raise ValueError( + f"Missing required fields for {provider}: {', '.join(missing_fields)}" + ) + + +def get_supported_providers() -> List[str]: + """Return a list of all supported provider names.""" + return [p.value for p in Provider] diff --git a/backend/app/core/security.py b/backend/app/core/security.py index 65594fcbf..ace78c3a7 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -1,21 +1,42 @@ +""" +Security module for handling authentication, encryption, and password management. +This module provides utilities for: +- JWT token generation and validation +- Password hashing and verification +- API key encryption/decryption +- Credentials encryption/decryption +""" + from datetime import datetime, timedelta, timezone from typing import Any import base64 +import json + +import jwt from cryptography.fernet import Fernet from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC - -import jwt from passlib.context import CryptContext from app.core.config import settings +# Password hashing configuration pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +# JWT configuration +ALGORITHM = "HS256" + +# Fernet instance for encryption/decryption +_fernet = None + -# Generate a key for API key encryption def get_encryption_key() -> bytes: - """Generate a key for API key encryption using the app's secret key.""" + """ + Generate a key for API key encryption using the app's secret key. + + Returns: + bytes: A URL-safe base64 encoded encryption key derived from the app's secret key. + """ kdf = PBKDF2HMAC( algorithm=hashes.SHA256(), length=32, @@ -25,47 +46,136 @@ def get_encryption_key() -> bytes: return base64.urlsafe_b64encode(kdf.derive(settings.SECRET_KEY.encode())) -# Initialize Fernet with our encryption key -_fernet = None - - def get_fernet() -> Fernet: - """Get a Fernet instance with the encryption key.""" + """ + Get a Fernet instance with the encryption key. + Uses singleton pattern to avoid creating multiple instances. + + Returns: + Fernet: A Fernet instance initialized with the encryption key. + """ global _fernet if _fernet is None: _fernet = Fernet(get_encryption_key()) return _fernet -ALGORITHM = "HS256" +def create_access_token(subject: str | Any, expires_delta: timedelta) -> str: + """ + Create a JWT access token. + Args: + subject: The subject of the token (typically user ID) + expires_delta: Token expiration time delta -def create_access_token(subject: str | Any, expires_delta: timedelta) -> str: + Returns: + str: Encoded JWT token + """ expire = datetime.now(timezone.utc) + expires_delta to_encode = {"exp": expire, "sub": str(subject)} - encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) - return encoded_jwt + return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM) def verify_password(plain_password: str, hashed_password: str) -> bool: + """ + Verify a password against its hash. + + Args: + plain_password: The plain text password to verify + hashed_password: The hashed password to check against + + Returns: + bool: True if password matches, False otherwise + """ return pwd_context.verify(plain_password, hashed_password) def get_password_hash(password: str) -> str: + """ + Generate a password hash. + + Args: + password: The plain text password to hash + + Returns: + str: The hashed password + """ return pwd_context.hash(password) def encrypt_api_key(api_key: str) -> str: - """Encrypt an API key before storage.""" + """ + Encrypt an API key before storage. + + Args: + api_key: The plain text API key to encrypt + + Returns: + str: The encrypted API key + + Raises: + ValueError: If encryption fails + """ try: return get_fernet().encrypt(api_key.encode()).decode() except Exception as e: - raise ValueError(f"Failed to encrypt API key: {str(e)}") + raise ValueError(f"Failed to encrypt API key: {e}") def decrypt_api_key(encrypted_api_key: str) -> str: - """Decrypt an API key when retrieving it.""" + """ + Decrypt an API key when retrieving it. + + Args: + encrypted_api_key: The encrypted API key to decrypt + + Returns: + str: The decrypted API key + + Raises: + ValueError: If decryption fails + """ try: return get_fernet().decrypt(encrypted_api_key.encode()).decode() except Exception as e: - raise ValueError(f"Failed to decrypt API key: {str(e)}") + raise ValueError(f"Failed to decrypt API key: {e}") + + +def encrypt_credentials(credentials: dict) -> str: + """ + Encrypt the entire credentials object before storage. + + Args: + credentials: Dictionary containing credentials to encrypt + + Returns: + str: The encrypted credentials + + Raises: + ValueError: If encryption fails + """ + try: + credentials_str = json.dumps(credentials) + return get_fernet().encrypt(credentials_str.encode()).decode() + except Exception as e: + raise ValueError(f"Failed to encrypt credentials: {e}") + + +def decrypt_credentials(encrypted_credentials: str) -> dict: + """ + Decrypt the entire credentials object when retrieving it. + + Args: + encrypted_credentials: The encrypted credentials string to decrypt + + Returns: + dict: The decrypted credentials dictionary + + Raises: + ValueError: If decryption fails + """ + try: + decrypted_str = get_fernet().decrypt(encrypted_credentials.encode()).decode() + return json.loads(decrypted_str) + except Exception as e: + raise ValueError(f"Failed to decrypt credentials: {e}") diff --git a/backend/app/crud/__init__.py b/backend/app/crud/__init__.py index 963b0bf1f..748aecc24 100644 --- a/backend/app/crud/__init__.py +++ b/backend/app/crud/__init__.py @@ -30,4 +30,12 @@ delete_api_key, ) +from .credentials import ( + set_creds_for_org, + get_creds_by_org, + get_key_by_org, + update_creds_for_org, + remove_creds_for_org, +) + from .thread_results import upsert_thread_result, get_thread_result diff --git a/backend/app/crud/api_key.py b/backend/app/crud/api_key.py index cbff6d355..10a89885c 100644 --- a/backend/app/crud/api_key.py +++ b/backend/app/crud/api_key.py @@ -1,6 +1,6 @@ import uuid import secrets -from datetime import datetime +from datetime import datetime, timezone from sqlmodel import Session, select from app.core.security import ( verify_password, @@ -12,6 +12,7 @@ from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from app.core import settings +from app.core.util import now from app.models.api_key import APIKey, APIKeyPublic @@ -110,8 +111,8 @@ def delete_api_key(session: Session, api_key_id: int) -> None: raise ValueError("API key not found or already deleted") api_key.is_deleted = True - api_key.deleted_at = datetime.utcnow() - api_key.updated_at = datetime.utcnow() + api_key.deleted_at = now() + api_key.updated_at = now() session.add(api_key) session.commit() diff --git a/backend/app/crud/credentials.py b/backend/app/crud/credentials.py index a66fa7b78..32dcf6465 100644 --- a/backend/app/crud/credentials.py +++ b/backend/app/crud/credentials.py @@ -1,104 +1,199 @@ -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, List from sqlmodel import Session, select from sqlalchemy.exc import IntegrityError -from datetime import datetime +from datetime import datetime, timezone from app.models import Credential, CredsCreate, CredsUpdate +from app.core.providers import ( + validate_provider, + validate_provider_credentials, + get_supported_providers, +) +from app.core.security import encrypt_credentials, decrypt_credentials +from app.core.util import now + + +def set_creds_for_org(*, session: Session, creds_add: CredsCreate) -> List[Credential]: + """Set credentials for an organization. Creates a separate row for each provider.""" + created_credentials = [] + + if not creds_add.credential: + raise ValueError("No credentials provided") + + for provider, credentials in creds_add.credential.items(): + # Validate provider and credentials + validate_provider(provider) + validate_provider_credentials(provider, credentials) + + # Encrypt entire credentials object + encrypted_credentials = encrypt_credentials(credentials) + + # Create a row for each provider + credential = Credential( + organization_id=creds_add.organization_id, + project_id=creds_add.project_id, + is_active=creds_add.is_active, + provider=provider, + credential=encrypted_credentials, + ) + credential.inserted_at = now() + try: + session.add(credential) + session.commit() + session.refresh(credential) + created_credentials.append(credential) + except IntegrityError as e: + session.rollback() + raise ValueError( + f"Error while adding credentials for provider {provider}: {str(e)}" + ) + + return created_credentials + + +def get_key_by_org( + *, + session: Session, + org_id: int, + provider: str = "openai", + project_id: Optional[int] = None, +) -> Optional[str]: + """Fetches the API key from the credentials for the given organization and provider.""" + statement = select(Credential).where( + Credential.organization_id == org_id, + Credential.provider == provider, + Credential.is_active == True, + Credential.project_id == project_id if project_id is not None else True, + ) + creds = session.exec(statement).first() + if creds and creds.credential and "api_key" in creds.credential: + return creds.credential["api_key"] -def set_creds_for_org(*, session: Session, creds_add: CredsCreate) -> Credential: - creds = Credential.model_validate(creds_add) + return None - # Set the inserted_at timestamp (current UTC time) - creds.inserted_at = datetime.utcnow() - - try: - session.add(creds) - session.commit() - session.refresh(creds) - except IntegrityError as e: - session.rollback() # Rollback the session if there's a unique constraint violation - raise ValueError(f"Error while adding credentials: {str(e)}") +def get_creds_by_org( + *, session: Session, org_id: int, project_id: Optional[int] = None +) -> List[Credential]: + """Fetches all credentials for an organization.""" + statement = select(Credential).where( + Credential.organization_id == org_id, + Credential.is_active == True, + Credential.project_id == project_id if project_id is not None else True, + ) + creds = session.exec(statement).all() return creds -def get_creds_by_org(*, session: Session, org_id: int) -> Optional[Credential]: - """Fetches the credentials for the given organization.""" - statement = select(Credential).where(Credential.organization_id == org_id) - return session.exec(statement).first() - +def get_provider_credential( + *, session: Session, org_id: int, provider: str, project_id: Optional[int] = None +) -> Optional[Dict[str, Any]]: + """Fetches credentials for a specific provider of an organization.""" + validate_provider(provider) -def get_key_by_org(*, session: Session, org_id: int) -> Optional[str]: - """Fetches the API key from the credentials for the given organization.""" - statement = select(Credential).where(Credential.organization_id == org_id) + statement = select(Credential).where( + Credential.organization_id == org_id, + Credential.provider == provider, + Credential.is_active == True, + Credential.project_id == project_id if project_id is not None else True, + ) creds = session.exec(statement).first() - # Check if creds exists and if the credential field contains the api_key - if ( - creds - and creds.credential - and "openai" in creds.credential - and "api_key" in creds.credential["openai"] - ): - return creds.credential["openai"]["api_key"] - + if creds and creds.credential: + # Decrypt entire credentials object + return decrypt_credentials(creds.credential) return None -def update_creds_for_org( - session: Session, org_id: int, creds_in: CredsUpdate -) -> Credential: - # Fetch the current credentials for the organization - creds = session.exec( - select(Credential).where(Credential.organization_id == org_id) - ).first() +def get_providers( + *, session: Session, org_id: int, project_id: Optional[int] = None +) -> List[str]: + """Returns a list of all active providers for which credentials are stored.""" + creds = get_creds_by_org(session=session, org_id=org_id, project_id=project_id) + return [cred.provider for cred in creds] - if not creds: - raise ValueError(f"Credentials not found") - - # Update the credentials data with the provided values - creds_data = creds_in.dict(exclude_unset=True) - # Directly update the fields on the original creds object instead of creating a new one - for key, value in creds_data.items(): - setattr(creds, key, value) - - # Set the updated_at timestamp (current UTC time) - creds.updated_at = datetime.utcnow() +def update_creds_for_org( + *, session: Session, org_id: int, creds_in: CredsUpdate +) -> List[Credential]: + """Updates credentials for a specific provider of an organization.""" + if not creds_in.provider or not creds_in.credential: + raise ValueError("Provider and credential must be provided") + + validate_provider(creds_in.provider) + validate_provider_credentials(creds_in.provider, creds_in.credential) + + # Encrypt the entire credentials object + encrypted_credentials = encrypt_credentials(creds_in.credential) + + statement = select(Credential).where( + Credential.organization_id == org_id, + Credential.provider == creds_in.provider, + Credential.is_active == True, + Credential.project_id == creds_in.project_id + if creds_in.project_id is not None + else True, + ) + creds = session.exec(statement).first() - try: - # Add the updated creds to the session and flush the changes to the database - session.add(creds) - session.flush() # This will flush the changes to the database but without committing - session.commit() # Now we commit the changes to make them permanent - except IntegrityError as e: - # Rollback in case of any integrity errors (e.g., constraint violations) - session.rollback() - raise ValueError(f"Error while updating credentials: {str(e)}") + if not creds: + raise ValueError(f"No credentials found for provider {creds_in.provider}") - # Refresh the session to get the latest updated data + creds.credential = encrypted_credentials + creds.updated_at = now() + session.add(creds) + session.commit() session.refresh(creds) - return creds + return [creds] -def remove_creds_for_org(*, session: Session, org_id: int) -> Optional[Credential]: - """Removes (soft deletes) the credentials for the given organization.""" - statement = select(Credential).where(Credential.organization_id == org_id) +def remove_provider_credential( + session: Session, org_id: int, provider: str, project_id: Optional[int] = None +) -> Credential: + """Remove credentials for a specific provider.""" + validate_provider(provider) + + statement = select(Credential).where( + Credential.organization_id == org_id, + Credential.provider == provider, + Credential.project_id == project_id if project_id is not None else True, + ) creds = session.exec(statement).first() - if creds: - try: - # Soft delete: Set is_active to False and set deleted_at timestamp - creds.is_active = False - creds.deleted_at = ( - datetime.utcnow() - ) # Set the current time as the deleted_at timestamp - session.add(creds) - session.commit() - except IntegrityError as e: - session.rollback() # Rollback in case of a failure during delete operation - raise ValueError(f"Error while deleting credentials: {str(e)}") + if not creds: + raise ValueError(f"Credentials not found for provider '{provider}'") + + # Soft delete by setting is_active to False + creds.is_active = False + creds.updated_at = now() + try: + session.add(creds) + session.commit() + session.refresh(creds) + return creds + except IntegrityError as e: + session.rollback() + raise ValueError(f"Error while removing provider credentials: {str(e)}") + + +def remove_creds_for_org( + *, session: Session, org_id: int, project_id: Optional[int] = None +) -> List[Credential]: + """Removes all credentials for an organization.""" + statement = select(Credential).where( + Credential.organization_id == org_id, + Credential.is_active == True, + Credential.project_id == project_id if project_id is not None else True, + ) + creds = session.exec(statement).all() + + for cred in creds: + cred.is_active = False + cred.updated_at = now() + session.add(cred) + + session.commit() return creds diff --git a/backend/app/crud/organization.py b/backend/app/crud/organization.py index bef289da6..a1da8eccc 100644 --- a/backend/app/crud/organization.py +++ b/backend/app/crud/organization.py @@ -1,16 +1,17 @@ from typing import Any, Optional -from datetime import datetime +from datetime import datetime, timezone from sqlmodel import Session, select from app.models import Organization, OrganizationCreate +from app.core.util import now def create_organization( *, session: Session, org_create: OrganizationCreate ) -> Organization: db_org = Organization.model_validate(org_create) - db_org.inserted_at = datetime.utcnow() - db_org.updated_at = datetime.utcnow() + db_org.inserted_at = now() + db_org.updated_at = now() session.add(db_org) session.commit() session.refresh(db_org) diff --git a/backend/app/crud/project.py b/backend/app/crud/project.py index b62fbaab9..c77f8e567 100644 --- a/backend/app/crud/project.py +++ b/backend/app/crud/project.py @@ -1,14 +1,15 @@ from typing import List, Optional -from datetime import datetime +from datetime import datetime, timezone from sqlmodel import Session, select from app.models import Project, ProjectCreate +from app.core.util import now def create_project(*, session: Session, project_create: ProjectCreate) -> Project: db_project = Project.model_validate(project_create) - db_project.inserted_at = datetime.utcnow() - db_project.updated_at = datetime.utcnow() + db_project.inserted_at = now() + db_project.updated_at = now() session.add(db_project) session.commit() session.refresh(db_project) diff --git a/backend/app/crud/project_user.py b/backend/app/crud/project_user.py index f11224b59..8a3d88350 100644 --- a/backend/app/crud/project_user.py +++ b/backend/app/crud/project_user.py @@ -1,7 +1,9 @@ import uuid from sqlmodel import Session, select, delete, func from app.models import ProjectUser, ProjectUserPublic, User, Project -from datetime import datetime +from datetime import datetime, timezone + +from app.core.util import now def is_project_admin(session: Session, user_id: str, project_id: int) -> bool: @@ -62,7 +64,7 @@ def remove_user_from_project( raise ValueError("User is not a member of this project or already removed.") project_user.is_deleted = True - project_user.deleted_at = datetime.utcnow() + project_user.deleted_at = now() session.add(project_user) # Required to mark as dirty for commit session.commit() diff --git a/backend/app/models/credentials.py b/backend/app/models/credentials.py index d006b572c..0a05cff2d 100644 --- a/backend/app/models/credentials.py +++ b/backend/app/models/credentials.py @@ -8,23 +8,54 @@ class CredsBase(SQLModel): organization_id: int = Field(foreign_key="organization.id") + project_id: Optional[int] = Field(default=None, foreign_key="project.id") is_active: bool = True class CredsCreate(CredsBase): - credential: Dict[str, Any] = Field(default=None, sa_column=sa.Column(sa.JSON)) + """Create new credentials for an organization. + The credential field should be a dictionary mapping provider names to their credentials. + Example: {"openai": {"api_key": "..."}, "langfuse": {"public_key": "..."}} + """ + + credential: Dict[str, Any] = Field( + default=None, + description="Dictionary mapping provider names to their credentials", + ) class CredsUpdate(SQLModel): - credential: Optional[Dict[str, Any]] = Field( - default=None, sa_column=sa.Column(sa.JSON) + """Update credentials for an organization. + Can update a specific provider's credentials or add a new provider. + """ + + provider: str = Field( + description="Name of the provider to update/add credentials for" + ) + credential: Dict[str, Any] = Field( + description="Credentials for the specified provider", + ) + is_active: Optional[bool] = Field( + default=None, description="Whether the credentials are active" + ) + project_id: Optional[int] = Field( + default=None, description="Project ID to associate with these credentials" ) - is_active: Optional[bool] = Field(default=None) class Credential(CredsBase, table=True): + """Database model for storing provider credentials. + Each row represents credentials for a single provider. + """ + id: int = Field(default=None, primary_key=True) - credential: Dict[str, Any] = Field(default=None, sa_column=sa.Column(sa.JSON)) + provider: str = Field( + index=True, description="Provider name like 'openai', 'gemini'" + ) + credential: str = Field( + sa_column=sa.Column(sa.String), + description="Encrypted provider-specific credentials", + ) inserted_at: datetime = Field( default_factory=now, sa_column=sa.Column(sa.DateTime, default=datetime.utcnow), @@ -38,11 +69,33 @@ class Credential(CredsBase, table=True): ) organization: Optional["Organization"] = Relationship(back_populates="creds") + project: Optional["Project"] = Relationship(back_populates="creds") + + def to_public(self) -> "CredsPublic": + """Convert the database model to a public model with decrypted credentials.""" + from app.core.security import decrypt_credentials + + return CredsPublic( + id=self.id, + organization_id=self.organization_id, + project_id=self.project_id, + is_active=self.is_active, + provider=self.provider, + credential=decrypt_credentials(self.credential) + if self.credential + else None, + inserted_at=self.inserted_at, + updated_at=self.updated_at, + deleted_at=self.deleted_at, + ) class CredsPublic(CredsBase): + """Public representation of credentials, excluding sensitive information.""" + id: int - credential: Dict[str, Any] + provider: str + credential: Optional[Dict[str, Any]] = None inserted_at: datetime updated_at: datetime deleted_at: Optional[datetime] diff --git a/backend/app/models/project.py b/backend/app/models/project.py index 19a5cc1aa..0d6bddc0b 100644 --- a/backend/app/models/project.py +++ b/backend/app/models/project.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Optional +from typing import Optional, List from sqlmodel import Field, Relationship, SQLModel from app.core.util import now @@ -34,7 +34,9 @@ class Project(ProjectBase, table=True): users: list["ProjectUser"] = Relationship( back_populates="project", cascade_delete=True ) - + creds: list["Credential"] = Relationship( + back_populates="project", sa_relationship_kwargs={"cascade": "all, delete"} + ) organization: Optional["Organization"] = Relationship(back_populates="project") diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index 2f1a0b78e..c931c20cb 100644 --- a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -2,19 +2,16 @@ from fastapi.testclient import TestClient from sqlmodel import Session import random -import string, datetime +import string from app.main import app from app.api.deps import get_db -from app.crud.credentials import ( - set_creds_for_org, - get_creds_by_org, - remove_creds_for_org, -) -from app.models import CredsCreate, CredsUpdate, Organization, OrganizationCreate -from app.utils import APIResponse -from app.tests.utils.utils import random_lower_string +from app.crud.credentials import set_creds_for_org +from app.models import CredsCreate, Organization, OrganizationCreate from app.core.config import settings +from app.core.security import encrypt_api_key +from app.core.providers import Provider +from app.models.credentials import Credential client = TestClient(app) @@ -24,37 +21,45 @@ def generate_random_string(length=10): @pytest.fixture -def create_organization_and_creds(db: Session, superuser_token_headers: dict[str, str]): - unique_org_name = "Test Organization " + generate_random_string( - 5 - ) # Ensure unique name - org_data = OrganizationCreate(name=unique_org_name, is_active=True) - org = Organization(**org_data.dict()) # Create Organization instance +def create_organization_and_creds(db: Session): + unique_org_name = "Test Organization " + generate_random_string(5) + org = Organization(name=unique_org_name, is_active=True) db.add(org) db.commit() db.refresh(org) + api_key = "sk-" + generate_random_string(10) creds_data = CredsCreate( organization_id=org.id, is_active=True, - credential={"openai": {"api_key": "sk-" + generate_random_string(10)}}, + credential={ + Provider.OPENAI.value: { + "api_key": api_key, + "model": "gpt-4", + "temperature": 0.7, + } + }, ) return org, creds_data def test_set_creds_for_org(db: Session, superuser_token_headers: dict[str, str]): - unique_name = "Test Organization " + generate_random_string(5) - - new_org = Organization(name=unique_name, is_active=True) - db.add(new_org) + org = Organization(name="Org for Set Creds", is_active=True) + db.add(org) db.commit() - db.refresh(new_org) + db.refresh(org) api_key = "sk-" + generate_random_string(10) creds_data = { - "organization_id": new_org.id, + "organization_id": org.id, "is_active": True, - "credential": {"openai": {"api_key": api_key}}, + "credential": { + Provider.OPENAI.value: { + "api_key": api_key, + "model": "gpt-4", + "temperature": 0.7, + } + }, } response = client.post( @@ -64,105 +69,74 @@ def test_set_creds_for_org(db: Session, superuser_token_headers: dict[str, str]) ) assert response.status_code == 200 - created_creds = response.json() - assert "data" in created_creds - assert created_creds["data"]["organization_id"] == new_org.id - assert created_creds["data"]["credential"]["openai"]["api_key"] == api_key + data = response.json()["data"] + assert isinstance(data, list) + assert len(data) == 1 + assert data[0]["organization_id"] == org.id + assert data[0]["provider"] == Provider.OPENAI.value + assert data[0]["credential"]["model"] == "gpt-4" -# Test reading credentials def test_read_credentials_with_creds( db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds ): - # Create the organization and credentials (this time with credentials) org, creds_data = create_organization_and_creds - # Create credentials for the organization set_creds_for_org(session=db, creds_add=creds_data) - # Case 3: Organization exists and credentials are found response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}", headers=superuser_token_headers + f"{settings.API_V1_STR}/credentials/{org.id}", + headers=superuser_token_headers, ) + assert response.status_code == 200 - response_data = response.json() - assert "data" in response_data - assert response_data["data"]["organization_id"] == org.id - assert "credential" in response_data["data"] - assert ( - response_data["data"]["credential"]["openai"]["api_key"] - == creds_data.credential["openai"]["api_key"] - ) + data = response.json()["data"] + assert isinstance(data, list) + assert len(data) == 1 + assert data[0]["organization_id"] == org.id + assert data[0]["provider"] == Provider.OPENAI.value + assert data[0]["credential"]["model"] == "gpt-4" def test_read_credentials_not_found( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str] ): - # Create the organization without credentials - org, _ = create_organization_and_creds - - # Case 1: Organization exists but no credentials - response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}", headers=superuser_token_headers - ) - - # Assert that the status code is 404 - assert response.status_code == 404 - response_data = response.json() - - # Assert the correct error message - assert response_data["detail"] == "Credentials not found" - - # Case 2: Organization does not exist - non_existing_org_id = 999 # Assuming this ID does not exist response = client.get( - f"{settings.API_V1_STR}/credentials/{non_existing_org_id}", + f"{settings.API_V1_STR}/credentials/999999", headers=superuser_token_headers, ) - - # Assert that the status code is 404 assert response.status_code == 404 - response_data = response.json() - - # Assert the correct error message - assert response_data["detail"] == "Credentials not found" + assert "Credentials not found" in response.json()["detail"] -def test_read_api_key( +def test_read_provider_credential( db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds ): org, creds_data = create_organization_and_creds set_creds_for_org(session=db, creds_add=creds_data) response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}/api-key", + f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) assert response.status_code == 200 - response_data = response.json() - - assert "data" in response_data - - assert "api_key" in response_data["data"] - assert ( - response_data["data"]["api_key"] == creds_data.credential["openai"]["api_key"] - ) + data = response.json()["data"] + assert data["model"] == "gpt-4" + assert "api_key" in data -def test_read_api_key_not_found( +def test_read_provider_credential_not_found( db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds ): org, _ = create_organization_and_creds response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}/api-key", + f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) assert response.status_code == 404 - - response_data = response.json() - assert response_data["detail"] == "API key not found" + assert response.json()["detail"] == "Provider credentials not found" def test_update_credentials( @@ -172,12 +146,12 @@ def test_update_credentials( set_creds_for_org(session=db, creds_add=creds_data) update_data = { + "provider": Provider.OPENAI.value, "credential": { - "openai": { - "api_key": "sk-" - + generate_random_string() # Generate a new API key for the update - } - } + "api_key": "sk-" + generate_random_string(), + "model": "gpt-4-turbo", + "temperature": 0.8, + }, } response = client.patch( @@ -186,75 +160,270 @@ def test_update_credentials( headers=superuser_token_headers, ) - print(response.json()) - assert response.status_code == 200 - response_data = response.json() - - assert "data" in response_data - - assert ( - response_data["data"]["credential"]["openai"]["api_key"] - == update_data["credential"]["openai"]["api_key"] - ) - - assert response_data["data"]["updated_at"] is not None + data = response.json()["data"] + assert isinstance(data, list) + assert len(data) == 1 + assert data[0]["provider"] == Provider.OPENAI.value + assert data[0]["credential"]["model"] == "gpt-4-turbo" + assert data[0]["updated_at"] is not None def test_update_credentials_not_found( db: Session, superuser_token_headers: dict[str, str] ): - update_data = {"credential": {"openai": "sk-" + generate_random_string()}} + # Create a non-existent organization ID + non_existent_org_id = 999999 + + update_data = { + "provider": Provider.OPENAI.value, + "credential": { + "api_key": "sk-" + generate_random_string(), + "model": "gpt-4", + "temperature": 0.7, + }, + } response = client.patch( - f"{settings.API_V1_STR}/credentials/999", + f"{settings.API_V1_STR}/credentials/{non_existent_org_id}", json=update_data, headers=superuser_token_headers, ) - assert response.status_code == 404 - assert response.json()["detail"] == "Credentials not found" + + assert response.status_code == 500 # Expect 404 for non-existent organization + assert "Organization not found" in response.json()["detail"] -def test_delete_credentials( +def test_delete_provider_credential( db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds ): org, creds_data = create_organization_and_creds set_creds_for_org(session=db, creds_add=creds_data) response = client.delete( - f"{settings.API_V1_STR}/credentials/{org.id}/api-key", + f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) assert response.status_code == 200 - response_data = response.json() - print(f"Response Data: {response_data}") + data = response.json()["data"] + assert data["message"] == "Provider credentials removed successfully" - assert "data" in response_data - assert "message" in response_data["data"] - assert response_data["data"]["message"] == "Credentials deleted successfully" +def test_delete_provider_credential_not_found( + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds +): + org, _ = create_organization_and_creds + + response = client.delete( + f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", + headers=superuser_token_headers, + ) + + assert response.status_code == 404 # Expect 404 for not found + assert response.json()["detail"] == "Provider credentials not found" + + +def test_delete_all_credentials( + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds +): + org, creds_data = create_organization_and_creds + set_creds_for_org(session=db, creds_add=creds_data) + + response = client.delete( + f"{settings.API_V1_STR}/credentials/{org.id}", + headers=superuser_token_headers, + ) + + assert response.status_code == 200 # Expect 200 for successful deletion + data = response.json()["data"] + assert data["message"] == "Credentials deleted successfully" + + # Verify the credentials are soft deleted response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}", headers=superuser_token_headers + f"{settings.API_V1_STR}/credentials/{org.id}", + headers=superuser_token_headers, + ) + assert response.status_code == 404 # Expect 404 as credentials are soft deleted + assert response.json()["detail"] == "Credentials not found" + + +def test_delete_all_credentials_not_found( + db: Session, superuser_token_headers: dict[str, str] +): + response = client.delete( + f"{settings.API_V1_STR}/credentials/999999", + headers=superuser_token_headers, ) - response_data = response.json() + assert response.status_code == 500 # Expect 404 for not found + assert "Credentials for organization not found" in response.json()["detail"] + + +def test_duplicate_credential_creation( + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds +): + org, creds_data = create_organization_and_creds + # First create credentials + response = client.post( + f"{settings.API_V1_STR}/credentials/", + json=creds_data.dict(), + headers=superuser_token_headers, + ) assert response.status_code == 200 - assert response_data["data"]["deleted_at"] is not None - assert response_data["data"]["is_active"] is False + + # Try to create the same credentials again + response = client.post( + f"{settings.API_V1_STR}/credentials/", + json=creds_data.dict(), + headers=superuser_token_headers, + ) + assert response.status_code == 500 + assert "already exist" in response.json()["detail"] -def test_delete_credentials_not_found( +def test_multiple_provider_credentials( db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds ): org, _ = create_organization_and_creds - response = client.delete( - f"{settings.API_V1_STR}/credentials/{org.id}/api-key", + # Create OpenAI credentials + openai_creds = { + "organization_id": org.id, + "is_active": True, + "credential": { + Provider.OPENAI.value: { + "api_key": "sk-" + generate_random_string(10), + "model": "gpt-4", + "temperature": 0.7, + } + }, + } + + # Create Langfuse credentials + langfuse_creds = { + "organization_id": org.id, + "is_active": True, + "credential": { + Provider.LANGFUSE.value: { + "secret_key": "sk-" + generate_random_string(10), + "public_key": "pk-" + generate_random_string(10), + "host": "https://cloud.langfuse.com", + } + }, + } + + # Create both credentials + response = client.post( + f"{settings.API_V1_STR}/credentials/", + json=openai_creds, headers=superuser_token_headers, ) + assert response.status_code == 200 - assert response.status_code == 404 - response_data = response.json() + response = client.post( + f"{settings.API_V1_STR}/credentials/", + json=langfuse_creds, + headers=superuser_token_headers, + ) + assert response.status_code == 200 + + # Fetch all credentials + response = client.get( + f"{settings.API_V1_STR}/credentials/{org.id}", + headers=superuser_token_headers, + ) + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 2 + providers = [cred["provider"] for cred in data] + assert Provider.OPENAI.value in providers + assert Provider.LANGFUSE.value in providers + + +def test_credential_encryption( + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds +): + org, creds_data = create_organization_and_creds + original_api_key = creds_data.credential[Provider.OPENAI.value]["api_key"] - assert response_data["detail"] == "Credentials for organization not found" + # Create credentials + response = client.post( + f"{settings.API_V1_STR}/credentials/", + json=creds_data.dict(), + headers=superuser_token_headers, + ) + assert response.status_code == 200 + + # Get the raw credential from database to verify encryption + from app.core.security import decrypt_credentials + + db_cred = ( + db.query(Credential) + .filter( + Credential.organization_id == org.id, + Credential.provider == Provider.OPENAI.value, + ) + .first() + ) + + assert db_cred is not None + # Verify the stored credential is encrypted + assert db_cred.credential != original_api_key + + # Verify we can decrypt and get the original value + decrypted_creds = decrypt_credentials(db_cred.credential) + assert decrypted_creds["api_key"] == original_api_key + + +def test_credential_encryption_consistency( + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds +): + org, creds_data = create_organization_and_creds + original_api_key = creds_data.credential[Provider.OPENAI.value]["api_key"] + + # Create credentials + response = client.post( + f"{settings.API_V1_STR}/credentials/", + json=creds_data.dict(), + headers=superuser_token_headers, + ) + assert response.status_code == 200 + + # Fetch the credentials through the API + response = client.get( + f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", + headers=superuser_token_headers, + ) + assert response.status_code == 200 + data = response.json()["data"] + + # Verify the API returns the decrypted value + assert data["api_key"] == original_api_key + + # Update the credentials + new_api_key = "sk-" + generate_random_string(10) + update_data = { + "provider": Provider.OPENAI.value, + "credential": { + "api_key": new_api_key, + "model": "gpt-4", + "temperature": 0.7, + }, + } + + response = client.patch( + f"{settings.API_V1_STR}/credentials/{org.id}", + json=update_data, + headers=superuser_token_headers, + ) + assert response.status_code == 200 + + # Verify the updated value is also properly encrypted/decrypted + response = client.get( + f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", + headers=superuser_token_headers, + ) + assert response.status_code == 200 + data = response.json()["data"] + assert data["api_key"] == new_api_key diff --git a/backend/app/tests/core/test_providers.py b/backend/app/tests/core/test_providers.py new file mode 100644 index 000000000..1cbf10755 --- /dev/null +++ b/backend/app/tests/core/test_providers.py @@ -0,0 +1,31 @@ +import pytest +from app.core.providers import ( + validate_provider, + validate_provider_credentials, + Provider, +) + + +def test_validate_provider_invalid(): + """Test validating an invalid provider name.""" + with pytest.raises(ValueError) as exc_info: + validate_provider("invalid_provider") + assert "Unsupported provider" in str(exc_info.value) + assert "openai" in str(exc_info.value) # Check that supported providers are listed + + +def test_validate_provider_credentials_missing_fields(): + """Test validating provider credentials with missing required fields.""" + # Test OpenAI missing api_key + with pytest.raises(ValueError) as exc_info: + validate_provider_credentials("openai", {}) + assert "Missing required fields" in str(exc_info.value) + assert "api_key" in str(exc_info.value) + + # Test AWS missing region + with pytest.raises(ValueError) as exc_info: + validate_provider_credentials( + "aws", {"access_key_id": "test-id", "secret_access_key": "test-secret"} + ) + assert "Missing required fields" in str(exc_info.value) + assert "region" in str(exc_info.value) diff --git a/backend/app/tests/crud/test_credentials.py b/backend/app/tests/crud/test_credentials.py new file mode 100644 index 000000000..f47ea8d67 --- /dev/null +++ b/backend/app/tests/crud/test_credentials.py @@ -0,0 +1,285 @@ +import uuid +from sqlmodel import Session +import pytest +from datetime import datetime + +from app.crud import credentials as credentials_crud +from app.models import Credential, CredsCreate, CredsUpdate, Organization, Project +from app.tests.utils.utils import random_email +from app.core.security import get_password_hash + + +def create_organization_and_project(db: Session) -> tuple[Organization, Project]: + """Helper function to create an organization and a project.""" + organization = Organization( + name=f"Test Organization {uuid.uuid4()}", is_active=True + ) + db.add(organization) + db.commit() + db.refresh(organization) + + project = Project( + name=f"Test Project {uuid.uuid4()}", + description="A test project", + organization_id=organization.id, + is_active=True, + ) + db.add(project) + db.commit() + db.refresh(project) + + return organization, project + + +def test_set_creds_for_org(db: Session) -> None: + """Test setting credentials for an organization.""" + organization, _ = create_organization_and_project(db) + + # Test credentials for supported providers + creds_data = { + "openai": {"api_key": "test-openai-key"}, + "langfuse": { + "public_key": "test-public-key", + "secret_key": "test-secret-key", + "host": "https://cloud.langfuse.com", + }, + } + + creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) + + created_creds = credentials_crud.set_creds_for_org( + session=db, creds_add=creds_create + ) + + assert len(created_creds) == 2 + assert all(cred.organization_id == organization.id for cred in created_creds) + assert all(cred.is_active for cred in created_creds) + assert {cred.provider for cred in created_creds} == {"openai", "langfuse"} + + +def test_set_creds_for_org_with_project(db: Session) -> None: + """Test setting credentials for an organization with a specific project.""" + organization, project = create_organization_and_project(db) + + creds_data = {"openai": {"api_key": "test-openai-key"}} + + creds_create = CredsCreate( + organization_id=organization.id, project_id=project.id, credential=creds_data + ) + + created_creds = credentials_crud.set_creds_for_org( + session=db, creds_add=creds_create + ) + + assert len(created_creds) == 1 + assert created_creds[0].organization_id == organization.id + assert created_creds[0].project_id == project.id + assert created_creds[0].provider == "openai" + assert created_creds[0].is_active + + +def test_get_creds_by_org(db: Session) -> None: + """Test retrieving all credentials for an organization.""" + organization, _ = create_organization_and_project(db) + + # Set up test credentials + creds_data = { + "openai": {"api_key": "test-openai-key"}, + "langfuse": { + "public_key": "test-public-key", + "secret_key": "test-secret-key", + "host": "https://cloud.langfuse.com", + }, + } + + creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) + credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + + # Test retrieving credentials + retrieved_creds = credentials_crud.get_creds_by_org( + session=db, org_id=organization.id + ) + + assert len(retrieved_creds) == 2 + assert all(cred.organization_id == organization.id for cred in retrieved_creds) + assert {cred.provider for cred in retrieved_creds} == {"openai", "langfuse"} + + +def test_get_provider_credential(db: Session) -> None: + """Test retrieving credentials for a specific provider.""" + organization, _ = create_organization_and_project(db) + + # Set up test credentials + creds_data = {"openai": {"api_key": "test-openai-key"}} + + creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) + credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + + # Test retrieving specific provider credentials + retrieved_cred = credentials_crud.get_provider_credential( + session=db, org_id=organization.id, provider="openai" + ) + + assert retrieved_cred is not None + assert "api_key" in retrieved_cred + assert retrieved_cred["api_key"] == "test-openai-key" + + +def test_update_creds_for_org(db: Session) -> None: + """Test updating credentials for a provider.""" + organization, _ = create_organization_and_project(db) + + # Set up initial credentials + initial_creds = {"openai": {"api_key": "initial-key"}} + creds_create = CredsCreate( + organization_id=organization.id, credential=initial_creds + ) + credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + + # Update credentials + updated_creds = {"api_key": "updated-key"} + creds_update = CredsUpdate(provider="openai", credential=updated_creds) + + updated = credentials_crud.update_creds_for_org( + session=db, org_id=organization.id, creds_in=creds_update + ) + + assert len(updated) == 1 + assert updated[0].provider == "openai" + retrieved_cred = credentials_crud.get_provider_credential( + session=db, org_id=organization.id, provider="openai" + ) + assert retrieved_cred["api_key"] == "updated-key" + + +def test_remove_provider_credential(db: Session) -> None: + """Test removing credentials for a specific provider.""" + organization, _ = create_organization_and_project(db) + + # Set up test credentials + creds_data = { + "openai": {"api_key": "test-openai-key"}, + "langfuse": { + "public_key": "test-public-key", + "secret_key": "test-secret-key", + "host": "https://cloud.langfuse.com", + }, + } + + creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) + credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + + # Remove one provider's credentials + removed = credentials_crud.remove_provider_credential( + session=db, org_id=organization.id, provider="openai" + ) + + assert removed.is_active is False + assert removed.updated_at is not None + + # Verify the credentials are no longer retrievable + retrieved_cred = credentials_crud.get_provider_credential( + session=db, org_id=organization.id, provider="openai" + ) + assert retrieved_cred is None + + +def test_remove_creds_for_org(db: Session) -> None: + """Test removing all credentials for an organization.""" + organization, _ = create_organization_and_project(db) + + # Set up test credentials + creds_data = { + "openai": {"api_key": "test-openai-key"}, + "langfuse": { + "public_key": "test-public-key", + "secret_key": "test-secret-key", + "host": "https://cloud.langfuse.com", + }, + } + + creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) + credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + + # Remove all credentials + removed = credentials_crud.remove_creds_for_org(session=db, org_id=organization.id) + + assert len(removed) == 2 + assert all(not cred.is_active for cred in removed) + assert all(cred.updated_at is not None for cred in removed) + + # Verify no credentials are retrievable + retrieved_creds = credentials_crud.get_creds_by_org( + session=db, org_id=organization.id + ) + assert len(retrieved_creds) == 0 + + +def test_invalid_provider(db: Session) -> None: + """Test handling of invalid provider names.""" + organization, _ = create_organization_and_project(db) + + # Test with unsupported provider + creds_data = {"gemini": {"api_key": "test-key"}} + + creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) + + with pytest.raises(ValueError, match="Unsupported provider"): + credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + + +def test_duplicate_provider_credentials(db: Session) -> None: + """Test handling of duplicate provider credentials.""" + organization, _ = create_organization_and_project(db) + + # Set up initial credentials + creds_data = {"openai": {"api_key": "test-key"}} + + creds_create = CredsCreate(organization_id=organization.id, credential=creds_data) + credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + + # Verify credentials exist and are active + existing_creds = credentials_crud.get_provider_credential( + session=db, org_id=organization.id, provider="openai" + ) + assert existing_creds is not None + assert "api_key" in existing_creds + assert existing_creds["api_key"] == "test-key" + + +def test_langfuse_credential_validation(db: Session) -> None: + """Test validation of Langfuse credentials structure.""" + organization, _ = create_organization_and_project(db) + + # Test with missing required fields + invalid_creds = { + "langfuse": { + "public_key": "test-public-key", + "secret_key": "test-secret-key" + # Missing host + } + } + + creds_create = CredsCreate( + organization_id=organization.id, credential=invalid_creds + ) + + with pytest.raises(ValueError): + credentials_crud.set_creds_for_org(session=db, creds_add=creds_create) + + # Test with valid Langfuse credentials + valid_creds = { + "langfuse": { + "public_key": "test-public-key", + "secret_key": "test-secret-key", + "host": "https://cloud.langfuse.com", + } + } + + creds_create = CredsCreate(organization_id=organization.id, credential=valid_creds) + + created_creds = credentials_crud.set_creds_for_org( + session=db, creds_add=creds_create + ) + assert len(created_creds) == 1 + assert created_creds[0].provider == "langfuse" diff --git a/backend/app/tests/crud/test_creds.py b/backend/app/tests/crud/test_creds.py deleted file mode 100644 index 7de6d54d1..000000000 --- a/backend/app/tests/crud/test_creds.py +++ /dev/null @@ -1,133 +0,0 @@ -import pytest -import random -import string -from fastapi.testclient import TestClient -from sqlmodel import Session -from sqlalchemy.exc import IntegrityError -from datetime import datetime - -from app.models import ( - Credential, - CredsCreate, - Organization, - OrganizationCreate, - CredsUpdate, -) -from app.crud.credentials import ( - set_creds_for_org, - get_creds_by_org, - get_key_by_org, - remove_creds_for_org, - update_creds_for_org, -) -from app.main import app -from app.utils import APIResponse -from app.core.config import settings - -client = TestClient(app) - - -# Helper function to generate random API key -def generate_random_string(length=10): - return "".join(random.choices(string.ascii_letters + string.digits, k=length)) - - -@pytest.fixture -def test_credential(db: Session): - # Create a unique organization name - unique_org_name = "Test Organization " + generate_random_string( - 5 - ) # Ensure unique name - - # Check if the organization already exists by name - existing_org = ( - db.query(Organization).filter(Organization.name == unique_org_name).first() - ) - - if existing_org: - org = existing_org # If organization exists, use the existing one - else: - # If not, create a new organization - organization_data = OrganizationCreate(name=unique_org_name, is_active=True) - org = Organization(**organization_data.dict()) # Create Organization instance - db.add(org) # Add to the session - - try: - db.commit() # Commit to save the organization to the database - db.refresh(org) # Refresh to get the organization_id - except IntegrityError as e: - db.rollback() # Rollback the transaction in case of an error (e.g., duplicate key) - raise ValueError(f"Error during organization commit: {str(e)}") - - # Generate a random API key for the test - api_key = "sk-" + generate_random_string(10) - - # Create the credentials using the mock organization_id - creds_data = CredsCreate( - organization_id=org.id, # Use the created organization_id - is_active=True, - credential={"openai": {"api_key": api_key}}, - ) - - creds = set_creds_for_org(session=db, creds_add=creds_data) - return creds - - -def test_create_credentials(db: Session, test_credential): - creds = test_credential # Using the fixture - assert creds is not None - assert creds.credential["openai"]["api_key"].startswith("sk-") - assert creds.is_active is True - assert creds.inserted_at is not None # Ensure inserted_at is set - - -def test_get_creds_by_org(db: Session, test_credential): - creds = test_credential # Using the fixture - retrieved_creds = get_creds_by_org(session=db, org_id=creds.organization_id) - - assert retrieved_creds is not None - assert retrieved_creds.organization_id == creds.organization_id - assert retrieved_creds.inserted_at is not None # Ensure inserted_at is not None - - -def test_update_creds_for_org(db: Session, test_credential): - creds = test_credential # Using the fixture - updated_creds_data = CredsUpdate(credential={"openai": {"api_key": "sk-newkey"}}) - - updated_creds = update_creds_for_org( - session=db, org_id=creds.organization_id, creds_in=updated_creds_data - ) - - assert updated_creds is not None - assert updated_creds.credential["openai"]["api_key"] == "sk-newkey" - assert updated_creds.updated_at is not None # Ensure updated_at is set - - -def test_remove_creds_for_org(db: Session, test_credential): - creds = test_credential # Using the fixture - removed_creds = remove_creds_for_org(session=db, org_id=creds.organization_id) - - assert removed_creds is not None - assert removed_creds.organization_id == creds.organization_id - - # Ensure the deleted_at timestamp is set for soft delete - assert removed_creds.deleted_at is not None # Ensure deleted_at is set - - # Check that credentials are soft deleted and not removed - deleted_creds = ( - db.query(Credential) - .filter(Credential.organization_id == creds.organization_id) - .first() - ) - assert deleted_creds is not None # Ensure the record still exists in the DB - assert deleted_creds.deleted_at is not None # Ensure it's marked as deleted - - -def test_remove_creds_for_org_not_found(db: Session): - # Try to remove credentials for a non-existent organization ID (999) - non_existing_org_id = 999 - - removed_creds = remove_creds_for_org(session=db, org_id=non_existing_org_id) - - # Assert that no credentials were removed since they don't exist - assert removed_creds is None