diff --git a/internal/slack/agentCallbackHandler.go b/internal/slack/agentCallbackHandler.go index f51615b2..dcdcbe24 100644 --- a/internal/slack/agentCallbackHandler.go +++ b/internal/slack/agentCallbackHandler.go @@ -2,7 +2,13 @@ package slackbot import ( "context" + "encoding/json" "github.com/tmc/langchaingo/callbacks" + "github.com/tuannvm/slack-mcp-client/internal/common/logging" + "regexp" + "strings" + + "github.com/slack-go/slack" ) type sendMessageFunc func(message string) @@ -10,12 +16,48 @@ type sendMessageFunc func(message string) type agentCallbackHandler struct { callbacks.SimpleHandler sendMessage sendMessageFunc + logger *logging.Logger } func (handler *agentCallbackHandler) HandleChainEnd(_ context.Context, outputs map[string]any) { if text, ok := outputs["text"]; ok { if textStr, ok := text.(string); ok { + if isThinkingMessage(textStr) { + textStr = formatContextMessageBlock(textStr, handler.logger) + } else { + textStr = formatFinalResponse(textStr) + } handler.sendMessage(textStr) } } } + +var thinkingPattern = regexp.MustCompile(`Do I need to use a tool\? Yes`) + +func isThinkingMessage(msg string) bool { + return thinkingPattern.MatchString(msg) +} + +// formatFinalResponse removes LLM agent response prefixes. +// The agent response format is defined in internal/llm/langchain.go +// > Thought: Do I need to use a tool? No +// > AI: [your response here] +func formatFinalResponse(msg string) string { + msg = strings.Replace(msg, "Do I need to use a tool? No", "", 1) + msg = strings.Replace(msg, "AI:", "", 1) + return strings.TrimSpace(msg) +} + +func formatContextMessageBlock(message string, logger *logging.Logger) string { + mrkdwnBlock := slack.NewTextBlockObject("mrkdwn", message, false, false) + contextBlock := slack.NewContextBlock("", []slack.MixedElement{mrkdwnBlock}...) + blockMessage := slack.NewBlockMessage(contextBlock) + + jsonByte, err := json.Marshal(blockMessage) + if err != nil { + // Fallback to plain message if marshaling fails + logger.ErrorKV("Failed to marshal block message", "error", err) + return message + } + return string(jsonByte) +} diff --git a/internal/slack/agentCallbackHandler_test.go b/internal/slack/agentCallbackHandler_test.go new file mode 100644 index 00000000..fba02b52 --- /dev/null +++ b/internal/slack/agentCallbackHandler_test.go @@ -0,0 +1,99 @@ +package slackbot + +import ( + "encoding/json" + + "github.com/slack-go/slack" + "testing" +) + +func TestIsThinkingMessage(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + { + name: "Thinking message", + input: ` + Do I need to use a tool? Yes + thinking... + `, + expected: true, + }, + { + name: "Not thinking message", + input: ` + Do I need to use a tool? No + AI: Here is the final response. + `, + expected: false, + }, + // Add more test cases as needed + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isThinkingMessage(tt.input) + if result != tt.expected { + t.Errorf("isThinkingMessage() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestFormatFinalResponse(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "Final response formatting", + input: `Do I need to use a tool? No + AI: Here is the final response.`, + expected: `Here is the final response.`, + }, + { + name: "Fallback final response formatting", + input: `This is final response without prefixes.`, + expected: `This is final response without prefixes.`, + }, + // Add more test cases as needed + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := formatFinalResponse(tt.input) + if result != tt.expected { + t.Errorf("isThinkingMessage() = \"%s\", want \"%s\"", result, tt.expected) + } + }) + } +} + +func TestFormatContextMessageBlock(t *testing.T) { + tests := []struct { + name string + input string + }{ + { + name: "Simple text context", + input: "Here is the final response.", + }, + // Add more test cases as needed + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := formatContextMessageBlock(tt.input, nil) + + var contextBlock struct { + Elements slack.ContextElements `json:"elements"` + } + if err := json.Unmarshal([]byte(result), &contextBlock); err != nil { + t.Errorf("Failed to unmarshal block message JSON: %v", err) + } + }) + } +} diff --git a/internal/slack/client.go b/internal/slack/client.go index e3ccbe00..354fa695 100644 --- a/internal/slack/client.go +++ b/internal/slack/client.go @@ -538,6 +538,7 @@ func (c *Client) handleUserPrompt(userPrompt, channelID, threadTS string, timest &agentCallbackHandler{ callbacks.SimpleHandler{}, sendMsg, + c.logger, }) duration := time.Since(startTime) diff --git a/internal/slack/formatter/formatter.go b/internal/slack/formatter/formatter.go index 7b15d6f5..ee92290f 100644 --- a/internal/slack/formatter/formatter.go +++ b/internal/slack/formatter/formatter.go @@ -113,9 +113,12 @@ func FormatMessage(text string, options FormatOptions) []slack.MsgOption { case "divider": slackBlock = slack.NewDividerBlock() case "context": - var context slack.ContextBlock - if err := json.Unmarshal(blockJSON, &context); err == nil { - slackBlock = context + var contextBlock struct { + Elements slack.ContextElements `json:"elements"` + } + + if err := json.Unmarshal(blockJSON, &contextBlock); err == nil { + slackBlock = slack.NewContextBlock("", contextBlock.Elements.Elements...) } // Add more block types as needed }