Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 22 additions & 38 deletions app/api/routes/provider_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,22 +89,6 @@ async def _create_provider_key_internal(
"""
Internal logic to create a new provider key for the current user.
"""
# Check if provider already exists for user
result = await db.execute(
select(ProviderKeyModel).filter(
ProviderKeyModel.user_id == current_user.id,
ProviderKeyModel.provider_name == provider_key_create.provider_name,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we need this check? I think we do not allow providers with same name right?

ProviderKeyModel.deleted_at == None,
)
)
existing_key = result.scalar_one_or_none()

if existing_key:
raise HTTPException(
status_code=400,
detail=f"Provider key for {provider_key_create.provider_name} already exists",
)

db_provider_key = await _process_provider_key_create_data(db, provider_key_create, current_user.id)
await db.commit()
await db.refresh(db_provider_key)
Expand Down Expand Up @@ -140,7 +124,7 @@ async def _process_provider_key_update_data(


async def _update_provider_key_internal(
provider_name: str,
provider_key_id: int,
provider_key_update: ProviderKeyUpdate,
db: AsyncSession,
current_user: UserModel,
Expand All @@ -150,7 +134,7 @@ async def _update_provider_key_internal(
"""
result = await db.execute(
select(ProviderKeyModel).filter(
ProviderKeyModel.provider_name == provider_name,
ProviderKeyModel.id == provider_key_id,
ProviderKeyModel.user_id == current_user.id,
ProviderKeyModel.deleted_at == None,
)
Expand All @@ -173,12 +157,12 @@ async def _update_provider_key_internal(

async def _process_provider_key_delete_data(
db: AsyncSession,
provider_name: str,
provider_key_id: int,
user_id: int,
) -> ProviderKeyModel:
result = await db.execute(
select(ProviderKeyModel).filter(
ProviderKeyModel.provider_name == provider_name,
ProviderKeyModel.id == provider_key_id,
ProviderKeyModel.user_id == user_id,
ProviderKeyModel.deleted_at == None,
)
Expand Down Expand Up @@ -206,12 +190,12 @@ async def _process_provider_key_delete_data(


async def _delete_provider_key_internal(
provider_name: str, db: AsyncSession, current_user: UserModel
provider_key_id: int, db: AsyncSession, current_user: UserModel
) -> ProviderKey:
"""
Internal logic to delete a provider key for the current user.
"""
provider_key_data, scoped_forge_api_keys = await _process_provider_key_delete_data(db, provider_name, current_user.id)
provider_key_data, scoped_forge_api_keys = await _process_provider_key_delete_data(db, provider_key_id, current_user.id)
await db.commit()

# Invalidate caches after deleting a provider key
Expand Down Expand Up @@ -241,25 +225,25 @@ async def create_provider_key(
return await _create_provider_key_internal(provider_key_create, db, current_user)


@router.put("/{provider_name}", response_model=ProviderKey)
@router.put("/{provider_key_id}", response_model=ProviderKey)
async def update_provider_key(
provider_name: str,
provider_key_id: int,
provider_key_update: ProviderKeyUpdate,
db: AsyncSession = Depends(get_async_db),
current_user: UserModel = Depends(get_current_active_user),
) -> Any:
return await _update_provider_key_internal(
provider_name, provider_key_update, db, current_user
provider_key_id, provider_key_update, db, current_user
)


@router.delete("/{provider_name}", response_model=ProviderKey)
@router.delete("/{provider_key_id}", response_model=ProviderKey)
async def delete_provider_key(
provider_name: str,
provider_key_id: int,
db: AsyncSession = Depends(get_async_db),
current_user: UserModel = Depends(get_current_active_user),
) -> Any:
return await _delete_provider_key_internal(provider_name, db, current_user)
return await _delete_provider_key_internal(provider_key_id, db, current_user)


# --- Clerk API Routes ---
Expand All @@ -282,25 +266,25 @@ async def create_provider_key_clerk(
return await _create_provider_key_internal(provider_key_create, db, current_user)


@router.put("/clerk/{provider_name}", response_model=ProviderKey)
@router.put("/clerk/{provider_key_id}", response_model=ProviderKey)
async def update_provider_key_clerk(
provider_name: str,
provider_key_id: int,
provider_key_update: ProviderKeyUpdate,
db: AsyncSession = Depends(get_async_db),
current_user: UserModel = Depends(get_current_active_user_from_clerk),
) -> Any:
return await _update_provider_key_internal(
provider_name, provider_key_update, db, current_user
provider_key_id, provider_key_update, db, current_user
)


@router.delete("/clerk/{provider_name}", response_model=ProviderKey)
@router.delete("/clerk/{provider_key_id}", response_model=ProviderKey)
async def delete_provider_key_clerk(
provider_name: str,
provider_key_id: int,
db: AsyncSession = Depends(get_async_db),
current_user: UserModel = Depends(get_current_active_user_from_clerk),
) -> Any:
return await _delete_provider_key_internal(provider_name, db, current_user)
return await _delete_provider_key_internal(provider_key_id, db, current_user)


# --- Batch Upsert API Endpoint ---
Expand All @@ -323,19 +307,19 @@ async def _batch_upsert_provider_keys_internal(
)
existing_keys_query = result.scalars().all()
# 2. Map them by provider_name for efficient lookup
existing_keys_map: dict[str, ProviderKeyModel] = {
key.provider_name: key for key in existing_keys_query
existing_keys_map: dict[int, ProviderKeyModel] = {
key.id: key for key in existing_keys_query
}
invalidated_forge_api_keys = set()

for item in items:
try:
existing_provider_key: ProviderKeyModel | None = existing_keys_map.get(item.provider_name)
existing_provider_key: ProviderKeyModel | None = existing_keys_map.get(item.id)

# Handle deletion if api_key is "DELETE"
if item.api_key == "DELETE":
if existing_provider_key:
_, scoped_forge_api_keys = await _process_provider_key_delete_data(db, item.provider_name, current_user.id)
_, scoped_forge_api_keys = await _process_provider_key_delete_data(db, item.id, current_user.id)
invalidated_forge_api_keys.update(scoped_forge_api_keys)
processed = True
elif existing_provider_key: # Update existing key
Expand Down
19 changes: 19 additions & 0 deletions app/api/schemas/provider_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ class ProviderKeyBase(BaseModel):
model_mapping: dict[str, str] | None = None
config: dict[str, str] | None = None

@field_validator("provider_name")
@classmethod
def strip_provider_name(cls, v):
"""Strip whitespace from provider_name."""
stripped = v.strip()
if len(stripped) < 1:
raise ValueError("provider_name must have at least 1 character after stripping whitespace")
return stripped


class ProviderKeyCreate(ProviderKeyBase):
pass
Expand Down Expand Up @@ -100,8 +109,18 @@ def config(self) -> dict[str, str] | None:


class ProviderKeyUpsertItem(BaseModel):
id: int | None = None
provider_name: str = Field(..., min_length=1)
api_key: str | None = None
base_url: str | None = None
model_mapping: dict[str, str] | None = None
config: dict[str, str] | None = None

@field_validator("provider_name")
@classmethod
def strip_provider_name(cls, v):
"""Strip whitespace from provider_name."""
stripped = v.strip()
if len(stripped) < 1:
raise ValueError("provider_name must have at least 1 character after stripping whitespace")
return stripped
25 changes: 14 additions & 11 deletions cli_tools/forge-cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ def add_provider_key(
response = requests.post(url, headers=headers, json=data)

if response.status_code == HTTPStatus.OK:
print(f"✅ Successfully added {provider_name} API key!")
resp_json = response.json()
print(f"✅ Successfully added {resp_json['provider_name']} API key!")
return True
else:
print(f"❌ Error adding provider key: {response.status_code}")
Expand Down Expand Up @@ -220,13 +221,13 @@ def list_provider_keys(self) -> list[dict[str, Any]]:
print(f"❌ Error listing provider keys: {str(e)}")
return []

def update_provider_key(self, provider_name: str, api_key: str | None = None, base_url: str | None = None, model_mapping: str | None = None, config: str | None = None) -> bool:
def update_provider_key(self, provider_key_id: int, api_key: str | None = None, base_url: str | None = None, model_mapping: str | None = None, config: str | None = None) -> bool:
"""Update a provider key"""
if not self.token:
print("❌ Not authenticated. Please login first.")
return False

url = f"{self.api_url}/provider-keys/{provider_name}"
url = f"{self.api_url}/provider-keys/{provider_key_id}"
headers = {
"Authorization": f"Bearer {self.token}",
"Content-Type": "application/json",
Expand All @@ -241,7 +242,8 @@ def update_provider_key(self, provider_name: str, api_key: str | None = None, ba
response = requests.put(url, headers=headers, json=data)

if response.status_code == HTTPStatus.OK:
print(f"✅ Successfully updated {provider_name} API key!")
resp_json = response.json()
print(f"✅ Successfully updated {provider_key_id}:{resp_json['provider_name']} API key!")
return True
else:
print(f"❌ Error updated provider key: {response.status_code}")
Expand All @@ -252,20 +254,21 @@ def update_provider_key(self, provider_name: str, api_key: str | None = None, ba
return False


def delete_provider_key(self, provider_name: str) -> bool:
def delete_provider_key(self, provider_key_id: int) -> bool:
"""Delete a provider key"""
if not self.token:
print("❌ Not authenticated. Please login first.")
return False

url = f"{self.api_url}/provider-keys/{provider_name}"
url = f"{self.api_url}/provider-keys/{provider_key_id}"
headers = {"Authorization": f"Bearer {self.token}"}

try:
response = requests.delete(url, headers=headers)

if response.status_code == HTTPStatus.OK:
print(f"✅ Successfully deleted provider key {provider_name}!")
resp_json = response.json()
print(f"✅ Successfully deleted provider key {provider_key_id}:{resp_json['provider_name']}!")
return True
else:
print(f"❌ Error deleting provider key: {response.status_code}")
Expand Down Expand Up @@ -660,19 +663,19 @@ def main():
if not forge.token:
token = input("Enter JWT token: ")
forge.token = token
provider_name = input("Enter provider name to delete: ")
forge.delete_provider_key(provider_name)
provider_key_id = int(input("Enter provider key id to delete: "))
forge.delete_provider_key(provider_key_id)

elif choice == "11":
if not forge.token:
token = input("Enter JWT token: ")
forge.token = token
provider_name = input("Enter provider name to update: ")
provider_key_id = int(input("Enter provider key id to update: "))
api_key = getpass("Enter provider API key: ")
base_url = input("Enter provider base URL (optional, press Enter to skip): ")
config = input("Enter provider config in json string format (optional, press Enter to skip): ")
model_mapping = input("Enter model mapping config in json string format (optional, press Enter to skip): ")
forge.update_provider_key(provider_name, api_key, base_url=base_url, config=config, model_mapping=model_mapping)
forge.update_provider_key(provider_key_id, api_key, base_url=base_url, config=config, model_mapping=model_mapping)

elif choice == "12":
model = input("Enter model ID: ")
Expand Down
Loading