Skip to content

Commit c73704e

Browse files
authored
llm call: polling llm db endpoint (#726)
1 parent e84122b commit c73704e

File tree

19 files changed

+504
-56
lines changed

19 files changed

+504
-56
lines changed
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""add project id to job table
2+
3+
Revision ID: 051
4+
Revises: 050
5+
Create Date: 2026-04-07 14:23:00.938901
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
import sqlmodel.sql.sqltypes
11+
from sqlalchemy.dialects import postgresql
12+
13+
# revision identifiers, used by Alembic.
revision = "051"
down_revision = "050"
branch_labels = None
depends_on = None

# PostgreSQL enum backing llm_chain.status. create_type=False prevents
# SQLAlchemy from emitting an implicit CREATE TYPE during alter_column;
# upgrade()/downgrade() manage the type's lifecycle explicitly instead.
chain_status_enum = postgresql.ENUM(
    "PENDING",
    "RUNNING",
    "FAILED",
    "COMPLETED",
    name="chainstatus",
    create_type=False,
)
27+
28+
29+
def upgrade():
    """Apply migration 051.

    - Add a nullable ``project_id`` column to ``job``.
    - Refresh comments on ``llm_call.chain_id`` and ``llm_call.input_type``.
    - Convert ``llm_chain.status`` from VARCHAR to the ``chainstatus`` enum,
      uppercasing existing values, and restore a ``'PENDING'`` default.
    - Widen ``llm_chain.error`` to the SQLModel AutoString type.
    """
    # checkfirst=True makes re-running the migration safe when the enum type
    # already exists (e.g. after a partially applied upgrade); the original
    # unconditional create() would raise a DuplicateObject error.
    chain_status_enum.create(op.get_bind(), checkfirst=True)
    op.add_column(
        "job",
        sa.Column(
            "project_id",
            sa.Integer(),
            nullable=True,
            comment="Project ID of the job's project",
        ),
    )
    op.alter_column(
        "llm_call",
        "chain_id",
        existing_type=sa.UUID(),
        comment="Reference to the parent chain (NULL for standalone llm_call requests)",
        existing_comment="Reference to the parent chain (NULL for standalone /llm/call requests)",
        existing_nullable=True,
    )
    op.alter_column(
        "llm_call",
        "input_type",
        existing_type=sa.VARCHAR(),
        comment="Input type: text, audio, image, pdf, multimodal",
        existing_comment="Input type: text, audio, image",
        existing_nullable=False,
    )
    # The VARCHAR default must be dropped before the type change; PostgreSQL
    # cannot cast the old string default to the new enum type automatically.
    op.execute("ALTER TABLE llm_chain ALTER COLUMN status DROP DEFAULT")
    op.alter_column(
        "llm_chain",
        "status",
        existing_type=sa.VARCHAR(),
        type_=chain_status_enum,
        existing_comment="Chain execution status (pending, running, failed, completed)",
        existing_nullable=False,
        # Existing rows hold lowercase strings; uppercase them to match the
        # enum labels during the in-place conversion.
        postgresql_using="UPPER(status)::chainstatus",
    )
    op.execute(
        "ALTER TABLE llm_chain ALTER COLUMN status SET DEFAULT 'PENDING'::chainstatus"
    )
    op.alter_column(
        "llm_chain",
        "error",
        existing_type=sa.TEXT(),
        type_=sqlmodel.sql.sqltypes.AutoString(),
        existing_comment="Error message if the chain execution failed",
        existing_nullable=True,
    )
77+
78+
79+
def downgrade():
    """Revert migration 051.

    Restores VARCHAR status (lowercased) and TEXT error on ``llm_chain``,
    reverts the ``llm_call`` column comments, drops the ``chainstatus`` enum
    type, and removes ``job.project_id``.
    """
    op.alter_column(
        "llm_chain",
        "error",
        existing_type=sqlmodel.sql.sqltypes.AutoString(),
        type_=sa.TEXT(),
        existing_comment="Error message if the chain execution failed",
        existing_nullable=True,
    )
    # Drop the enum-typed default before converting the column back to VARCHAR.
    op.execute("ALTER TABLE llm_chain ALTER COLUMN status DROP DEFAULT")
    op.alter_column(
        "llm_chain",
        "status",
        existing_type=sa.Enum(
            "PENDING", "RUNNING", "FAILED", "COMPLETED", name="chainstatus"
        ),
        type_=sa.VARCHAR(),
        existing_comment="Chain execution status (pending, running, failed, completed)",
        existing_nullable=False,
        # BUG FIX: mirror upgrade()'s UPPER(...) cast. Without this USING
        # clause the enum labels are cast verbatim, leaving uppercase strings
        # ('PENDING', ...) in a column whose pre-051 consumers and the
        # restored default below use lowercase values.
        postgresql_using="LOWER(status::text)",
    )
    op.execute("ALTER TABLE llm_chain ALTER COLUMN status SET DEFAULT 'pending'")
    op.execute("DROP TYPE IF EXISTS chainstatus")
    op.alter_column(
        "llm_call",
        "input_type",
        existing_type=sa.VARCHAR(),
        comment="Input type: text, audio, image",
        existing_comment="Input type: text, audio, image, pdf, multimodal",
        existing_nullable=False,
    )
    op.alter_column(
        "llm_call",
        "chain_id",
        existing_type=sa.UUID(),
        comment="Reference to the parent chain (NULL for standalone /llm/call requests)",
        existing_comment="Reference to the parent chain (NULL for standalone llm_call requests)",
        existing_nullable=True,
    )
    op.drop_column("job", "project_id")
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
Retrieve the status and results of an LLM call job by job ID.
2+
3+
This endpoint allows you to poll for the status and results of an asynchronous LLM call job that was previously initiated via the POST `/llm/call` endpoint.
4+
5+
6+
### Notes

- This endpoint returns both the job status and, once the job completes, the actual LLM response.
- LLM responses are also delivered asynchronously via the callback URL, if one was provided.
- Jobs can be queried at any time after creation.

backend/app/api/routes/llm.py

Lines changed: 98 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,20 @@
11
import logging
2+
from uuid import UUID
23

3-
from fastapi import APIRouter, Depends
4+
from fastapi import APIRouter, Depends, HTTPException
45

56
from app.api.deps import AuthContextDep, SessionDep
67
from app.api.permissions import Permission, require_permission
7-
from app.models import LLMCallRequest, LLMCallResponse, Message
8+
from app.crud.jobs import JobCrud
9+
from app.crud.llm import get_llm_calls_by_job_id
10+
from app.models import (
11+
LLMCallRequest,
12+
LLMCallResponse,
13+
LLMJobImmediatePublic,
14+
LLMJobPublic,
15+
JobStatus,
16+
)
17+
from app.models.llm.response import LLMResponse, Usage
818
from app.services.llm.jobs import start_job
919
from app.utils import APIResponse, validate_callback_url, load_description
1020

@@ -34,7 +44,7 @@ def llm_callback_notification(body: APIResponse[LLMCallResponse]):
3444
@router.post(
3545
"/llm/call",
3646
description=load_description("llm/llm_call.md"),
37-
response_model=APIResponse[Message],
47+
response_model=APIResponse[LLMJobImmediatePublic],
3848
callbacks=llm_callback_router.routes,
3949
dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))],
4050
)
@@ -43,22 +53,102 @@ def llm_call(
4353
):
4454
"""
4555
Endpoint to initiate an LLM call as a background job.
56+
Returns job information for polling.
4657
"""
4758
project_id = _current_user.project_.id
4859
organization_id = _current_user.organization_.id
4960

5061
if request.callback_url:
5162
validate_callback_url(str(request.callback_url))
5263

53-
start_job(
64+
job_id = start_job(
5465
db=session,
5566
request=request,
5667
project_id=project_id,
5768
organization_id=organization_id,
5869
)
5970

60-
return APIResponse.success_response(
61-
data=Message(
62-
message=f"Your response is being generated and will be delivered via callback."
63-
),
71+
# Fetch job details to return immediate response
72+
job_crud = JobCrud(session=session)
73+
job = job_crud.get(job_id=job_id, project_id=project_id)
74+
75+
if not job:
76+
raise HTTPException(status_code=404, detail="Job not found")
77+
78+
message = "Your response is being generated and will be delivered via callback."
79+
if not request.callback_url:
80+
message = "Your response is being generated"
81+
82+
job_response = LLMJobImmediatePublic(
83+
job_id=job.id,
84+
status=job.status.value,
85+
message=message,
86+
job_inserted_at=job.created_at,
87+
job_updated_at=job.updated_at,
6488
)
89+
90+
return APIResponse.success_response(data=job_response)
91+
92+
93+
@router.get(
    "/llm/call/{job_id}",
    description=load_description("llm/get_llm_call.md"),
    response_model=APIResponse[LLMJobPublic],
    dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))],
)
def get_llm_call_status(
    _current_user: AuthContextDep,
    session: SessionDep,
    job_id: UUID,
) -> APIResponse[LLMJobPublic]:
    """
    Poll for LLM call job status and results.

    Returns the job's status and, when the job has succeeded, the nested LLM
    response built from the job's llm_call row. Raises 404 when the job does
    not exist or belongs to a different project.
    """
    project_id = _current_user.project_.id

    job_crud = JobCrud(session=session)
    # JobCrud.get scopes the lookup to the caller's project.
    job = job_crud.get(job_id=job_id, project_id=project_id)

    if not job:
        raise HTTPException(status_code=404, detail="Job not found")

    llm_call_response = None
    if job.status.value == JobStatus.SUCCESS:
        llm_calls = get_llm_calls_by_job_id(
            session=session, job_id=job_id, project_id=project_id
        )

        if llm_calls:
            # Newest call first; a plain /llm/call job produces exactly one
            # llm_call row, since this endpoint does not serve llm chains yet.
            llm_call = llm_calls[0]

            llm_response = LLMResponse(
                provider_response_id=llm_call.provider_response_id or "",
                conversation_id=llm_call.conversation_id,
                provider=llm_call.provider,
                model=llm_call.model,
                output=llm_call.content,
            )

            if not llm_call.usage:
                logger.warning(
                    f"[get_llm_call] Missing usage data for llm_call job_id={job_id}, project_id={project_id}"
                )

            # BUG FIX: the original called Usage(**llm_call.usage)
            # unconditionally, which raises TypeError when usage is None —
            # the very case the warning above detects. Fall back to an empty
            # mapping so Usage's field defaults apply.
            # NOTE(review): assumes Usage's fields all have defaults — confirm
            # against app.models.llm.response.Usage.
            llm_call_response = LLMCallResponse(
                response=llm_response,
                usage=Usage(**(llm_call.usage or {})),
                provider_raw_response=None,
            )

    job_response = LLMJobPublic(
        job_id=job.id,
        status=job.status.value,
        llm_response=llm_call_response,
        error_message=job.error_message,
    )

    return APIResponse.success_response(data=job_response)

backend/app/crud/jobs.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@ class JobCrud:
1212
def __init__(self, session: Session):
1313
self.session = session
1414

15-
def create(self, job_type: JobType, trace_id: str | None = None) -> Job:
16-
new_job = Job(
17-
job_type=job_type,
18-
trace_id=trace_id,
19-
)
15+
def create(
16+
self,
17+
job_type: JobType,
18+
trace_id: str | None = None,
19+
project_id: int | None = None,
20+
) -> Job:
21+
new_job = Job(job_type=job_type, trace_id=trace_id, project_id=project_id)
2022
self.session.add(new_job)
2123
self.session.commit()
2224
self.session.refresh(new_job)
@@ -38,5 +40,10 @@ def update(self, job_id: UUID, job_update: JobUpdate) -> Job:
3840

3941
return job
4042

41-
def get(self, job_id: UUID) -> Job | None:
42-
return self.session.get(Job, job_id)
43+
def get(self, job_id: UUID, project_id: int) -> Job | None:
44+
job = self.session.get(Job, job_id)
45+
if job is None:
46+
return None
47+
if job.project_id not in (None, project_id):
48+
return None
49+
return job

backend/app/crud/llm.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import logging
2+
import base64
3+
import json
4+
from uuid import UUID
25
from typing import Any, Literal
36

4-
from uuid import UUID
57
from sqlmodel import Session, select
8+
69
from app.core.util import now
7-
import base64
8-
import json
910
from app.models.llm import LlmCall, LLMCallRequest, ConfigBlob
1011
from app.models.llm.request import (
1112
TextInput,
@@ -234,13 +235,13 @@ def get_llm_call_by_id(
234235

235236

236237
def get_llm_calls_by_job_id(
237-
session: Session,
238-
job_id: UUID,
238+
session: Session, job_id: UUID, project_id: int
239239
) -> list[LlmCall]:
240240
statement = (
241241
select(LlmCall)
242242
.where(
243243
LlmCall.job_id == job_id,
244+
LlmCall.project_id == project_id,
244245
LlmCall.deleted_at.is_(None),
245246
)
246247
.order_by(LlmCall.created_at.desc())

backend/app/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@
122122
LLMChainRequest,
123123
LLMChainResponse,
124124
LlmChain,
125+
LLMJobImmediatePublic,
126+
LLMJobPublic,
125127
)
126128

127129
from .message import Message

backend/app/models/job.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ class Job(SQLModel, table=True):
4040
description="Tracing ID for correlating logs and traces.",
4141
sa_column_kwargs={"comment": "Tracing ID for correlating logs and traces"},
4242
)
43+
project_id: int | None = Field(
44+
default=None,
45+
description="Project ID of the project the job belongs to.",
46+
sa_column_kwargs={"comment": "Project ID of the job's project"},
47+
)
4348
error_message: str | None = Field(
4449
default=None,
4550
description="Error details if the job fails.",

backend/app/models/llm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,6 @@
3030
AudioOutput,
3131
LLMChainResponse,
3232
IntermediateChainResponse,
33+
LLMJobImmediatePublic,
34+
LLMJobPublic,
3335
)

backend/app/models/llm/request.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pydantic import HttpUrl, model_validator
88
from sqlalchemy.dialects.postgresql import JSONB
99
from sqlmodel import Field, Index, SQLModel, text
10+
1011
from app.core.util import now
1112
from app.models.llm.constants import (
1213
DEFAULT_STT_MODEL,
@@ -450,6 +451,11 @@ class LlmCall(SQLModel, table=True):
450451
"conversation_id",
451452
postgresql_where=text("conversation_id IS NOT NULL AND deleted_at IS NULL"),
452453
),
454+
Index(
455+
"idx_llm_call_chain_id",
456+
"chain_id",
457+
postgresql_where=text("chain_id IS NOT NULL"),
458+
),
453459
)
454460

455461
id: UUID = Field(
@@ -666,10 +672,10 @@ class LLMChainRequest(SQLModel):
666672
class ChainStatus(str, Enum):
    """Status of an LLM chain execution."""

    # Uppercase values match the labels of the PostgreSQL `chainstatus` enum
    # type (see migration 051), which stores llm_chain.status verbatim.
    PENDING = "PENDING"
    RUNNING = "RUNNING"
    FAILED = "FAILED"
    COMPLETED = "COMPLETED"
673679

674680

675681
class LlmChain(SQLModel, table=True):

0 commit comments

Comments
 (0)