Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions app/control/account/backends/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,10 @@ def _sync() -> AccountPage:
if query.status:
where_parts.append("status = ?")
params.append(query.status.value)
if query.exclude_statuses:
placeholders = ", ".join("?" for _ in query.exclude_statuses)
where_parts.append(f"status NOT IN ({placeholders})")
params.extend(s.value for s in query.exclude_statuses)

where_sql = ("WHERE " + " AND ".join(where_parts)) if where_parts else ""
order_dir = "DESC" if query.sort_desc else "ASC"
Expand Down Expand Up @@ -512,5 +516,77 @@ def _sync() -> AccountMutationResult:
async def close(self) -> None:
    """No-op for SQLite — connections are opened and closed per operation.

    Kept so that this backend offers the same ``close()`` coroutine as the
    other repository backends; there is no long-lived handle to release.
    """

async def get_stats(self) -> dict:
    """Return aggregated account statistics, computed mostly inside SQLite.

    Only rows with ``deleted_at IS NULL`` are considered.  Three statements
    are issued:

    1. one ``GROUP BY pool, status`` query, from which the per-status,
       per-pool and pool-by-status counts (and the total) are all derived,
       so those figures always agree with each other;
    2. one aggregate query for the usage sums and the four quota
       ``remaining`` sums (extracted from the JSON columns by SQLite);
    3. one count of rows whose ``tags`` text contains ``"nsfw"``.

    Returns:
        A dict with the keys ``total``, ``status_counts``, ``pool_counts``,
        ``pool_status``, ``usage``, ``quota_sums``, ``nsfw`` and
        ``elapsed_ms`` (wall time of the queries, in milliseconds).
    """
    import time as _time

    modes = ("auto", "fast", "expert", "heavy")

    def _sync() -> dict:
        t0 = _time.monotonic()
        live = f"FROM {_TBL} WHERE deleted_at IS NULL"
        # One SUM per quota column; ``mode`` comes from the fixed tuple above,
        # never from user input, so interpolating it into SQL is safe.
        quota_cols = ", ".join(
            f"COALESCE(SUM(CAST(json_extract(quota_{mode}, '$.remaining') AS INTEGER)), 0)"
            for mode in modes
        )

        with closing(self._connect()) as conn:
            # 1. Pool x status grid; the flat counts are folded from it.
            grouped = conn.execute(
                f"SELECT pool, status, COUNT(*) {live} GROUP BY pool, status"
            ).fetchall()
            status_counts: dict[str, int] = {}
            pool_counts: dict[str, int] = {}
            pool_status: dict[str, dict[str, int]] = {}
            for row in grouped:
                pool, status, count = row[0], row[1], row[2]
                pool_status.setdefault(pool, {})[status] = count
                pool_counts[pool] = pool_counts.get(pool, 0) + count
                status_counts[status] = status_counts.get(status, 0) + count
            total = sum(status_counts.values())

            # 2. Usage and quota sums in a single table scan.
            sums = conn.execute(
                f"SELECT COALESCE(SUM(usage_use_count), 0), "
                f"COALESCE(SUM(usage_fail_count), 0), {quota_cols} {live}"
            ).fetchone()
            success = int(sums[0])
            fail = int(sums[1])
            quota_sums: dict[str, int] = {
                mode: int(sums[2 + offset]) for offset, mode in enumerate(modes)
            }

            # 3. NSFW count: tags is stored as a JSON array, so look for the
            #    quoted element in the serialized text.
            nsfw_row = conn.execute(
                f"SELECT COUNT(*) {live} AND tags LIKE ?",
                ('%"nsfw"%',),
            ).fetchone()
            nsfw_enabled = int(nsfw_row[0])
            nsfw_disabled = total - nsfw_enabled

        elapsed = _time.monotonic() - t0
        return {
            "total": total,
            "status_counts": status_counts,
            "pool_counts": pool_counts,
            "pool_status": pool_status,
            "usage": {"success": success, "fail": fail, "calls": success + fail},
            "quota_sums": quota_sums,
            "nsfw": {"enabled": nsfw_enabled, "disabled": nsfw_disabled},
            "elapsed_ms": round(elapsed * 1000, 1),
        }

    return await asyncio.to_thread(_sync)


__all__ = ["LocalAccountRepository"]
52 changes: 52 additions & 0 deletions app/control/account/backends/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,8 @@ async def list_accounts(
continue
if query.status and r.status != query.status:
continue
if query.exclude_statuses and r.status in query.exclude_statuses:
continue
all_records.append(r)

# Sort.
Expand Down Expand Up @@ -407,6 +409,56 @@ async def replace_pool(
revision=upserted_result.revision,
)

async def get_stats(self) -> dict:
    """Return aggregated account statistics by scanning every record key.

    Redis has no server-side aggregation here, so each
    ``accounts:record:*`` hash is fetched and tallied in Python.  Empty
    hashes and records reporting ``is_deleted()`` are skipped.  A missing
    status is counted as ``"active"`` and a missing pool as ``"basic"``.

    Returns:
        A dict with the keys ``total``, ``status_counts``, ``pool_counts``,
        ``pool_status``, ``usage``, ``quota_sums``, ``nsfw`` and
        ``elapsed_ms`` (wall time of the scan, in milliseconds).
    """
    import time as _time

    modes = ("auto", "fast", "expert", "heavy")

    t0 = _time.monotonic()
    status_counts: dict[str, int] = {}
    pool_counts: dict[str, int] = {}
    pool_status: dict[str, dict[str, int]] = {}
    success = fail = 0
    quota_sums: dict[str, int] = dict.fromkeys(modes, 0)
    nsfw_enabled = 0
    total = 0
    # SCAN-based iteration may hand back the same key more than once;
    # remember the tokens already tallied so no account is double counted.
    seen: set[str] = set()

    # NOTE(review): this issues one HGETALL round-trip per key; consider
    # pipelining if the scan becomes a latency problem.
    async for key in self._r.scan_iter("accounts:record:*"):
        token = (key.decode() if isinstance(key, bytes) else key).split(":", 2)[-1]
        if token in seen:
            continue
        seen.add(token)
        h = await self._r.hgetall(key)
        if not h:
            continue
        r = self._from_hash(token, h)
        if r.is_deleted():
            continue
        total += 1
        raw_status = r.status or "active"
        # Use the plain value when status is an enum member, so the keys
        # are strings like the ones the SQL backends return.
        st = getattr(raw_status, "value", raw_status)
        status_counts[st] = status_counts.get(st, 0) + 1
        pool = r.pool or "basic"
        pool_counts[pool] = pool_counts.get(pool, 0) + 1
        per_pool = pool_status.setdefault(pool, {})
        per_pool[st] = per_pool.get(st, 0) + 1
        success += r.usage_use_count or 0
        fail += r.usage_fail_count or 0
        if "nsfw" in (r.tags or []):
            nsfw_enabled += 1
        if isinstance(r.quota, dict):
            for mode in modes:
                v = r.quota.get(mode)
                if isinstance(v, dict):
                    quota_sums[mode] += int(v.get("remaining", 0) or 0)

    elapsed = _time.monotonic() - t0
    return {
        "total": total,
        "status_counts": status_counts,
        "pool_counts": pool_counts,
        "pool_status": pool_status,
        "usage": {"success": success, "fail": fail, "calls": success + fail},
        "quota_sums": quota_sums,
        "nsfw": {"enabled": nsfw_enabled, "disabled": total - nsfw_enabled},
        "elapsed_ms": round(elapsed * 1000, 1),
    }

async def close(self) -> None:
    """Shut down the underlying Redis connection pool via the client's ``aclose()``."""
    client = self._r
    await client.aclose()
Expand Down
87 changes: 87 additions & 0 deletions app/control/account/backends/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,10 @@ async def list_accounts(
stmt = stmt.where(accounts_table.c.pool == query.pool)
if query.status:
stmt = stmt.where(accounts_table.c.status == query.status.value)
if query.exclude_statuses:
stmt = stmt.where(accounts_table.c.status.notin_(
[s.value for s in query.exclude_statuses]
))

total_row = (await conn.execute(
sa.select(sa.func.count()).select_from(stmt.subquery())
Expand Down Expand Up @@ -835,6 +839,89 @@ async def replace_pool(
revision=upserted_result.revision,
)

async def get_stats(self) -> dict:
    """Return aggregated account statistics computed by the database.

    Only rows whose ``deleted_at`` is NULL are considered.  Three statements
    are issued on one session:

    1. a ``GROUP BY pool, status`` query, from which the per-status,
       per-pool and pool-by-status counts (and the total) are derived, so
       those figures always agree with each other;
    2. a single aggregate SELECT with the usage sums and the four quota
       ``remaining`` sums (dialect-aware JSON extraction);
    3. a count of rows whose ``tags`` text contains ``"nsfw"``.

    Returns:
        A dict with the keys ``total``, ``status_counts``, ``pool_counts``,
        ``pool_status``, ``usage``, ``quota_sums``, ``nsfw`` and
        ``elapsed_ms`` (wall time of the queries, in milliseconds).
    """
    import time as _time

    t = accounts_table
    live = t.c.deleted_at.is_(None)
    modes = ("auto", "fast", "expert", "heavy")

    t0 = _time.monotonic()
    async with self._session() as session:
        # 1. Pool x status grid; the flat counts are folded from it.
        stmt = (
            sa.select(t.c.pool, t.c.status, sa.func.count())
            .where(live)
            .group_by(t.c.pool, t.c.status)
        )
        rows = (await session.execute(stmt)).all()
        status_counts: dict[str, int] = {}
        pool_counts: dict[str, int] = {}
        pool_status: dict[str, dict[str, int]] = {}
        for pool, status, count in rows:
            pool_status.setdefault(pool, {})[status] = count
            pool_counts[pool] = pool_counts.get(pool, 0) + count
            status_counts[status] = status_counts.get(status, 0) + count
        total = sum(status_counts.values())

        # 2. Usage and quota sums in one round-trip.
        quota_exprs = []
        for mode in modes:
            col = getattr(t.c, f"quota_{mode}")
            if self._dialect == "mysql":
                remaining = sa.func.json_extract(col, "$.remaining")
            else:  # postgresql
                remaining = sa.func.json_extract_path_text(col, "remaining")
            quota_exprs.append(
                sa.func.coalesce(sa.func.sum(sa.cast(remaining, sa.Integer)), 0)
            )
        stmt = sa.select(
            sa.func.coalesce(sa.func.sum(t.c.usage_use_count), 0),
            sa.func.coalesce(sa.func.sum(t.c.usage_fail_count), 0),
            *quota_exprs,
        ).where(live)
        sums = (await session.execute(stmt)).one()
        success, fail = int(sums[0]), int(sums[1])
        quota_sums: dict[str, int] = {
            mode: int(sums[2 + offset]) for offset, mode in enumerate(modes)
        }

        # 3. NSFW: match the quoted element inside the serialized tags.
        stmt = sa.select(sa.func.count()).where(
            live,
            t.c.tags.like('%"nsfw"%'),
        )
        nsfw_enabled = int((await session.execute(stmt)).scalar())
        nsfw_disabled = total - nsfw_enabled

    elapsed = _time.monotonic() - t0
    return {
        "total": total,
        "status_counts": status_counts,
        "pool_counts": pool_counts,
        "pool_status": pool_status,
        "usage": {"success": success, "fail": fail, "calls": success + fail},
        "quota_sums": quota_sums,
        "nsfw": {"enabled": nsfw_enabled, "disabled": nsfw_disabled},
        "elapsed_ms": round(elapsed * 1000, 1),
    }

async def close(self) -> None:
"""Dispose the SQLAlchemy connection pool."""
if self._dispose_engine:
Expand Down
1 change: 1 addition & 0 deletions app/control/account/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class ListAccountsQuery(BaseModel):
page_size: int = Field(default=50, ge=1, le=2000)
pool: str | None = None
status: AccountStatus | None = None
exclude_statuses: list[AccountStatus] = Field(default_factory=list)
tags: list[str] = Field(default_factory=list)
include_deleted: bool = False
sort_by: str = "updated_at" # field name
Expand Down
4 changes: 4 additions & 0 deletions app/control/account/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ async def replace_pool(
"""Atomically replace all accounts in a pool."""
...

async def get_stats(self) -> dict:
    """Return aggregated stats (status/pool/usage/quota counts).

    The backend implementations return a dict with the keys ``total``,
    ``status_counts``, ``pool_counts``, ``pool_status``, ``usage``
    (``success`` / ``fail`` / ``calls``), ``quota_sums`` (``auto`` /
    ``fast`` / ``expert`` / ``heavy``), ``nsfw`` (``enabled`` /
    ``disabled``) and ``elapsed_ms``.
    """
    ...

async def close(self) -> None:
    """Release database connections / file handles.

    A backend that holds no long-lived resources may implement this as a
    no-op.
    """
    ...
Expand Down
35 changes: 29 additions & 6 deletions app/products/openai/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,31 @@


async def _available_pools(request: Request) -> frozenset[str]:
    """Return the set of pool names that have at least one manageable account.

    Prefers the in-memory account runtime table (a linear scan over its
    pool/status arrays, no repository call).  When that table has not been
    built yet, falls back to ``repo.runtime_snapshot()`` and filters its
    records with ``is_manageable``; if no repository is attached to the app
    state either, an empty set is returned.
    """
    from app.dataplane.account import _directory
    from app.dataplane.shared.enums import POOL_ID_TO_STR, StatusId

    table = None if _directory is None else _directory._table
    if table is None:
        # Fallback: no runtime table yet — use repo (startup path).
        repo = getattr(request.app.state, "repository", None)
        if repo is None:
            return frozenset()
        snapshot = await repo.runtime_snapshot()
        return frozenset(
            record.pool for record in snapshot.items if is_manageable(record)
        )

    # Status ids treated as manageable on the fast path.
    manageable_statuses = {int(StatusId.ACTIVE), int(StatusId.COOLING)}
    pools: set[str] = set()
    for pool_id, status_id in zip(table.pool_by_idx, table.status_by_idx):
        if status_id not in manageable_statuses:
            continue
        pool_name = POOL_ID_TO_STR.get(pool_id)
        if pool_name:
            pools.add(pool_name)
    return frozenset(pools)


Expand All @@ -64,7 +83,9 @@ def _model_available_for_pools(spec: ModelSpec, pools: frozenset[str]) -> bool:
@router.get("/models", tags=[_TAG_MODELS], dependencies=[Depends(verify_api_key)])
async def list_models(request: Request):
import time
import time as _time

t0 = _time.monotonic()
pools = await _available_pools(request)
models = [
{
Expand All @@ -77,6 +98,8 @@ async def list_models(request: Request):
for m in model_registry.list_enabled()
if _model_available_for_pools(m, pools)
]
elapsed = _time.monotonic() - t0
logger.info("openai list_models: pools={} models={} elapsed_ms={:.1f}", pools, len(models), elapsed * 1000)
return JSONResponse({"object": "list", "data": models})


Expand Down
Loading