diff --git a/backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py b/backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py new file mode 100644 index 000000000..2864516b1 --- /dev/null +++ b/backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py @@ -0,0 +1,68 @@ +"""Added provider column to the credential table + +Revision ID: 904ed70e7dab +Revises: 543f97951bd0 +Create Date: 2025-05-10 11:13:17.868238 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision = "904ed70e7dab" +down_revision = "f23675767ed2" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "credential", + sa.Column("provider", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + ) + op.create_index( + op.f("ix_credential_provider"), "credential", ["provider"], unique=False + ) + op.drop_constraint( + "credential_organization_id_fkey", "credential", type_="foreignkey" + ) + op.create_foreign_key( + "credential_organization_id_fkey", + "credential", + "organization", + ["organization_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_constraint("project_organization_id_fkey", "project", type_="foreignkey") + op.create_foreign_key(None, "project", "organization", ["organization_id"], ["id"]) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, "project", type_="foreignkey") + op.create_foreign_key( + "project_organization_id_fkey", + "project", + "organization", + ["organization_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_constraint( + "credential_organization_id_fkey", "credential", type_="foreignkey" + ) + op.create_foreign_key( + "credential_organization_id_fkey", + "credential", + "organization", + ["organization_id"], + ["id"], + ) + op.drop_index(op.f("ix_credential_provider"), table_name="credential") + op.drop_column("credential", "provider") + # ### end Alembic commands ### diff --git a/backend/app/api/routes/credentials.py b/backend/app/api/routes/credentials.py index ccb68a1bc..5a2da39b0 100644 --- a/backend/app/api/routes/credentials.py +++ b/backend/app/api/routes/credentials.py @@ -2,14 +2,19 @@ from app.api.deps import SessionDep, get_current_active_superuser from app.crud.credentials import ( get_creds_by_org, - get_key_by_org, + get_provider_credential, remove_creds_for_org, set_creds_for_org, update_creds_for_org, + remove_provider_credential, ) from app.models import CredsCreate, CredsPublic, CredsUpdate from app.utils import APIResponse from datetime import datetime +from app.core.providers import validate_provider +from typing import List +from sqlalchemy.exc import IntegrityError +from app.models.organization import Organization router = APIRouter(prefix="/credentials", tags=["credentials"]) @@ -17,99 +22,178 @@ @router.post( "/", dependencies=[Depends(get_current_active_superuser)], - response_model=APIResponse[CredsPublic], + response_model=APIResponse[List[CredsPublic]], + summary="Create new credentials for an organization", + description="Creates new credentials for a specific organization. This endpoint requires superuser privileges. If credentials already exist for the organization, it will return an error.", ) def create_new_credential(*, session: SessionDep, creds_in: CredsCreate): - new_creds = None try: existing_creds = get_creds_by_org( session=session, org_id=creds_in.organization_id ) - if not existing_creds: - new_creds = set_creds_for_org(session=session, creds_add=creds_in) + if existing_creds: + raise HTTPException( + status_code=400, + detail="Credentials already exist for this organization", + ) + + new_creds = set_creds_for_org(session=session, creds_add=creds_in) + if not new_creds: + raise HTTPException(status_code=500, detail="Failed to create credentials") + + # Return all created credentials + return APIResponse.success_response(new_creds) + + except ValueError as e: + if "Unsupported provider" in str(e): + raise HTTPException(status_code=400, detail=str(e)) + raise HTTPException(status_code=404, detail=str(e)) except Exception as e: raise HTTPException( status_code=500, detail=f"An unexpected error occurred: {str(e)}" ) - # Ensure inserted_at is set during creation - new_creds.inserted_at = datetime.utcnow() - - return APIResponse.success_response(new_creds) - @router.get( "/{org_id}", dependencies=[Depends(get_current_active_superuser)], - response_model=APIResponse[CredsPublic], + response_model=APIResponse[List[CredsPublic]], + summary="Get all credentials for an organization", + description="Retrieves all provider credentials associated with a specific organization. This endpoint requires superuser privileges.", ) def read_credential(*, session: SessionDep, org_id: int): try: creds = get_creds_by_org(session=session, org_id=org_id) + if not creds: + raise HTTPException(status_code=404, detail="Credentials not found") + return APIResponse.success_response(creds) + except HTTPException as e: + raise e # Ensure HTTPException is not wrapped again except Exception as e: raise HTTPException( status_code=500, detail=f"An unexpected error occurred: {str(e)}" ) - if creds is None: - raise HTTPException(status_code=404, detail="Credentials not found") - - return APIResponse.success_response(creds) - @router.get( - "/{org_id}/api-key", + "/{org_id}/{provider}", dependencies=[Depends(get_current_active_superuser)], response_model=APIResponse[dict], + summary="Get specific provider credentials", + description="Retrieves credentials for a specific provider (e.g., 'openai', 'anthropic') for a given organization. This endpoint requires superuser privileges.", ) -def read_api_key(*, session: SessionDep, org_id: int): +def read_provider_credential(*, session: SessionDep, org_id: int, provider: str): try: - api_key = get_key_by_org(session=session, org_id=org_id) + provider_enum = validate_provider(provider) + provider_creds = get_provider_credential( + session=session, org_id=org_id, provider=provider_enum + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) except Exception as e: raise HTTPException( status_code=500, detail=f"An unexpected error occurred: {str(e)}" ) - if api_key is None: - raise HTTPException(status_code=404, detail="API key not found") + if provider_creds is None: + raise HTTPException(status_code=404, detail="Provider credentials not found") - return APIResponse.success_response({"api_key": api_key}) + return APIResponse.success_response(provider_creds) @router.patch( "/{org_id}", dependencies=[Depends(get_current_active_superuser)], - response_model=APIResponse[CredsPublic], + response_model=APIResponse[List[CredsPublic]], + summary="Update organization credentials", + description="Updates credentials for a specific organization. Can update specific provider credentials or add new providers. This endpoint requires superuser privileges.", ) def update_credential(*, session: SessionDep, org_id: int, creds_in: CredsUpdate): try: + # Validate incoming payload + if not creds_in or not creds_in.provider or not creds_in.credential: + raise HTTPException( + status_code=400, detail="Provider and credential must be provided" + ) + + # Defensive check to ensure organization exists + try: + organization = session.get(Organization, org_id) + except Exception as e: + raise HTTPException( + status_code=500, detail=f"Failed to fetch organization: {str(e)}" + ) + + if not organization: + raise HTTPException(status_code=404, detail="Organization not found") + updated_creds = update_creds_for_org( session=session, org_id=org_id, creds_in=creds_in ) - - updated_creds.updated_at = datetime.utcnow() + if not updated_creds: + raise HTTPException(status_code=404, detail="Failed to update credentials") return APIResponse.success_response(updated_creds) + + except IntegrityError as e: + if "ForeignKeyViolation" in str(e): + raise HTTPException( + status_code=400, + detail="Invalid organization ID. Ensure the organization exists before updating credentials.", + ) + raise HTTPException( + status_code=500, detail=f"An unexpected database error occurred: {str(e)}" + ) except ValueError as e: + if "Unsupported provider" in str(e): + raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=404, detail=str(e)) + except HTTPException: + raise except Exception as e: raise HTTPException( status_code=500, detail=f"An unexpected error occurred: {str(e)}" ) -from fastapi import HTTPException, Depends -from app.crud.credentials import remove_creds_for_org -from app.utils import APIResponse -from app.api.deps import SessionDep, get_current_active_superuser +@router.delete( + "/{org_id}/{provider}", + dependencies=[Depends(get_current_active_superuser)], + response_model=APIResponse[dict], + summary="Delete specific provider credentials", + description="Removes credentials for a specific provider while keeping other provider credentials intact. This endpoint requires superuser privileges.", +) +def delete_provider_credential(*, session: SessionDep, org_id: int, provider: str): + try: + provider_enum = validate_provider(provider) + updated_creds = remove_provider_credential( + session=session, org_id=org_id, provider=provider_enum + ) + except ValueError as e: + raise HTTPException( + status_code=404, detail="Provider credentials not found" + ) # Updated to return 404 + except Exception as e: + raise HTTPException( + status_code=500, detail=f"An unexpected error occurred: {str(e)}" + ) + + if not updated_creds: # Ensure proper check for no credentials found + raise HTTPException(status_code=404, detail="Provider credentials not found") + + return APIResponse.success_response( + {"message": "Provider credentials removed successfully"} + ) @router.delete( - "/{org_id}/api-key", + "/{org_id}", dependencies=[Depends(get_current_active_superuser)], response_model=APIResponse[dict], + summary="Delete all organization credentials", + description="Removes all credentials for a specific organization. This is a soft delete operation that marks credentials as inactive. This endpoint requires superuser privileges.", ) -def delete_credential(*, session: SessionDep, org_id: int): +def delete_all_credentials(*, session: SessionDep, org_id: int): try: creds = remove_creds_for_org(session=session, org_id=org_id) except Exception as e: @@ -117,11 +201,9 @@ def delete_credential(*, session: SessionDep, org_id: int): status_code=500, detail=f"An unexpected error occurred: {str(e)}" ) - if creds is None: + if not creds: # Ensure proper check for no credentials found raise HTTPException( status_code=404, detail="Credentials for organization not found" ) - # No need to manually set deleted_at and is_active if it's done in remove_creds_for_org - # Simply return the success response return APIResponse.success_response({"message": "Credentials deleted successfully"}) diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py new file mode 100644 index 000000000..1f513fea9 --- /dev/null +++ b/backend/app/core/providers.py @@ -0,0 +1,58 @@ +from typing import Dict, List, Optional +from enum import Enum + + +class Provider(str, Enum): + """Enumeration of supported credential providers.""" + + OPENAI = "openai" + GEMINI = "gemini" + ANTHROPIC = "anthropic" + MISTRAL = "mistral" + COHERE = "cohere" + HUGGINGFACE = "huggingface" + AZURE = "azure" + AWS = "aws" + GOOGLE = "google" + + +# Required fields for each provider's credentials +PROVIDER_REQUIRED_FIELDS: Dict[str, List[str]] = { + Provider.OPENAI: ["api_key"], + Provider.GEMINI: ["api_key"], + Provider.ANTHROPIC: ["api_key"], + Provider.MISTRAL: ["api_key"], + Provider.COHERE: ["api_key"], + Provider.HUGGINGFACE: ["api_key"], + Provider.AZURE: ["api_key", "endpoint"], + Provider.AWS: ["access_key_id", "secret_access_key", "region"], + Provider.GOOGLE: ["api_key"], +} + + +def validate_provider(provider: str) -> Provider: + """Validate that the provider name is supported and return the Provider enum.""" + try: + return Provider(provider.lower()) + except ValueError: + supported = ", ".join([p.value for p in Provider]) + raise ValueError( + f"Unsupported provider: {provider}. Supported providers are: {supported}" + ) + + +def validate_provider_credentials(provider: str, credentials: dict) -> None: + """Validate that the credentials contain all required fields for the provider.""" + provider_enum = validate_provider(provider) + required_fields = PROVIDER_REQUIRED_FIELDS[provider_enum] + + missing_fields = [field for field in required_fields if field not in credentials] + if missing_fields: + raise ValueError( + f"Missing required fields for {provider}: {', '.join(missing_fields)}" + ) + + +def get_supported_providers() -> List[str]: + """Return a list of all supported provider names.""" + return [p.value for p in Provider] diff --git a/backend/app/crud/credentials.py b/backend/app/crud/credentials.py index a66fa7b78..ff3c01b6f 100644 --- a/backend/app/crud/credentials.py +++ b/backend/app/crud/credentials.py @@ -1,104 +1,207 @@ -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, List from sqlmodel import Session, select from sqlalchemy.exc import IntegrityError from datetime import datetime from app.models import Credential, CredsCreate, CredsUpdate +from app.core.providers import ( + validate_provider, + validate_provider_credentials, + get_supported_providers, +) +from app.core.security import encrypt_api_key + + +def set_creds_for_org(*, session: Session, creds_add: CredsCreate) -> List[Credential]: + """Set credentials for an organization. Creates a separate row for each provider.""" + created_credentials = [] + + if not creds_add.credential: + raise ValueError("No credentials provided") + + for provider, credentials in creds_add.credential.items(): + # Validate provider and credentials + validate_provider(provider) + validate_provider_credentials(provider, credentials) + + # Encrypt API key if present + if isinstance(credentials, dict) and "api_key" in credentials: + credentials["api_key"] = encrypt_api_key(credentials["api_key"]) + + # Create a row for each provider + credential = Credential( + organization_id=creds_add.organization_id, + is_active=creds_add.is_active, + provider=provider, + credential=credentials, + ) + credential.inserted_at = datetime.utcnow() + try: + session.add(credential) + session.commit() + session.refresh(credential) + created_credentials.append(credential) + except IntegrityError as e: + session.rollback() + raise ValueError( + f"Error while adding credentials for provider {provider}: {str(e)}" + ) + + return created_credentials + + +def get_key_by_org( + *, session: Session, org_id: int, provider: str = "openai" +) -> Optional[str]: + """Fetches the API key from the credentials for the given organization and provider.""" + statement = select(Credential).where( + Credential.organization_id == org_id, + Credential.provider == provider, + Credential.is_active == True, + ) + creds = session.exec(statement).first() + if creds and creds.credential and "api_key" in creds.credential: + return creds.credential["api_key"] -def set_creds_for_org(*, session: Session, creds_add: CredsCreate) -> Credential: - creds = Credential.model_validate(creds_add) - - # Set the inserted_at timestamp (current UTC time) - creds.inserted_at = datetime.utcnow() - - try: - session.add(creds) - session.commit() - session.refresh(creds) - except IntegrityError as e: - session.rollback() # Rollback the session if there's a unique constraint violation - raise ValueError(f"Error while adding credentials: {str(e)}") + return None - return creds +def get_creds_by_org(*, session: Session, org_id: int) -> List[Credential]: + """Fetches all active credentials for the given organization.""" + statement = select(Credential).where( + Credential.organization_id == org_id, Credential.is_active == True + ) + return session.exec(statement).all() -def get_creds_by_org(*, session: Session, org_id: int) -> Optional[Credential]: - """Fetches the credentials for the given organization.""" - statement = select(Credential).where(Credential.organization_id == org_id) - return session.exec(statement).first() +def get_provider_credential( + *, session: Session, org_id: int, provider: str +) -> Optional[Dict[str, Any]]: + """Fetches credentials for a specific provider of an organization.""" + validate_provider(provider) -def get_key_by_org(*, session: Session, org_id: int) -> Optional[str]: - """Fetches the API key from the credentials for the given organization.""" - statement = select(Credential).where(Credential.organization_id == org_id) + statement = select(Credential).where( + Credential.organization_id == org_id, + Credential.provider == provider, + Credential.is_active == True, + ) creds = session.exec(statement).first() - # Check if creds exists and if the credential field contains the api_key - if ( - creds - and creds.credential - and "openai" in creds.credential - and "api_key" in creds.credential["openai"] - ): - return creds.credential["openai"]["api_key"] + return creds.credential if creds else None - return None + +def get_providers(*, session: Session, org_id: int) -> List[str]: + """Returns a list of all active providers for which credentials are stored.""" + creds = get_creds_by_org(session=session, org_id=org_id) + return [cred.provider for cred in creds] def update_creds_for_org( session: Session, org_id: int, creds_in: CredsUpdate -) -> Credential: - # Fetch the current credentials for the organization - creds = session.exec( - select(Credential).where(Credential.organization_id == org_id) - ).first() +) -> List[Credential]: + if not creds_in: + raise ValueError( + "Missing request body or failed to parse JSON into CredsUpdate" + ) + + """Update credentials for an organization. Can update specific provider or add new provider.""" + if not creds_in.provider or not creds_in.credential: + raise ValueError("Provider and credential information must be provided") + + # Validate provider and credentials + validate_provider(creds_in.provider) + validate_provider_credentials(creds_in.provider, creds_in.credential) + + # Encrypt API key if present + if isinstance(creds_in.credential, dict) and "api_key" in creds_in.credential: + creds_in.credential["api_key"] = encrypt_api_key(creds_in.credential["api_key"]) + + # Check if credentials exist for this provider + statement = select(Credential).where( + Credential.organization_id == org_id, Credential.provider == creds_in.provider + ) + existing_cred = session.exec(statement).first() + + if existing_cred: + # Update existing credentials + existing_cred.credential = creds_in.credential + existing_cred.is_active = ( + creds_in.is_active if creds_in.is_active is not None else True + ) + existing_cred.updated_at = datetime.utcnow() + try: + session.add(existing_cred) + session.commit() + session.refresh(existing_cred) + return [existing_cred] + except IntegrityError as e: + session.rollback() + raise ValueError(f"Error while updating credentials: {str(e)}") + else: + # Create new credentials + new_cred = Credential( + organization_id=org_id, + provider=creds_in.provider, + credential=creds_in.credential, + is_active=creds_in.is_active if creds_in.is_active is not None else True, + ) + try: + session.add(new_cred) + session.commit() + session.refresh(new_cred) + return [new_cred] + except IntegrityError as e: + session.rollback() + raise ValueError(f"Error while creating credentials: {str(e)}") - if not creds: - raise ValueError(f"Credentials not found") - # Update the credentials data with the provided values - creds_data = creds_in.dict(exclude_unset=True) +def remove_provider_credential( + session: Session, org_id: int, provider: str +) -> Credential: + """Remove credentials for a specific provider.""" + validate_provider(provider) - # Directly update the fields on the original creds object instead of creating a new one - for key, value in creds_data.items(): - setattr(creds, key, value) + statement = select(Credential).where( + Credential.organization_id == org_id, Credential.provider == provider + ) + creds = session.exec(statement).first() - # Set the updated_at timestamp (current UTC time) + if not creds: + raise ValueError(f"Credentials not found for provider '{provider}'") + + # Soft delete by setting is_active to False + creds.is_active = False creds.updated_at = datetime.utcnow() try: - # Add the updated creds to the session and flush the changes to the database session.add(creds) - session.flush() # This will flush the changes to the database but without committing - session.commit() # Now we commit the changes to make them permanent + session.commit() + session.refresh(creds) + return creds except IntegrityError as e: - # Rollback in case of any integrity errors (e.g., constraint violations) session.rollback() - raise ValueError(f"Error while updating credentials: {str(e)}") - - # Refresh the session to get the latest updated data - session.refresh(creds) + raise ValueError(f"Error while removing provider credentials: {str(e)}") - return creds +def remove_creds_for_org(session: Session, org_id: int): + """ + Removes all credentials for a specific organization by marking them as inactive. + Returns the list of updated credentials or None if no credentials were found. + """ + creds = session.exec( + select(Credential).where( + Credential.organization_id == org_id, Credential.is_active == True + ) + ).all() -def remove_creds_for_org(*, session: Session, org_id: int) -> Optional[Credential]: - """Removes (soft deletes) the credentials for the given organization.""" - statement = select(Credential).where(Credential.organization_id == org_id) - creds = session.exec(statement).first() + if not creds: + return None # Return None if no credentials are found - if creds: - try: - # Soft delete: Set is_active to False and set deleted_at timestamp - creds.is_active = False - creds.deleted_at = ( - datetime.utcnow() - ) # Set the current time as the deleted_at timestamp - session.add(creds) - session.commit() - except IntegrityError as e: - session.rollback() # Rollback in case of a failure during delete operation - raise ValueError(f"Error while deleting credentials: {str(e)}") + for cred in creds: + cred.is_active = False + cred.deleted_at = datetime.utcnow() + session.add(cred) + session.commit() return creds diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 693e73bbf..5a6c0b638 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -44,10 +44,4 @@ UpdatePassword, ) -from .credentials import ( - Credential, - CredsBase, - CredsCreate, - CredsPublic, - CredsUpdate, -) +from .credentials import Credential, CredsBase, CredsCreate, CredsPublic, CredsUpdate diff --git a/backend/app/models/credentials.py b/backend/app/models/credentials.py index d006b572c..6c7146a80 100644 --- a/backend/app/models/credentials.py +++ b/backend/app/models/credentials.py @@ -1,5 +1,6 @@ from typing import Dict, Any, Optional import sqlalchemy as sa +from sqlalchemy.ext.mutable import MutableDict from sqlmodel import Field, Relationship, SQLModel from datetime import datetime @@ -12,19 +13,48 @@ class CredsBase(SQLModel): class CredsCreate(CredsBase): - credential: Dict[str, Any] = Field(default=None, sa_column=sa.Column(sa.JSON)) + """Create new credentials for an organization. + The credential field should be a dictionary mapping provider names to their credentials. + Example: {"openai": {"api_key": "..."}, "gemini": {"api_key": "..."}} + """ + + credential: Dict[str, Any] = Field( + default=None, + sa_column=sa.Column(MutableDict.as_mutable(sa.JSON)), + description="Dictionary mapping provider names to their credentials", + ) class CredsUpdate(SQLModel): - credential: Optional[Dict[str, Any]] = Field( - default=None, sa_column=sa.Column(sa.JSON) + """Update credentials for an organization. + Can update a specific provider's credentials or add a new provider. + """ + + provider: str = Field( + description="Name of the provider to update/add credentials for" + ) + credential: Dict[str, Any] = Field( + sa_column=sa.Column(MutableDict.as_mutable(sa.JSON)), + description="Credentials for the specified provider", + ) + is_active: Optional[bool] = Field( + default=None, description="Whether the credentials are active" ) - is_active: Optional[bool] = Field(default=None) class Credential(CredsBase, table=True): + """Database model for storing provider credentials. + Each row represents credentials for a single provider. + """ + id: int = Field(default=None, primary_key=True) - credential: Dict[str, Any] = Field(default=None, sa_column=sa.Column(sa.JSON)) + provider: str = Field( + index=True, description="Provider name like 'openai', 'gemini'" + ) + credential: Dict[str, Any] = Field( + sa_column=sa.Column(MutableDict.as_mutable(sa.JSON)), + description="Provider-specific credentials (e.g., API keys)", + ) inserted_at: datetime = Field( default_factory=now, sa_column=sa.Column(sa.DateTime, default=datetime.utcnow), @@ -41,7 +71,10 @@ class Credential(CredsBase, table=True): class CredsPublic(CredsBase): + """Public representation of credentials, excluding sensitive information.""" + id: int + provider: str credential: Dict[str, Any] inserted_at: datetime updated_at: datetime diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index bdba0d8e0..9d09c5b19 100644 --- a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -2,19 +2,15 @@ from fastapi.testclient import TestClient from sqlmodel import Session import random -import string, datetime +import string from app.main import app from app.api.deps import get_db -from app.crud.credentials import ( - set_creds_for_org, - get_creds_by_org, - remove_creds_for_org, -) -from app.models import CredsCreate, CredsUpdate, Organization, OrganizationCreate -from app.utils import APIResponse -from app.tests.utils.utils import random_lower_string +from app.crud.credentials import set_creds_for_org +from app.models import CredsCreate, Organization, OrganizationCreate from app.core.config import settings +from app.core.security import encrypt_api_key +from app.core.providers import Provider client = TestClient(app) @@ -24,42 +20,45 @@ def generate_random_string(length=10): @pytest.fixture -def create_organization_and_creds(db: Session, superuser_token_headers: dict[str, str]): - unique_org_name = "Test Organization " + generate_random_string( - 5 - ) # Ensure unique name - org_data = OrganizationCreate(name=unique_org_name, is_active=True) - org = Organization(**org_data.dict()) # Create Organization instance +def create_organization_and_creds(db: Session): + unique_org_name = "Test Organization " + generate_random_string(5) + org = Organization(name=unique_org_name, is_active=True) db.add(org) db.commit() db.refresh(org) + api_key = "sk-" + generate_random_string(10) creds_data = CredsCreate( organization_id=org.id, is_active=True, - credential={"openai": {"api_key": "sk-" + generate_random_string(10)}}, + credential={ + Provider.OPENAI.value: { + "api_key": api_key, + "model": "gpt-4", + "temperature": 0.7, + } + }, ) return org, creds_data def test_set_creds_for_org(db: Session, superuser_token_headers: dict[str, str]): - unique_org_id = 2 - existing_org = ( - db.query(Organization).filter(Organization.id == unique_org_id).first() - ) - - if not existing_org: - new_org = Organization( - id=unique_org_id, name="Test Organization", is_active=True - ) - db.add(new_org) - db.commit() + org = Organization(name="Org for Set Creds", is_active=True) + db.add(org) + db.commit() + db.refresh(org) api_key = "sk-" + generate_random_string(10) creds_data = { - "organization_id": unique_org_id, + "organization_id": org.id, "is_active": True, - "credential": {"openai": {"api_key": api_key}}, + "credential": { + Provider.OPENAI.value: { + "api_key": api_key, + "model": "gpt-4", + "temperature": 0.7, + } + }, } response = client.post( @@ -69,106 +68,74 @@ def test_set_creds_for_org(db: Session, superuser_token_headers: dict[str, str]) ) assert response.status_code == 200 + data = response.json()["data"] + assert isinstance(data, list) + assert len(data) == 1 + assert data[0]["organization_id"] == org.id + assert data[0]["provider"] == Provider.OPENAI.value + assert data[0]["credential"]["model"] == "gpt-4" - created_creds = response.json() - assert "data" in created_creds - assert created_creds["data"]["organization_id"] == unique_org_id - assert created_creds["data"]["credential"]["openai"]["api_key"] == api_key - -# Test reading credentials def test_read_credentials_with_creds( db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds ): - # Create the organization and credentials (this time with credentials) org, creds_data = create_organization_and_creds - # Create credentials for the organization set_creds_for_org(session=db, creds_add=creds_data) - # Case 3: Organization exists and credentials are found response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}", headers=superuser_token_headers + f"{settings.API_V1_STR}/credentials/{org.id}", + headers=superuser_token_headers, ) + assert response.status_code == 200 - response_data = response.json() - assert "data" in response_data - assert response_data["data"]["organization_id"] == org.id - assert "credential" in response_data["data"] - assert ( - response_data["data"]["credential"]["openai"]["api_key"] - == creds_data.credential["openai"]["api_key"] - ) + data = response.json()["data"] + assert isinstance(data, list) + assert len(data) == 1 + assert data[0]["organization_id"] == org.id + assert data[0]["provider"] == Provider.OPENAI.value + assert data[0]["credential"]["model"] == "gpt-4" def test_read_credentials_not_found( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds + db: Session, superuser_token_headers: dict[str, str] ): - # Create the organization without credentials - org, _ = create_organization_and_creds - - # Case 1: Organization exists but no credentials response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}", headers=superuser_token_headers - ) - - # Assert that the status code is 404 - assert response.status_code == 404 - response_data = response.json() - - # Assert the correct error message - assert response_data["detail"] == "Credentials not found" - - # Case 2: Organization does not exist - non_existing_org_id = 999 # Assuming this ID does not exist - response = client.get( - f"{settings.API_V1_STR}/credentials/{non_existing_org_id}", + f"{settings.API_V1_STR}/credentials/999999", headers=superuser_token_headers, ) - - # Assert that the status code is 404 assert response.status_code == 404 - response_data = response.json() - - # Assert the correct error message - assert response_data["detail"] == "Credentials not found" + assert response.json()["detail"] == "Credentials not found" -def test_read_api_key( +def test_read_provider_credential( db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds ): org, creds_data = create_organization_and_creds set_creds_for_org(session=db, creds_add=creds_data) response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}/api-key", + f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) assert response.status_code == 200 - response_data = response.json() + data = response.json()["data"] + assert data["model"] == "gpt-4" + assert "api_key" in data - assert "data" in response_data - assert "api_key" in response_data["data"] - assert ( - response_data["data"]["api_key"] == creds_data.credential["openai"]["api_key"] - ) - - -def test_read_api_key_not_found( +def test_read_provider_credential_not_found( db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds ): org, _ = create_organization_and_creds response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}/api-key", + f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) assert response.status_code == 404 - - response_data = response.json() - assert response_data["detail"] == "API key not found" + assert response.json()["detail"] == "Provider credentials not found" def test_update_credentials( @@ -178,12 +145,12 @@ def test_update_credentials( set_creds_for_org(session=db, creds_add=creds_data) update_data = { + "provider": Provider.OPENAI.value, "credential": { - "openai": { - "api_key": "sk-" - + generate_random_string() # Generate a new API key for the update - } - } + "api_key": "sk-" + generate_random_string(), + "model": "gpt-4-turbo", + "temperature": 0.8, + }, } response = client.patch( @@ -192,75 +159,101 @@ def test_update_credentials( headers=superuser_token_headers, ) - print(response.json()) - assert response.status_code == 200 - response_data = response.json() - - assert "data" in response_data - - assert ( - response_data["data"]["credential"]["openai"]["api_key"] - == update_data["credential"]["openai"]["api_key"] - ) - - assert response_data["data"]["updated_at"] is not None + data = response.json()["data"] + assert isinstance(data, list) + assert len(data) == 1 + assert data[0]["provider"] == Provider.OPENAI.value + assert data[0]["credential"]["model"] == "gpt-4-turbo" + assert data[0]["updated_at"] is not None def test_update_credentials_not_found( db: Session, superuser_token_headers: dict[str, str] ): - update_data = {"credential": {"openai": "sk-" + generate_random_string()}} + # Create a non-existent organization ID + non_existent_org_id = 999999 + + update_data = { + "provider": Provider.OPENAI.value, + "credential": { + "api_key": "sk-" + generate_random_string(), + "model": "gpt-4", + "temperature": 0.7, + }, + } response = client.patch( - f"{settings.API_V1_STR}/credentials/999", + f"{settings.API_V1_STR}/credentials/{non_existent_org_id}", json=update_data, headers=superuser_token_headers, ) - assert response.status_code == 404 - assert response.json()["detail"] == "Credentials not found" + + assert response.status_code == 404 # Expect 404 for non-existent organization + assert response.json()["detail"] == "Organization not found" -def test_delete_credentials( +def test_delete_provider_credential( db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds ): org, creds_data = create_organization_and_creds set_creds_for_org(session=db, creds_add=creds_data) response = client.delete( - f"{settings.API_V1_STR}/credentials/{org.id}/api-key", + f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) assert response.status_code == 200 - response_data = response.json() - print(f"Response Data: {response_data}") + data = response.json()["data"] + assert data["message"] == "Provider credentials removed successfully" - assert "data" in response_data - assert "message" in response_data["data"] - assert response_data["data"]["message"] == "Credentials deleted successfully" - response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}", headers=superuser_token_headers +def test_delete_provider_credential_not_found( + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds +): + org, _ = create_organization_and_creds + + response = client.delete( + f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", + headers=superuser_token_headers, ) - response_data = response.json() - assert response.status_code == 200 - assert response_data["data"]["deleted_at"] is not None - assert response_data["data"]["is_active"] is False + assert response.status_code == 404 # Expect 404 for not found + assert response.json()["detail"] == "Provider credentials not found" -def test_delete_credentials_not_found( +def test_delete_all_credentials( db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds ): - org, _ = create_organization_and_creds + org, creds_data = create_organization_and_creds + set_creds_for_org(session=db, creds_add=creds_data) response = client.delete( - f"{settings.API_V1_STR}/credentials/{org.id}/api-key", + f"{settings.API_V1_STR}/credentials/{org.id}", headers=superuser_token_headers, ) - assert response.status_code == 404 - response_data = response.json() + assert response.status_code == 200 # Expect 200 for successful deletion + data = response.json()["data"] + assert data["message"] == "Credentials deleted successfully" + + # Verify the credentials are soft deleted + response = client.get( + f"{settings.API_V1_STR}/credentials/{org.id}", + headers=superuser_token_headers, + ) + assert response.status_code == 404 # Expect 404 as credentials are soft deleted + assert response.json()["detail"] == "Credentials not found" + + +def test_delete_all_credentials_not_found( + db: Session, superuser_token_headers: dict[str, str] +): + response = client.delete( + f"{settings.API_V1_STR}/credentials/999999", + headers=superuser_token_headers, + ) - assert response_data["detail"] == "Credentials for organization not found" + assert response.status_code == 404 # Expect 404 for not found + assert response.json()["detail"] == "Credentials for organization not found" diff --git a/backend/app/tests/core/test_providers.py b/backend/app/tests/core/test_providers.py new file mode 100644 index 000000000..6d54e6c12 --- /dev/null +++ b/backend/app/tests/core/test_providers.py @@ -0,0 +1,37 @@ +import pytest +from app.core.providers import ( + validate_provider, + validate_provider_credentials, + Provider, +) + + +def test_validate_provider_invalid(): + """Test validating an invalid provider name.""" + with pytest.raises(ValueError) as exc_info: + validate_provider("invalid_provider") + assert "Unsupported provider" in str(exc_info.value) + assert "openai" in str(exc_info.value) # Check that supported providers are listed + + +def test_validate_provider_credentials_missing_fields(): + """Test validating provider credentials with missing required fields.""" + # Test OpenAI missing api_key + with pytest.raises(ValueError) as exc_info: + validate_provider_credentials("openai", {}) + assert "Missing required fields" in str(exc_info.value) + assert "api_key" in str(exc_info.value) + + # Test Azure missing endpoint + with pytest.raises(ValueError) as exc_info: + validate_provider_credentials("azure", {"api_key": "test-key"}) + assert "Missing required fields" in str(exc_info.value) + assert "endpoint" in str(exc_info.value) + + # Test AWS missing region + with pytest.raises(ValueError) as exc_info: + validate_provider_credentials( + "aws", {"access_key_id": "test-id", "secret_access_key": "test-secret"} + ) + assert "Missing required fields" in str(exc_info.value) + assert "region" in str(exc_info.value) diff --git a/backend/app/tests/crud/test_creds.py b/backend/app/tests/crud/test_creds.py index 7de6d54d1..2f4667c98 100644 --- a/backend/app/tests/crud/test_creds.py +++ b/backend/app/tests/crud/test_creds.py @@ -4,130 +4,114 @@ from fastapi.testclient import TestClient from sqlmodel import Session from sqlalchemy.exc import IntegrityError -from datetime import datetime from app.models import ( Credential, CredsCreate, + CredsUpdate, Organization, OrganizationCreate, - CredsUpdate, ) from app.crud.credentials import ( set_creds_for_org, get_creds_by_org, get_key_by_org, - remove_creds_for_org, update_creds_for_org, + remove_creds_for_org, ) +from app.core.providers import Provider +from app.core.security import encrypt_api_key, decrypt_api_key from app.main import app -from app.utils import APIResponse -from app.core.config import settings client = TestClient(app) -# Helper function to generate random API key def generate_random_string(length=10): return "".join(random.choices(string.ascii_letters + string.digits, k=length)) @pytest.fixture -def test_credential(db: Session): - # Create a unique organization name - unique_org_name = "Test Organization " + generate_random_string( - 5 - ) # Ensure unique name - - # Check if the organization already exists by name - existing_org = ( - db.query(Organization).filter(Organization.name == unique_org_name).first() - ) +def org_with_creds(db: Session): + org = Organization(name=f"Test Org {generate_random_string(5)}", is_active=True) + db.add(org) + db.commit() + db.refresh(org) - if existing_org: - org = existing_org # If organization exists, use the existing one - else: - # If not, create a new organization - organization_data = OrganizationCreate(name=unique_org_name, is_active=True) - org = Organization(**organization_data.dict()) # Create Organization instance - db.add(org) # Add to the session - - try: - db.commit() # Commit to save the organization to the database - db.refresh(org) # Refresh to get the organization_id - except IntegrityError as e: - db.rollback() # Rollback the transaction in case of an error (e.g., duplicate key) - raise ValueError(f"Error during organization commit: {str(e)}") - - # Generate a random API key for the test api_key = "sk-" + generate_random_string(10) - - # Create the credentials using the mock organization_id creds_data = CredsCreate( - organization_id=org.id, # Use the created organization_id + organization_id=org.id, is_active=True, - credential={"openai": {"api_key": api_key}}, + credential={ + Provider.OPENAI.value: { + "api_key": api_key, + "model": "gpt-4", + "temperature": 0.7, + } + }, ) - creds = set_creds_for_org(session=db, creds_add=creds_data) - return creds + return org, creds -def test_create_credentials(db: Session, test_credential): - creds = test_credential # Using the fixture +def test_create_credentials(db: Session, org_with_creds): + org, creds = org_with_creds assert creds is not None - assert creds.credential["openai"]["api_key"].startswith("sk-") - assert creds.is_active is True - assert creds.inserted_at is not None # Ensure inserted_at is set - - -def test_get_creds_by_org(db: Session, test_credential): - creds = test_credential # Using the fixture - retrieved_creds = get_creds_by_org(session=db, org_id=creds.organization_id) - - assert retrieved_creds is not None - assert retrieved_creds.organization_id == creds.organization_id - assert retrieved_creds.inserted_at is not None # Ensure inserted_at is not None - - -def test_update_creds_for_org(db: Session, test_credential): - creds = test_credential # Using the fixture - updated_creds_data = CredsUpdate(credential={"openai": {"api_key": "sk-newkey"}}) - - updated_creds = update_creds_for_org( - session=db, org_id=creds.organization_id, creds_in=updated_creds_data + assert len(creds) == 1 + assert creds[0].provider == Provider.OPENAI.value + assert "api_key" in creds[0].credential + # Decrypt the stored API key before asserting + decrypted_api_key = decrypt_api_key(creds[0].credential["api_key"]) + assert decrypted_api_key.startswith("sk-") + assert creds[0].is_active + assert creds[0].inserted_at is not None + + +def test_get_creds_by_org(db: Session, org_with_creds): + org, creds = org_with_creds + retrieved = get_creds_by_org(session=db, org_id=org.id) + assert retrieved is not None + assert len(retrieved) == 1 + assert retrieved[0].organization_id == org.id + assert retrieved[0].provider == Provider.OPENAI.value + assert "api_key" in retrieved[0].credential + assert retrieved[0].inserted_at is not None + + +def test_update_creds_for_org(db: Session, org_with_creds): + org, _ = org_with_creds + new_api_key = "sk-" + generate_random_string(12) + update_data = CredsUpdate( + provider=Provider.OPENAI.value, + credential={"api_key": new_api_key, "model": "gpt-4-turbo", "temperature": 0.8}, ) - assert updated_creds is not None - assert updated_creds.credential["openai"]["api_key"] == "sk-newkey" - assert updated_creds.updated_at is not None # Ensure updated_at is set + updated = update_creds_for_org(session=db, org_id=org.id, creds_in=update_data) + assert updated is not None + assert len(updated) == 1 + # Decrypt the stored API key before asserting equality + assert decrypt_api_key(updated[0].credential["api_key"]) == new_api_key + assert updated[0].credential["model"] == "gpt-4-turbo" + assert updated[0].updated_at is not None -def test_remove_creds_for_org(db: Session, test_credential): - creds = test_credential # Using the fixture - removed_creds = remove_creds_for_org(session=db, org_id=creds.organization_id) - assert removed_creds is not None - assert removed_creds.organization_id == creds.organization_id +def test_remove_creds_for_org(db: Session, org_with_creds): + org, creds = org_with_creds + removed = remove_creds_for_org(session=db, org_id=org.id) - # Ensure the deleted_at timestamp is set for soft delete - assert removed_creds.deleted_at is not None # Ensure deleted_at is set + assert removed is not None + assert len(removed) == 1 + assert removed[0].organization_id == org.id + assert removed[0].deleted_at is not None - # Check that credentials are soft deleted and not removed - deleted_creds = ( - db.query(Credential) - .filter(Credential.organization_id == creds.organization_id) - .first() + # Ensure the record is still present but soft-deleted + still_exists = ( + db.query(Credential).filter(Credential.organization_id == org.id).first() ) - assert deleted_creds is not None # Ensure the record still exists in the DB - assert deleted_creds.deleted_at is not None # Ensure it's marked as deleted + assert still_exists is not None + assert still_exists.deleted_at is not None def test_remove_creds_for_org_not_found(db: Session): - # Try to remove credentials for a non-existent organization ID (999) - non_existing_org_id = 999 - - removed_creds = remove_creds_for_org(session=db, org_id=non_existing_org_id) - - # Assert that no credentials were removed since they don't exist - assert removed_creds is None + removed = remove_creds_for_org(session=db, org_id=999999) + assert removed is None