diff --git a/pkg/api/v1/router_instance_ec2_metadata.go b/pkg/api/v1/router_instance_ec2_metadata.go index 7325eab..3a42c68 100644 --- a/pkg/api/v1/router_instance_ec2_metadata.go +++ b/pkg/api/v1/router_instance_ec2_metadata.go @@ -91,9 +91,10 @@ func (r *Router) instanceEc2MetadataItemGet(c *gin.Context) { } if subPath, ok := c.Params.Get("subpath"); ok { - // If subPath is empty, we're just hitting the EC2 endpoint with a trailing - // slash, so perform the same operation as in instanceEc2MetadataGet() - if subPath == "" { + // If subPath is only a fwd slash, we're just hitting the EC2 endpoint + // with a trailing slash, so return the ItemNames as we would in + // instanceEc2MetadataGet() + if subPath == "/" { c.String(http.StatusOK, strings.Join(metadata.ItemNames(), "\n")) return } diff --git a/pkg/api/v1/router_instance_ec2_metadata_test.go b/pkg/api/v1/router_instance_ec2_metadata_test.go index 7fc51ac..c328052 100644 --- a/pkg/api/v1/router_instance_ec2_metadata_test.go +++ b/pkg/api/v1/router_instance_ec2_metadata_test.go @@ -14,6 +14,20 @@ import ( v1api "go.hollow.sh/metadataservice/pkg/api/v1" ) +// GetEc2MetadataItemPathWithoutTrim is used to test routing edge cases where +// the trailing '/' is kept +func getEc2MetadataItemPathWithoutTrim(itemPath string) string { + fullpath := v1api.GetEc2MetadataItemPath(itemPath) + + // GetEc2MetadataItemPath() calls path.Join(), which strips trailing slashes. + // So restore a trailing slash if itemPath came with one + if itemPath != "" && itemPath[len(itemPath)-1:] == "/" { + fullpath += "/" + } + + return fullpath +} + func TestGetEc2MetadataByIP(t *testing.T) { router := *testHTTPServer(t) @@ -405,4 +419,25 @@ func TestGetEc2MetadataItemByIP(t *testing.T) { } }) } + + t.Run("check routing works with trailing slash in the url", func(t *testing.T) { + w := httptest.NewRecorder() + + standardFields := "instance-id\nhostname\niqn\nplan\nfacility\ntags\noperating-system\npublic-keys" + + itemName := "/" + instanceIP := "139.178.82.3" + expectedStatus := http.StatusOK + expectedBody := fmt.Sprintf("%s\npublic-ipv4\npublic-ipv6\nlocal-ipv4", standardFields) + + req, _ := http.NewRequestWithContext(context.TODO(), http.MethodGet, getEc2MetadataItemPathWithoutTrim(itemName), nil) + req.RemoteAddr = net.JoinHostPort(instanceIP, "0") + router.ServeHTTP(w, req) + + assert.Equal(t, expectedStatus, w.Code) + + if expectedStatus == http.StatusOK { + assert.Equal(t, expectedBody, w.Body.String()) + } + }) }