Skip to content
Merged
3 changes: 3 additions & 0 deletions backend/app/crud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
get_key_by_org,
update_creds_for_org,
remove_creds_for_org,
get_provider_credential,
remove_provider_credential,
get_full_provider_credential,
)

from .thread_results import upsert_thread_result, get_thread_result
Expand Down
20 changes: 20 additions & 0 deletions backend/app/crud/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,26 @@
return None


def get_full_provider_credential(
*, session: Session, org_id: int, provider: str, project_id: Optional[int] = None
) -> Optional[Dict[str, Any]]:
"""Fetches credentials for a specific provider of an organization."""
validate_provider(provider)

statement = select(Credential).where(
Credential.organization_id == org_id,
Credential.provider == provider,
Credential.is_active == True,
Credential.project_id == project_id if project_id is not None else True,
)
Comment thread
nishika26 marked this conversation as resolved.
Comment on lines +116 to +113
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix the WHERE clause logic for project_id filtering.

The current implementation has two critical issues:

  1. The boolean comparison Credential.is_active == True should be simplified to Credential.is_active
  2. The conditional logic Credential.project_id == project_id if project_id is not None else True is incorrect. When project_id is None, this becomes Credential.project_id == True, which doesn't make logical sense and will cause unexpected filtering behavior.

Apply this diff to fix the WHERE clause:

    statement = select(Credential).where(
        Credential.organization_id == org_id,
        Credential.provider == provider,
-        Credential.is_active == True,
-        Credential.project_id == project_id if project_id is not None else True,
+        Credential.is_active,
+        Credential.project_id == project_id if project_id is not None else Credential.project_id.is_not(None),
    )

Note: This assumes that when project_id is None, you want to filter for records that have a non-null project_id. If you want to match records regardless of their project_id value when the parameter is None, you should restructure the query to conditionally add the filter instead.

🧰 Tools
🪛 Ruff (0.11.9)

119-119: Avoid equality comparisons to True; use if Credential.is_active: for truth checks

Replace with Credential.is_active

(E712)

🤖 Prompt for AI Agents
In backend/app/crud/credentials.py around lines 116 to 121, simplify the
condition `Credential.is_active == True` to just `Credential.is_active`. For the
`project_id` filter, do not use a conditional expression inside the `where`
clause; instead, build the query so that if `project_id` is not None, add a
filter `Credential.project_id == project_id`, otherwise either omit this filter
or explicitly filter for non-null `project_id` depending on the intended logic.
This avoids incorrect comparisons and ensures proper filtering behavior.

creds = session.exec(statement).first()

if creds and creds.credential:
# Decrypt entire credentials object
return creds
return None

Check warning on line 127 in backend/app/crud/credentials.py

View check run for this annotation

Codecov / codecov/patch

backend/app/crud/credentials.py#L127

Added line #L127 was not covered by tests

Comment thread
nishika26 marked this conversation as resolved.
Outdated

def get_providers(
*, session: Session, org_id: int, project_id: Optional[int] = None
) -> List[str]:
Expand Down
90 changes: 21 additions & 69 deletions backend/app/tests/api/routes/test_api_key.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,19 @@
import uuid
import pytest
from fastapi.testclient import TestClient
from sqlmodel import Session

from app.main import app
from app.models import APIKey, User, Organization, Project
from app.models import APIKey
from app.core.config import settings
from app.crud.api_key import create_api_key
from app.tests.utils.utils import random_email
from app.core.security import get_password_hash
from app.tests.utils.utils import get_non_existent_id
from app.tests.utils.user import create_random_user
from app.tests.utils.test_data import create_test_api_key, create_test_project

client = TestClient(app)


def create_test_user(db: Session) -> User:
user = User(
email=random_email(),
hashed_password=get_password_hash("password123"),
is_superuser=True,
)
db.add(user)
db.commit()
db.refresh(user)
return user


def create_test_organization(db: Session) -> Organization:
org = Organization(
name=f"Test Organization {uuid.uuid4()}", description="Test Organization"
)
db.add(org)
db.commit()
db.refresh(org)
return org


def create_test_project(db: Session, organization_id: int) -> Project:
project = Project(name="Test Project", organization_id=organization_id)
db.add(project)
db.commit()
db.refresh(project)
return project


def test_create_api_key(db: Session, superuser_token_headers: dict[str, str]):
user = create_test_user(db)
org = create_test_organization(db)
project = create_test_project(db, organization_id=org.id)
user = create_random_user(db)
project = create_test_project(db)

response = client.post(
f"{settings.API_V1_STR}/apikeys",
Expand All @@ -57,14 +25,13 @@ def test_create_api_key(db: Session, superuser_token_headers: dict[str, str]):
assert data["success"] is True
assert "id" in data["data"]
assert "key" in data["data"]
assert data["data"]["organization_id"] == org.id
assert data["data"]["organization_id"] == project.organization_id
assert data["data"]["user_id"] == user.id


def test_create_duplicate_api_key(db: Session, superuser_token_headers: dict[str, str]):
user = create_test_user(db)
org = create_test_organization(db)
project = create_test_project(db, organization_id=org.id)
user = create_random_user(db)
project = create_test_project(db)

client.post(
f"{settings.API_V1_STR}/apikeys",
Expand All @@ -81,16 +48,11 @@ def test_create_duplicate_api_key(db: Session, superuser_token_headers: dict[str


def test_list_api_keys(db: Session, superuser_token_headers: dict[str, str]):
user = create_test_user(db)
org = create_test_organization(db)
project = create_test_project(db, organization_id=org.id)
api_key = create_api_key(
db, organization_id=org.id, user_id=user.id, project_id=project.id
)
api_key = create_test_api_key(db)

response = client.get(
f"{settings.API_V1_STR}/apikeys",
params={"project_id": project.id},
params={"project_id": api_key.project_id},
headers=superuser_token_headers,
)
assert response.status_code == 200
Expand All @@ -100,17 +62,12 @@ def test_list_api_keys(db: Session, superuser_token_headers: dict[str, str]):
assert len(data["data"]) > 0

first_key = data["data"][0]
assert first_key["organization_id"] == org.id
assert first_key["user_id"] == user.id
assert first_key["organization_id"] == api_key.organization_id
assert first_key["user_id"] == api_key.user_id


def test_get_api_key(db: Session, superuser_token_headers: dict[str, str]):
user = create_test_user(db)
org = create_test_organization(db)
project = create_test_project(db, organization_id=org.id)
api_key = create_api_key(
db, organization_id=org.id, user_id=user.id, project_id=project.id
)
api_key = create_test_api_key(db)

response = client.get(
f"{settings.API_V1_STR}/apikeys/{api_key.id}",
Expand All @@ -121,25 +78,21 @@ def test_get_api_key(db: Session, superuser_token_headers: dict[str, str]):
assert data["success"] is True
assert data["data"]["id"] == api_key.id
assert data["data"]["organization_id"] == api_key.organization_id
assert data["data"]["user_id"] == user.id
assert data["data"]["user_id"] == api_key.user_id


def test_get_nonexistent_api_key(db: Session, superuser_token_headers: dict[str, str]):
api_key_id = get_non_existent_id(db, APIKey)
response = client.get(
f"{settings.API_V1_STR}/apikeys/999999",
f"{settings.API_V1_STR}/apikeys/{api_key_id}",
headers=superuser_token_headers,
)
assert response.status_code == 404
assert "API Key does not exist" in response.json()["error"]


def test_revoke_api_key(db: Session, superuser_token_headers: dict[str, str]):
user = create_test_user(db)
org = create_test_organization(db)
project = create_test_project(db, organization_id=org.id)
api_key = create_api_key(
db, organization_id=org.id, user_id=user.id, project_id=project.id
)
api_key = create_test_api_key(db)

response = client.delete(
f"{settings.API_V1_STR}/apikeys/{api_key.id}",
Expand All @@ -154,11 +107,10 @@ def test_revoke_api_key(db: Session, superuser_token_headers: dict[str, str]):
def test_revoke_nonexistent_api_key(
db: Session, superuser_token_headers: dict[str, str]
):
user = create_test_user(db)
org = create_test_organization(db)
api_key_id = get_non_existent_id(db, APIKey)

response = client.delete(
f"{settings.API_V1_STR}/apikeys/999999",
f"{settings.API_V1_STR}/apikeys/{api_key_id}",
headers=superuser_token_headers,
)
assert response.status_code == 404
Expand Down
Loading