@@ -6,7 +6,6 @@
 import sentry_sdk
 
 from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
-from holmes.core.tools import Tool
 from pydantic import BaseModel
 import litellm
 import os
@@ -45,13 +44,13 @@ def count_tokens_for_message(self, messages: list[dict]) -> int:
     def completion(
         self,
         messages: List[Dict[str, Any]],
-        tools: Optional[List[Tool]] = [],
+        tools: Optional[List[Dict[str, Any]]] = [],
         tool_choice: Optional[Union[str, dict]] = None,
         response_format: Optional[Union[dict, Type[BaseModel]]] = None,
         temperature: Optional[float] = None,
         drop_params: Optional[bool] = None,
         stream: Optional[bool] = None,
-    ) -> ModelResponse:
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
         pass
 
 
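The widened return type reflects that `litellm.completion` returns a `CustomStreamWrapper` (an iterator of delta chunks) when `stream=True`, and a plain `ModelResponse` otherwise, so callers have to branch on the result. A minimal caller-side sketch, assuming a hypothetical `llm` instance of this class and OpenAI-style chunk fields; `llm` and `messages` are placeholders, not names from this PR:

```python
from litellm import ModelResponse

# `llm` and `messages` are illustrative placeholders, not code from this PR.
result = llm.completion(messages=messages, stream=True)
if isinstance(result, ModelResponse):
    # Non-streaming path: the full message is already materialized.
    print(result.choices[0].message.content)
else:
    # Streaming path: CustomStreamWrapper yields OpenAI-style chunks
    # whose deltas carry incremental content.
    for chunk in result:
        delta = chunk.choices[0].delta.content
        if delta:
            print(delta, end="", flush=True)
```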
@@ -167,24 +166,28 @@ def count_tokens_for_message(self, messages: list[dict]) -> int:
     def completion(
         self,
         messages: List[Dict[str, Any]],
-        tools: Optional[List[Tool]] = [],
+        tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[str, dict]] = None,
         response_format: Optional[Union[dict, Type[BaseModel]]] = None,
         temperature: Optional[float] = None,
         drop_params: Optional[bool] = None,
         stream: Optional[bool] = None,
-    ) -> ModelResponse:
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
+        tools_args = {}
+        if tools and tool_choice:
+            tools_args["tools"] = tools
+            tools_args["tool_choice"] = tool_choice
+
         result = litellm.completion(
             model=self.model,
             api_key=self.api_key,
             messages=messages,
-            tools=tools,
-            tool_choice=tool_choice,
             base_url=self.base_url,
             temperature=temperature,
             response_format=response_format,
             drop_params=drop_params,
             stream=stream,
+            **tools_args,
         )
 
         if isinstance(result, ModelResponse):
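The `tools_args` indirection forwards `tools` and `tool_choice` only when both are set: spreading `**tools_args` keeps the keys out of the request payload entirely, rather than sending `tools=None` or an empty list, which some providers reject. And with `holmes.core.tools.Tool` dropped from the signature, callers now pass OpenAI-style function specs, which litellm forwards as-is. A minimal sketch of such a call, with an invented tool schema and the same hypothetical `llm` instance as above:

```python
# Hypothetical dict-format tool spec (OpenAI function-calling schema,
# which litellm accepts); the tool itself is invented for illustration.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_pod_logs",
            "description": "Fetch logs for a Kubernetes pod",
            "parameters": {
                "type": "object",
                "properties": {"pod": {"type": "string"}},
                "required": ["pod"],
            },
        },
    }
]

# Both arguments set -> tools_args is populated and forwarded.
# tools=None or tool_choice=None -> neither key appears in the request.
result = llm.completion(messages=messages, tools=tools, tool_choice="auto")
```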