Skip to content

Commit df4c296

Browse files
author
Pablo Marin
committed
improve SSE /stream endpoint and clients
1 parent 6c30fed commit df4c296

File tree

3 files changed

+109
-57
lines changed

3 files changed

+109
-57
lines changed

15-FastAPI-API.ipynb

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -224,18 +224,35 @@
224224
" resp.raise_for_status()\n",
225225
" client = SSEClient(resp)\n",
226226
" for event in client.events():\n",
227-
" if event.event == 'tool_start':\n",
228-
" print(\"\\n[Tool Start]\", event.data)\n",
229-
" elif event.event == 'tool_end':\n",
230-
" print(\"\\n[Tool End]\", event.data)\n",
231-
" elif event.event == 'partial':\n",
232-
" process_partial_text(event.data)\n",
233-
" elif event.event == 'end':\n",
234-
" print(f\"\\n[Done Streaming] Final text: {event.data}\")\n",
235-
" elif event.event == 'error':\n",
236-
" print(f\"\\n[SSE Error] {event.data}\")\n",
227+
" evt_type = event.event\n",
228+
" evt_data = event.data\n",
229+
"\n",
230+
" if evt_type == 'metadata':\n",
231+
" info = json.loads(evt_data)\n",
232+
" print(f\"\\n[Metadata] run_id={info.get('run_id', '')}\")\n",
233+
"\n",
234+
" elif evt_type == 'data':\n",
235+
" # The server is sending partial chunk(s) as a \"data\" event\n",
236+
" # so we treat it as partial text to display:\n",
237+
" process_partial_text(evt_data)\n",
238+
"\n",
239+
" elif evt_type == 'on_tool_start':\n",
240+
" print(f\"\\n[Tool Start] {evt_data}\")\n",
241+
"\n",
242+
" elif evt_type == 'on_tool_end':\n",
243+
" print(f\"\\n[Tool End] {evt_data}\")\n",
244+
"\n",
245+
" elif evt_type == 'end':\n",
246+
" # This signals no more data will follow.\n",
247+
" print(f\"\\n[Done Streaming] Final text: {evt_data}\")\n",
248+
"\n",
249+
" elif evt_type == 'error':\n",
250+
" # The server encountered an exception or error\n",
251+
" print(f\"\\n[SSE Error] {evt_data}\")\n",
252+
"\n",
237253
" else:\n",
238-
" print(f\"\\n[Unrecognized event: {event.event}]\", event.data)\n",
254+
" # Some unexpected event name\n",
255+
" print(f\"\\n[Unrecognized event={evt_type}] {evt_data}\")\n",
239256
"\n"
240257
]
241258
},

apps/backend/fastapi/app/server.py

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919
csv_file_path = "data/all-states-history.csv"
2020
api_file_path = "data/openapi_kraken.json"
2121

22-
##########################################################
23-
## Uncomment this section to run locally
22+
######### Uncomment this section to run locally ##########
2423
# current_file = Path(__file__).resolve()
2524
# library_path = current_file.parents[4]
2625
# data_path = library_path / "data"
@@ -181,39 +180,75 @@ async def batch(req: BatchRequest):
181180
# Streaming Endpoint
182181
# -----------------------------------------------------------------------------
183182
@app.post("/stream")
184-
async def stream(req: AskRequest):
183+
async def stream_endpoint(req: AskRequest):
184+
"""
185+
Stream partial chunks from the chain in SSE format.
186+
187+
SSE event structure:
188+
- event: "metadata" (OPTIONAL) – any run-specific metadata
189+
- event: "data" – a chunk of text
190+
- event: "end" – signals no more data
191+
    - event: "on_tool_start" - signals the beginning of tool use
192+
    - event: "on_tool_end" - signals the end of tool use
193+
- event: "error" – signals an error
194+
"""
185195
logger.info("[/stream] Called with user_input=%s, thread_id=%s", req.user_input, req.thread_id)
186196

187197
if not graph_async:
188198
logger.error("Graph not compiled yet.")
189199
raise HTTPException(status_code=500, detail="Graph not compiled yet.")
190200

191-
config = {"configurable": {"thread_id": req.thread_id or str(uuid.uuid4())}}
201+
run_id = req.thread_id or str(uuid.uuid4())
202+
config = {"configurable": {"thread_id": run_id}}
192203
inputs = {"messages": [("human", req.user_input)]}
193204

194205
async def event_generator():
195-
accumulated_text = ""
196206
try:
207+
yield {
208+
"event": "metadata",
209+
"data": json.dumps({"run_id": run_id})
210+
}
211+
212+
accumulated_text = ""
197213
async for event in graph_async.astream_events(inputs, config, version="v2"):
198-
if event["event"] == "on_chat_model_stream" and event["metadata"].get("langgraph_node") == "agent":
199-
chunk_text = event["data"]["chunk"].content
200-
accumulated_text += chunk_text
201-
yield {"event": "partial", "data": chunk_text}
214+
if event["event"] == "on_chat_model_stream":
215+
if event["metadata"].get("langgraph_node") == "agent":
216+
chunk_text = event["data"]["chunk"].content
217+
accumulated_text += chunk_text
218+
219+
yield {
220+
"event": "data",
221+
"data": chunk_text # partial chunk
222+
}
223+
202224
elif event["event"] == "on_tool_start":
203-
yield {"event": "tool_start", "data": f"Starting {event.get('name','')}"}
225+
yield {"event": "on_tool_start", "data": f"Tool Start: {event.get('name', '')}"}
226+
204227
elif event["event"] == "on_tool_end":
205-
yield {"event": "tool_end", "data": f"Done {event.get('name','')}"}
228+
yield {"event": "on_tool_end", "data": f"Tool End: {event.get('name', '')}"}
229+
206230
elif event["event"] == "on_chain_end" and event.get("name") == "LangGraph":
231+
# If "FINISH" is the next step
207232
if event["data"]["output"].get("next") == "FINISH":
208233
yield {"event": "end", "data": accumulated_text}
209-
return
234+
return # Stop iteration
235+
210236
except Exception as ex:
211237
logger.exception("[/stream] Error streaming events")
212-
yield {"event": "error", "data": str(ex)}
238+
# SSE "error" event
239+
yield {
240+
"event": "error",
241+
"data": json.dumps({
242+
"status_code": 500,
243+
"message": str(ex)
244+
})
245+
}
246+
raise
213247

214248
return EventSourceResponse(event_generator(), media_type="text/event-stream")
215249

216250

251+
217252
# -----------------------------------------------------------------------------
218253
# Main Entrypoint
219254
# -----------------------------------------------------------------------------

apps/frontend/app/helpers/streamlit_helpers.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -71,28 +71,21 @@ def get_or_create_ids():
7171

7272
def consume_api(url, user_query, session_id, user_id):
7373
"""
74-
Send a POST request to the FastAPI backend at `url` (the SSE /stream endpoint),
74+
Send a POST request to the FastAPI backend at `url` (/stream endpoint),
7575
and consume the SSE stream using sseclient-py.
7676
7777
The server is expected to return events like:
78-
{"event": "partial", "data": "..."}
78+
{"event": "metadata", "data": "..."}
79+
{"event": "data", "data": "..."}
80+
{"event": "on_tool_start", "data": "..."}
81+
{"event": "on_tool_end", "data": "..."}
7982
{"event": "end", "data": "..."}
80-
{"event": "tool_start", "data": "..."}
81-
{"event": "tool_end", "data": "..."}
82-
{"event": "error", "data": "some error text"}
83-
84-
This function yields text chunks (e.g., partial content) as they arrive.
85-
86-
:param url: SSE /stream endpoint
87-
:param user_query: The user query text to send
88-
:param session_id: A unique ID representing the conversation
89-
:param user_id: The user ID (not strictly used in this minimal example)
90-
:yield: Text chunks that Streamlit can display incrementally
83+
{"event": "error", "data": "..."}
9184
"""
9285
headers = {"Content-Type": "application/json"}
9386
payload = {
9487
"user_input": user_query,
95-
"thread_id": session_id # Typically reusing session_id as conversation ID
88+
"thread_id": session_id
9689
}
9790

9891
logger.info(
@@ -104,43 +97,50 @@ def consume_api(url, user_query, session_id, user_id):
10497
try:
10598
with requests.post(url, json=payload, headers=headers, stream=True) as resp:
10699
resp.raise_for_status()
107-
logger.info("SSE stream opened successfully with status code %d.", resp.status_code)
100+
logger.info("SSE stream opened with status code: %d", resp.status_code)
108101

109-
# Use SSEClient to parse the stream
110102
client = SSEClient(resp)
111103
for event in client.events():
112-
if not event.data.strip():
113-
# Skip keep-alive messages or empty lines
104+
if not event.data:
105+
# Skip empty lines
114106
continue
115107

116108
evt_type = event.event
117109
evt_data = event.data
118110
logger.debug("Received SSE event: %s, data: %s", evt_type, evt_data)
119111

120-
# Switch on event type from the server
121-
if evt_type == "partial":
122-
# Yield partial text; can be streamed in real-time
112+
if evt_type == "metadata":
113+
# Possibly parse run_id from the JSON
114+
# e.g. { "run_id": "...some uuid..." }
115+
info = json.loads(evt_data)
116+
run_id = info.get("run_id", "")
117+
# For streamlit, you might store it as session state, etc.
118+
# st.write(f"New run_id: {run_id}")
119+
120+
elif evt_type == "data":
121+
# The server is sending partial tokens as "data"
122+
# We can yield them so Streamlit can display incrementally
123123
yield evt_data
124-
125-
elif evt_type == "tool_start":
126-
# Display tool start
127-
# e.g. [Tool Start] Starting documents_retrieval
128-
# yield f"\n[Tool Start] {evt_data}\n"
124+
125+
elif evt_type == "on_tool_start":
126+
# Optionally display: yield or do a Streamlit update
127+
# yield f"[Tool Start] {evt_data}"
129128
pass
130129

131-
elif evt_type == "tool_end":
132-
# Display tool end
133-
# e.g. [Tool End] Done documents_retrieval
134-
# yield f"\n[Tool End] {evt_data}\n"
130+
elif evt_type == "on_tool_end":
131+
# yield f"[Tool End] {evt_data}"
135132
pass
136-
133+
137134
elif evt_type == "end":
138-
# Yield final accumulated text without leading newline
135+
# This is the final text.
136+
# Typically you might do a final display or update the UI
139137
yield evt_data
138+
140139
elif evt_type == "error":
140+
# The server had an error
141141
yield f"[SSE Error] {evt_data}"
142+
142143
else:
143-
# Unrecognized event
144144
yield f"[Unrecognized event: {evt_type}] {evt_data}"
145145

146146
except requests.exceptions.HTTPError as err:

0 commit comments

Comments
 (0)