Skip to content

Commit

Permalink
Fix request query encoding of array parameters (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattt authored Jan 27, 2025
1 parent beb1915 commit bfbe3ba
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 159 deletions.
13 changes: 12 additions & 1 deletion mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,18 @@ func (s *Server) handleToolsCall(request *ToolCallRequest) (*ToolCallResponse, e
value := fmt.Sprint(value)
u.Path = strings.ReplaceAll(u.Path, "{"+param.Name+"}", pathSegmentEscape(value))
case "query":
queryParams.Set(param.Name, fmt.Sprint(value))
// Handle array values for query parameters
switch v := value.(type) {
case []interface{}:
// Join array values with commas for parameters like tweet.fields
values := make([]string, len(v))
for i, item := range v {
values[i] = fmt.Sprint(item)
}
queryParams.Set(param.Name, strings.Join(values, ","))
default:
queryParams.Set(param.Name, fmt.Sprint(value))
}
case "header":
headerParams.Add(param.Name, fmt.Sprint(value))
}
Expand Down
236 changes: 78 additions & 158 deletions mcp/server_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package mcp

import (
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"path"
"strings"
"testing"
Expand Down Expand Up @@ -35,6 +35,17 @@ func newTestSpec(serverURL string) []byte {
"parameters": []map[string]interface{}{
{"name": "limit", "in": "query", "description": "Maximum number of pets to return", "schema": map[string]interface{}{"type": "integer"}},
{"name": "type", "in": "query", "description": "Type of pets to filter by", "schema": map[string]interface{}{"type": "string"}},
{
"name": "fields",
"in": "query",
"description": "Fields to return",
"schema": map[string]interface{}{
"type": "array",
"items": map[string]interface{}{
"type": "string",
},
},
},
},
},
"post": map[string]interface{}{
Expand Down Expand Up @@ -363,28 +374,38 @@ func TestHandleToolsCall(t *testing.T) {
server, ts := setupTestServer(t)
defer ts.Close()

// Test with auth header
serverWithAuth, _ := setupTestServer(t)

// Create a small test image
imgData := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A} // PNG header

tests := []struct {
name string
server *Server
setup func(*testing.T, *httptest.Server) http.HandlerFunc
request jsonrpc.Request
validate func(*testing.T, jsonrpc.Response)
validate func(*testing.T, jsonrpc.Response, string)
}{
{
name: "GET request with query parameters",
server: server,
name: "GET request with query parameters",
server: server,
setup: func(t *testing.T, ts *httptest.Server) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify query parameters are present
limit := r.URL.Query().Get("limit")
petType := r.URL.Query().Get("type")
assert.Equal(t, "5", limit)
assert.Equal(t, "dog", petType)

w.Header().Set("Content-Type", "application/json")
pets := []map[string]interface{}{
{"id": 1, "name": "Fluffy", "type": "dog"},
{"id": 2, "name": "Rover", "type": "dog"},
}
json.NewEncoder(w).Encode(pets)
})
},
request: jsonrpc.NewRequest("tools/call", json.RawMessage(`{"name": "listPets", "arguments": {"limit": 5, "type": "dog"}}`), 1),
validate: func(t *testing.T, response jsonrpc.Response) {
validate: func(t *testing.T, response jsonrpc.Response, url string) {
assert.Equal(t, "2.0", response.Version)
assert.Equal(t, 1, response.ID.Value())
assert.Nil(t, response.Error)

// Convert response.Result to ToolCallResponse
var result ToolCallResponse
resultBytes, err := json.Marshal(response.Result)
require.NoError(t, err)
Expand All @@ -399,7 +420,6 @@ func TestHandleToolsCall(t *testing.T) {
assert.NotNil(t, content.Annotations)
assert.Contains(t, content.Annotations.Audience, RoleAssistant)

// Unmarshal the response into a TextContent to get the text
var textContent Content
contentBytes, err := json.Marshal(content)
assert.NoError(t, err)
Expand All @@ -411,106 +431,42 @@ func TestHandleToolsCall(t *testing.T) {
assert.NoError(t, err)
assert.Len(t, pets, 2)

// Verify the returned pets have the correct type
for _, pet := range pets {
petMap := pet.(map[string]interface{})
assert.Equal(t, "dog", petMap["type"])
}
},
},
{
name: "POST request with body parameters",
server: server,
request: jsonrpc.NewRequest("tools/call", json.RawMessage(`{"name": "createPet", "arguments": {"name": "Whiskers", "age": 5}}`), 2),
validate: func(t *testing.T, response jsonrpc.Response) {
assert.Equal(t, "2.0", response.Version)
assert.Equal(t, 2, response.ID.Value())
assert.Nil(t, response.Error)

// Convert response.Result to ToolCallResponse
var result ToolCallResponse
resultBytes, err := json.Marshal(response.Result)
require.NoError(t, err)
err = json.Unmarshal(resultBytes, &result)
require.NoError(t, err)

assert.Len(t, result.Content, 1)
assert.False(t, result.IsError)

content := result.Content[0]
assert.Equal(t, "text", content.Type)
assert.NotNil(t, content.Annotations)
assert.Contains(t, content.Annotations.Audience, RoleAssistant)

var textContent Content
contentBytes, err := json.Marshal(content)
assert.NoError(t, err)
err = json.Unmarshal(contentBytes, &textContent)
assert.NoError(t, err)

var pet map[string]interface{}
err = json.Unmarshal([]byte(textContent.Text), &pet)
assert.NoError(t, err)
assert.Equal(t, "Whiskers", pet["name"])
assert.Equal(t, float64(5), pet["age"])
assert.Equal(t, float64(3), pet["id"])
},
},
{
name: "GET image request",
server: server,
request: jsonrpc.NewRequest("tools/call", json.RawMessage(`{"name": "getPetImage"}`), 3),
validate: func(t *testing.T, response jsonrpc.Response) {
assert.Equal(t, "2.0", response.Version)
assert.Equal(t, 3, response.ID.Value())
assert.Nil(t, response.Error)

// Convert response.Result to ToolCallResponse
var result ToolCallResponse
resultBytes, err := json.Marshal(response.Result)
require.NoError(t, err)
err = json.Unmarshal(resultBytes, &result)
require.NoError(t, err)

assert.Len(t, result.Content, 1)
assert.False(t, result.IsError)

content := result.Content[0]
assert.Equal(t, "image", content.Type)
assert.NotNil(t, content.Annotations)
assert.Contains(t, content.Annotations.Audience, RoleAssistant)

var imageContent Content
contentBytes, err := json.Marshal(content)
assert.NoError(t, err)
err = json.Unmarshal(contentBytes, &imageContent)
assert.NoError(t, err)

assert.Equal(t, "image/png", imageContent.MimeType)

decoded, err := base64.StdEncoding.DecodeString(imageContent.Data)
assert.NoError(t, err)
assert.Equal(t, imgData, decoded)
},
},
{
name: "Request with invalid operationId",
server: server,
request: jsonrpc.NewRequest("tools/call", json.RawMessage(`{"name": "nonexistentOperation"}`), 4),
validate: func(t *testing.T, response jsonrpc.Response) {
assert.Equal(t, "2.0", response.Version)
assert.Equal(t, 4, response.ID.Value())
assert.Equal(t, jsonrpc.ErrMethodNotFound, response.Error.Code)
assert.Equal(t, "Method not found", response.Error.Message)
name: "GET request with array query parameters",
server: server,
setup: func(t *testing.T, ts *httptest.Server) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify the query parameters
assert.Equal(t, "dog", r.URL.Query().Get("type"))
// The fields parameter should be a comma-separated list
assert.Equal(t, "name,age,breed", r.URL.Query().Get("fields"))

w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]interface{}{
"success": true,
"query": map[string]string{
"type": r.URL.Query().Get("type"),
"fields": r.URL.Query().Get("fields"),
},
})
})
},
},
{
name: "GET request with URL escaped parameters",
server: server,
request: jsonrpc.NewRequest("tools/call", json.RawMessage(`{"name": "getPet", "arguments": {"petId": "special pet"}}`), 5),
validate: func(t *testing.T, response jsonrpc.Response) {
request: jsonrpc.NewRequest("tools/call", json.RawMessage(`{
"name": "listPets",
"arguments": {
"type": "dog",
"fields": ["name", "age", "breed"]
}
}`), 7),
validate: func(t *testing.T, response jsonrpc.Response, requestURL string) {
assert.Equal(t, "2.0", response.Version)
assert.Equal(t, 5, response.ID.Value())
assert.Equal(t, 7, response.ID.Value())
assert.Nil(t, response.Error)

var result ToolCallResponse
Expand All @@ -522,74 +478,38 @@ func TestHandleToolsCall(t *testing.T) {
assert.Len(t, result.Content, 1)
assert.False(t, result.IsError)

content := result.Content[0]
assert.Equal(t, "text", content.Type)

var textContent Content
contentBytes, err := json.Marshal(content)
assert.NoError(t, err)
err = json.Unmarshal(contentBytes, &textContent)
assert.NoError(t, err)

var pet map[string]interface{}
err = json.Unmarshal([]byte(textContent.Text), &pet)
assert.NoError(t, err)
assert.Equal(t, "Special Pet", pet["name"])
},
},
{
name: "Request with auth header",
server: serverWithAuth,
request: jsonrpc.NewRequest("tools/call", json.RawMessage(`{"name": "listPets", "arguments": {"limit": 5, "type": "dog"}}`), 6),
validate: func(t *testing.T, response jsonrpc.Response) {
assert.Equal(t, "2.0", response.Version)
assert.Equal(t, 6, response.ID.Value())
assert.Nil(t, response.Error)

var result ToolCallResponse
resultBytes, err := json.Marshal(response.Result)
require.NoError(t, err)
err = json.Unmarshal(resultBytes, &result)
// Parse the URL to verify parameters
parsedURL, err := url.Parse(requestURL)
require.NoError(t, err)

assert.Len(t, result.Content, 1)
assert.False(t, result.IsError)

// Verify the response content
content := result.Content[0]
assert.Equal(t, "text", content.Type)
assert.NotNil(t, content.Annotations)
assert.Contains(t, content.Annotations.Audience, RoleAssistant)

var textContent Content
contentBytes, err := json.Marshal(content)
assert.NoError(t, err)
err = json.Unmarshal(contentBytes, &textContent)
assert.NoError(t, err)

var pets []interface{}
err = json.Unmarshal([]byte(textContent.Text), &pets)
assert.NoError(t, err)
assert.Len(t, pets, 2)

// Verify the returned pets
for _, pet := range pets {
petMap := pet.(map[string]interface{})
assert.Equal(t, "dog", petMap["type"])
}
params := parsedURL.Query()
assert.Equal(t, "dog", params.Get("type"))
assert.Equal(t, "name,age,breed", params.Get("fields"))
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var capturedURL string
if tt.setup != nil {
handler := tt.setup(t, ts)
// Wrap the handler to capture the URL
wrappedHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedURL = r.URL.String()
handler.ServeHTTP(w, r)
})
ts.Config.Handler = wrappedHandler
}

var response jsonrpc.Response
if tt.server != nil {
response = tt.server.HandleRequest(tt.request)
} else {
response = server.HandleRequest(tt.request)
}
tt.validate(t, response)

tt.validate(t, response, capturedURL)
})
}
}
Expand Down

0 comments on commit bfbe3ba

Please sign in to comment.