Commit d2c27cd

standardize params
1 parent 071d5cf commit d2c27cd

File tree

  • llama-index-server/llama_index/server

2 files changed: +25 −6 lines changed

llama-index-server/llama_index/server/api/routers/chat.py

Lines changed: 7 additions & 1 deletion
@@ -1,4 +1,5 @@
 import asyncio
+import inspect
 import logging
 import os
 from typing import AsyncGenerator, Callable, Union
@@ -35,7 +36,12 @@ async def chat(
     chat_history = [
         message.to_llamaindex_message() for message in request.messages[:-1]
     ]
-    workflow = workflow_factory()
+    # detect if the workflow factory has chat_request as a parameter
+    factory_sig = inspect.signature(workflow_factory)
+    if "chat_request" in factory_sig.parameters:
+        workflow = workflow_factory(chat_request=request)
+    else:
+        workflow = workflow_factory()
     workflow_handler = workflow.run(
         user_msg=user_message.content,
         chat_history=chat_history,
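
This makes the new chat_request parameter opt-in: a factory that does not declare it keeps working unchanged. A minimal runnable sketch of the same dispatch check, with hypothetical factories (make_workflow, legacy_factory, and the dict request are illustrative, not part of the repo):

import inspect

def make_workflow(chat_request=None):
    # hypothetical factory that accepts the request
    return {"configured_from": chat_request}

def legacy_factory():
    # hypothetical zero-argument factory
    return {"configured_from": None}

def build(workflow_factory, request):
    # Mirror the router's check: pass chat_request only when the
    # factory's signature declares a parameter with that name.
    if "chat_request" in inspect.signature(workflow_factory).parameters:
        return workflow_factory(chat_request=request)
    return workflow_factory()

print(build(make_workflow, {"id": "r1"}))   # {'configured_from': {'id': 'r1'}}
print(build(legacy_factory, {"id": "r1"}))  # {'configured_from': None}

Note that inspect.signature sees only declared parameters, so a factory taking **kwargs would still fall through to the zero-argument call.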

llama-index-server/llama_index/server/services/llamacloud/index.py

Lines changed: 18 additions & 5 deletions
@@ -3,14 +3,14 @@
 from typing import TYPE_CHECKING, Any, Optional
 
 from llama_cloud import PipelineType
-from pydantic import BaseModel, Field, field_validator
-
 from llama_index.core.callbacks import CallbackManager
 from llama_index.core.ingestion.api_utils import (
     get_client as llama_cloud_get_client,
 )
 from llama_index.core.settings import Settings
 from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
+from llama_index.server.api.models import ChatRequest
+from pydantic import BaseModel, Field, field_validator
 
 if TYPE_CHECKING:
     from llama_cloud.client import LlamaCloud
@@ -87,13 +87,26 @@ def to_index_kwargs(self) -> dict:
             "callback_manager": self.callback_manager,
         }
 
+    @classmethod
+    def from_chat_request(cls, chat_request: ChatRequest) -> "IndexConfig":
+        default_config = cls()
+        if chat_request is not None:
+            llamacloud_config = chat_request.data["llamaCloudPipeline"]
+            if llamacloud_config is not None:
+                default_config.llama_cloud_pipeline_config.pipeline = llamacloud_config[
+                    "pipeline"
+                ]
+                default_config.llama_cloud_pipeline_config.project = llamacloud_config[
+                    "project"
+                ]
+        return default_config
+
 
 def get_index(
-    config: Optional[IndexConfig] = None,
+    chat_request: Optional[ChatRequest] = None,
     create_if_missing: bool = False,
 ) -> Optional[LlamaCloudIndex]:
-    if config is None:
-        config = IndexConfig()
+    config = IndexConfig.from_chat_request(chat_request)
     # Check whether the index exists
     try:
         index = LlamaCloudIndex(**config.to_index_kwargs())
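
With get_index now taking the request itself, the pipeline/project override is resolved in one place instead of at every call site. A self-contained sketch of that resolution pattern, using stand-in dataclasses rather than the real ChatRequest and IndexConfig (all names here are illustrative, and the sketch uses .get for the "llamaCloudPipeline" lookup so it also runs against an empty dict):

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class PipelineConfig:
    # stand-in for IndexConfig's llama_cloud_pipeline_config
    pipeline: str = "default-pipeline"
    project: str = "default-project"

@dataclass
class FakeChatRequest:
    # stand-in for llama_index.server.api.models.ChatRequest
    data: dict = field(default_factory=dict)

def config_from_chat_request(chat_request: Optional[FakeChatRequest]) -> PipelineConfig:
    # Same shape as IndexConfig.from_chat_request: start from defaults,
    # then override pipeline/project when the request carries them.
    config = PipelineConfig()
    if chat_request is not None:
        override = chat_request.data.get("llamaCloudPipeline")
        if override is not None:
            config.pipeline = override["pipeline"]
            config.project = override["project"]
    return config

req = FakeChatRequest(data={"llamaCloudPipeline": {"pipeline": "docs", "project": "acme"}})
print(config_from_chat_request(req))   # PipelineConfig(pipeline='docs', project='acme')
print(config_from_chat_request(None))  # defaults

A caller that previously built an IndexConfig by hand would now just forward the request, e.g. get_index(chat_request=request); passing None falls back to the defaults, replacing the old "if config is None" branch.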
