Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.idea
65 changes: 65 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,54 @@ func (c *Client) CallTool(ctx context.Context, callToolReq *CallToolRequest) (*C
return parseCallToolResult(rawResp)
}

func (c *Client) CallToolAsync(ctx context.Context, callToolReq *CallToolRequest) (requestID int64, resultChan <-chan *CallToolResult, errChan <-chan error) {
resultCh := make(chan *CallToolResult, 1)
errCh := make(chan error, 1)

if !c.initialized {
errCh <- errors.ErrNotInitialized
return
}

requestID = c.requestID.Add(1)
req := &JSONRPCRequest{
JSONRPC: JSONRPCVersion,
ID: requestID,
Request: Request{
Method: MethodToolsCall,
},
Params: callToolReq.Params,
}

go func() {
rawResp, err := c.transport.sendRequest(ctx, req)
if err != nil {
errCh <- fmt.Errorf("tool call request failed: %w", err)
return
}

if isErrorResponse(rawResp) {
errResp, err := parseRawMessageToError(rawResp)
if err != nil {
errCh <- fmt.Errorf("failed to parse error response: %w", err)
}
errCh <- fmt.Errorf("tool call error: %s (code: %d)",
errResp.Error.Message, errResp.Error.Code)
return
}

result, err := parseCallToolResult(rawResp)
if err != nil {
errCh <- fmt.Errorf("failed to parse tool call result: %w", err)
return
}

resultCh <- result
}()

return requestID, resultCh, errCh
}

// Close closes the client connection and cleans up resources.
func (c *Client) Close() error {
if c.transport != nil {
Expand Down Expand Up @@ -606,6 +654,23 @@ func (c *Client) ReadResource(ctx context.Context, readResourceReq *ReadResource
return parseReadResourceResultFromJSON(rawResp)
}

func (c *Client) CancelRequest(ctx context.Context, requestID int64, reason string) error {
notification := newJSONRPCNotification(Notification{
Method: MethodNotificationsCancelled,
Params: NotificationParams{
AdditionalFields: map[string]interface{}{
"requestID": requestID,
"reason": reason,
},
},
})
return c.transport.sendNotification(ctx, notification)
}

func isZeroStruct(x interface{}) bool {
return reflect.ValueOf(x).IsZero()
}

func (c *Client) GetTransport() httpTransport {
return c.transport
}
6 changes: 3 additions & 3 deletions e2e/scenarios/basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
// Includes connection, initialization, tool listing, tool calling, and disconnection.
func TestBasicWorkflow(t *testing.T) {
// Set up test server.
serverURL, cleanup := e2e.StartTestServer(t, e2e.WithTestTools())
serverURL, cleanup, _ := e2e.StartTestServer(t, e2e.WithTestTools())
defer cleanup()

// Create test client.
Expand Down Expand Up @@ -53,7 +53,7 @@ func TestBasicWorkflow(t *testing.T) {
// TestErrorHandling tests error handling logic.
func TestErrorHandling(t *testing.T) {
// Set up test server.
serverURL, cleanup := e2e.StartTestServer(t, e2e.WithTestTools())
serverURL, cleanup, _ := e2e.StartTestServer(t, e2e.WithTestTools())
defer cleanup()

// Create test client.
Expand Down Expand Up @@ -89,7 +89,7 @@ func TestErrorHandling(t *testing.T) {
// TestMultipleCalls tests multiple consecutive calls.
func TestMultipleCalls(t *testing.T) {
// Set up test server.
serverURL, cleanup := e2e.StartTestServer(t, e2e.WithTestTools())
serverURL, cleanup, _ := e2e.StartTestServer(t, e2e.WithTestTools())
defer cleanup()

// Create test client.
Expand Down
8 changes: 4 additions & 4 deletions e2e/scenarios/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
// TestSessionCreation tests session creation.
func TestSessionCreation(t *testing.T) {
// Set up test server.
serverURL, cleanup := e2e.StartTestServer(t, e2e.WithTestTools())
serverURL, cleanup, _ := e2e.StartTestServer(t, e2e.WithTestTools())
defer cleanup()

// Create test client.
Expand All @@ -39,7 +39,7 @@ func TestSessionCreation(t *testing.T) {
// TestSessionTermination tests session termination.
func TestSessionTermination(t *testing.T) {
// Set up test server.
serverURL, cleanup := e2e.StartTestServer(t, e2e.WithTestTools())
serverURL, cleanup, _ := e2e.StartTestServer(t, e2e.WithTestTools())
defer cleanup()

// Create test client.
Expand All @@ -64,7 +64,7 @@ func TestSessionTermination(t *testing.T) {
// TestMultipleSessions tests multiple sessions.
func TestMultipleSessions(t *testing.T) {
// Set up test server.
serverURL, cleanup := e2e.StartTestServer(t, e2e.WithTestTools())
serverURL, cleanup, _ := e2e.StartTestServer(t, e2e.WithTestTools())
defer cleanup()

// Create and initialize first client.
Expand All @@ -90,7 +90,7 @@ func TestMultipleSessions(t *testing.T) {
// TestSessionReconnect tests reconnection after disconnection.
func TestSessionReconnect(t *testing.T) {
// Set up test server.
serverURL, cleanup := e2e.StartTestServer(t, e2e.WithTestTools())
serverURL, cleanup, _ := e2e.StartTestServer(t, e2e.WithTestTools())
defer cleanup()

// Create and initialize client.
Expand Down
44 changes: 41 additions & 3 deletions e2e/scenarios/streaming_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
// TestStreamingContent tests streaming content transmission from server to client.
func TestStreamingContent(t *testing.T) {
// Set up test server.
serverURL, cleanup := e2e.StartTestServer(t, e2e.WithTestTools())
serverURL, cleanup, _ := e2e.StartTestServer(t, e2e.WithTestTools())
defer cleanup()

// Create test client.
Expand Down Expand Up @@ -58,7 +58,7 @@ func TestStreamingContent(t *testing.T) {
// TestDelayedResponse tests delayed response.
func TestDelayedResponse(t *testing.T) {
// Set up test server.
serverURL, cleanup := e2e.StartTestServer(t, e2e.WithTestTools())
serverURL, cleanup, _ := e2e.StartTestServer(t, e2e.WithTestTools())
defer cleanup()

// Create test client.
Expand Down Expand Up @@ -104,7 +104,7 @@ func TestDelayedResponse(t *testing.T) {
// TestContextCancellation tests context cancellation.
func TestContextCancellation(t *testing.T) {
// Set up test server.
serverURL, cleanup := e2e.StartTestServer(t, e2e.WithTestTools())
serverURL, cleanup, _ := e2e.StartTestServer(t, e2e.WithTestTools())
defer cleanup()

// Create test client.
Expand Down Expand Up @@ -134,3 +134,41 @@ func TestContextCancellation(t *testing.T) {
assert.Contains(t, err.Error(), "context", "error should be related to context cancellation")
})
}

func TestRequestCancellation(t *testing.T) {
serverURL, cleanup, _ := e2e.StartTestServer(t, e2e.WithTestTools())
defer cleanup()

client := e2e.CreateTestClient(t, serverURL)
defer e2e.CleanupClient(t, client)

e2e.InitializeClient(t, client)

t.Run("CancelRequest", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

callToolReq := &mcp.CallToolRequest{}
callToolReq.Params.Name = "delay-tool"
callToolReq.Params.Arguments = map[string]interface{}{
"delay_ms": 2000,
"message": "This message should not be received",
}

requestID, resultCh, errCh := client.CallToolAsync(ctx, callToolReq)

time.Sleep(200 * time.Millisecond)

client.CancelRequest(ctx, requestID, "test cancel")

select {
case err := <-errCh:
require.Error(t, err)
assert.Contains(t, err.Error(), "cancel")
case <-resultCh:
t.Fatal("should not receive result")
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for result or error")
}
})
}
4 changes: 2 additions & 2 deletions e2e/server_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func WithServerOptions(opts ...mcp.ServerOption) ServerOption {
}

// StartTestServer starts a test server and returns its URL and cleanup function.
func StartTestServer(t *testing.T, opts ...ServerOption) (string, func()) {
func StartTestServer(t *testing.T, opts ...ServerOption) (string, func(), *mcp.Server) {
t.Helper()

// Create basic server config.
Expand Down Expand Up @@ -77,7 +77,7 @@ func StartTestServer(t *testing.T, opts ...ServerOption) (string, func()) {
httpServer.Close()
}

return serverURL, cleanup
return serverURL, cleanup, s
}

// StartSSETestServer starts a test server with SSE enabled, returns its URL and cleanup function.
Expand Down
27 changes: 27 additions & 0 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package mcp

import (
"context"
"fmt"
)

const (
Expand Down Expand Up @@ -39,6 +40,8 @@ type mcpHandler struct {

// Prompt manager
promptManager *promptManager

cancelRequestFunc func(id interface{},reason interface{}) error
}

// newMCPHandler creates an MCP protocol handler
Expand Down Expand Up @@ -192,6 +195,8 @@ func (h *mcpHandler) handleNotification(ctx context.Context, notification *JSONR
switch notification.Method {
case MethodNotificationsInitialized:
return h.lifecycleManager.handleInitialized(ctx, notification, session)
case MethodNotificationsCancelled:
return h.handleNotificationCancel(ctx, notification, session)
default:
// Ignore unknown notifications
return nil
Expand All @@ -203,3 +208,25 @@ func (h *mcpHandler) onSessionTerminated(sessionID string) {
// Notify lifecycle manager that session has terminated
h.lifecycleManager.onSessionTerminated(sessionID)
}

func (h *mcpHandler) setCancelRequestFunc(cancelRequestFunc func(id interface{},reason interface{}) error) {
h.cancelRequestFunc = cancelRequestFunc
}

func (h *mcpHandler) handleNotificationCancel(ctx context.Context, notification *JSONRPCNotification, session Session) error {
requestID,ok := notification.Params.AdditionalFields["requestID"]
if !ok {
return fmt.Errorf("requestID not found in params")
}
reason,ok := notification.Params.AdditionalFields["reason"]
if !ok {
return fmt.Errorf("reason not found in params")
}
if initialRequestID,ok := session.GetData("initialRequestID");ok && initialRequestID == requestID {
return nil
}
if h.cancelRequestFunc != nil {
return h.cancelRequestFunc(requestID,reason)
}
return nil
}
1 change: 1 addition & 0 deletions jsonrpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const (
ErrCodeMethodNotFound = -32601
ErrCodeInvalidParams = -32602
ErrCodeInternal = -32603
ErrCodeRequestCancelled = -32000

// MCP custom error code range: -32000 to -32099
)
Expand Down
18 changes: 18 additions & 0 deletions manager_prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ func (m *promptManager) getPrompts() []*Prompt {

// handleListPrompts handles listing prompts requests
func (m *promptManager) handleListPrompts(ctx context.Context, req *JSONRPCRequest) (JSONRPCMessage, error) {
select {
case <-ctx.Done():
return newJSONRPCErrorResponse(req.ID, ErrCodeRequestCancelled, "request cancelled", nil), ctx.Err()
default:
}

prompts := m.getPrompts()

// Convert []*mcp.Prompt to []mcp.Prompt for the result
Expand Down Expand Up @@ -156,6 +162,12 @@ func buildPromptMessages(prompt *Prompt, arguments map[string]interface{}) []Pro

// Refactored: handleGetPrompt with logic unchanged, now using helpers
func (m *promptManager) handleGetPrompt(ctx context.Context, req *JSONRPCRequest) (JSONRPCMessage, error) {
select {
case <-ctx.Done():
return newJSONRPCErrorResponse(req.ID, ErrCodeRequestCancelled, "request cancelled", nil), ctx.Err()
default:
}

name, arguments, errResp, ok := parseGetPromptParams(req)
if !ok {
return errResp, nil
Expand Down Expand Up @@ -232,6 +244,12 @@ func parseCompletionCompleteParams(req *JSONRPCRequest) (promptName string, errR

// Refactored: handleCompletionComplete with logic unchanged, now using helpers
func (m *promptManager) handleCompletionComplete(ctx context.Context, req *JSONRPCRequest) (JSONRPCMessage, error) {
select {
case <-ctx.Done():
return newJSONRPCErrorResponse(req.ID, ErrCodeRequestCancelled, "request cancelled", nil), ctx.Err()
default:
}

promptName, errResp, ok := parseCompletionCompleteParams(req)
if !ok {
return errResp, nil
Expand Down
30 changes: 30 additions & 0 deletions manager_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ func (m *resourceManager) notifyUpdate(uri string) {

// handleListResources handles listing resources requests
func (m *resourceManager) handleListResources(ctx context.Context, req *JSONRPCRequest) (JSONRPCMessage, error) {
select {
case <-ctx.Done():
return newJSONRPCErrorResponse(req.ID, ErrCodeRequestCancelled, "request cancelled", nil), ctx.Err()
default:
}

resources := m.getResources()

// Convert []*mcp.Resource to []mcp.Resource for the result
Expand All @@ -223,6 +229,12 @@ func (m *resourceManager) handleListResources(ctx context.Context, req *JSONRPCR

// handleReadResource handles reading resource requests
func (m *resourceManager) handleReadResource(ctx context.Context, req *JSONRPCRequest) (JSONRPCMessage, error) {
select {
case <-ctx.Done():
return newJSONRPCErrorResponse(req.ID, ErrCodeRequestCancelled, "request cancelled", nil), ctx.Err()
default:
}

// Convert params to map for easier access
paramsMap, ok := req.Params.(map[string]interface{})
if !ok {
Expand Down Expand Up @@ -272,6 +284,12 @@ func (m *resourceManager) handleReadResource(ctx context.Context, req *JSONRPCRe

// handleListTemplates handles listing templates requests
func (m *resourceManager) handleListTemplates(ctx context.Context, req *JSONRPCRequest) (JSONRPCMessage, error) {
select {
case <-ctx.Done():
return newJSONRPCErrorResponse(req.ID, ErrCodeRequestCancelled, "request cancelled", nil), ctx.Err()
default:
}

templates := m.getTemplates()

// Convert []*mcp.ResourceTemplate to []mcp.ResourceTemplate for the result
Expand All @@ -290,6 +308,12 @@ func (m *resourceManager) handleListTemplates(ctx context.Context, req *JSONRPCR

// handleSubscribe handles subscription requests
func (m *resourceManager) handleSubscribe(ctx context.Context, req *JSONRPCRequest) (JSONRPCMessage, error) {
select {
case <-ctx.Done():
return newJSONRPCErrorResponse(req.ID, ErrCodeRequestCancelled, "request cancelled", nil), ctx.Err()
default:
}

// Convert params to map for easier access
paramsMap, ok := req.Params.(map[string]interface{})
if !ok {
Expand Down Expand Up @@ -322,6 +346,12 @@ func (m *resourceManager) handleSubscribe(ctx context.Context, req *JSONRPCReque

// handleUnsubscribe handles unsubscription requests
func (m *resourceManager) handleUnsubscribe(ctx context.Context, req *JSONRPCRequest) (JSONRPCMessage, error) {
select {
case <-ctx.Done():
return newJSONRPCErrorResponse(req.ID, ErrCodeRequestCancelled, "request cancelled", nil), ctx.Err()
default:
}

// Convert params to map for easier access
paramsMap, ok := req.Params.(map[string]interface{})
if !ok {
Expand Down
Loading