From 29d219a9761a688b083997c6553f7bdcfdd1ba80 Mon Sep 17 00:00:00 2001 From: Priyanshu singh <111607560+PriyanSingh@users.noreply.github.com> Date: Sat, 10 May 2025 11:21:42 +0530 Subject: [PATCH 01/10] Add provider column to credential table and update API for provider-specific credentials - Introduced a new column 'provider' in the credential table to support multiple credential providers. - Updated API routes to handle provider-specific credential operations, including creation, retrieval, updating, and deletion. - Enhanced validation for provider credentials and added support for multiple providers in the data model. - Refactored existing credential handling functions to accommodate the new structure and improve error handling. - Ensured backward compatibility by maintaining existing functionality while expanding capabilities. --- ...dded_provider_column_to_the_credential_.py | 39 +++ backend/app/api/routes/credentials.py | 122 ++++++--- backend/app/core/providers.py | 58 ++++ backend/app/crud/credentials.py | 247 ++++++++++++------ backend/app/models/__init__.py | 2 +- backend/app/models/credentials.py | 48 +++- 6 files changed, 398 insertions(+), 118 deletions(-) create mode 100644 backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py create mode 100644 backend/app/core/providers.py diff --git a/backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py b/backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py new file mode 100644 index 000000000..bb7677802 --- /dev/null +++ b/backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py @@ -0,0 +1,39 @@ +"""Added provider column to the credential table + +Revision ID: 904ed70e7dab +Revises: 543f97951bd0 +Create Date: 2025-05-10 11:13:17.868238 + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision = '904ed70e7dab' +down_revision = '543f97951bd0' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('credential', sa.Column('provider', sqlmodel.sql.sqltypes.AutoString(), nullable=False)) + op.create_index(op.f('ix_credential_provider'), 'credential', ['provider'], unique=False) + op.drop_constraint('credential_organization_id_fkey', 'credential', type_='foreignkey') + op.create_foreign_key(None, 'credential', 'organization', ['organization_id'], ['id']) + op.drop_constraint('project_organization_id_fkey', 'project', type_='foreignkey') + op.create_foreign_key(None, 'project', 'organization', ['organization_id'], ['id']) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, 'project', type_='foreignkey') + op.create_foreign_key('project_organization_id_fkey', 'project', 'organization', ['organization_id'], ['id'], ondelete='CASCADE') + op.drop_constraint(None, 'credential', type_='foreignkey') + op.create_foreign_key('credential_organization_id_fkey', 'credential', 'organization', ['organization_id'], ['id'], ondelete='CASCADE') + op.drop_index(op.f('ix_credential_provider'), table_name='credential') + op.drop_column('credential', 'provider') + # ### end Alembic commands ### diff --git a/backend/app/api/routes/credentials.py b/backend/app/api/routes/credentials.py index ccb68a1bc..0edbde56f 100644 --- a/backend/app/api/routes/credentials.py +++ b/backend/app/api/routes/credentials.py @@ -2,14 +2,17 @@ from app.api.deps import SessionDep, get_current_active_superuser from app.crud.credentials import ( get_creds_by_org, - get_key_by_org, + get_provider_credential, remove_creds_for_org, set_creds_for_org, update_creds_for_org, + remove_provider_credential, ) from app.models import CredsCreate, CredsPublic, CredsUpdate from app.utils import APIResponse from datetime import datetime +from app.core.providers import validate_provider +from typing import List router = APIRouter(prefix="/credentials", tags=["credentials"]) @@ -17,80 +20,107 @@ @router.post( "/", dependencies=[Depends(get_current_active_superuser)], - response_model=APIResponse[CredsPublic], + response_model=APIResponse[List[CredsPublic]], + summary="Create new credentials for an organization", + description="Creates new credentials for a specific organization. This endpoint requires superuser privileges. If credentials already exist for the organization, it will return an error.", ) def create_new_credential(*, session: SessionDep, creds_in: CredsCreate): - new_creds = None try: existing_creds = get_creds_by_org( session=session, org_id=creds_in.organization_id ) - if not existing_creds: - new_creds = set_creds_for_org(session=session, creds_add=creds_in) + if existing_creds: + raise HTTPException( + status_code=400, + detail="Credentials already exist for this organization" + ) + + new_creds = set_creds_for_org(session=session, creds_add=creds_in) + if not new_creds: + raise HTTPException( + status_code=500, + detail="Failed to create credentials" + ) + + # Return all created credentials + return APIResponse.success_response(new_creds) + + except ValueError as e: + if "Unsupported provider" in str(e): + raise HTTPException(status_code=400, detail=str(e)) + raise HTTPException(status_code=404, detail=str(e)) except Exception as e: raise HTTPException( status_code=500, detail=f"An unexpected error occurred: {str(e)}" ) - # Ensure inserted_at is set during creation - new_creds.inserted_at = datetime.utcnow() - - return APIResponse.success_response(new_creds) - @router.get( "/{org_id}", dependencies=[Depends(get_current_active_superuser)], - response_model=APIResponse[CredsPublic], + response_model=APIResponse[List[CredsPublic]], + summary="Get all credentials for an organization", + description="Retrieves all provider credentials associated with a specific organization. This endpoint requires superuser privileges.", ) def read_credential(*, session: SessionDep, org_id: int): try: creds = get_creds_by_org(session=session, org_id=org_id) + if not creds: + raise HTTPException(status_code=404, detail="Credentials not found") + return APIResponse.success_response(creds) except Exception as e: raise HTTPException( status_code=500, detail=f"An unexpected error occurred: {str(e)}" ) - if creds is None: - raise HTTPException(status_code=404, detail="Credentials not found") - - return APIResponse.success_response(creds) - @router.get( - "/{org_id}/api-key", + "/{org_id}/{provider}", dependencies=[Depends(get_current_active_superuser)], response_model=APIResponse[dict], + summary="Get specific provider credentials", + description="Retrieves credentials for a specific provider (e.g., 'openai', 'anthropic') for a given organization. This endpoint requires superuser privileges.", ) -def read_api_key(*, session: SessionDep, org_id: int): +def read_provider_credential(*, session: SessionDep, org_id: int, provider: str): try: - api_key = get_key_by_org(session=session, org_id=org_id) + provider_enum = validate_provider(provider) + provider_creds = get_provider_credential( + session=session, org_id=org_id, provider=provider_enum + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) except Exception as e: raise HTTPException( status_code=500, detail=f"An unexpected error occurred: {str(e)}" ) - if api_key is None: - raise HTTPException(status_code=404, detail="API key not found") + if provider_creds is None: + raise HTTPException(status_code=404, detail="Provider credentials not found") - return APIResponse.success_response({"api_key": api_key}) + return APIResponse.success_response(provider_creds) @router.patch( "/{org_id}", dependencies=[Depends(get_current_active_superuser)], - response_model=APIResponse[CredsPublic], + response_model=APIResponse[List[CredsPublic]], + summary="Update organization credentials", + description="Updates credentials for a specific organization. Can update specific provider credentials or add new providers. This endpoint requires superuser privileges.", ) def update_credential(*, session: SessionDep, org_id: int, creds_in: CredsUpdate): try: updated_creds = update_creds_for_org( session=session, org_id=org_id, creds_in=creds_in ) - - updated_creds.updated_at = datetime.utcnow() - + if not updated_creds: + raise HTTPException( + status_code=404, + detail="Failed to update credentials" + ) return APIResponse.success_response(updated_creds) except ValueError as e: + if "Unsupported provider" in str(e): + raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=404, detail=str(e)) except Exception as e: raise HTTPException( @@ -98,18 +128,42 @@ def update_credential(*, session: SessionDep, org_id: int, creds_in: CredsUpdate ) -from fastapi import HTTPException, Depends -from app.crud.credentials import remove_creds_for_org -from app.utils import APIResponse -from app.api.deps import SessionDep, get_current_active_superuser +@router.delete( + "/{org_id}/{provider}", + dependencies=[Depends(get_current_active_superuser)], + response_model=APIResponse[dict], + summary="Delete specific provider credentials", + description="Removes credentials for a specific provider while keeping other provider credentials intact. This endpoint requires superuser privileges.", +) +def delete_provider_credential(*, session: SessionDep, org_id: int, provider: str): + try: + provider_enum = validate_provider(provider) + updated_creds = remove_provider_credential( + session=session, org_id=org_id, provider=provider_enum + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + raise HTTPException( + status_code=500, detail=f"An unexpected error occurred: {str(e)}" + ) + + if updated_creds is None: + raise HTTPException(status_code=404, detail="Provider credentials not found") + + return APIResponse.success_response( + {"message": "Provider credentials removed successfully"} + ) @router.delete( - "/{org_id}/api-key", + "/{org_id}", dependencies=[Depends(get_current_active_superuser)], response_model=APIResponse[dict], + summary="Delete all organization credentials", + description="Removes all credentials for a specific organization. This is a soft delete operation that marks credentials as inactive. This endpoint requires superuser privileges.", ) -def delete_credential(*, session: SessionDep, org_id: int): +def delete_all_credentials(*, session: SessionDep, org_id: int): try: creds = remove_creds_for_org(session=session, org_id=org_id) except Exception as e: @@ -122,6 +176,4 @@ def delete_credential(*, session: SessionDep, org_id: int): status_code=404, detail="Credentials for organization not found" ) - # No need to manually set deleted_at and is_active if it's done in remove_creds_for_org - # Simply return the success response - return APIResponse.success_response({"message": "Credentials deleted successfully"}) + return APIResponse.success_response({"message": "Credentials deleted successfully"}) \ No newline at end of file diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py new file mode 100644 index 000000000..1f572e6d4 --- /dev/null +++ b/backend/app/core/providers.py @@ -0,0 +1,58 @@ +from typing import Dict, List, Optional +from enum import Enum + + +class Provider(str, Enum): + """Enumeration of supported credential providers.""" + + OPENAI = "openai" + GEMINI = "gemini" + ANTHROPIC = "anthropic" + MISTRAL = "mistral" + COHERE = "cohere" + HUGGINGFACE = "huggingface" + AZURE = "azure" + AWS = "aws" + GOOGLE = "google" + + +# Required fields for each provider's credentials +PROVIDER_REQUIRED_FIELDS: Dict[str, List[str]] = { + Provider.OPENAI: ["api_key"], + Provider.GEMINI: ["api_key"], + Provider.ANTHROPIC: ["api_key"], + Provider.MISTRAL: ["api_key"], + Provider.COHERE: ["api_key"], + Provider.HUGGINGFACE: ["api_key"], + Provider.AZURE: ["api_key", "endpoint"], + Provider.AWS: ["access_key_id", "secret_access_key", "region"], + Provider.GOOGLE: ["api_key"], +} + + +def validate_provider(provider: str) -> Provider: + """Validate that the provider name is supported and return the Provider enum.""" + try: + return Provider(provider.lower()) + except ValueError: + supported = ", ".join([p.value for p in Provider]) + raise ValueError( + f"Unsupported provider: {provider}. Supported providers are: {supported}" + ) + + +def validate_provider_credentials(provider: str, credentials: dict) -> None: + """Validate that the credentials contain all required fields for the provider.""" + provider_enum = validate_provider(provider) + required_fields = PROVIDER_REQUIRED_FIELDS[provider_enum] + + missing_fields = [field for field in required_fields if field not in credentials] + if missing_fields: + raise ValueError( + f"Missing required fields for {provider}: {', '.join(missing_fields)}" + ) + + +def get_supported_providers() -> List[str]: + """Return a list of all supported provider names.""" + return [p.value for p in Provider] \ No newline at end of file diff --git a/backend/app/crud/credentials.py b/backend/app/crud/credentials.py index a66fa7b78..828ba66ff 100644 --- a/backend/app/crud/credentials.py +++ b/backend/app/crud/credentials.py @@ -1,104 +1,201 @@ -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, List from sqlmodel import Session, select from sqlalchemy.exc import IntegrityError from datetime import datetime from app.models import Credential, CredsCreate, CredsUpdate +from app.core.providers import ( + validate_provider, + validate_provider_credentials, + get_supported_providers, +) +from app.core.security import encrypt_api_key + + +def set_creds_for_org(*, session: Session, creds_add: CredsCreate) -> List[Credential]: + """Set credentials for an organization. Creates a separate row for each provider.""" + created_credentials = [] + + if not creds_add.credential: + raise ValueError("No credentials provided") + + for provider, credentials in creds_add.credential.items(): + # Validate provider and credentials + validate_provider(provider) + validate_provider_credentials(provider, credentials) + + # Encrypt API key if present + if isinstance(credentials, dict) and "api_key" in credentials: + credentials["api_key"] = encrypt_api_key(credentials["api_key"]) + + # Create a row for each provider + credential = Credential( + organization_id=creds_add.organization_id, + is_active=creds_add.is_active, + provider=provider, + credential=credentials, + ) + credential.inserted_at = datetime.utcnow() + try: + session.add(credential) + session.commit() + session.refresh(credential) + created_credentials.append(credential) + except IntegrityError as e: + session.rollback() + raise ValueError(f"Error while adding credentials for provider {provider}: {str(e)}") + return created_credentials -def set_creds_for_org(*, session: Session, creds_add: CredsCreate) -> Credential: - creds = Credential.model_validate(creds_add) - - # Set the inserted_at timestamp (current UTC time) - creds.inserted_at = datetime.utcnow() - try: - session.add(creds) - session.commit() - session.refresh(creds) - except IntegrityError as e: - session.rollback() # Rollback the session if there's a unique constraint violation - raise ValueError(f"Error while adding credentials: {str(e)}") - - return creds +def get_key_by_org( + *, session: Session, org_id: int, provider: str = "openai" +) -> Optional[str]: + """Fetches the API key from the credentials for the given organization and provider.""" + statement = select(Credential).where( + Credential.organization_id == org_id, + Credential.provider == provider, + Credential.is_active == True + ) + creds = session.exec(statement).first() + if creds and creds.credential and "api_key" in creds.credential: + return creds.credential["api_key"] -def get_creds_by_org(*, session: Session, org_id: int) -> Optional[Credential]: - """Fetches the credentials for the given organization.""" - statement = select(Credential).where(Credential.organization_id == org_id) - return session.exec(statement).first() + return None -def get_key_by_org(*, session: Session, org_id: int) -> Optional[str]: - """Fetches the API key from the credentials for the given organization.""" - statement = select(Credential).where(Credential.organization_id == org_id) +def get_creds_by_org(*, session: Session, org_id: int) -> List[Credential]: + """Fetches all active credentials for the given organization.""" + statement = select(Credential).where( + Credential.organization_id == org_id, + Credential.is_active == True + ) + return session.exec(statement).all() + + +def get_provider_credential( + *, session: Session, org_id: int, provider: str +) -> Optional[Dict[str, Any]]: + """Fetches credentials for a specific provider of an organization.""" + validate_provider(provider) + + statement = select(Credential).where( + Credential.organization_id == org_id, + Credential.provider == provider, + Credential.is_active == True + ) creds = session.exec(statement).first() + + return creds.credential if creds else None - # Check if creds exists and if the credential field contains the api_key - if ( - creds - and creds.credential - and "openai" in creds.credential - and "api_key" in creds.credential["openai"] - ): - return creds.credential["openai"]["api_key"] - return None +def get_providers(*, session: Session, org_id: int) -> List[str]: + """Returns a list of all active providers for which credentials are stored.""" + creds = get_creds_by_org(session=session, org_id=org_id) + return [cred.provider for cred in creds] def update_creds_for_org( session: Session, org_id: int, creds_in: CredsUpdate -) -> Credential: - # Fetch the current credentials for the organization - creds = session.exec( - select(Credential).where(Credential.organization_id == org_id) - ).first() +) -> List[Credential]: + """Update credentials for an organization. Can update specific provider or add new provider.""" + if not creds_in.provider or not creds_in.credential: + raise ValueError("Provider and credential information must be provided") + + # Validate provider and credentials + validate_provider(creds_in.provider) + validate_provider_credentials(creds_in.provider, creds_in.credential) + + # Encrypt API key if present + if isinstance(creds_in.credential, dict) and "api_key" in creds_in.credential: + creds_in.credential["api_key"] = encrypt_api_key(creds_in.credential["api_key"]) + + # Check if credentials exist for this provider + statement = select(Credential).where( + Credential.organization_id == org_id, + Credential.provider == creds_in.provider + ) + existing_cred = session.exec(statement).first() + + if existing_cred: + # Update existing credentials + existing_cred.credential = creds_in.credential + existing_cred.is_active = creds_in.is_active if creds_in.is_active is not None else True + existing_cred.updated_at = datetime.utcnow() + try: + session.add(existing_cred) + session.commit() + session.refresh(existing_cred) + return [existing_cred] + except IntegrityError as e: + session.rollback() + raise ValueError(f"Error while updating credentials: {str(e)}") + else: + # Create new credentials + new_cred = Credential( + organization_id=org_id, + provider=creds_in.provider, + credential=creds_in.credential, + is_active=creds_in.is_active if creds_in.is_active is not None else True + ) + try: + session.add(new_cred) + session.commit() + session.refresh(new_cred) + return [new_cred] + except IntegrityError as e: + session.rollback() + raise ValueError(f"Error while creating credentials: {str(e)}") - if not creds: - raise ValueError(f"Credentials not found") - # Update the credentials data with the provided values - creds_data = creds_in.dict(exclude_unset=True) +def remove_provider_credential( + session: Session, org_id: int, provider: str +) -> Credential: + """Remove credentials for a specific provider.""" + validate_provider(provider) - # Directly update the fields on the original creds object instead of creating a new one - for key, value in creds_data.items(): - setattr(creds, key, value) + statement = select(Credential).where( + Credential.organization_id == org_id, + Credential.provider == provider + ) + creds = session.exec(statement).first() - # Set the updated_at timestamp (current UTC time) + if not creds: + raise ValueError(f"Credentials not found for provider '{provider}'") + + # Soft delete by setting is_active to False + creds.is_active = False creds.updated_at = datetime.utcnow() try: - # Add the updated creds to the session and flush the changes to the database session.add(creds) - session.flush() # This will flush the changes to the database but without committing - session.commit() # Now we commit the changes to make them permanent + session.commit() + session.refresh(creds) + return creds except IntegrityError as e: - # Rollback in case of any integrity errors (e.g., constraint violations) session.rollback() - raise ValueError(f"Error while updating credentials: {str(e)}") - - # Refresh the session to get the latest updated data - session.refresh(creds) - - return creds - - -def remove_creds_for_org(*, session: Session, org_id: int) -> Optional[Credential]: - """Removes (soft deletes) the credentials for the given organization.""" - statement = select(Credential).where(Credential.organization_id == org_id) - creds = session.exec(statement).first() - - if creds: - try: - # Soft delete: Set is_active to False and set deleted_at timestamp - creds.is_active = False - creds.deleted_at = ( - datetime.utcnow() - ) # Set the current time as the deleted_at timestamp - session.add(creds) - session.commit() - except IntegrityError as e: - session.rollback() # Rollback in case of a failure during delete operation - raise ValueError(f"Error while deleting credentials: {str(e)}") - - return creds + raise ValueError(f"Error while removing provider credentials: {str(e)}") + + +def remove_creds_for_org(*, session: Session, org_id: int) -> List[Credential]: + """Removes (soft deletes) all credentials for the given organization.""" + statement = select(Credential).where( + Credential.organization_id == org_id, + Credential.is_active == True + ) + creds = session.exec(statement).all() + + for cred in creds: + cred.is_active = False + cred.updated_at = datetime.utcnow() + session.add(cred) + + try: + session.commit() + for cred in creds: + session.refresh(cred) + return creds + except IntegrityError as e: + session.rollback() + raise ValueError(f"Error while removing organization credentials: {str(e)}") \ No newline at end of file diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index eaa9507c6..ca1f07f9b 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -48,5 +48,5 @@ CredsBase, CredsCreate, CredsPublic, - CredsUpdate, + CredsUpdate ) diff --git a/backend/app/models/credentials.py b/backend/app/models/credentials.py index d006b572c..c4531ff78 100644 --- a/backend/app/models/credentials.py +++ b/backend/app/models/credentials.py @@ -1,5 +1,6 @@ from typing import Dict, Any, Optional import sqlalchemy as sa +from sqlalchemy.ext.mutable import MutableDict from sqlmodel import Field, Relationship, SQLModel from datetime import datetime @@ -12,19 +13,48 @@ class CredsBase(SQLModel): class CredsCreate(CredsBase): - credential: Dict[str, Any] = Field(default=None, sa_column=sa.Column(sa.JSON)) + """Create new credentials for an organization. + The credential field should be a dictionary mapping provider names to their credentials. + Example: {"openai": {"api_key": "..."}, "gemini": {"api_key": "..."}} + """ + + credential: Dict[str, Any] = Field( + default=None, + sa_column=sa.Column(MutableDict.as_mutable(sa.JSON)), + description="Dictionary mapping provider names to their credentials", + ) class CredsUpdate(SQLModel): - credential: Optional[Dict[str, Any]] = Field( - default=None, sa_column=sa.Column(sa.JSON) + """Update credentials for an organization. + Can update a specific provider's credentials or add a new provider. + """ + + provider: str = Field( + description="Name of the provider to update/add credentials for" + ) + credential: Dict[str, Any] = Field( + sa_column=sa.Column(MutableDict.as_mutable(sa.JSON)), + description="Credentials for the specified provider", + ) + is_active: Optional[bool] = Field( + default=None, description="Whether the credentials are active" ) - is_active: Optional[bool] = Field(default=None) class Credential(CredsBase, table=True): + """Database model for storing provider credentials. + Each row represents credentials for a single provider. + """ id: int = Field(default=None, primary_key=True) - credential: Dict[str, Any] = Field(default=None, sa_column=sa.Column(sa.JSON)) + provider: str = Field( + index=True, + description="Provider name like 'openai', 'gemini'" + ) + credential: Dict[str, Any] = Field( + sa_column=sa.Column(MutableDict.as_mutable(sa.JSON)), + description="Provider-specific credentials (e.g., API keys)", + ) inserted_at: datetime = Field( default_factory=now, sa_column=sa.Column(sa.DateTime, default=datetime.utcnow), @@ -34,15 +64,19 @@ class Credential(CredsBase, table=True): sa_column=sa.Column(sa.DateTime, onupdate=datetime.utcnow), ) deleted_at: Optional[datetime] = Field( - default=None, sa_column=sa.Column(sa.DateTime, nullable=True) + default=None, + sa_column=sa.Column(sa.DateTime, nullable=True) ) organization: Optional["Organization"] = Relationship(back_populates="creds") class CredsPublic(CredsBase): + """Public representation of credentials, excluding sensitive information.""" + id: int + provider: str credential: Dict[str, Any] inserted_at: datetime updated_at: datetime - deleted_at: Optional[datetime] + deleted_at: Optional[datetime] \ No newline at end of file From b91b3fca9a50c503bc49ef52110b9955337a6b44 Mon Sep 17 00:00:00 2001 From: Priyanshu singh <111607560+PriyanSingh@users.noreply.github.com> Date: Sat, 10 May 2025 12:43:52 +0530 Subject: [PATCH 02/10] Refactor credential tests to streamline organization and credential creation, enhance readability, and ensure proper handling of provider-specific data. --- backend/app/tests/api/routes/test_creds.py | 261 ++++++++++----------- backend/app/tests/crud/test_creds.py | 153 ++++++------ 2 files changed, 188 insertions(+), 226 deletions(-) diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index bdba0d8e0..b1a04c712 100644 --- a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -2,19 +2,15 @@ from fastapi.testclient import TestClient from sqlmodel import Session import random -import string, datetime +import string from app.main import app from app.api.deps import get_db -from app.crud.credentials import ( - set_creds_for_org, - get_creds_by_org, - remove_creds_for_org, -) -from app.models import CredsCreate, CredsUpdate, Organization, OrganizationCreate -from app.utils import APIResponse -from app.tests.utils.utils import random_lower_string +from app.crud.credentials import set_creds_for_org +from app.models import CredsCreate, Organization, OrganizationCreate from app.core.config import settings +from app.core.security import encrypt_api_key +from app.core.providers import Provider client = TestClient(app) @@ -24,42 +20,45 @@ def generate_random_string(length=10): @pytest.fixture -def create_organization_and_creds(db: Session, superuser_token_headers: dict[str, str]): - unique_org_name = "Test Organization " + generate_random_string( - 5 - ) # Ensure unique name - org_data = OrganizationCreate(name=unique_org_name, is_active=True) - org = Organization(**org_data.dict()) # Create Organization instance +def create_organization_and_creds(db: Session): + unique_org_name = "Test Organization " + generate_random_string(5) + org = Organization(name=unique_org_name, is_active=True) db.add(org) db.commit() db.refresh(org) + api_key = "sk-" + generate_random_string(10) creds_data = CredsCreate( organization_id=org.id, is_active=True, - credential={"openai": {"api_key": "sk-" + generate_random_string(10)}}, + credential={ + Provider.OPENAI.value: { + "api_key": api_key, + "model": "gpt-4", + "temperature": 0.7 + } + }, ) return org, creds_data def test_set_creds_for_org(db: Session, superuser_token_headers: dict[str, str]): - unique_org_id = 2 - existing_org = ( - db.query(Organization).filter(Organization.id == unique_org_id).first() - ) - - if not existing_org: - new_org = Organization( - id=unique_org_id, name="Test Organization", is_active=True - ) - db.add(new_org) - db.commit() + org = Organization(name="Org for Set Creds", is_active=True) + db.add(org) + db.commit() + db.refresh(org) api_key = "sk-" + generate_random_string(10) creds_data = { - "organization_id": unique_org_id, + "organization_id": org.id, "is_active": True, - "credential": {"openai": {"api_key": api_key}}, + "credential": { + Provider.OPENAI.value: { + "api_key": api_key, + "model": "gpt-4", + "temperature": 0.7 + } + }, } response = client.post( @@ -69,120 +68,78 @@ def test_set_creds_for_org(db: Session, superuser_token_headers: dict[str, str]) ) assert response.status_code == 200 - - created_creds = response.json() - assert "data" in created_creds - assert created_creds["data"]["organization_id"] == unique_org_id - assert created_creds["data"]["credential"]["openai"]["api_key"] == api_key + data = response.json()["data"] + assert isinstance(data, list) + assert len(data) == 1 + assert data[0]["organization_id"] == org.id + assert data[0]["provider"] == Provider.OPENAI.value + assert data[0]["credential"]["model"] == "gpt-4" -# Test reading credentials -def test_read_credentials_with_creds( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds -): - # Create the organization and credentials (this time with credentials) +def test_read_credentials_with_creds(db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds): org, creds_data = create_organization_and_creds - # Create credentials for the organization set_creds_for_org(session=db, creds_add=creds_data) - # Case 3: Organization exists and credentials are found - response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}", headers=superuser_token_headers - ) - assert response.status_code == 200 - response_data = response.json() - assert "data" in response_data - assert response_data["data"]["organization_id"] == org.id - assert "credential" in response_data["data"] - assert ( - response_data["data"]["credential"]["openai"]["api_key"] - == creds_data.credential["openai"]["api_key"] - ) - - -def test_read_credentials_not_found( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds -): - # Create the organization without credentials - org, _ = create_organization_and_creds - - # Case 1: Organization exists but no credentials response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}", headers=superuser_token_headers + f"{settings.API_V1_STR}/credentials/{org.id}", + headers=superuser_token_headers, ) - # Assert that the status code is 404 - assert response.status_code == 404 - response_data = response.json() + assert response.status_code == 200 + data = response.json()["data"] + assert isinstance(data, list) + assert len(data) == 1 + assert data[0]["organization_id"] == org.id + assert data[0]["provider"] == Provider.OPENAI.value + assert data[0]["credential"]["model"] == "gpt-4" - # Assert the correct error message - assert response_data["detail"] == "Credentials not found" - # Case 2: Organization does not exist - non_existing_org_id = 999 # Assuming this ID does not exist +def test_read_credentials_not_found(db: Session, superuser_token_headers: dict[str, str]): response = client.get( - f"{settings.API_V1_STR}/credentials/{non_existing_org_id}", + f"{settings.API_V1_STR}/credentials/999999", headers=superuser_token_headers, ) - - # Assert that the status code is 404 assert response.status_code == 404 - response_data = response.json() - - # Assert the correct error message - assert response_data["detail"] == "Credentials not found" + assert response.json()["detail"] == "Credentials not found" -def test_read_api_key( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds -): +def test_read_provider_credential(db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds): org, creds_data = create_organization_and_creds set_creds_for_org(session=db, creds_add=creds_data) response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}/api-key", + f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) assert response.status_code == 200 - response_data = response.json() - - assert "data" in response_data - - assert "api_key" in response_data["data"] - assert ( - response_data["data"]["api_key"] == creds_data.credential["openai"]["api_key"] - ) + data = response.json()["data"] + assert data["model"] == "gpt-4" + assert "api_key" in data -def test_read_api_key_not_found( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds -): +def test_read_provider_credential_not_found(db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds): org, _ = create_organization_and_creds response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}/api-key", + f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) assert response.status_code == 404 + assert response.json()["detail"] == "Provider credentials not found" - response_data = response.json() - assert response_data["detail"] == "API key not found" - -def test_update_credentials( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds -): +def test_update_credentials(db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds): org, creds_data = create_organization_and_creds set_creds_for_org(session=db, creds_add=creds_data) update_data = { + "provider": Provider.OPENAI.value, "credential": { - "openai": { - "api_key": "sk-" - + generate_random_string() # Generate a new API key for the update - } + "api_key": "sk-" + generate_random_string(), + "model": "gpt-4-turbo", + "temperature": 0.8 } } @@ -192,75 +149,95 @@ def test_update_credentials( headers=superuser_token_headers, ) - print(response.json()) - assert response.status_code == 200 - response_data = response.json() + data = response.json()["data"] + assert isinstance(data, list) + assert len(data) == 1 + assert data[0]["provider"] == Provider.OPENAI.value + assert data[0]["credential"]["model"] == "gpt-4-turbo" + assert data[0]["updated_at"] is not None - assert "data" in response_data - assert ( - response_data["data"]["credential"]["openai"]["api_key"] - == update_data["credential"]["openai"]["api_key"] - ) +def test_update_credentials_not_found(db: Session, superuser_token_headers: dict[str, str]): + # Create a non-existent organization ID + non_existent_org_id = 999999 - assert response_data["data"]["updated_at"] is not None - - -def test_update_credentials_not_found( - db: Session, superuser_token_headers: dict[str, str] -): - update_data = {"credential": {"openai": "sk-" + generate_random_string()}} + update_data = { + "provider": Provider.OPENAI.value, + "credential": { + "api_key": "sk-" + generate_random_string(), + "model": "gpt-4", + "temperature": 0.7 + } + } response = client.patch( - f"{settings.API_V1_STR}/credentials/999", + f"{settings.API_V1_STR}/credentials/{non_existent_org_id}", json=update_data, headers=superuser_token_headers, ) + assert response.status_code == 404 - assert response.json()["detail"] == "Credentials not found" + assert response.json()["detail"] == "Failed to update credentials" -def test_delete_credentials( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds -): +def test_delete_provider_credential(db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds): org, creds_data = create_organization_and_creds set_creds_for_org(session=db, creds_add=creds_data) response = client.delete( - f"{settings.API_V1_STR}/credentials/{org.id}/api-key", + f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", headers=superuser_token_headers, ) assert response.status_code == 200 - response_data = response.json() - print(f"Response Data: {response_data}") + data = response.json()["data"] + assert data["message"] == "Provider credentials removed successfully" - assert "data" in response_data - assert "message" in response_data["data"] - assert response_data["data"]["message"] == "Credentials deleted successfully" - response = client.get( - f"{settings.API_V1_STR}/credentials/{org.id}", headers=superuser_token_headers +def test_delete_provider_credential_not_found(db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds): + org, _ = create_organization_and_creds + + response = client.delete( + f"{settings.API_V1_STR}/credentials/{org.id}/{Provider.OPENAI.value}", + headers=superuser_token_headers, + ) + + assert response.status_code == 404 + assert response.json()["detail"] == "Provider credentials not found" + + +def test_delete_all_credentials(db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds): + org, creds_data = create_organization_and_creds + set_creds_for_org(session=db, creds_add=creds_data) + + response = client.delete( + f"{settings.API_V1_STR}/credentials/{org.id}", + headers=superuser_token_headers, ) - response_data = response.json() assert response.status_code == 200 - assert response_data["data"]["deleted_at"] is not None - assert response_data["data"]["is_active"] is False + data = response.json()["data"] + assert data["message"] == "Credentials deleted successfully" + # Verify the credentials are soft deleted + response = client.get( + f"{settings.API_V1_STR}/credentials/{org.id}", + headers=superuser_token_headers, + ) + assert response.status_code == 200 + data = response.json()["data"] + assert isinstance(data, list) + assert len(data) == 1 + assert data[0]["deleted_at"] is not None + assert data[0]["is_active"] is False -def test_delete_credentials_not_found( - db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds -): - org, _ = create_organization_and_creds +def test_delete_all_credentials_not_found(db: Session, superuser_token_headers: dict[str, str]): response = client.delete( - f"{settings.API_V1_STR}/credentials/{org.id}/api-key", + f"{settings.API_V1_STR}/credentials/999999", headers=superuser_token_headers, ) assert response.status_code == 404 - response_data = response.json() - - assert response_data["detail"] == "Credentials for organization not found" + assert response.json()["detail"] == "Credentials for organization not found" diff --git a/backend/app/tests/crud/test_creds.py b/backend/app/tests/crud/test_creds.py index 7de6d54d1..6c8b3ad20 100644 --- a/backend/app/tests/crud/test_creds.py +++ b/backend/app/tests/crud/test_creds.py @@ -4,130 +4,115 @@ from fastapi.testclient import TestClient from sqlmodel import Session from sqlalchemy.exc import IntegrityError -from datetime import datetime from app.models import ( Credential, CredsCreate, + CredsUpdate, Organization, OrganizationCreate, - CredsUpdate, ) from app.crud.credentials import ( set_creds_for_org, get_creds_by_org, get_key_by_org, - remove_creds_for_org, update_creds_for_org, + remove_creds_for_org, ) +from app.core.providers import Provider +from app.core.security import encrypt_api_key from app.main import app -from app.utils import APIResponse -from app.core.config import settings client = TestClient(app) -# Helper function to generate random API key def generate_random_string(length=10): return "".join(random.choices(string.ascii_letters + string.digits, k=length)) @pytest.fixture -def test_credential(db: Session): - # Create a unique organization name - unique_org_name = "Test Organization " + generate_random_string( - 5 - ) # Ensure unique name - - # Check if the organization already exists by name - existing_org = ( - db.query(Organization).filter(Organization.name == unique_org_name).first() - ) +def org_with_creds(db: Session): + org = Organization(name=f"Test Org {generate_random_string(5)}", is_active=True) + db.add(org) + db.commit() + db.refresh(org) - if existing_org: - org = existing_org # If organization exists, use the existing one - else: - # If not, create a new organization - organization_data = OrganizationCreate(name=unique_org_name, is_active=True) - org = Organization(**organization_data.dict()) # Create Organization instance - db.add(org) # Add to the session - - try: - db.commit() # Commit to save the organization to the database - db.refresh(org) # Refresh to get the organization_id - except IntegrityError as e: - db.rollback() # Rollback the transaction in case of an error (e.g., duplicate key) - raise ValueError(f"Error during organization commit: {str(e)}") - - # Generate a random API key for the test api_key = "sk-" + generate_random_string(10) - - # Create the credentials using the mock organization_id creds_data = CredsCreate( - organization_id=org.id, # Use the created organization_id + organization_id=org.id, is_active=True, - credential={"openai": {"api_key": api_key}}, + credential={ + Provider.OPENAI.value: { + "api_key": api_key, + "model": "gpt-4", + "temperature": 0.7 + } + }, ) - creds = set_creds_for_org(session=db, creds_add=creds_data) - return creds + return org, creds -def test_create_credentials(db: Session, test_credential): - creds = test_credential # Using the fixture +def test_create_credentials(db: Session, org_with_creds): + org, creds = org_with_creds assert creds is not None - assert creds.credential["openai"]["api_key"].startswith("sk-") - assert creds.is_active is True - assert creds.inserted_at is not None # Ensure inserted_at is set - - -def test_get_creds_by_org(db: Session, test_credential): - creds = test_credential # Using the fixture - retrieved_creds = get_creds_by_org(session=db, org_id=creds.organization_id) - - assert retrieved_creds is not None - assert retrieved_creds.organization_id == creds.organization_id - assert retrieved_creds.inserted_at is not None # Ensure inserted_at is not None - - -def test_update_creds_for_org(db: Session, test_credential): - creds = test_credential # Using the fixture - updated_creds_data = CredsUpdate(credential={"openai": {"api_key": "sk-newkey"}}) - - updated_creds = update_creds_for_org( - session=db, org_id=creds.organization_id, creds_in=updated_creds_data + assert len(creds) == 1 + assert creds[0].provider == Provider.OPENAI.value + assert "api_key" in creds[0].credential + assert creds[0].credential["api_key"].startswith("encrypted_") + assert creds[0].is_active + assert creds[0].inserted_at is not None + + +def test_get_creds_by_org(db: Session, org_with_creds): + org, creds = org_with_creds + retrieved = get_creds_by_org(session=db, org_id=org.id) + assert retrieved is not None + assert len(retrieved) == 1 + assert retrieved[0].organization_id == org.id + assert retrieved[0].provider == Provider.OPENAI.value + assert "api_key" in retrieved[0].credential + assert retrieved[0].inserted_at is not None + + +def test_update_creds_for_org(db: Session, org_with_creds): + org, _ = org_with_creds + new_api_key = "sk-" + generate_random_string(12) + update_data = CredsUpdate( + provider=Provider.OPENAI.value, + credential={ + "api_key": new_api_key, + "model": "gpt-4-turbo", + "temperature": 0.8 + } ) - assert updated_creds is not None - assert updated_creds.credential["openai"]["api_key"] == "sk-newkey" - assert updated_creds.updated_at is not None # Ensure updated_at is set + updated = update_creds_for_org(session=db, org_id=org.id, creds_in=update_data) + assert updated is not None + assert len(updated) == 1 + assert updated[0].credential["api_key"] == encrypt_api_key(new_api_key) + assert updated[0].credential["model"] == "gpt-4-turbo" + assert updated[0].updated_at is not None -def test_remove_creds_for_org(db: Session, test_credential): - creds = test_credential # Using the fixture - removed_creds = remove_creds_for_org(session=db, org_id=creds.organization_id) - assert removed_creds is not None - assert removed_creds.organization_id == creds.organization_id +def test_remove_creds_for_org(db: Session, org_with_creds): + org, creds = org_with_creds + removed = remove_creds_for_org(session=db, org_id=org.id) - # Ensure the deleted_at timestamp is set for soft delete - assert removed_creds.deleted_at is not None # Ensure deleted_at is set + assert removed is not None + assert len(removed) == 1 + assert removed[0].organization_id == org.id + assert removed[0].deleted_at is not None - # Check that credentials are soft deleted and not removed - deleted_creds = ( - db.query(Credential) - .filter(Credential.organization_id == creds.organization_id) - .first() + # Ensure the record is still present but soft-deleted + still_exists = ( + db.query(Credential).filter(Credential.organization_id == org.id).first() ) - assert deleted_creds is not None # Ensure the record still exists in the DB - assert deleted_creds.deleted_at is not None # Ensure it's marked as deleted + assert still_exists is not None + assert still_exists.deleted_at is not None def test_remove_creds_for_org_not_found(db: Session): - # Try to remove credentials for a non-existent organization ID (999) - non_existing_org_id = 999 - - removed_creds = remove_creds_for_org(session=db, org_id=non_existing_org_id) - - # Assert that no credentials were removed since they don't exist - assert removed_creds is None + removed = remove_creds_for_org(session=db, org_id=999999) + assert removed is None From a7fb8882b7d73ea98eac96b2d0cda827d66c0374 Mon Sep 17 00:00:00 2001 From: Priyanshu singh <111607560+PriyanSingh@users.noreply.github.com> Date: Sat, 10 May 2025 16:02:23 +0530 Subject: [PATCH 03/10] Add soft delete functionality for credentials and update tests to verify deletion --- backend/app/crud/credentials.py | 1 + backend/app/tests/api/routes/test_creds.py | 1 + 2 files changed, 2 insertions(+) diff --git a/backend/app/crud/credentials.py b/backend/app/crud/credentials.py index 828ba66ff..b06668459 100644 --- a/backend/app/crud/credentials.py +++ b/backend/app/crud/credentials.py @@ -188,6 +188,7 @@ def remove_creds_for_org(*, session: Session, org_id: int) -> List[Credential]: for cred in creds: cred.is_active = False + cred.deleted_at = datetime.utcnow() cred.updated_at = datetime.utcnow() session.add(cred) diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index b1a04c712..2a148911b 100644 --- a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -231,6 +231,7 @@ def test_delete_all_credentials(db: Session, superuser_token_headers: dict[str, assert len(data) == 1 assert data[0]["deleted_at"] is not None assert data[0]["is_active"] is False + assert "credential" not in data[0] def test_delete_all_credentials_not_found(db: Session, superuser_token_headers: dict[str, str]): From 206fd9b6d62fdaa4fdccfd331b9705c296680e09 Mon Sep 17 00:00:00 2001 From: Priyanshu singh <111607560+PriyanSingh@users.noreply.github.com> Date: Sat, 10 May 2025 16:27:22 +0530 Subject: [PATCH 04/10] Refactor credential handling to improve error management and ensure proper checks for credential existence; update tests for accuracy in response validation. --- backend/app/api/routes/credentials.py | 6 ++-- backend/app/crud/credentials.py | 33 ++++++++++------------ backend/app/tests/api/routes/test_creds.py | 15 ++++------ backend/app/tests/crud/test_creds.py | 9 ++++-- 4 files changed, 30 insertions(+), 33 deletions(-) diff --git a/backend/app/api/routes/credentials.py b/backend/app/api/routes/credentials.py index 0edbde56f..61e63b56a 100644 --- a/backend/app/api/routes/credentials.py +++ b/backend/app/api/routes/credentials.py @@ -68,6 +68,8 @@ def read_credential(*, session: SessionDep, org_id: int): if not creds: raise HTTPException(status_code=404, detail="Credentials not found") return APIResponse.success_response(creds) + except HTTPException as e: + raise e # Ensure HTTPException is not wrapped again except Exception as e: raise HTTPException( status_code=500, detail=f"An unexpected error occurred: {str(e)}" @@ -148,7 +150,7 @@ def delete_provider_credential(*, session: SessionDep, org_id: int, provider: st status_code=500, detail=f"An unexpected error occurred: {str(e)}" ) - if updated_creds is None: + if not updated_creds: # Ensure proper check for no credentials found raise HTTPException(status_code=404, detail="Provider credentials not found") return APIResponse.success_response( @@ -171,7 +173,7 @@ def delete_all_credentials(*, session: SessionDep, org_id: int): status_code=500, detail=f"An unexpected error occurred: {str(e)}" ) - if creds is None: + if not creds: # Ensure proper check for no credentials found raise HTTPException( status_code=404, detail="Credentials for organization not found" ) diff --git a/backend/app/crud/credentials.py b/backend/app/crud/credentials.py index b06668459..747cefc1b 100644 --- a/backend/app/crud/credentials.py +++ b/backend/app/crud/credentials.py @@ -178,25 +178,22 @@ def remove_provider_credential( raise ValueError(f"Error while removing provider credentials: {str(e)}") -def remove_creds_for_org(*, session: Session, org_id: int) -> List[Credential]: - """Removes (soft deletes) all credentials for the given organization.""" - statement = select(Credential).where( - Credential.organization_id == org_id, - Credential.is_active == True - ) - creds = session.exec(statement).all() - +def remove_creds_for_org(session: Session, org_id: int): + """ + Removes all credentials for a specific organization by marking them as inactive. + Returns the list of updated credentials or None if no credentials were found. + """ + creds = session.exec( + select(Credential).where(Credential.organization_id == org_id, Credential.is_active == True) + ).all() + + if not creds: + return None # Return None if no credentials are found + for cred in creds: cred.is_active = False cred.deleted_at = datetime.utcnow() - cred.updated_at = datetime.utcnow() session.add(cred) - - try: - session.commit() - for cred in creds: - session.refresh(cred) - return creds - except IntegrityError as e: - session.rollback() - raise ValueError(f"Error while removing organization credentials: {str(e)}") \ No newline at end of file + + session.commit() + return creds \ No newline at end of file diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index 2a148911b..9ef802374 100644 --- a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -203,7 +203,7 @@ def test_delete_provider_credential_not_found(db: Session, superuser_token_heade headers=superuser_token_headers, ) - assert response.status_code == 404 + assert response.status_code == 404 # Expect 404 for not found assert response.json()["detail"] == "Provider credentials not found" @@ -216,7 +216,7 @@ def test_delete_all_credentials(db: Session, superuser_token_headers: dict[str, headers=superuser_token_headers, ) - assert response.status_code == 200 + assert response.status_code == 200 # Expect 200 for successful deletion data = response.json()["data"] assert data["message"] == "Credentials deleted successfully" @@ -225,13 +225,8 @@ def test_delete_all_credentials(db: Session, superuser_token_headers: dict[str, f"{settings.API_V1_STR}/credentials/{org.id}", headers=superuser_token_headers, ) - assert response.status_code == 200 - data = response.json()["data"] - assert isinstance(data, list) - assert len(data) == 1 - assert data[0]["deleted_at"] is not None - assert data[0]["is_active"] is False - assert "credential" not in data[0] + assert response.status_code == 404 # Expect 404 as credentials are soft deleted + assert response.json()["detail"] == "Credentials not found" def test_delete_all_credentials_not_found(db: Session, superuser_token_headers: dict[str, str]): @@ -240,5 +235,5 @@ def test_delete_all_credentials_not_found(db: Session, superuser_token_headers: headers=superuser_token_headers, ) - assert response.status_code == 404 + assert response.status_code == 404 # Expect 404 for not found assert response.json()["detail"] == "Credentials for organization not found" diff --git a/backend/app/tests/crud/test_creds.py b/backend/app/tests/crud/test_creds.py index 6c8b3ad20..c7a0b3e63 100644 --- a/backend/app/tests/crud/test_creds.py +++ b/backend/app/tests/crud/test_creds.py @@ -20,7 +20,7 @@ remove_creds_for_org, ) from app.core.providers import Provider -from app.core.security import encrypt_api_key +from app.core.security import encrypt_api_key, decrypt_api_key from app.main import app client = TestClient(app) @@ -59,7 +59,9 @@ def test_create_credentials(db: Session, org_with_creds): assert len(creds) == 1 assert creds[0].provider == Provider.OPENAI.value assert "api_key" in creds[0].credential - assert creds[0].credential["api_key"].startswith("encrypted_") + # Decrypt the stored API key before asserting + decrypted_api_key = decrypt_api_key(creds[0].credential["api_key"]) + assert decrypted_api_key.startswith("sk-") assert creds[0].is_active assert creds[0].inserted_at is not None @@ -91,7 +93,8 @@ def test_update_creds_for_org(db: Session, org_with_creds): assert updated is not None assert len(updated) == 1 - assert updated[0].credential["api_key"] == encrypt_api_key(new_api_key) + # Decrypt the stored API key before asserting equality + assert decrypt_api_key(updated[0].credential["api_key"]) == new_api_key assert updated[0].credential["model"] == "gpt-4-turbo" assert updated[0].updated_at is not None From e935a1538cf6793efd59c720c33b134afab8eef9 Mon Sep 17 00:00:00 2001 From: Priyanshu singh <111607560+PriyanSingh@users.noreply.github.com> Date: Sat, 10 May 2025 16:54:26 +0530 Subject: [PATCH 05/10] Enhance credential update handling with organization existence checks and improved error responses; update tests for accurate status codes and messages. --- backend/app/api/routes/credentials.py | 21 ++++++++++++++++++++- backend/app/crud/credentials.py | 3 +++ backend/app/tests/api/routes/test_creds.py | 4 ++-- 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/backend/app/api/routes/credentials.py b/backend/app/api/routes/credentials.py index 61e63b56a..dd7d1a7bd 100644 --- a/backend/app/api/routes/credentials.py +++ b/backend/app/api/routes/credentials.py @@ -13,6 +13,8 @@ from datetime import datetime from app.core.providers import validate_provider from typing import List +from sqlalchemy.exc import IntegrityError +from app.models.organization import Organization router = APIRouter(prefix="/credentials", tags=["credentials"]) @@ -111,6 +113,14 @@ def read_provider_credential(*, session: SessionDep, org_id: int, provider: str) ) def update_credential(*, session: SessionDep, org_id: int, creds_in: CredsUpdate): try: + # Check if the organization exists + organization = session.get(Organization, org_id) + if not organization: + raise HTTPException( + status_code=404, + detail="Organization not found" + ) + updated_creds = update_creds_for_org( session=session, org_id=org_id, creds_in=creds_in ) @@ -120,6 +130,15 @@ def update_credential(*, session: SessionDep, org_id: int, creds_in: CredsUpdate detail="Failed to update credentials" ) return APIResponse.success_response(updated_creds) + except IntegrityError as e: + if "ForeignKeyViolation" in str(e): + raise HTTPException( + status_code=400, + detail="Invalid organization ID. Ensure the organization exists before updating credentials." + ) + raise HTTPException( + status_code=500, detail=f"An unexpected database error occurred: {str(e)}" + ) except ValueError as e: if "Unsupported provider" in str(e): raise HTTPException(status_code=400, detail=str(e)) @@ -144,7 +163,7 @@ def delete_provider_credential(*, session: SessionDep, org_id: int, provider: st session=session, org_id=org_id, provider=provider_enum ) except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) + raise HTTPException(status_code=404, detail="Provider credentials not found") # Updated to return 404 except Exception as e: raise HTTPException( status_code=500, detail=f"An unexpected error occurred: {str(e)}" diff --git a/backend/app/crud/credentials.py b/backend/app/crud/credentials.py index 747cefc1b..16ade641c 100644 --- a/backend/app/crud/credentials.py +++ b/backend/app/crud/credentials.py @@ -99,6 +99,9 @@ def get_providers(*, session: Session, org_id: int) -> List[str]: def update_creds_for_org( session: Session, org_id: int, creds_in: CredsUpdate ) -> List[Credential]: + if not creds_in: + raise ValueError("Missing request body or failed to parse JSON into CredsUpdate") + """Update credentials for an organization. Can update specific provider or add new provider.""" if not creds_in.provider or not creds_in.credential: raise ValueError("Provider and credential information must be provided") diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index 9ef802374..6a133de2e 100644 --- a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -177,8 +177,8 @@ def test_update_credentials_not_found(db: Session, superuser_token_headers: dict headers=superuser_token_headers, ) - assert response.status_code == 404 - assert response.json()["detail"] == "Failed to update credentials" + assert response.status_code == 404 # Expect 404 for non-existent organization + assert response.json()["detail"] == "Organization not found" def test_delete_provider_credential(db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds): From dfd17b476f83a5cf1ec4571d7571e722bb97fffd Mon Sep 17 00:00:00 2001 From: Priyanshu singh <111607560+PriyanSingh@users.noreply.github.com> Date: Sat, 10 May 2025 21:06:09 +0530 Subject: [PATCH 06/10] Enhance credential update handling with validation for provider and credential fields; improve error responses for organization checks and unexpected exceptions. --- ...dded_provider_column_to_the_credential_.py | 19 ++++++++++++--- backend/app/api/routes/credentials.py | 24 +++++++++++++++---- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py b/backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py index bb7677802..2269a74d1 100644 --- a/backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py +++ b/backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py @@ -22,7 +22,14 @@ def upgrade(): op.add_column('credential', sa.Column('provider', sqlmodel.sql.sqltypes.AutoString(), nullable=False)) op.create_index(op.f('ix_credential_provider'), 'credential', ['provider'], unique=False) op.drop_constraint('credential_organization_id_fkey', 'credential', type_='foreignkey') - op.create_foreign_key(None, 'credential', 'organization', ['organization_id'], ['id']) + op.create_foreign_key( + 'credential_organization_id_fkey', + 'credential', + 'organization', + ['organization_id'], + ['id'], + ondelete='CASCADE' + ) op.drop_constraint('project_organization_id_fkey', 'project', type_='foreignkey') op.create_foreign_key(None, 'project', 'organization', ['organization_id'], ['id']) # ### end Alembic commands ### @@ -32,8 +39,14 @@ def downgrade(): # ### commands auto generated by Alembic - please adjust! ### op.drop_constraint(None, 'project', type_='foreignkey') op.create_foreign_key('project_organization_id_fkey', 'project', 'organization', ['organization_id'], ['id'], ondelete='CASCADE') - op.drop_constraint(None, 'credential', type_='foreignkey') - op.create_foreign_key('credential_organization_id_fkey', 'credential', 'organization', ['organization_id'], ['id'], ondelete='CASCADE') + op.drop_constraint('credential_organization_id_fkey', 'credential', type_='foreignkey') + op.create_foreign_key( + 'credential_organization_id_fkey', + 'credential', + 'organization', + ['organization_id'], + ['id'] + ) op.drop_index(op.f('ix_credential_provider'), table_name='credential') op.drop_column('credential', 'provider') # ### end Alembic commands ### diff --git a/backend/app/api/routes/credentials.py b/backend/app/api/routes/credentials.py index dd7d1a7bd..90d1fdb3a 100644 --- a/backend/app/api/routes/credentials.py +++ b/backend/app/api/routes/credentials.py @@ -113,8 +113,19 @@ def read_provider_credential(*, session: SessionDep, org_id: int, provider: str) ) def update_credential(*, session: SessionDep, org_id: int, creds_in: CredsUpdate): try: - # Check if the organization exists - organization = session.get(Organization, org_id) + # Validate incoming payload + if not creds_in or not creds_in.provider or not creds_in.credential: + raise HTTPException( + status_code=400, + detail="Provider and credential must be provided" + ) + + # Defensive check to ensure organization exists + try: + organization = session.get(Organization, org_id) + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to fetch organization: {str(e)}") + if not organization: raise HTTPException( status_code=404, @@ -129,7 +140,9 @@ def update_credential(*, session: SessionDep, org_id: int, creds_in: CredsUpdate status_code=404, detail="Failed to update credentials" ) + return APIResponse.success_response(updated_creds) + except IntegrityError as e: if "ForeignKeyViolation" in str(e): raise HTTPException( @@ -143,12 +156,12 @@ def update_credential(*, session: SessionDep, org_id: int, creds_in: CredsUpdate if "Unsupported provider" in str(e): raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=404, detail=str(e)) + except HTTPException: + raise except Exception as e: raise HTTPException( status_code=500, detail=f"An unexpected error occurred: {str(e)}" ) - - @router.delete( "/{org_id}/{provider}", dependencies=[Depends(get_current_active_superuser)], @@ -194,7 +207,8 @@ def delete_all_credentials(*, session: SessionDep, org_id: int): if not creds: # Ensure proper check for no credentials found raise HTTPException( - status_code=404, detail="Credentials for organization not found" + status_code=404, + detail="Credentials for organization not found" ) return APIResponse.success_response({"message": "Credentials deleted successfully"}) \ No newline at end of file From c1a54e88df35063761ce405da4d2e77cdb2d2598 Mon Sep 17 00:00:00 2001 From: Priyanshu singh <111607560+PriyanSingh@users.noreply.github.com> Date: Sat, 10 May 2025 21:19:31 +0530 Subject: [PATCH 07/10] Refactor credential --- ...dded_provider_column_to_the_credential_.py | 62 ++++++++++++------- backend/app/api/routes/credentials.py | 43 ++++++------- backend/app/core/providers.py | 2 +- backend/app/crud/credentials.py | 37 ++++++----- backend/app/models/__init__.py | 8 +-- backend/app/models/credentials.py | 9 ++- backend/app/tests/api/routes/test_creds.py | 52 +++++++++++----- backend/app/tests/crud/test_creds.py | 8 +-- 8 files changed, 123 insertions(+), 98 deletions(-) diff --git a/backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py b/backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py index 2269a74d1..d368bcdbf 100644 --- a/backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py +++ b/backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py @@ -11,42 +11,58 @@ # revision identifiers, used by Alembic. -revision = '904ed70e7dab' -down_revision = '543f97951bd0' +revision = "904ed70e7dab" +down_revision = "543f97951bd0" branch_labels = None depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.add_column('credential', sa.Column('provider', sqlmodel.sql.sqltypes.AutoString(), nullable=False)) - op.create_index(op.f('ix_credential_provider'), 'credential', ['provider'], unique=False) - op.drop_constraint('credential_organization_id_fkey', 'credential', type_='foreignkey') + op.add_column( + "credential", + sa.Column("provider", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + ) + op.create_index( + op.f("ix_credential_provider"), "credential", ["provider"], unique=False + ) + op.drop_constraint( + "credential_organization_id_fkey", "credential", type_="foreignkey" + ) op.create_foreign_key( - 'credential_organization_id_fkey', - 'credential', - 'organization', - ['organization_id'], - ['id'], - ondelete='CASCADE' + "credential_organization_id_fkey", + "credential", + "organization", + ["organization_id"], + ["id"], + ondelete="CASCADE", ) - op.drop_constraint('project_organization_id_fkey', 'project', type_='foreignkey') - op.create_foreign_key(None, 'project', 'organization', ['organization_id'], ['id']) + op.drop_constraint("project_organization_id_fkey", "project", type_="foreignkey") + op.create_foreign_key(None, "project", "organization", ["organization_id"], ["id"]) # ### end Alembic commands ### def downgrade(): # ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint(None, 'project', type_='foreignkey') - op.create_foreign_key('project_organization_id_fkey', 'project', 'organization', ['organization_id'], ['id'], ondelete='CASCADE') - op.drop_constraint('credential_organization_id_fkey', 'credential', type_='foreignkey') + op.drop_constraint(None, "project", type_="foreignkey") + op.create_foreign_key( + "project_organization_id_fkey", + "project", + "organization", + ["organization_id"], + ["id"], + ondelete="CASCADE", + ) + op.drop_constraint( + "credential_organization_id_fkey", "credential", type_="foreignkey" + ) op.create_foreign_key( - 'credential_organization_id_fkey', - 'credential', - 'organization', - ['organization_id'], - ['id'] + "credential_organization_id_fkey", + "credential", + "organization", + ["organization_id"], + ["id"], ) - op.drop_index(op.f('ix_credential_provider'), table_name='credential') - op.drop_column('credential', 'provider') + op.drop_index(op.f("ix_credential_provider"), table_name="credential") + op.drop_column("credential", "provider") # ### end Alembic commands ### diff --git a/backend/app/api/routes/credentials.py b/backend/app/api/routes/credentials.py index 90d1fdb3a..5a2da39b0 100644 --- a/backend/app/api/routes/credentials.py +++ b/backend/app/api/routes/credentials.py @@ -34,19 +34,16 @@ def create_new_credential(*, session: SessionDep, creds_in: CredsCreate): if existing_creds: raise HTTPException( status_code=400, - detail="Credentials already exist for this organization" + detail="Credentials already exist for this organization", ) - + new_creds = set_creds_for_org(session=session, creds_add=creds_in) if not new_creds: - raise HTTPException( - status_code=500, - detail="Failed to create credentials" - ) - + raise HTTPException(status_code=500, detail="Failed to create credentials") + # Return all created credentials return APIResponse.success_response(new_creds) - + except ValueError as e: if "Unsupported provider" in str(e): raise HTTPException(status_code=400, detail=str(e)) @@ -116,30 +113,25 @@ def update_credential(*, session: SessionDep, org_id: int, creds_in: CredsUpdate # Validate incoming payload if not creds_in or not creds_in.provider or not creds_in.credential: raise HTTPException( - status_code=400, - detail="Provider and credential must be provided" + status_code=400, detail="Provider and credential must be provided" ) # Defensive check to ensure organization exists try: organization = session.get(Organization, org_id) except Exception as e: - raise HTTPException(status_code=500, detail=f"Failed to fetch organization: {str(e)}") - - if not organization: raise HTTPException( - status_code=404, - detail="Organization not found" + status_code=500, detail=f"Failed to fetch organization: {str(e)}" ) + if not organization: + raise HTTPException(status_code=404, detail="Organization not found") + updated_creds = update_creds_for_org( session=session, org_id=org_id, creds_in=creds_in ) if not updated_creds: - raise HTTPException( - status_code=404, - detail="Failed to update credentials" - ) + raise HTTPException(status_code=404, detail="Failed to update credentials") return APIResponse.success_response(updated_creds) @@ -147,7 +139,7 @@ def update_credential(*, session: SessionDep, org_id: int, creds_in: CredsUpdate if "ForeignKeyViolation" in str(e): raise HTTPException( status_code=400, - detail="Invalid organization ID. Ensure the organization exists before updating credentials." + detail="Invalid organization ID. Ensure the organization exists before updating credentials.", ) raise HTTPException( status_code=500, detail=f"An unexpected database error occurred: {str(e)}" @@ -162,6 +154,8 @@ def update_credential(*, session: SessionDep, org_id: int, creds_in: CredsUpdate raise HTTPException( status_code=500, detail=f"An unexpected error occurred: {str(e)}" ) + + @router.delete( "/{org_id}/{provider}", dependencies=[Depends(get_current_active_superuser)], @@ -176,7 +170,9 @@ def delete_provider_credential(*, session: SessionDep, org_id: int, provider: st session=session, org_id=org_id, provider=provider_enum ) except ValueError as e: - raise HTTPException(status_code=404, detail="Provider credentials not found") # Updated to return 404 + raise HTTPException( + status_code=404, detail="Provider credentials not found" + ) # Updated to return 404 except Exception as e: raise HTTPException( status_code=500, detail=f"An unexpected error occurred: {str(e)}" @@ -207,8 +203,7 @@ def delete_all_credentials(*, session: SessionDep, org_id: int): if not creds: # Ensure proper check for no credentials found raise HTTPException( - status_code=404, - detail="Credentials for organization not found" + status_code=404, detail="Credentials for organization not found" ) - return APIResponse.success_response({"message": "Credentials deleted successfully"}) \ No newline at end of file + return APIResponse.success_response({"message": "Credentials deleted successfully"}) diff --git a/backend/app/core/providers.py b/backend/app/core/providers.py index 1f572e6d4..1f513fea9 100644 --- a/backend/app/core/providers.py +++ b/backend/app/core/providers.py @@ -55,4 +55,4 @@ def validate_provider_credentials(provider: str, credentials: dict) -> None: def get_supported_providers() -> List[str]: """Return a list of all supported provider names.""" - return [p.value for p in Provider] \ No newline at end of file + return [p.value for p in Provider] diff --git a/backend/app/crud/credentials.py b/backend/app/crud/credentials.py index 16ade641c..ff3c01b6f 100644 --- a/backend/app/crud/credentials.py +++ b/backend/app/crud/credentials.py @@ -43,7 +43,9 @@ def set_creds_for_org(*, session: Session, creds_add: CredsCreate) -> List[Crede created_credentials.append(credential) except IntegrityError as e: session.rollback() - raise ValueError(f"Error while adding credentials for provider {provider}: {str(e)}") + raise ValueError( + f"Error while adding credentials for provider {provider}: {str(e)}" + ) return created_credentials @@ -55,7 +57,7 @@ def get_key_by_org( statement = select(Credential).where( Credential.organization_id == org_id, Credential.provider == provider, - Credential.is_active == True + Credential.is_active == True, ) creds = session.exec(statement).first() @@ -68,8 +70,7 @@ def get_key_by_org( def get_creds_by_org(*, session: Session, org_id: int) -> List[Credential]: """Fetches all active credentials for the given organization.""" statement = select(Credential).where( - Credential.organization_id == org_id, - Credential.is_active == True + Credential.organization_id == org_id, Credential.is_active == True ) return session.exec(statement).all() @@ -79,14 +80,14 @@ def get_provider_credential( ) -> Optional[Dict[str, Any]]: """Fetches credentials for a specific provider of an organization.""" validate_provider(provider) - + statement = select(Credential).where( Credential.organization_id == org_id, Credential.provider == provider, - Credential.is_active == True + Credential.is_active == True, ) creds = session.exec(statement).first() - + return creds.credential if creds else None @@ -100,7 +101,9 @@ def update_creds_for_org( session: Session, org_id: int, creds_in: CredsUpdate ) -> List[Credential]: if not creds_in: - raise ValueError("Missing request body or failed to parse JSON into CredsUpdate") + raise ValueError( + "Missing request body or failed to parse JSON into CredsUpdate" + ) """Update credentials for an organization. Can update specific provider or add new provider.""" if not creds_in.provider or not creds_in.credential: @@ -116,15 +119,16 @@ def update_creds_for_org( # Check if credentials exist for this provider statement = select(Credential).where( - Credential.organization_id == org_id, - Credential.provider == creds_in.provider + Credential.organization_id == org_id, Credential.provider == creds_in.provider ) existing_cred = session.exec(statement).first() if existing_cred: # Update existing credentials existing_cred.credential = creds_in.credential - existing_cred.is_active = creds_in.is_active if creds_in.is_active is not None else True + existing_cred.is_active = ( + creds_in.is_active if creds_in.is_active is not None else True + ) existing_cred.updated_at = datetime.utcnow() try: session.add(existing_cred) @@ -140,7 +144,7 @@ def update_creds_for_org( organization_id=org_id, provider=creds_in.provider, credential=creds_in.credential, - is_active=creds_in.is_active if creds_in.is_active is not None else True + is_active=creds_in.is_active if creds_in.is_active is not None else True, ) try: session.add(new_cred) @@ -159,8 +163,7 @@ def remove_provider_credential( validate_provider(provider) statement = select(Credential).where( - Credential.organization_id == org_id, - Credential.provider == provider + Credential.organization_id == org_id, Credential.provider == provider ) creds = session.exec(statement).first() @@ -187,7 +190,9 @@ def remove_creds_for_org(session: Session, org_id: int): Returns the list of updated credentials or None if no credentials were found. """ creds = session.exec( - select(Credential).where(Credential.organization_id == org_id, Credential.is_active == True) + select(Credential).where( + Credential.organization_id == org_id, Credential.is_active == True + ) ).all() if not creds: @@ -199,4 +204,4 @@ def remove_creds_for_org(session: Session, org_id: int): session.add(cred) session.commit() - return creds \ No newline at end of file + return creds diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index ca1f07f9b..874764c5d 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -43,10 +43,4 @@ UpdatePassword, ) -from .credentials import ( - Credential, - CredsBase, - CredsCreate, - CredsPublic, - CredsUpdate -) +from .credentials import Credential, CredsBase, CredsCreate, CredsPublic, CredsUpdate diff --git a/backend/app/models/credentials.py b/backend/app/models/credentials.py index c4531ff78..6c7146a80 100644 --- a/backend/app/models/credentials.py +++ b/backend/app/models/credentials.py @@ -46,10 +46,10 @@ class Credential(CredsBase, table=True): """Database model for storing provider credentials. Each row represents credentials for a single provider. """ + id: int = Field(default=None, primary_key=True) provider: str = Field( - index=True, - description="Provider name like 'openai', 'gemini'" + index=True, description="Provider name like 'openai', 'gemini'" ) credential: Dict[str, Any] = Field( sa_column=sa.Column(MutableDict.as_mutable(sa.JSON)), @@ -64,8 +64,7 @@ class Credential(CredsBase, table=True): sa_column=sa.Column(sa.DateTime, onupdate=datetime.utcnow), ) deleted_at: Optional[datetime] = Field( - default=None, - sa_column=sa.Column(sa.DateTime, nullable=True) + default=None, sa_column=sa.Column(sa.DateTime, nullable=True) ) organization: Optional["Organization"] = Relationship(back_populates="creds") @@ -79,4 +78,4 @@ class CredsPublic(CredsBase): credential: Dict[str, Any] inserted_at: datetime updated_at: datetime - deleted_at: Optional[datetime] \ No newline at end of file + deleted_at: Optional[datetime] diff --git a/backend/app/tests/api/routes/test_creds.py b/backend/app/tests/api/routes/test_creds.py index 6a133de2e..9d09c5b19 100644 --- a/backend/app/tests/api/routes/test_creds.py +++ b/backend/app/tests/api/routes/test_creds.py @@ -35,7 +35,7 @@ def create_organization_and_creds(db: Session): Provider.OPENAI.value: { "api_key": api_key, "model": "gpt-4", - "temperature": 0.7 + "temperature": 0.7, } }, ) @@ -56,7 +56,7 @@ def test_set_creds_for_org(db: Session, superuser_token_headers: dict[str, str]) Provider.OPENAI.value: { "api_key": api_key, "model": "gpt-4", - "temperature": 0.7 + "temperature": 0.7, } }, } @@ -76,7 +76,9 @@ def test_set_creds_for_org(db: Session, superuser_token_headers: dict[str, str]) assert data[0]["credential"]["model"] == "gpt-4" -def test_read_credentials_with_creds(db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds): +def test_read_credentials_with_creds( + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds +): org, creds_data = create_organization_and_creds set_creds_for_org(session=db, creds_add=creds_data) @@ -94,7 +96,9 @@ def test_read_credentials_with_creds(db: Session, superuser_token_headers: dict[ assert data[0]["credential"]["model"] == "gpt-4" -def test_read_credentials_not_found(db: Session, superuser_token_headers: dict[str, str]): +def test_read_credentials_not_found( + db: Session, superuser_token_headers: dict[str, str] +): response = client.get( f"{settings.API_V1_STR}/credentials/999999", headers=superuser_token_headers, @@ -103,7 +107,9 @@ def test_read_credentials_not_found(db: Session, superuser_token_headers: dict[s assert response.json()["detail"] == "Credentials not found" -def test_read_provider_credential(db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds): +def test_read_provider_credential( + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds +): org, creds_data = create_organization_and_creds set_creds_for_org(session=db, creds_add=creds_data) @@ -118,7 +124,9 @@ def test_read_provider_credential(db: Session, superuser_token_headers: dict[str assert "api_key" in data -def test_read_provider_credential_not_found(db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds): +def test_read_provider_credential_not_found( + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds +): org, _ = create_organization_and_creds response = client.get( @@ -130,7 +138,9 @@ def test_read_provider_credential_not_found(db: Session, superuser_token_headers assert response.json()["detail"] == "Provider credentials not found" -def test_update_credentials(db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds): +def test_update_credentials( + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds +): org, creds_data = create_organization_and_creds set_creds_for_org(session=db, creds_add=creds_data) @@ -139,8 +149,8 @@ def test_update_credentials(db: Session, superuser_token_headers: dict[str, str] "credential": { "api_key": "sk-" + generate_random_string(), "model": "gpt-4-turbo", - "temperature": 0.8 - } + "temperature": 0.8, + }, } response = client.patch( @@ -158,7 +168,9 @@ def test_update_credentials(db: Session, superuser_token_headers: dict[str, str] assert data[0]["updated_at"] is not None -def test_update_credentials_not_found(db: Session, superuser_token_headers: dict[str, str]): +def test_update_credentials_not_found( + db: Session, superuser_token_headers: dict[str, str] +): # Create a non-existent organization ID non_existent_org_id = 999999 @@ -167,8 +179,8 @@ def test_update_credentials_not_found(db: Session, superuser_token_headers: dict "credential": { "api_key": "sk-" + generate_random_string(), "model": "gpt-4", - "temperature": 0.7 - } + "temperature": 0.7, + }, } response = client.patch( @@ -181,7 +193,9 @@ def test_update_credentials_not_found(db: Session, superuser_token_headers: dict assert response.json()["detail"] == "Organization not found" -def test_delete_provider_credential(db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds): +def test_delete_provider_credential( + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds +): org, creds_data = create_organization_and_creds set_creds_for_org(session=db, creds_add=creds_data) @@ -195,7 +209,9 @@ def test_delete_provider_credential(db: Session, superuser_token_headers: dict[s assert data["message"] == "Provider credentials removed successfully" -def test_delete_provider_credential_not_found(db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds): +def test_delete_provider_credential_not_found( + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds +): org, _ = create_organization_and_creds response = client.delete( @@ -207,7 +223,9 @@ def test_delete_provider_credential_not_found(db: Session, superuser_token_heade assert response.json()["detail"] == "Provider credentials not found" -def test_delete_all_credentials(db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds): +def test_delete_all_credentials( + db: Session, superuser_token_headers: dict[str, str], create_organization_and_creds +): org, creds_data = create_organization_and_creds set_creds_for_org(session=db, creds_add=creds_data) @@ -229,7 +247,9 @@ def test_delete_all_credentials(db: Session, superuser_token_headers: dict[str, assert response.json()["detail"] == "Credentials not found" -def test_delete_all_credentials_not_found(db: Session, superuser_token_headers: dict[str, str]): +def test_delete_all_credentials_not_found( + db: Session, superuser_token_headers: dict[str, str] +): response = client.delete( f"{settings.API_V1_STR}/credentials/999999", headers=superuser_token_headers, diff --git a/backend/app/tests/crud/test_creds.py b/backend/app/tests/crud/test_creds.py index c7a0b3e63..2f4667c98 100644 --- a/backend/app/tests/crud/test_creds.py +++ b/backend/app/tests/crud/test_creds.py @@ -45,7 +45,7 @@ def org_with_creds(db: Session): Provider.OPENAI.value: { "api_key": api_key, "model": "gpt-4", - "temperature": 0.7 + "temperature": 0.7, } }, ) @@ -82,11 +82,7 @@ def test_update_creds_for_org(db: Session, org_with_creds): new_api_key = "sk-" + generate_random_string(12) update_data = CredsUpdate( provider=Provider.OPENAI.value, - credential={ - "api_key": new_api_key, - "model": "gpt-4-turbo", - "temperature": 0.8 - } + credential={"api_key": new_api_key, "model": "gpt-4-turbo", "temperature": 0.8}, ) updated = update_creds_for_org(session=db, org_id=org.id, creds_in=update_data) From df539d098acd972754e75021c6ae97b6694e29f6 Mon Sep 17 00:00:00 2001 From: Priyanshu singh <111607560+PriyanSingh@users.noreply.github.com> Date: Sat, 10 May 2025 21:42:21 +0530 Subject: [PATCH 08/10] Fix down_revision reference in migration script for provider column addition --- .../904ed70e7dab_added_provider_column_to_the_credential_.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py b/backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py index d368bcdbf..2864516b1 100644 --- a/backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py +++ b/backend/app/alembic/versions/904ed70e7dab_added_provider_column_to_the_credential_.py @@ -12,7 +12,7 @@ # revision identifiers, used by Alembic. revision = "904ed70e7dab" -down_revision = "543f97951bd0" +down_revision = "f23675767ed2" branch_labels = None depends_on = None From 35329e246212705aa6a62f0535613dd3870914d6 Mon Sep 17 00:00:00 2001 From: Priyanshu singh <111607560+PriyanSingh@users.noreply.github.com> Date: Sat, 10 May 2025 23:49:22 +0530 Subject: [PATCH 09/10] added provider test in core --- backend/app/tests/core/test_providers.py | 37 ++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 backend/app/tests/core/test_providers.py diff --git a/backend/app/tests/core/test_providers.py b/backend/app/tests/core/test_providers.py new file mode 100644 index 000000000..647f9751b --- /dev/null +++ b/backend/app/tests/core/test_providers.py @@ -0,0 +1,37 @@ +import pytest +from app.core.providers import ( + validate_provider, + validate_provider_credentials, + Provider, +) + + +def test_validate_provider_invalid(): + """Test validating an invalid provider name.""" + with pytest.raises(ValueError) as exc_info: + validate_provider("invalid_provider") + assert "Unsupported provider" in str(exc_info.value) + assert "openai" in str(exc_info.value) # Check that supported providers are listed + + +def test_validate_provider_credentials_missing_fields(): + """Test validating provider credentials with missing required fields.""" + # Test OpenAI missing api_key + with pytest.raises(ValueError) as exc_info: + validate_provider_credentials("openai", {}) + assert "Missing required fields" in str(exc_info.value) + assert "api_key" in str(exc_info.value) + + # Test Azure missing endpoint + with pytest.raises(ValueError) as exc_info: + validate_provider_credentials("azure", {"api_key": "test-key"}) + assert "Missing required fields" in str(exc_info.value) + assert "endpoint" in str(exc_info.value) + + # Test AWS missing region + with pytest.raises(ValueError) as exc_info: + validate_provider_credentials( + "aws", {"access_key_id": "test-id", "secret_access_key": "test-secret"} + ) + assert "Missing required fields" in str(exc_info.value) + assert "region" in str(exc_info.value) \ No newline at end of file From d4fa0ccd956d4fb005b5572cde70dbde6a7250f0 Mon Sep 17 00:00:00 2001 From: Priyanshu singh <111607560+PriyanSingh@users.noreply.github.com> Date: Sat, 10 May 2025 23:51:54 +0530 Subject: [PATCH 10/10] refactor test_provider.py --- backend/app/tests/core/test_providers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/app/tests/core/test_providers.py b/backend/app/tests/core/test_providers.py index 647f9751b..6d54e6c12 100644 --- a/backend/app/tests/core/test_providers.py +++ b/backend/app/tests/core/test_providers.py @@ -34,4 +34,4 @@ def test_validate_provider_credentials_missing_fields(): "aws", {"access_key_id": "test-id", "secret_access_key": "test-secret"} ) assert "Missing required fields" in str(exc_info.value) - assert "region" in str(exc_info.value) \ No newline at end of file + assert "region" in str(exc_info.value)