11from collections .abc import Generator
2- from typing import Annotated
2+ from typing import Annotated , Optional
33
44import jwt
5- from fastapi import Depends , HTTPException , status , Request
5+ from fastapi import Depends , HTTPException , status , Request , Header , Security
66from fastapi .responses import JSONResponse
7- from fastapi .security import OAuth2PasswordBearer
7+ from fastapi .security import OAuth2PasswordBearer , APIKeyHeader
88from jwt .exceptions import InvalidTokenError
99from pydantic import ValidationError
1010from sqlmodel import Session , select
1313from app .core .config import settings
1414from app .core .db import engine
1515from app .utils import APIResponse
16- from app .models import TokenPayload , User , UserProjectOrg , ProjectUser , Project , Organization
16+ from app .crud .organization import validate_organization
17+ from app .crud .api_key import get_api_key_by_value
18+ from app .models import TokenPayload , User , UserProjectOrg , UserOrganization , ProjectUser , Project , Organization
1719
18- reusable_oauth2 = OAuth2PasswordBearer (
19- tokenUrl = f"{ settings .API_V1_STR } /login/access-token"
20- )
20+ # reusable_oauth2 = OAuth2PasswordBearer(
21+ # tokenUrl=f"{settings.API_V1_STR}/login/access-token"
22+ # )
2123
2224
2325def get_db () -> Generator [Session , None , None ]:
2426 with Session (engine ) as session :
2527 yield session
2628
27-
29+ api_key_header = APIKeyHeader ( name = "Authorization" , auto_error = False )
2830SessionDep = Annotated [Session , Depends (get_db )]
29- TokenDep = Annotated [str , Depends (reusable_oauth2 )]
30-
31-
32- def get_current_user (session : SessionDep , token : TokenDep ) -> User :
33- try :
34- payload = jwt .decode (
35- token , settings .SECRET_KEY , algorithms = [security .ALGORITHM ]
36- )
37- token_data = TokenPayload (** payload )
38- except (InvalidTokenError , ValidationError ):
39- raise HTTPException (
40- status_code = status .HTTP_403_FORBIDDEN ,
41- detail = "Could not validate credentials" ,
42- )
43- user = session .get (User , token_data .sub )
44- if not user :
45- raise HTTPException (status_code = 404 , detail = "User not found" )
46- if not user .is_active :
47- raise HTTPException (status_code = 400 , detail = "Inactive user" )
48- return user
49-
50-
51- CurrentUser = Annotated [User , Depends (get_current_user )]
31+ # TokenDep = Annotated[str, Depends(reusable_oauth2)]
32+
33+ def get_current_user (
34+ session : SessionDep ,
35+ auth_header : str = Security (api_key_header ),
36+ ) -> UserOrganization :
37+ """Authenticate user via API Key first, fallback to JWT token."""
38+
39+ if auth_header .startswith ("ApiKey " ):
40+ api_key = auth_header .split (" " , 1 )[1 ]
41+ api_key_record = get_api_key_by_value (session , api_key )
42+ if not api_key_record :
43+ raise HTTPException (status_code = 401 , detail = "Invalid API Key" )
44+
45+ user = session .get (User , api_key_record .user_id )
46+ if not user :
47+ raise HTTPException (status_code = 404 , detail = "User linked to API Key not found" )
48+
49+ validate_organization (session , api_key_record .organization_id )
50+
51+ # Return UserOrganization model with organization ID
52+ return UserOrganization (** user .model_dump (), organization_id = api_key_record .organization_id )
53+
54+ if auth_header .startswith ("Bearer " ):
55+ try :
56+ token = auth_header .split (" " , 1 )[1 ]
57+ payload = jwt .decode (
58+ token , settings .SECRET_KEY , algorithms = [security .ALGORITHM ]
59+ )
60+ token_data = TokenPayload (** payload )
61+ except (InvalidTokenError , ValidationError ):
62+ raise HTTPException (
63+ status_code = status .HTTP_403_FORBIDDEN ,
64+ detail = "Could not validate credentials" ,
65+ )
66+ user = session .get (User , token_data .sub )
67+ if not user :
68+ raise HTTPException (status_code = 404 , detail = "User not found" )
69+ if not user .is_active :
70+ raise HTTPException (status_code = 400 , detail = "Inactive user" )
71+
72+ return UserOrganization (** user .model_dump (), organization_id = None )
73+ raise HTTPException (status_code = 401 , detail = "Invalid Authorization header format" )
74+
75+ CurrentUser = Annotated [UserOrganization , Depends (get_current_user )]
5276
5377
5478def get_current_active_superuser (current_user : CurrentUser ) -> User :
@@ -78,6 +102,8 @@ def verify_user_project_organization(
78102 Verify that the authenticated user is part of the project
79103 and that the project belongs to the organization.
80104 """
105+ if current_user .organization_id and current_user .organization_id != organization_id :
106+ raise HTTPException (status_code = 403 , detail = "User does not belong to the specified organization" )
81107
82108 project_organization = db .exec (
83109 select (Project , Organization )
@@ -105,9 +131,11 @@ def verify_user_project_organization(
105131 raise HTTPException (status_code = 403 , detail = "Project does not belong to the organization" )
106132
107133
134+ current_user .organization_id = organization_id
135+
108136 # Superuser bypasses all checks
109137 if current_user .is_superuser :
110- return UserProjectOrg (** current_user .model_dump (), project_id = project_id , organization_id = organization_id )
138+ return UserProjectOrg (** current_user .model_dump (), project_id = project_id )
111139
112140 # Check if the user is part of the project
113141 user_in_project = db .exec (
@@ -121,4 +149,4 @@ def verify_user_project_organization(
121149 if not user_in_project :
122150 raise HTTPException (status_code = 403 , detail = "User is not part of the project" )
123151
124- return UserProjectOrg (** current_user .model_dump (), project_id = project_id , organization_id = organization_id )
152+ return UserProjectOrg (** current_user .model_dump (), project_id = project_id )
0 commit comments