
Commit ee3057e: attach file to user message
1 parent 095badc

File tree: 8 files changed, +65 -64 lines changed

packages/server/next/app/api/files/helpers.ts
Lines changed: 3 additions & 1 deletion

@@ -8,7 +8,9 @@ export async function storeFile(name: string, fileBuffer: Buffer) {
   const parts = name.split(".");
   const fileName = parts[0];
   const fileExt = parts[1];
-
+  if (!fileName) {
+    throw new Error("File name is required");
+  }
   if (!fileExt) {
     throw new Error("File extension is required");
   }

python/llama-index-server/examples/private_file/agent-workflow.py
Lines changed: 3 additions & 3 deletions

@@ -18,7 +18,7 @@ def create_file_tool(chat_request: ChatRequest) -> Optional[FunctionTool]:
     Create a tool to read file if the user uploads a file.
     """
     file_ids = []
-    for file in get_file_attachments(chat_request):
+    for file in get_file_attachments(chat_request.messages):
         file_ids.append(file.id)
     if len(file_ids) == 0:
         return None
@@ -29,7 +29,7 @@ def create_file_tool(chat_request: ChatRequest) -> Optional[FunctionTool]:
     )
 
     def read_file(file_id: str) -> str:
-        file_path = FileService.get_private_file_path(file_id)
+        file_path = FileService.get_file_path(file_id)
         try:
             with open(file_path, "r") as file:
                 return file.read()
@@ -57,7 +57,7 @@ def create_app() -> FastAPI:
         workflow_factory=create_workflow,
         suggest_next_questions=False,
         ui_config=UIConfig(
-            file_upload_enabled=True,
+            enable_file_upload=True,
             component_dir="components",
         ),
     )
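
Taken together, the two renames in this example (get_file_attachments now takes the message list; FileService.get_file_path replaces get_private_file_path) change how callers read an upload. A minimal sketch of the updated pattern; the helper name read_first_attachment is illustrative, not part of the commit:

from llama_index.server.models.chat import ChatRequest
from llama_index.server.services.file import FileService
from llama_index.server.utils.chat_attachments import get_file_attachments


def read_first_attachment(chat_request: ChatRequest) -> str:
    # Pass chat_request.messages (not the whole request) after this commit.
    files = get_file_attachments(chat_request.messages)
    if not files:
        raise ValueError("No uploaded files found")
    # get_file_path replaces the former get_private_file_path.
    file_path = FileService.get_file_path(files[0].id)
    with open(file_path, "r") as f:
        return f.read()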

python/llama-index-server/examples/private_file/custom-workflow.py
Lines changed: 14 additions & 45 deletions

@@ -2,21 +2,19 @@
 
 from fastapi import FastAPI
 from llama_index.core.agent.workflow.workflow_events import AgentStream
-from llama_index.core.llms import LLM
+from llama_index.core.llms import LLM, ChatMessage, DocumentBlock
 from llama_index.core.prompts import PromptTemplate
 from llama_index.core.workflow import (
     Context,
     Event,
     StartEvent,
     StopEvent,
     Workflow,
+    WorkflowRuntimeError,
     step,
 )
 from llama_index.llms.openai import OpenAI
 from llama_index.server import LlamaIndexServer, UIConfig
-from llama_index.server.models import ChatRequest
-from llama_index.server.services.file import FileService
-from llama_index.server.utils.chat_attachments import get_file_attachments
 
 
 class FileHelpEvent(Event):
@@ -25,7 +23,7 @@ class FileHelpEvent(Event):
     """
 
     file_content: str
-    user_msg: str
+    user_request: str
 
 
 class FileHelpWorkflow(Workflow):
@@ -38,53 +36,25 @@ class FileHelpWorkflow(Workflow):
     def __init__(
         self,
         llm: LLM,
-        chat_request: ChatRequest,  # Initial the workflow with the chat request
         **kwargs: Any,
     ):
         super().__init__(**kwargs)
         self.llm = llm
-        # Get the uploaded files from the chat request and stores them in the workflow instance for accessing later
-        self.uploaded_files = get_file_attachments(chat_request)
-        if len(self.uploaded_files) == 0:
-            raise ValueError("No uploaded files found. Please upload a file to start")
 
     @step
     async def read_files(self, ctx: Context, ev: StartEvent) -> FileHelpEvent:
-        user_msg = ev.user_msg
+        user_msg: ChatMessage = ev.user_msg
+        # All the uploaded files are included in the user_msg.blocks as DocumentBlock
+        files = [block for block in user_msg.blocks if isinstance(block, DocumentBlock)]
+        if len(files) != 1:
+            raise WorkflowRuntimeError("Please upload only one file")
 
-        # 1. Access through workflow instance as is
-        # last_file = self.uploaded_files[-1]
-
-
-        # 2. Access through user_msg (if it's a ChatMessage)
-        # llama_index support ChatMessage with DocumentBlock which mostly the same as our FileServer.
-        # (but I guess we'll get back to dealing with other problems
-        # that we need to pass other data to the workflow later)
-        # e.g:
-        # files = [
-        #     ServerFile.from_document_block(block)
-        #     for block in user_msg.blocks
-        #     if isinstance(block, DocumentBlock)
-        # ]
-        #
-        # or they can just use files: List[DocumentBlock] as is.
-
-
-        # 3. Introduce server start event with additional fields
-        # e.g:
-        # class ChatStartEvent(StartEvent):
-        #     user_msg: Union[str, ChatMessage]
-        #     chat_history: list[ChatMessage]
-        #     attachments: list[ServerFile]
-        # Then the user can clearly know what do they have with the StartEvent
-
-        file_path = FileService.get_private_file_path(last_file.id)
-        with open(file_path, "r", encoding="utf-8") as f:
-            file_content = f.read()
+        # Simply call resolve_document() to get the file content
+        file_content = files[0].resolve_document().read().decode("utf-8")
 
         return FileHelpEvent(
             file_content=file_content,
-            user_msg=user_msg,
+            user_request=ev.user_msg.content,
         )
 
     @step
@@ -100,7 +70,7 @@ async def help_user(self, ctx: Context, ev: FileHelpEvent) -> StopEvent:
             {file_content}
             """)
         prompt = default_prompt.format(
-            user_msg=ev.user_msg,
+            user_msg=ev.user_request,
             file_content=ev.file_content,
        )
         stream = await self.llm.astream_complete(prompt)
@@ -120,10 +90,9 @@ async def help_user(self, ctx: Context, ev: FileHelpEvent) -> StopEvent:
         )
 
 
-def create_workflow(chat_request: ChatRequest) -> Workflow:
+def create_workflow() -> Workflow:
     return FileHelpWorkflow(
         llm=OpenAI(model="gpt-4.1-mini"),
-        chat_request=chat_request,
     )
 
 
@@ -132,7 +101,7 @@ def create_app() -> FastAPI:
         workflow_factory=create_workflow,
         suggest_next_questions=False,
         ui_config=UIConfig(
-            file_upload_enabled=True,
+            enable_file_upload=True,
             component_dir="components",
         ),
     )
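
The net effect of this rewrite is that the workflow no longer needs ChatRequest or FileService: uploads now arrive as DocumentBlock entries on the user's ChatMessage. A minimal sketch of the new access pattern, mirroring the code above (the standalone helper name is illustrative):

from llama_index.core.llms import ChatMessage, DocumentBlock


def read_single_upload(user_msg: ChatMessage) -> str:
    # Uploaded files are carried in user_msg.blocks as DocumentBlock entries.
    blocks = [b for b in user_msg.blocks if isinstance(b, DocumentBlock)]
    if len(blocks) != 1:
        raise ValueError("Please upload exactly one file")
    # resolve_document() yields a readable buffer of the raw file bytes.
    return blocks[0].resolve_document().read().decode("utf-8")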

python/llama-index-server/llama_index/server/api/routers/chat.py
Lines changed: 3 additions & 3 deletions

@@ -6,7 +6,6 @@
 
 from fastapi import APIRouter, BackgroundTasks, HTTPException
 from fastapi.responses import StreamingResponse
-
 from llama_index.core.agent.workflow.workflow_events import (
     AgentInput,
     AgentSetup,
@@ -24,6 +23,7 @@
     SuggestNextQuestions,
 )
 from llama_index.server.api.callbacks.stream_handler import StreamHandler
+from llama_index.server.api.utils.chat_request import prepare_user_message
 from llama_index.server.api.utils.vercel_stream import VercelStreamResponse
 from llama_index.server.models.chat import ChatFile, ChatRequest
 from llama_index.server.models.hitl import HumanInputEvent
@@ -46,7 +46,7 @@ async def chat(
 ) -> StreamingResponse:
     try:
         last_message = request.messages[-1]
-        user_message = last_message.to_llamaindex_message()
+        user_message = prepare_user_message(request)
         chat_history = [
             message.to_llamaindex_message() for message in request.messages[:-1]
         ]
@@ -68,7 +68,7 @@ async def chat(
             workflow_handler = workflow.run(ctx=ctx)
         else:
             workflow_handler = workflow.run(
-                user_msg=user_message.content,
+                user_msg=user_message,
                 chat_history=chat_history,
             )
 
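This changes the contract for workflows started by the router: user_msg is now the full ChatMessage (text plus any attached DocumentBlock entries) instead of a plain string. A sketch of a start step written against the new contract; EchoWorkflow is a made-up example, not part of the commit:

from llama_index.core.llms import ChatMessage
from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step


class EchoWorkflow(Workflow):
    @step
    async def start(self, ctx: Context, ev: StartEvent) -> StopEvent:
        user_msg: ChatMessage = ev.user_msg  # a ChatMessage, no longer a str
        # Read the text via .content; attachments live in user_msg.blocks.
        return StopEvent(result=user_msg.content)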

python/llama-index-server/llama_index/server/api/utils/chat_request.py
Lines changed: 26 additions & 0 deletions

@@ -1,7 +1,11 @@
 from typing import List, Optional
 
+from llama_index.core.llms import DocumentBlock
+from llama_index.core.types import ChatMessage, MessageRole
 from llama_index.server.models.artifacts import Artifact
 from llama_index.server.models.chat import ChatRequest
+from llama_index.server.services.file import FileService
+from llama_index.server.utils.chat_attachments import get_file_attachments
 
 
 def get_artifacts(chat_request: ChatRequest) -> List[Artifact]:
@@ -22,3 +26,25 @@ def get_artifacts(chat_request: ChatRequest) -> List[Artifact]:
 def get_last_artifact(chat_request: ChatRequest) -> Optional[Artifact]:
     artifacts = get_artifacts(chat_request)
     return artifacts[-1] if len(artifacts) > 0 else None
+
+
+def prepare_user_message(chat_request: ChatRequest) -> ChatMessage:
+    """
+    Prepare the user message from the chat request.
+    """
+    last_message: ChatMessage = chat_request.messages[-1].to_llamaindex_message()
+    if last_message.role != MessageRole.USER:
+        raise ValueError("Last message must be from user")
+
+    # Add attached files to the user message
+    attachment_files = get_file_attachments(chat_request.messages)
+    last_message.blocks += [
+        DocumentBlock(
+            path=file.path or FileService.get_file_path(file.id),
+            url=file.url,
+            document_mimetype=file.type,
+        )
+        for file in attachment_files
+    ]
+
+    return last_message
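
prepare_user_message is consumed by the chat router above; a sketch of the call pattern, with the wrapper name build_run_kwargs invented for illustration:

from llama_index.server.api.utils.chat_request import prepare_user_message
from llama_index.server.models.chat import ChatRequest


def build_run_kwargs(request: ChatRequest) -> dict:
    # The returned ChatMessage carries one DocumentBlock per file attached
    # to the last (user) message.
    user_message = prepare_user_message(request)
    chat_history = [m.to_llamaindex_message() for m in request.messages[:-1]]
    return {"user_msg": user_message, "chat_history": chat_history}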

python/llama-index-server/llama_index/server/models/chat.py
Lines changed: 5 additions & 2 deletions

@@ -1,9 +1,8 @@
 import re
 from typing import Any, List, Literal, Optional, Union
 
-from pydantic import BaseModel, Field, field_validator
-
 from llama_index.core.types import ChatMessage, MessageRole
+from pydantic import BaseModel, Field, field_validator
 
 
 class ServerFile(BaseModel):
@@ -74,6 +73,10 @@ class ChatAPIMessage(BaseModel):
     annotations: Optional[List[Union[FileAnnotation, Any]]] = None
 
     def to_llamaindex_message(self) -> ChatMessage:
+        """
+        Simply convert text content of API message to llama_index's ChatMessage.
+        Annotations are not included.
+        """
         return ChatMessage(role=self.role, content=self.content)
 
     @property
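
The new docstring makes the split explicit: to_llamaindex_message() converts text only, while attachment handling lives in prepare_user_message. A small sketch of that behavior, assuming the role field accepts the plain string "user" via pydantic coercion:

from llama_index.server.models.chat import ChatAPIMessage

msg = ChatAPIMessage(role="user", content="Summarize the attached file")
li_msg = msg.to_llamaindex_message()
# li_msg is a text-only ChatMessage; any file annotations on msg are not
# converted here and must be merged via prepare_user_message instead.
assert li_msg.content == "Summarize the attached file"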

python/llama-index-server/llama_index/server/utils/chat_attachments.py
Lines changed: 11 additions & 8 deletions

@@ -1,23 +1,26 @@
 from typing import List
 
-from llama_index.server.models.chat import ChatRequest, FileAnnotation, ServerFile
+from llama_index.core.types import MessageRole
+from llama_index.server.models.chat import ChatAPIMessage, FileAnnotation, ServerFile
 
 
-def get_file_attachments(chat_request: ChatRequest) -> List[ServerFile]:
+def get_file_attachments(messages: List[ChatAPIMessage]) -> List[ServerFile]:
     """
-    Extract all file attachments from the chat request.
+    Extract all file attachments from user messages.
 
     Args:
-        chat_request (ChatRequest): The chat request.
+        messages (List[ChatAPIMessage]): The list of messages.
 
     Returns:
-        List[PrivateFile]: The list of private files.
+        List[ServerFile]: The list of private files.
     """
-    message_annotations = [
-        message.annotations for message in chat_request.messages if message.annotations
+    user_message_annotations = [
+        message.annotations
+        for message in messages
+        if message.annotations and message.role == MessageRole.USER
     ]
     files: List[ServerFile] = []
-    for annotation in message_annotations:
+    for annotation in user_message_annotations:
         if isinstance(annotation, list):
             for item in annotation:
                 if isinstance(item, FileAnnotation):
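
Two behavior changes here: the helper now takes a plain message list, and annotations on non-user messages are filtered out. A usage sketch (the attachment_ids wrapper is illustrative):

from typing import List

from llama_index.server.models.chat import ChatAPIMessage
from llama_index.server.utils.chat_attachments import get_file_attachments


def attachment_ids(messages: List[ChatAPIMessage]) -> List[str]:
    # Only annotations on MessageRole.USER messages are collected now.
    return [f.id for f in get_file_attachments(messages)]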

python/llama-index-server/tests/api/test_chat_api.py
Lines changed: 0 additions & 2 deletions

@@ -90,7 +90,6 @@ async def test_chat_router(
 
     # Verify the workflow was called with the correct arguments
     call_args = mock_workflow.run.call_args[1]
-    assert call_args["user_msg"] == "Hello, how are you?"
    assert isinstance(call_args["chat_history"], list)
     assert len(call_args["chat_history"]) == 0  # No history for first message
 
@@ -153,6 +152,5 @@ def workflow_factory(verbose: bool = False) -> MagicMock:
 
     # Verify the workflow was called with the correct arguments
     call_args = mock_workflow.run.call_args[1]
-    assert call_args["user_msg"] == "What's the weather in New York?"
     assert isinstance(call_args["chat_history"], list)
     assert len(call_args["chat_history"]) == 0  # No history for first message
