diff --git a/app/api/routes/provider_keys.py b/app/api/routes/provider_keys.py index 5041c55..15fc863 100644 --- a/app/api/routes/provider_keys.py +++ b/app/api/routes/provider_keys.py @@ -89,22 +89,6 @@ async def _create_provider_key_internal( """ Internal logic to create a new provider key for the current user. """ - # Check if provider already exists for user - result = await db.execute( - select(ProviderKeyModel).filter( - ProviderKeyModel.user_id == current_user.id, - ProviderKeyModel.provider_name == provider_key_create.provider_name, - ProviderKeyModel.deleted_at == None, - ) - ) - existing_key = result.scalar_one_or_none() - - if existing_key: - raise HTTPException( - status_code=400, - detail=f"Provider key for {provider_key_create.provider_name} already exists", - ) - db_provider_key = await _process_provider_key_create_data(db, provider_key_create, current_user.id) await db.commit() await db.refresh(db_provider_key) @@ -140,7 +124,7 @@ async def _process_provider_key_update_data( async def _update_provider_key_internal( - provider_name: str, + provider_key_id: int, provider_key_update: ProviderKeyUpdate, db: AsyncSession, current_user: UserModel, @@ -150,7 +134,7 @@ async def _update_provider_key_internal( """ result = await db.execute( select(ProviderKeyModel).filter( - ProviderKeyModel.provider_name == provider_name, + ProviderKeyModel.id == provider_key_id, ProviderKeyModel.user_id == current_user.id, ProviderKeyModel.deleted_at == None, ) @@ -173,12 +157,12 @@ async def _update_provider_key_internal( async def _process_provider_key_delete_data( db: AsyncSession, - provider_name: str, + provider_key_id: int, user_id: int, ) -> ProviderKeyModel: result = await db.execute( select(ProviderKeyModel).filter( - ProviderKeyModel.provider_name == provider_name, + ProviderKeyModel.id == provider_key_id, ProviderKeyModel.user_id == user_id, ProviderKeyModel.deleted_at == None, ) @@ -206,12 +190,12 @@ async def _process_provider_key_delete_data( async def _delete_provider_key_internal( - provider_name: str, db: AsyncSession, current_user: UserModel + provider_key_id: int, db: AsyncSession, current_user: UserModel ) -> ProviderKey: """ Internal logic to delete a provider key for the current user. """ - provider_key_data, scoped_forge_api_keys = await _process_provider_key_delete_data(db, provider_name, current_user.id) + provider_key_data, scoped_forge_api_keys = await _process_provider_key_delete_data(db, provider_key_id, current_user.id) await db.commit() # Invalidate caches after deleting a provider key @@ -241,25 +225,25 @@ async def create_provider_key( return await _create_provider_key_internal(provider_key_create, db, current_user) -@router.put("/{provider_name}", response_model=ProviderKey) +@router.put("/{provider_key_id}", response_model=ProviderKey) async def update_provider_key( - provider_name: str, + provider_key_id: int, provider_key_update: ProviderKeyUpdate, db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user), ) -> Any: return await _update_provider_key_internal( - provider_name, provider_key_update, db, current_user + provider_key_id, provider_key_update, db, current_user ) -@router.delete("/{provider_name}", response_model=ProviderKey) +@router.delete("/{provider_key_id}", response_model=ProviderKey) async def delete_provider_key( - provider_name: str, + provider_key_id: int, db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user), ) -> Any: - return await _delete_provider_key_internal(provider_name, db, current_user) + return await _delete_provider_key_internal(provider_key_id, db, current_user) # --- Clerk API Routes --- @@ -282,25 +266,25 @@ async def create_provider_key_clerk( return await _create_provider_key_internal(provider_key_create, db, current_user) -@router.put("/clerk/{provider_name}", response_model=ProviderKey) +@router.put("/clerk/{provider_key_id}", response_model=ProviderKey) async def update_provider_key_clerk( - provider_name: str, + provider_key_id: int, provider_key_update: ProviderKeyUpdate, db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: return await _update_provider_key_internal( - provider_name, provider_key_update, db, current_user + provider_key_id, provider_key_update, db, current_user ) -@router.delete("/clerk/{provider_name}", response_model=ProviderKey) +@router.delete("/clerk/{provider_key_id}", response_model=ProviderKey) async def delete_provider_key_clerk( - provider_name: str, + provider_key_id: int, db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: - return await _delete_provider_key_internal(provider_name, db, current_user) + return await _delete_provider_key_internal(provider_key_id, db, current_user) # --- Batch Upsert API Endpoint --- @@ -323,19 +307,19 @@ async def _batch_upsert_provider_keys_internal( ) existing_keys_query = result.scalars().all() # 2. Map them by provider_name for efficient lookup - existing_keys_map: dict[str, ProviderKeyModel] = { - key.provider_name: key for key in existing_keys_query + existing_keys_map: dict[int, ProviderKeyModel] = { + key.id: key for key in existing_keys_query } invalidated_forge_api_keys = set() for item in items: try: - existing_provider_key: ProviderKeyModel | None = existing_keys_map.get(item.provider_name) + existing_provider_key: ProviderKeyModel | None = existing_keys_map.get(item.id) # Handle deletion if api_key is "DELETE" if item.api_key == "DELETE": if existing_provider_key: - _, scoped_forge_api_keys = await _process_provider_key_delete_data(db, item.provider_name, current_user.id) + _, scoped_forge_api_keys = await _process_provider_key_delete_data(db, item.id, current_user.id) invalidated_forge_api_keys.update(scoped_forge_api_keys) processed = True elif existing_provider_key: # Update existing key diff --git a/app/api/schemas/provider_key.py b/app/api/schemas/provider_key.py index 1a44588..a8751d1 100644 --- a/app/api/schemas/provider_key.py +++ b/app/api/schemas/provider_key.py @@ -17,6 +17,15 @@ class ProviderKeyBase(BaseModel): model_mapping: dict[str, str] | None = None config: dict[str, str] | None = None + @field_validator("provider_name") + @classmethod + def strip_provider_name(cls, v): + """Strip whitespace from provider_name.""" + stripped = v.strip() + if len(stripped) < 1: + raise ValueError("provider_name must have at least 1 character after stripping whitespace") + return stripped + class ProviderKeyCreate(ProviderKeyBase): pass @@ -100,8 +109,18 @@ def config(self) -> dict[str, str] | None: class ProviderKeyUpsertItem(BaseModel): + id: int | None = None provider_name: str = Field(..., min_length=1) api_key: str | None = None base_url: str | None = None model_mapping: dict[str, str] | None = None config: dict[str, str] | None = None + + @field_validator("provider_name") + @classmethod + def strip_provider_name(cls, v): + """Strip whitespace from provider_name.""" + stripped = v.strip() + if len(stripped) < 1: + raise ValueError("provider_name must have at least 1 character after stripping whitespace") + return stripped diff --git a/cli_tools/forge-cli.py b/cli_tools/forge-cli.py index 05f1c1b..731a26c 100755 --- a/cli_tools/forge-cli.py +++ b/cli_tools/forge-cli.py @@ -178,7 +178,8 @@ def add_provider_key( response = requests.post(url, headers=headers, json=data) if response.status_code == HTTPStatus.OK: - print(f"✅ Successfully added {provider_name} API key!") + resp_json = response.json() + print(f"✅ Successfully added {resp_json['provider_name']} API key!") return True else: print(f"❌ Error adding provider key: {response.status_code}") @@ -220,13 +221,13 @@ def list_provider_keys(self) -> list[dict[str, Any]]: print(f"❌ Error listing provider keys: {str(e)}") return [] - def update_provider_key(self, provider_name: str, api_key: str | None = None, base_url: str | None = None, model_mapping: str | None = None, config: str | None = None) -> bool: + def update_provider_key(self, provider_key_id: int, api_key: str | None = None, base_url: str | None = None, model_mapping: str | None = None, config: str | None = None) -> bool: """Update a provider key""" if not self.token: print("❌ Not authenticated. Please login first.") return False - url = f"{self.api_url}/provider-keys/{provider_name}" + url = f"{self.api_url}/provider-keys/{provider_key_id}" headers = { "Authorization": f"Bearer {self.token}", "Content-Type": "application/json", @@ -241,7 +242,8 @@ def update_provider_key(self, provider_name: str, api_key: str | None = None, ba response = requests.put(url, headers=headers, json=data) if response.status_code == HTTPStatus.OK: - print(f"✅ Successfully updated {provider_name} API key!") + resp_json = response.json() + print(f"✅ Successfully updated {provider_key_id}:{resp_json['provider_name']} API key!") return True else: print(f"❌ Error updated provider key: {response.status_code}") @@ -252,20 +254,21 @@ def update_provider_key(self, provider_name: str, api_key: str | None = None, ba return False - def delete_provider_key(self, provider_name: str) -> bool: + def delete_provider_key(self, provider_key_id: int) -> bool: """Delete a provider key""" if not self.token: print("❌ Not authenticated. Please login first.") return False - url = f"{self.api_url}/provider-keys/{provider_name}" + url = f"{self.api_url}/provider-keys/{provider_key_id}" headers = {"Authorization": f"Bearer {self.token}"} try: response = requests.delete(url, headers=headers) if response.status_code == HTTPStatus.OK: - print(f"✅ Successfully deleted provider key {provider_name}!") + resp_json = response.json() + print(f"✅ Successfully deleted provider key {provider_key_id}:{resp_json['provider_name']}!") return True else: print(f"❌ Error deleting provider key: {response.status_code}") @@ -660,19 +663,19 @@ def main(): if not forge.token: token = input("Enter JWT token: ") forge.token = token - provider_name = input("Enter provider name to delete: ") - forge.delete_provider_key(provider_name) + provider_key_id = int(input("Enter provider key id to delete: ")) + forge.delete_provider_key(provider_key_id) elif choice == "11": if not forge.token: token = input("Enter JWT token: ") forge.token = token - provider_name = input("Enter provider name to update: ") + provider_key_id = int(input("Enter provider key id to update: ")) api_key = getpass("Enter provider API key: ") base_url = input("Enter provider base URL (optional, press Enter to skip): ") config = input("Enter provider config in json string format (optional, press Enter to skip): ") model_mapping = input("Enter model mapping config in json string format (optional, press Enter to skip): ") - forge.update_provider_key(provider_name, api_key, base_url=base_url, config=config, model_mapping=model_mapping) + forge.update_provider_key(provider_key_id, api_key, base_url=base_url, config=config, model_mapping=model_mapping) elif choice == "12": model = input("Enter model ID: ")