diff --git a/internal/handlers/llm_mcp_bridge.go b/internal/handlers/llm_mcp_bridge.go index b1d2407e..b543d977 100644 --- a/internal/handlers/llm_mcp_bridge.go +++ b/internal/handlers/llm_mcp_bridge.go @@ -120,11 +120,11 @@ func NewLLMMCPBridgeFromClientsWithLogLevel(mcpClients interface{}, stdLogger *l // ProcessLLMResponse processes an LLM response, expecting a specific JSON tool call format. // It no longer uses natural language detection. -func (b *LLMMCPBridge) ProcessLLMResponse(ctx context.Context, llmResponse, _ string) (string, error) { +func (b *LLMMCPBridge) ProcessLLMResponse(ctx context.Context, llmResponse, _ string, userProfile interface{}) (string, error) { // Check for a tool call in JSON format if toolCall := b.detectSpecificJSONToolCall(llmResponse); toolCall != nil { // Execute the tool call - result, err := b.executeToolCall(ctx, toolCall) + result, err := b.executeToolCall(ctx, toolCall, userProfile) if err != nil { // Check if it's already a domain error var errorMessage string @@ -306,7 +306,7 @@ func (b *LLMMCPBridge) getClientForTool(toolName string) MCPClientInterface { } // executeToolCall executes a detected tool call (using the new ToolCall struct) -func (b *LLMMCPBridge) executeToolCall(ctx context.Context, toolCall *ToolCall) (string, error) { +func (b *LLMMCPBridge) executeToolCall(ctx context.Context, toolCall *ToolCall, userProfile interface{}) (string, error) { client := b.getClientForTool(toolCall.Tool) if client == nil { b.logger.ErrorKV("No MCP client available", "tool", toolCall.Tool) @@ -319,6 +319,11 @@ func (b *LLMMCPBridge) executeToolCall(ctx context.Context, toolCall *ToolCall) "server", serverName, "args", fmt.Sprintf("%v", toolCall.Args)) + if userProfile != nil { + toolCall.Args["user_profile"] = userProfile + b.logger.DebugKV("Added user profile to tool call", "tool", toolCall.Tool) + } + // Call the tool directly with the tool name and args result, err := client.CallTool(ctx, toolCall.Tool, toolCall.Args) if err != nil { diff --git a/internal/slack/client.go b/internal/slack/client.go index f73d5438..6f85a94b 100644 --- a/internal/slack/client.go +++ b/internal/slack/client.go @@ -212,7 +212,7 @@ func (c *Client) handleEventMessage(event slackevents.EventsAPIEvent) { // Add to message history c.addToHistory(ev.Channel, "user", messageText) // Use handleUserPrompt for app mentions too, for consistency - go c.handleUserPrompt(strings.TrimSpace(messageText), ev.Channel, ev.TimeStamp) + go c.handleUserPrompt(strings.TrimSpace(messageText), ev.Channel, ev.TimeStamp, ev.User) case *slackevents.MessageEvent: isDirectMessage := strings.HasPrefix(ev.Channel, "D") @@ -224,7 +224,7 @@ func (c *Client) handleEventMessage(event slackevents.EventsAPIEvent) { c.logger.InfoKV("Received direct message in channel", "channel", ev.Channel, "user", ev.User, "text", ev.Text) // Add to message history c.addToHistory(ev.Channel, "user", ev.Text) - go c.handleUserPrompt(ev.Text, ev.Channel, ev.ThreadTimeStamp) // Use goroutine to avoid blocking event loop + go c.handleUserPrompt(ev.Text, ev.Channel, ev.ThreadTimeStamp, ev.User) // Use goroutine to avoid blocking event loop } default: @@ -294,7 +294,7 @@ func (c *Client) getContextFromHistory(channelID string) string { } // handleUserPrompt sends the user's text to the configured LLM provider. -func (c *Client) handleUserPrompt(userPrompt, channelID, threadTS string) { +func (c *Client) handleUserPrompt(userPrompt, channelID, threadTS string, userID string) { // Determine the provider to use from config providerName := c.cfg.LLMProvider // Get the primary provider name from config c.logger.DebugKV("Routing prompt via configured provider", "provider", providerName) @@ -321,7 +321,7 @@ func (c *Client) handleUserPrompt(userPrompt, channelID, threadTS string) { c.logger.InfoKV("Received response from LLM", "provider", providerName, "length", len(llmResponse)) // Process the LLM response through the MCP pipeline - c.processLLMResponseAndReply(llmResponse, userPrompt, channelID, threadTS) + c.processLLMResponseAndReply(llmResponse, userPrompt, channelID, threadTS, userID) } // generateToolPrompt generates the prompt string for available tools @@ -437,7 +437,7 @@ func (c *Client) callLLM(providerName, prompt, contextHistory string) (string, e // processLLMResponseAndReply processes the LLM response, handles tool results with re-prompting, and sends the final reply. // Incorporates logic previously in LLMClient.ProcessToolResponse. -func (c *Client) processLLMResponseAndReply(llmResponse, userPrompt, channelID, threadTS string) { +func (c *Client) processLLMResponseAndReply(llmResponse, userPrompt, channelID, threadTS string, userID string) { // Log the raw LLM response for debugging c.logger.DebugKV("Raw LLM response", "response", truncateForLog(llmResponse, 500)) @@ -445,6 +445,17 @@ func (c *Client) processLLMResponseAndReply(llmResponse, userPrompt, channelID, ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) defer cancel() + var userProfile *slack.UserProfile + if userID != "" { + user, err := c.api.GetUserInfoContext(ctx, userID) + if err != nil { + c.logger.WarnKV("Failed to get user profile", "error", err, "user_id", userID) + } else { + userProfile = &user.Profile + c.logger.DebugKV("Retrieved user profile", "user_id", userID, "email", userProfile.Email) + } + } + // --- Process Tool Response (Logic from LLMClient.ProcessToolResponse) --- var finalResponse string var isToolResult bool @@ -458,7 +469,7 @@ func (c *Client) processLLMResponseAndReply(llmResponse, userPrompt, channelID, c.logger.Warn("LLMMCPBridge is nil, skipping tool processing") } else { // Process the response through the bridge - processedResponse, err := c.llmMCPBridge.ProcessLLMResponse(ctx, llmResponse, userPrompt) + processedResponse, err := c.llmMCPBridge.ProcessLLMResponse(ctx, llmResponse, userPrompt, userProfile) if err != nil { finalResponse = fmt.Sprintf("Sorry, I encountered an error while trying to use a tool: %v", err) isToolResult = false