
Commit b63bb56

ROB-717 stream (#240)
1 parent 40b6b88 commit b63bb56

4 files changed: +244 -0 lines changed

holmes/core/investigation.py

+96
@@ -5,6 +5,15 @@
 from holmes.core.models import InvestigateRequest, InvestigationResult
 from holmes.core.supabase_dal import SupabaseDal
 from holmes.utils.robusta import load_robusta_api_key
+import logging
+
+from holmes.core.investigation_structured_output import (
+    DEFAULT_SECTIONS,
+    REQUEST_STRUCTURED_OUTPUT_FROM_LLM,
+    get_output_format_for_investigation,
+)
+
+from holmes.plugins.prompts import load_and_render_prompt
 
 
 def investigate_issues(
@@ -49,3 +58,90 @@ def investigate_issues(
         tool_calls=investigation.tool_calls or [],
         instructions=investigation.instructions,
     )
+
+
+def get_investigation_context(
+    investigate_request: InvestigateRequest, dal: SupabaseDal, config: Config
+):
+    load_robusta_api_key(dal=dal, config=config)
+    ai = config.create_issue_investigator(dal=dal)
+
+    raw_data = investigate_request.model_dump()
+    context = dal.get_issue_data(investigate_request.context.get("robusta_issue_id"))
+    if context:
+        raw_data["extra_context"] = context
+
+    issue = Issue(
+        id=context["id"] if context else "",
+        name=investigate_request.title,
+        source_type=investigate_request.source,
+        source_instance_id=investigate_request.source_instance_id,
+        raw=raw_data,
+    )
+
+    runbooks = ai.runbook_manager.get_instructions_for_issue(issue)
+
+    instructions = dal.get_resource_instructions(
+        "alert", investigate_request.context.get("issue_type")
+    )
+    if instructions is not None and instructions.instructions:
+        runbooks.extend(instructions.instructions)
+    if instructions is not None and len(instructions.documents) > 0:
+        docPrompts = []
+        for document in instructions.documents:
+            docPrompts.append(f"* fetch information from this URL: {document.url}\n")
+        runbooks.extend(docPrompts)
+
+    # This section is about setting vars to request the LLM to return structured output.
+    # It does not mean that Holmes will not return structured sections for investigation as it is
+    # capable of splitting the markdown into sections
+    request_structured_output_from_llm = True
+    response_format = None
+    sections = investigate_request.sections
+    if not sections or len(sections) == 0:
+        # If no sections are passed, we will not ask the LLM for structured output
+        sections = DEFAULT_SECTIONS
+        request_structured_output_from_llm = False
+        logging.info(
+            "No section received from the client. Default sections will be used."
+        )
+    elif ai.llm.model and ai.llm.model.startswith("bedrock"):
+        # Structured output does not work well with Bedrock Anthropic Sonnet 3.5 through litellm
+        request_structured_output_from_llm = False
+
+    if not REQUEST_STRUCTURED_OUTPUT_FROM_LLM:
+        request_structured_output_from_llm = False
+
+    if request_structured_output_from_llm:
+        response_format = get_output_format_for_investigation(sections)
+        logging.info("Structured output is enabled for this request")
+    else:
+        logging.info("Structured output is disabled for this request")
+
+    system_prompt = load_and_render_prompt(
+        investigate_request.prompt_template,
+        {
+            "issue": issue,
+            "sections": sections,
+            "structured_output": request_structured_output_from_llm,
+        },
+    )
+
+    user_prompt = ""
+    if runbooks:
+        for runbook_str in runbooks:
+            user_prompt += f"* {runbook_str}\n"
+
+        user_prompt = f'My instructions to check \n"""{user_prompt}"""'
+
+    global_instructions = dal.get_global_instructions_for_account()
+    if (
+        global_instructions
+        and global_instructions.instructions
+        and len(global_instructions.instructions[0]) > 0
+    ):
+        user_prompt += f"\n\nGlobal Instructions (use only if relevant): {global_instructions.instructions[0]}\n"
+
+    user_prompt = f"{user_prompt}\n This is context from the issue {issue.raw}"
+
+    return ai, system_prompt, user_prompt, response_format, sections, runbooks
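The helper above returns everything needed to run an investigation: the investigator (`ai`), the rendered prompts, the optional `response_format`, the resolved `sections`, and the collected `runbooks`. A minimal sketch of unpacking and using that tuple, assuming `investigate_request`, `dal`, and `config` objects already exist; the loop below is illustrative, not part of the commit:

# Sketch only: investigate_request, dal and config are assumed to be
# already-constructed InvestigateRequest, SupabaseDal and Config objects.
ai, system_prompt, user_prompt, response_format, sections, runbooks = (
    get_investigation_context(investigate_request, dal, config)
)

# The same context can drive the new streaming path added in this commit:
for event in ai.call_stream(system_prompt, user_prompt, response_format, runbooks):
    print(event)  # each item is a JSON-encoded progress or result event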

holmes/core/llm.py

+6
@@ -5,6 +5,7 @@
 from litellm.types.utils import ModelResponse
 import sentry_sdk
 
+from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
 from holmes.core.tools import Tool
 from pydantic import BaseModel
 import litellm
@@ -49,6 +50,7 @@ def completion(
         response_format: Optional[Union[dict, Type[BaseModel]]] = None,
         temperature: Optional[float] = None,
         drop_params: Optional[bool] = None,
+        stream: Optional[bool] = None,
     ) -> ModelResponse:
         pass
 
@@ -170,6 +172,7 @@ def completion(
         response_format: Optional[Union[dict, Type[BaseModel]]] = None,
         temperature: Optional[float] = None,
         drop_params: Optional[bool] = None,
+        stream: Optional[bool] = None,
     ) -> ModelResponse:
         result = litellm.completion(
             model=self.model,
@@ -181,10 +184,13 @@
             temperature=temperature,
             response_format=response_format,
             drop_params=drop_params,
+            stream=stream,
         )
 
         if isinstance(result, ModelResponse):
             return result
+        elif isinstance(result, CustomStreamWrapper):
+            return result
         else:
             raise Exception(f"Unexpected type returned by the LLM {type(result)}")
 
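The `stream` parameter is forwarded to `litellm.completion`, which returns a `CustomStreamWrapper` instead of a `ModelResponse` when streaming is enabled; the change above simply lets that wrapper pass through. A minimal sketch of draining such a wrapper, assuming litellm's usual OpenAI-style delta chunks (the model name and prompt are placeholders, not part of the commit):

import litellm

# Placeholder model and prompt, for illustration only.
stream = litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Summarize this alert"}],
    stream=True,
)

# CustomStreamWrapper is an iterator; each chunk carries an incremental delta.
answer = ""
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        answer += delta
print(answer)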

holmes/core/tool_calling_llm.py

+123
@@ -12,6 +12,7 @@
     InputSectionsDataType,
     get_output_format_for_investigation,
     is_response_an_incorrect_tool_call,
+    process_response_into_sections,
 )
 from holmes.core.performance_timing import PerformanceTiming
 from holmes.utils.tags import format_tags_in_string, parse_messages_tags
@@ -37,6 +38,14 @@ class ToolCallResult(BaseModel):
     result: str
     size: Optional[int] = None
 
+    def as_dict(self):
+        return {
+            "tool_call_id": self.tool_call_id,
+            "role": "tool",
+            "name": self.tool_name,
+            "content": self.result,
+        }
+
 
 class LLMResult(BaseModel):
     tool_calls: Optional[List[ToolCallResult]] = None
@@ -357,6 +366,120 @@ def truncate_messages_to_fit_context(
                 message["content"] = message["content"][:tool_size]
         return messages
 
+    def call_stream(
+        self,
+        system_prompt: str,
+        user_prompt: Optional[str] = None,
+        response_format: Optional[Union[dict, Type[BaseModel]]] = None,
+        runbooks: List[str] = None,
+    ):
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt},
+        ]
+        perf_timing = PerformanceTiming("tool_calling_llm.call")
+        tool_calls: List[ToolCallResult] = []
+        tools = self.tool_executor.get_all_tools_openai_format()
+        perf_timing.measure("get_all_tools_openai_format")
+        i = 0
+
+        while i < self.max_steps:
+            i += 1
+            perf_timing.measure(f"start iteration {i}")
+            logging.debug(f"running iteration {i}")
+
+            tools = [] if i == self.max_steps - 1 else tools
+            tool_choice = None if tools == [] else "auto"
+
+            total_tokens = self.llm.count_tokens_for_message(messages)
+            max_context_size = self.llm.get_context_window_size()
+            maximum_output_token = self.llm.get_maximum_output_token()
+            perf_timing.measure("count tokens")
+
+            if (total_tokens + maximum_output_token) > max_context_size:
+                logging.warning("Token limit exceeded. Truncating tool responses.")
+                messages = self.truncate_messages_to_fit_context(
+                    messages, max_context_size, maximum_output_token
+                )
+                perf_timing.measure("truncate_messages_to_fit_context")
+
+            logging.debug(f"sending messages={messages}\n\ntools={tools}")
+            try:
+                full_response = self.llm.completion(
+                    messages=parse_messages_tags(messages),
+                    tools=tools,
+                    tool_choice=tool_choice,
+                    temperature=0.00000001,
+                    response_format=response_format,
+                    stream=False,
+                    drop_params=True,
+                )
+                perf_timing.measure("llm.completion")
+
+            # catch a known error that occurs with Azure and replace the error message with something more obvious to the user
+            except BadRequestError as e:
+                if "Unrecognized request arguments supplied: tool_choice, tools" in str(
+                    e
+                ):
+                    yield json.dumps(
+                        {
+                            "type": "error",
+                            "details": {
+                                "msg": "The Azure model you chose is not supported. Model version 1106 and higher required."
+                            },
+                        }
+                    )
+                    return
+                raise
+            except Exception:
+                raise
+
+            response_message = full_response.choices[0].message
+            tools_to_call = getattr(response_message, "tool_calls", None)
+            if not tools_to_call:
+                (text_response, _) = process_response_into_sections(
+                    response_message.content
+                )
+                yield json.dumps(
+                    {"type": "ai_answer", "details": {"answer": text_response}}
+                )
+                if runbooks:
+                    yield json.dumps(
+                        {
+                            "type": "instructions",
+                            "details": {"instructions": json.dumps(runbooks)},
+                        }
+                    )
+                return
+
+            messages.append(
+                response_message.model_dump(
+                    exclude_defaults=True, exclude_unset=True, exclude_none=True
+                )
+            )
+
+            perf_timing.measure("pre-tool-calls")
+            with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
+                futures = []
+                for t in tools_to_call:
+                    futures.append(executor.submit(self._invoke_tool, t))
+                    yield json.dumps(
+                        {
+                            "type": "start_tool_calling",
+                            "details": {"tool_name": t.function.name, "id": t.id},
+                        }
+                    )
+
+                for future in concurrent.futures.as_completed(futures):
+                    tool_call_result: ToolCallResult = future.result()
+                    tool_calls.append(tool_call_result)
+                    tool_call_dict = tool_call_result.as_dict()
+                    messages.append(tool_call_dict)
+                    perf_timing.measure(f"tool completed {tool_call_result.tool_name}")
+                    yield json.dumps(
+                        {"type": "tool_calling_result", "details": tool_call_dict}
+                    )
+
 
 # TODO: consider getting rid of this entirely and moving templating into the cmds in holmes.py
 class IssueInvestigator(ToolCallingLLM):
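`call_stream` is a generator: instead of assembling a single `LLMResult`, it yields JSON-encoded events whose `type` field is one of `start_tool_calling`, `tool_calling_result`, `ai_answer`, `instructions`, or `error`. A minimal consumer sketch, assuming an already-constructed `ToolCallingLLM` instance; the dispatch logic is illustrative, not part of the commit:

import json

def consume_stream(llm, system_prompt, user_prompt):
    # llm is assumed to be a ToolCallingLLM; each yielded item is a JSON string.
    answer = None
    for raw_event in llm.call_stream(system_prompt, user_prompt):
        event = json.loads(raw_event)
        details = event["details"]
        if event["type"] == "start_tool_calling":
            print(f"calling tool {details['tool_name']} (id={details['id']})")
        elif event["type"] == "tool_calling_result":
            # details is the ToolCallResult.as_dict() payload shown above
            print(f"tool {details['name']} returned {len(details['content'])} chars")
        elif event["type"] == "ai_answer":
            answer = details["answer"]
        elif event["type"] == "instructions":
            print(f"runbooks used: {details['instructions']}")
        elif event["type"] == "error":
            raise RuntimeError(details["msg"])
    return answer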

server.py

+19
@@ -20,6 +20,7 @@
 
 from litellm.exceptions import AuthenticationError
 from fastapi import FastAPI, HTTPException, Request
+from fastapi.responses import StreamingResponse
 from holmes.utils.robusta import load_robusta_api_key
 
 from holmes.common.env_vars import (
@@ -145,6 +146,24 @@ def investigate_issues(investigate_request: InvestigateRequest):
         raise HTTPException(status_code=500, detail=str(e))
 
 
+@app.post("/api/stream/investigate")
+def stream_investigate_issues(req: InvestigateRequest):
+    ai, system_prompt, user_prompt, response_format, sections, runbooks = (
+        investigation.get_investigation_context(req, dal, config=config)
+    )
+
+    try:
+        return StreamingResponse(
+            ai.call_stream(system_prompt, user_prompt, response_format, runbooks),
+            media_type="text/event-stream",
+        )
+    except AuthenticationError as e:
+        raise HTTPException(status_code=401, detail=e.message)
+    except Exception as e:
+        logging.exception(f"Error in /api/stream/investigate: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
 @app.post("/api/workload_health_check")
 def workload_health_check(request: WorkloadHealthRequest):
     load_robusta_api_key(dal=dal, config=config)
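The endpoint streams whatever `call_stream` yields: back-to-back JSON objects without SSE `data:` framing or newline delimiters, served with a `text/event-stream` media type. A rough client sketch that decodes the concatenated objects incrementally; the URL, port, and request payload below are placeholders (see `InvestigateRequest` for the real schema):

import json
import requests

# Placeholder URL and payload, for illustration only.
resp = requests.post(
    "http://localhost:5050/api/stream/investigate",
    json={"title": "KubePodCrashLooping", "source": "prometheus", "context": {}},
    stream=True,
)

# Decode the stream as a sequence of concatenated JSON objects.
decoder = json.JSONDecoder()
buffer = ""
for chunk in resp.iter_content(chunk_size=None):
    buffer += chunk.decode("utf-8")
    while buffer:
        try:
            event, offset = decoder.raw_decode(buffer)
        except json.JSONDecodeError:
            break  # incomplete object; wait for more data
        buffer = buffer[offset:].lstrip()
        print(event["type"], event["details"])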
