diff --git a/.env.example b/.env.example index 5df79b1..fcf7805 100644 --- a/.env.example +++ b/.env.example @@ -5,6 +5,7 @@ LANGSMITH_API_KEY= LANGSMITH_PROJECT= LANGSMITH_TRACING= +GROQ_API_KEY= GENSEE_API_KEY= # Only necessary for Open Agent Platform diff --git a/src/open_deep_research/configuration.py b/src/open_deep_research/configuration.py index 288302a..4b59bea 100644 --- a/src/open_deep_research/configuration.py +++ b/src/open_deep_research/configuration.py @@ -2,7 +2,7 @@ import os from enum import Enum -from typing import Any, List, Optional +from typing import Any, List, Optional, Literal from langchain_core.runnables import RunnableConfig from pydantic import BaseModel, Field @@ -75,6 +75,34 @@ class Configuration(BaseModel): } } ) + # Parallel Supervisors Configuration + max_concurrent_supervisors: int = Field( + default=3, + metadata={ + "x_oap_ui_config": { + "type": "slider", + "default": 3, + "min": 1, + "max": 10, + "step": 1, + "description": "Maximum number of supervisor subgraphs to run in parallel after the research brief." 
+ } + } + ) + parallel_supervisor_strategy: Literal["early_stop", "aggregate"] = Field( + default="early_stop", + metadata={ + "x_oap_ui_config": { + "type": "select", + "default": "early_stop", + "description": "Strategy for parallel supervisors: stop as soon as the first supervisor finishes (early_stop) or wait for all of them and aggregate their notes (aggregate).", + "options": [ + {"label": "Early Stop", "value": "early_stop"}, + {"label": "Aggregate", "value": "aggregate"} + ] + } + } + ) # Research Configuration search_api: SearchAPI = Field( default=SearchAPI.TAVILY, diff --git a/src/open_deep_research/deep_researcher.py b/src/open_deep_research/deep_researcher.py index 279dbff..e3a3650 100644 --- a/src/open_deep_research/deep_researcher.py +++ b/src/open_deep_research/deep_researcher.py @@ -1,6 +1,7 @@ """Main LangGraph implementation for the Deep Research agent.""" import asyncio +import json from typing import Literal from langchain.chat_models import init_chat_model @@ -86,12 +87,19 @@ async def clarify_with_user(state: AgentState, config: RunnableConfig) -> Comman } # Configure model with structured output and retry logic - clarification_model = ( - configurable_model - .with_structured_output(ClarifyWithUser) - .with_retry(stop_after_attempt=configurable.max_structured_output_retries) - .with_config(model_config) - ) + if "groq" in configurable.research_model: + clarification_model = ( + configurable_model + .with_retry(stop_after_attempt=configurable.max_structured_output_retries) + .with_config(model_config) + ) + else: + clarification_model = ( + configurable_model + .with_structured_output(ClarifyWithUser) + .with_retry(stop_after_attempt=configurable.max_structured_output_retries) + .with_config(model_config) + ) # Step 3: Analyze whether clarification is needed prompt_content = clarify_with_user_instructions.format( messages=get_buffer_string(state["messages"]), date=get_today_str() ) response = await clarification_model.ainvoke([HumanMessage(content=prompt_content)]) + if 
"groq" in configurable.research_model: + response = json.loads(response.content) + response = ClarifyWithUser(**response) # Step 4: Route based on clarification analysis if response.need_clarification: @@ -145,6 +156,19 @@ async def write_research_brief(state: AgentState, config: RunnableConfig) -> Com .with_retry(stop_after_attempt=configurable.max_structured_output_retries) .with_config(research_model_config) ) + if "groq" in configurable.research_model: + research_model = ( + configurable_model + .with_retry(stop_after_attempt=configurable.max_structured_output_retries) + .with_config(research_model_config) + ) + else: + research_model = ( + configurable_model + .with_structured_output(ResearchQuestion) + .with_retry(stop_after_attempt=configurable.max_structured_output_retries) + .with_config(research_model_config) + ) # Step 2: Generate structured research brief from user messages prompt_content = transform_messages_into_research_topic_prompt.format( @@ -152,7 +176,9 @@ async def write_research_brief(state: AgentState, config: RunnableConfig) -> Com date=get_today_str() ) response = await research_model.ainvoke([HumanMessage(content=prompt_content)]) - + if "groq" in configurable.research_model: + response = ResearchQuestion(research_brief=response.content) + # Step 3: Initialize supervisor with research brief and instructions supervisor_system_prompt = lead_researcher_prompt.format( date=get_today_str(), @@ -362,6 +388,75 @@ async def supervisor_tools(state: SupervisorState, config: RunnableConfig) -> Co # Compile supervisor subgraph for use in main workflow supervisor_subgraph = supervisor_builder.compile() +# ----------------------------- +# Parallel Supervisors Orchestrator +# ----------------------------- +async def multiple_supervisors(state: AgentState, config: RunnableConfig) -> Command[Literal["final_report_generation"]]: + """Spawn multiple supervisor subgraphs in parallel and aggregate OR early stop the process. 
+ + Each supervisor receives the identical research brief. + Notes (aggregated if chosen) are forwarded to final report. + """ + configurable = Configuration.from_runnable_config(config) + research_brief = state.get("research_brief", "") + + # Determine number of supervisors to launch (apply a hard cap for safety) + hard_cap = 10 + num_supervisors = max(1, min(configurable.max_concurrent_supervisors, hard_cap)) + + briefs_for_supervisors = [research_brief for _ in range(num_supervisors)] + + # Prepare supervisor system prompt once + supervisor_system_prompt = lead_researcher_prompt.format( + date=get_today_str(), + max_concurrent_research_units=configurable.max_concurrent_research_units, + max_researcher_iterations=configurable.max_researcher_iterations, + ) + + # Launch all supervisors in parallel + tasks = [ + asyncio.create_task( + supervisor_subgraph.ainvoke( + { + "supervisor_messages": { + "type": "override", + "value": [ + SystemMessage(content=supervisor_system_prompt), + HumanMessage(content=brief), + ], + }, + "research_brief": research_brief, + }, + config, + ) + ) + for brief in briefs_for_supervisors + ] + + if configurable.parallel_supervisor_strategy == "early_stop": + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + final_result = await next(iter(done)) + # Cancel remaining tasks + for t in pending: + t.cancel() + # Optionally, drain cancellations to avoid warnings + await asyncio.gather(*pending, return_exceptions=True) + + elif configurable.parallel_supervisor_strategy == "aggregate": + results = await asyncio.gather(*tasks) + final_result = { + "notes": [note for result in results for note in result.get("notes", [])], + } + + return Command( + goto="final_report_generation", + update={ + "notes": final_result.get("notes", []), + "research_brief": research_brief, + }, + ) + + async def researcher(state: ResearcherState, config: RunnableConfig) -> Command[Literal["researcher_tools"]]: """Individual researcher that conducts focused 
research on specific topics. @@ -707,7 +802,7 @@ async def final_report_generation(state: AgentState, config: RunnableConfig): # Add main workflow nodes for the complete research process deep_researcher_builder.add_node("clarify_with_user", clarify_with_user) # User clarification phase deep_researcher_builder.add_node("write_research_brief", write_research_brief) # Research planning phase -deep_researcher_builder.add_node("research_supervisor", supervisor_subgraph) # Research execution phase +deep_researcher_builder.add_node("research_supervisor", multiple_supervisors) # Research execution phase deep_researcher_builder.add_node("final_report_generation", final_report_generation) # Report generation phase # Define main workflow edges for sequential execution diff --git a/src/open_deep_research/utils.py b/src/open_deep_research/utils.py index cd551d0..08daffc 100644 --- a/src/open_deep_research/utils.py +++ b/src/open_deep_research/utils.py @@ -1032,6 +1032,8 @@ def get_api_key_for_model(model_name: str, config: RunnableConfig): return api_keys.get("ANTHROPIC_API_KEY") elif model_name.startswith("google"): return api_keys.get("GOOGLE_API_KEY") + elif model_name.startswith("groq"): + return api_keys.get("GROQ_API_KEY") return None else: if model_name.startswith("openai:"): @@ -1040,6 +1042,8 @@ def get_api_key_for_model(model_name: str, config: RunnableConfig): return os.getenv("ANTHROPIC_API_KEY") elif model_name.startswith("google"): return os.getenv("GOOGLE_API_KEY") + elif model_name.startswith("groq"): + return os.getenv("GROQ_API_KEY") return None def get_tavily_api_key(config: RunnableConfig):