From ae74836def651a97ac5bcacf23fc89811c817696 Mon Sep 17 00:00:00 2001 From: AdityaK011 Date: Tue, 14 Apr 2026 00:10:16 +0900 Subject: [PATCH 1/3] Rename to k8scope, add 14 new tools, dynamic client, graceful shutdown, and error handling --- .gitignore | 3 +- Dockerfile | 6 +- README.md | 24 +- cmd/server/main.go | 10 +- go.mod | 5 +- go.sum | 4 + internal/k8s/client.go | 19 + internal/tools/tools.go | 49 +- internal/tools/tools_extended.go | 765 +++++++++++++++++++++++++++++++ 9 files changed, 848 insertions(+), 37 deletions(-) create mode 100644 internal/tools/tools_extended.go diff --git a/.gitignore b/.gitignore index b6aace4..bf00b5d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .claude/ -.DS_Store \ No newline at end of file +.DS_Store +.mcp.json \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index f2edcf3..67de326 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,9 +3,9 @@ WORKDIR /src COPY go.mod go.sum ./ RUN go mod download COPY . . -RUN CGO_ENABLED=0 go build -o /kubelens ./cmd/server +RUN CGO_ENABLED=0 go build -o /k8scope ./cmd/server FROM gcr.io/distroless/static-debian12:nonroot -COPY --from=build /kubelens /kubelens +COPY --from=build /k8scope /k8scope EXPOSE 8080 -ENTRYPOINT ["/kubelens"] +ENTRYPOINT ["/k8scope"] diff --git a/README.md b/README.md index 62e22de..f3c3884 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,14 @@ -# KubeLens +# k8scope A hosted MCP server that lets AI assistants (Claude Code, Cursor, etc.) interact with your GKE clusters using **your own Google identity**. No shared service accounts, no manual token passing — you log in once via browser and the server handles everything. ## How it works -1. You connect Claude Code to the KubeLens server URL +1. You connect Claude Code to the k8scope server URL 2. First time, a browser opens → you log in with Google -3. KubeLens stores your tokens server-side and issues a session ID +3. k8scope stores your tokens server-side and issues a session ID 4. 
Claude Code sends the session ID on every MCP request -5. KubeLens uses your Google access token to call the GKE API +5. k8scope uses your Google access token to call the GKE API 6. All K8s operations run as **your IAM identity** with your RBAC permissions ## Prerequisites @@ -38,7 +38,7 @@ go run ./cmd/server ## Connect from Claude Code ```bash -claude mcp add --transport http kubelens http://localhost:8080/mcp +claude mcp add --transport http k8scope http://localhost:8080/mcp ``` Then use it: @@ -53,13 +53,13 @@ Then use it: ```bash # Build and push -docker build -t gcr.io/YOUR_PROJECT/kubelens . -docker push gcr.io/YOUR_PROJECT/kubelens +docker build -t gcr.io/YOUR_PROJECT/k8scope . +docker push gcr.io/YOUR_PROJECT/k8scope # Deploy -gcloud run deploy kubelens \ - --image gcr.io/YOUR_PROJECT/kubelens \ - --set-env-vars "GOOGLE_CLIENT_ID=xxx,GOOGLE_CLIENT_SECRET=xxx,REDIRECT_URL=https://kubelens-xxx.run.app/callback" \ +gcloud run deploy k8scope \ + --image gcr.io/YOUR_PROJECT/k8scope \ + --set-env-vars "GOOGLE_CLIENT_ID=xxx,GOOGLE_CLIENT_SECRET=xxx,REDIRECT_URL=https://k8scope-xxx.run.app/callback" \ --allow-unauthenticated \ --port 8080 ``` @@ -80,7 +80,7 @@ Update the OAuth client's redirect URI to match the Cloud Run URL. 
## Architecture ``` -Claude Code ──Bearer: session_id──▶ KubeLens MCP Server ──Bearer: ya29.xxx──▶ GKE API Server +Claude Code ──Bearer: session_id──▶ k8scope MCP Server ──Bearer: ya29.xxx──▶ GKE API Server │ ├── OAuth flow (one-time) ├── Session store (in-memory) @@ -90,7 +90,7 @@ Claude Code ──Bearer: session_id──▶ KubeLens MCP Server ──Bearer: ## Project structure ``` -kubelens/ +k8scope/ ├── cmd/server/main.go # Entrypoint, wires OAuth + MCP ├── internal/ │ ├── auth/ diff --git a/cmd/server/main.go b/cmd/server/main.go index a02d90a..67d8982 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -12,15 +12,15 @@ import ( "github.com/mark3labs/mcp-go/server" - "github.com/AdityaK011/kubelens/internal/auth" - "github.com/AdityaK011/kubelens/internal/tools" + "github.com/AdityaK011/k8scope/internal/auth" + "github.com/AdityaK011/k8scope/internal/tools" ) func main() { // Required env vars. clientID := mustEnv("GOOGLE_CLIENT_ID") clientSecret := mustEnv("GOOGLE_CLIENT_SECRET") - redirectURL := mustEnv("REDIRECT_URL") // e.g. https://kubelens.example.com/callback + redirectURL := mustEnv("REDIRECT_URL") // e.g. https://k8scope.example.com/callback port := getEnv("PORT", "8080") // Init OAuth handler. @@ -31,7 +31,7 @@ func main() { // Init MCP server. 
mcpServer := server.NewMCPServer( - "kubelens", + "k8scope", "0.1.0", server.WithToolCapabilities(true), ) @@ -57,7 +57,7 @@ func main() { mux.Handle("/mcp", oauth.Middleware(mcpHTTP)) addr := fmt.Sprintf(":%s", port) - slog.Info("kubelens MCP server starting", + slog.Info("k8scope MCP server starting", "port", port, "redirect_url", redirectURL, ) diff --git a/go.mod b/go.mod index deaedbe..840b2c0 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/AdityaK011/kubelens +module github.com/AdityaK011/k8scope go 1.25.0 @@ -39,6 +39,7 @@ require ( github.com/spf13/cast v1.7.1 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.opencensus.io v0.24.0 // indirect + go.yaml.in/yaml/v2 v2.4.2 // indirect golang.org/x/crypto v0.21.0 // indirect golang.org/x/net v0.22.0 // indirect golang.org/x/sync v0.20.0 // indirect @@ -58,5 +59,5 @@ require ( k8s.io/utils v0.0.0-20230726121419-3b25d923346b // indirect sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect - sigs.k8s.io/yaml v1.3.0 // indirect + sigs.k8s.io/yaml v1.6.0 // indirect ) diff --git a/go.sum b/go.sum index c3935a4..5fd620a 100644 --- a/go.sum +++ b/go.sum @@ -149,6 +149,8 @@ github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5t go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= +go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= +go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto 
v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -306,3 +308,5 @@ sigs.k8s.io/structured-merge-diff/v4 v4.4.1 h1:150L+0vs/8DA78h1u02ooW1/fFq/Lwr+s sigs.k8s.io/structured-merge-diff/v4 v4.4.1/go.mod h1:N8hJocpFajUSSeSJ9bOZ77VzejKZaXsTtZo4/u7Io08= sigs.k8s.io/yaml v1.3.0 h1:a2VclLzOGrwOHDiV8EfBGhvjHvP46CtW5j6POvhYGGo= sigs.k8s.io/yaml v1.3.0/go.mod h1:GeOyir5tyXNByN85N/dRIT9es5UQNerPYEKK56eTBm8= +sigs.k8s.io/yaml v1.6.0 h1:G8fkbMSAFqgEFgh4b1wmtzDnioxFCUgTZhlbj5P9QYs= +sigs.k8s.io/yaml v1.6.0/go.mod h1:796bPqUfzR/0jLAl6XjHl3Ck7MiyVv8dbTdyT3/pMf4= diff --git a/internal/k8s/client.go b/internal/k8s/client.go index 9527d7c..1c9e38c 100644 --- a/internal/k8s/client.go +++ b/internal/k8s/client.go @@ -10,6 +10,7 @@ import ( container "google.golang.org/api/container/v1" "google.golang.org/api/option" "golang.org/x/oauth2" + "k8s.io/client-go/dynamic" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" ) @@ -57,6 +58,24 @@ func NewClientForUser(ctx context.Context, accessToken string, cluster ClusterIn return kubernetes.NewForConfig(config) } +// NewDynamicClientForUser builds a dynamic Kubernetes client for CRDs and generic resources. +func NewDynamicClientForUser(ctx context.Context, accessToken string, cluster ClusterInfo) (dynamic.Interface, error) { + endpoint, ca, err := getCachedClusterDetails(ctx, accessToken, cluster) + if err != nil { + return nil, fmt.Errorf("failed to get cluster details: %w", err) + } + + config := &rest.Config{ + Host: "https://" + endpoint, + BearerToken: accessToken, + TLSClientConfig: rest.TLSClientConfig{ + CAData: ca, + }, + } + + return dynamic.NewForConfig(config) +} + // getCachedClusterDetails returns the cluster endpoint and CA from cache, // or fetches from GKE API and caches for 10 minutes. 
func getCachedClusterDetails(ctx context.Context, accessToken string, c ClusterInfo) (string, []byte, error) { diff --git a/internal/tools/tools.go b/internal/tools/tools.go index 1e94117..6348a2f 100644 --- a/internal/tools/tools.go +++ b/internal/tools/tools.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "log/slog" "regexp" "sort" "strings" @@ -14,8 +15,8 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "github.com/AdityaK011/kubelens/internal/auth" - k8sClient "github.com/AdityaK011/kubelens/internal/k8s" + "github.com/AdityaK011/k8scope/internal/auth" + k8sClient "github.com/AdityaK011/k8scope/internal/k8s" ) // Register adds all MCP tools to the server. @@ -26,6 +27,20 @@ func Register(s *server.MCPServer) { s.AddTool(getPodLogsTool(), handleGetPodLogs) s.AddTool(getEventsTool(), handleGetEvents) s.AddTool(getNodesTool(), handleGetNodes) + s.AddTool(listNamespacesTool(), handleListNamespaces) + s.AddTool(listDeploymentsTool(), handleListDeployments) + s.AddTool(describeDeploymentTool(), handleDescribeDeployment) + s.AddTool(listServicesTool(), handleListServices) + s.AddTool(listIngressesTool(), handleListIngresses) + s.AddTool(listJobsTool(), handleListJobs) + s.AddTool(listHPATool(), handleListHPA) + s.AddTool(listPVCsTool(), handleListPVCs) + s.AddTool(listConfigMapsTool(), handleListConfigMaps) + s.AddTool(listStatefulSetsTool(), handleListStatefulSets) + s.AddTool(listDaemonSetsTool(), handleListDaemonSets) + s.AddTool(listCRDsTool(), handleListCRDs) + s.AddTool(getCRDInstancesTool(), handleGetCRDInstances) + s.AddTool(getResourceYAMLTool(), handleGetResourceYAML) } // --- Tool definitions --- @@ -138,6 +153,12 @@ func errResult(format string, a ...interface{}) (*mcp.CallToolResult, error) { return mcp.NewToolResultError(fmt.Sprintf(format, a...)), nil } +// safeErr logs the error server-side and returns it to the client. 
+func safeErr(clientMsg string, err error) (*mcp.CallToolResult, error) { + slog.Error(clientMsg, "error", err) + return mcp.NewToolResultError(fmt.Sprintf("%s: %v", clientMsg, err)), nil +} + // --- Handlers --- func handleListClusters(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -155,7 +176,7 @@ func handleListClusters(ctx context.Context, req mcp.CallToolRequest) (*mcp.Call clusters, err := k8sClient.ListClusters(ctx, session.AccessToken, project) if err != nil { - return errResult("failed to list clusters: %v", err) + return safeErr("failed to list clusters", err) } var sb strings.Builder @@ -188,13 +209,13 @@ func handleListPods(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTool } client, err := k8sClient.NewClientForUser(ctx, session.AccessToken, ci) if err != nil { - return errResult("auth/connect failed: %v", err) + return safeErr("failed to connect to cluster", err) } const podLimit = 500 pods, err := client.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{Limit: podLimit}) if err != nil { - return errResult("k8s error: %v", err) + return safeErr("kubernetes API error", err) } var sb strings.Builder @@ -247,12 +268,12 @@ func handleDescribePod(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT } client, err := k8sClient.NewClientForUser(ctx, session.AccessToken, ci) if err != nil { - return errResult("auth/connect failed: %v", err) + return safeErr("failed to connect to cluster", err) } pod, err := client.CoreV1().Pods(namespace).Get(ctx, podName, metav1.GetOptions{}) if err != nil { - return errResult("k8s error: %v", err) + return safeErr("kubernetes API error", err) } var sb strings.Builder @@ -328,7 +349,7 @@ func handleGetPodLogs(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo } client, err := k8sClient.NewClientForUser(ctx, session.AccessToken, ci) if err != nil { - return errResult("auth/connect failed: %v", err) + return safeErr("failed to connect to cluster", err) } opts := 
&corev1.PodLogOptions{ @@ -340,14 +361,14 @@ func handleGetPodLogs(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo stream, err := client.CoreV1().Pods(namespace).GetLogs(podName, opts).Stream(ctx) if err != nil { - return errResult("failed to get logs: %v", err) + return safeErr("failed to get logs", err) } defer stream.Close() const maxLogBytes = 1 << 20 // 1 MB logs, err := io.ReadAll(io.LimitReader(stream, maxLogBytes)) if err != nil { - return errResult("failed to read log stream: %v", err) + return safeErr("failed to read log stream", err) } if len(logs) == 0 { @@ -373,12 +394,12 @@ func handleGetEvents(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToo } client, err := k8sClient.NewClientForUser(ctx, session.AccessToken, ci) if err != nil { - return errResult("auth/connect failed: %v", err) + return safeErr("failed to connect to cluster", err) } events, err := client.CoreV1().Events(namespace).List(ctx, metav1.ListOptions{Limit: 200}) if err != nil { - return errResult("k8s error: %v", err) + return safeErr("kubernetes API error", err) } // Sort by last timestamp (most recent first) — the API doesn't guarantee order. 
@@ -428,13 +449,13 @@ func handleGetNodes(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTool } client, err := k8sClient.NewClientForUser(ctx, session.AccessToken, ci) if err != nil { - return errResult("auth/connect failed: %v", err) + return safeErr("failed to connect to cluster", err) } const nodeLimit = 500 nodes, err := client.CoreV1().Nodes().List(ctx, metav1.ListOptions{Limit: nodeLimit}) if err != nil { - return errResult("k8s error: %v", err) + return safeErr("kubernetes API error", err) } var sb strings.Builder diff --git a/internal/tools/tools_extended.go b/internal/tools/tools_extended.go new file mode 100644 index 0000000..ae2ef55 --- /dev/null +++ b/internal/tools/tools_extended.go @@ -0,0 +1,765 @@ +package tools + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/mark3labs/mcp-go/mcp" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime/schema" + sigsyaml "sigs.k8s.io/yaml" + + "github.com/AdityaK011/k8scope/internal/auth" + k8sClient "github.com/AdityaK011/k8scope/internal/k8s" +) + +// --- Tool definitions --- + +func listNamespacesTool() mcp.Tool { + return mcp.NewTool("list_namespaces", + mcp.WithDescription("List all namespaces in a GKE cluster"), + mcp.WithString("project", mcp.Required(), mcp.Description("GCP project ID")), + mcp.WithString("location", mcp.Required(), mcp.Description("Cluster region/zone")), + mcp.WithString("cluster", mcp.Required(), mcp.Description("Cluster name")), + ) +} + +func listDeploymentsTool() mcp.Tool { + return mcp.NewTool("list_deployments", + mcp.WithDescription("List deployments with replica status"), + mcp.WithString("project", mcp.Required(), mcp.Description("GCP project ID")), + mcp.WithString("location", mcp.Required(), mcp.Description("Cluster region/zone")), + mcp.WithString("cluster", mcp.Required(), mcp.Description("Cluster name")), + mcp.WithString("namespace", mcp.Description("K8s namespace. 
Omit for all.")), + ) +} + +func describeDeploymentTool() mcp.Tool { + return mcp.NewTool("describe_deployment", + mcp.WithDescription("Get detailed status, conditions, and containers for a deployment"), + mcp.WithString("project", mcp.Required(), mcp.Description("GCP project ID")), + mcp.WithString("location", mcp.Required(), mcp.Description("Cluster region/zone")), + mcp.WithString("cluster", mcp.Required(), mcp.Description("Cluster name")), + mcp.WithString("namespace", mcp.Required(), mcp.Description("Deployment namespace")), + mcp.WithString("deployment", mcp.Required(), mcp.Description("Deployment name")), + ) +} + +func listServicesTool() mcp.Tool { + return mcp.NewTool("list_services", + mcp.WithDescription("List services with type, cluster IP, and ports"), + mcp.WithString("project", mcp.Required(), mcp.Description("GCP project ID")), + mcp.WithString("location", mcp.Required(), mcp.Description("Cluster region/zone")), + mcp.WithString("cluster", mcp.Required(), mcp.Description("Cluster name")), + mcp.WithString("namespace", mcp.Description("K8s namespace. Omit for all.")), + ) +} + +func listIngressesTool() mcp.Tool { + return mcp.NewTool("list_ingresses", + mcp.WithDescription("List ingresses with hosts and paths"), + mcp.WithString("project", mcp.Required(), mcp.Description("GCP project ID")), + mcp.WithString("location", mcp.Required(), mcp.Description("Cluster region/zone")), + mcp.WithString("cluster", mcp.Required(), mcp.Description("Cluster name")), + mcp.WithString("namespace", mcp.Description("K8s namespace. 
Omit for all.")), + ) +} + +func listJobsTool() mcp.Tool { + return mcp.NewTool("list_jobs", + mcp.WithDescription("List jobs with completion status"), + mcp.WithString("project", mcp.Required(), mcp.Description("GCP project ID")), + mcp.WithString("location", mcp.Required(), mcp.Description("Cluster region/zone")), + mcp.WithString("cluster", mcp.Required(), mcp.Description("Cluster name")), + mcp.WithString("namespace", mcp.Description("K8s namespace. Omit for all.")), + ) +} + +func listHPATool() mcp.Tool { + return mcp.NewTool("list_hpa", + mcp.WithDescription("List horizontal pod autoscalers with scaling targets"), + mcp.WithString("project", mcp.Required(), mcp.Description("GCP project ID")), + mcp.WithString("location", mcp.Required(), mcp.Description("Cluster region/zone")), + mcp.WithString("cluster", mcp.Required(), mcp.Description("Cluster name")), + mcp.WithString("namespace", mcp.Description("K8s namespace. Omit for all.")), + ) +} + +func listPVCsTool() mcp.Tool { + return mcp.NewTool("list_pvcs", + mcp.WithDescription("List persistent volume claims with status and capacity"), + mcp.WithString("project", mcp.Required(), mcp.Description("GCP project ID")), + mcp.WithString("location", mcp.Required(), mcp.Description("Cluster region/zone")), + mcp.WithString("cluster", mcp.Required(), mcp.Description("Cluster name")), + mcp.WithString("namespace", mcp.Description("K8s namespace. Omit for all.")), + ) +} + +func listConfigMapsTool() mcp.Tool { + return mcp.NewTool("list_configmaps", + mcp.WithDescription("List config maps with key counts"), + mcp.WithString("project", mcp.Required(), mcp.Description("GCP project ID")), + mcp.WithString("location", mcp.Required(), mcp.Description("Cluster region/zone")), + mcp.WithString("cluster", mcp.Required(), mcp.Description("Cluster name")), + mcp.WithString("namespace", mcp.Description("K8s namespace. 
Omit for all.")), + ) +} + +func listStatefulSetsTool() mcp.Tool { + return mcp.NewTool("list_statefulsets", + mcp.WithDescription("List stateful sets with replica status"), + mcp.WithString("project", mcp.Required(), mcp.Description("GCP project ID")), + mcp.WithString("location", mcp.Required(), mcp.Description("Cluster region/zone")), + mcp.WithString("cluster", mcp.Required(), mcp.Description("Cluster name")), + mcp.WithString("namespace", mcp.Description("K8s namespace. Omit for all.")), + ) +} + +func listDaemonSetsTool() mcp.Tool { + return mcp.NewTool("list_daemonsets", + mcp.WithDescription("List daemon sets with scheduling status"), + mcp.WithString("project", mcp.Required(), mcp.Description("GCP project ID")), + mcp.WithString("location", mcp.Required(), mcp.Description("Cluster region/zone")), + mcp.WithString("cluster", mcp.Required(), mcp.Description("Cluster name")), + mcp.WithString("namespace", mcp.Description("K8s namespace. Omit for all.")), + ) +} + +func listCRDsTool() mcp.Tool { + return mcp.NewTool("list_crds", + mcp.WithDescription("List custom resource definitions installed in the cluster"), + mcp.WithString("project", mcp.Required(), mcp.Description("GCP project ID")), + mcp.WithString("location", mcp.Required(), mcp.Description("Cluster region/zone")), + mcp.WithString("cluster", mcp.Required(), mcp.Description("Cluster name")), + ) +} + +func getCRDInstancesTool() mcp.Tool { + return mcp.NewTool("get_crd_instances", + mcp.WithDescription("List instances of a custom resource by group/version/resource"), + mcp.WithString("project", mcp.Required(), mcp.Description("GCP project ID")), + mcp.WithString("location", mcp.Required(), mcp.Description("Cluster region/zone")), + mcp.WithString("cluster", mcp.Required(), mcp.Description("Cluster name")), + mcp.WithString("group", mcp.Required(), mcp.Description("API group, e.g. networking.istio.io")), + mcp.WithString("version", mcp.Required(), mcp.Description("API version, e.g. 
v1beta1")), + mcp.WithString("resource", mcp.Required(), mcp.Description("Resource plural, e.g. virtualservices")), + mcp.WithString("namespace", mcp.Description("K8s namespace. Omit for cluster-scoped.")), + ) +} + +func getResourceYAMLTool() mcp.Tool { + return mcp.NewTool("get_resource_yaml", + mcp.WithDescription("Get the full YAML of any Kubernetes resource"), + mcp.WithString("project", mcp.Required(), mcp.Description("GCP project ID")), + mcp.WithString("location", mcp.Required(), mcp.Description("Cluster region/zone")), + mcp.WithString("cluster", mcp.Required(), mcp.Description("Cluster name")), + mcp.WithString("api_version", mcp.Required(), mcp.Description("API version, e.g. apps/v1, v1, networking.k8s.io/v1")), + mcp.WithString("kind", mcp.Required(), mcp.Description("Resource kind, e.g. Deployment, Service, Pod")), + mcp.WithString("name", mcp.Required(), mcp.Description("Resource name")), + mcp.WithString("namespace", mcp.Description("Namespace. Omit for cluster-scoped resources.")), + ) +} + +// --- Handlers --- + +func handleListNamespaces(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + session, err := auth.SessionFromContext(ctx) + if err != nil { + return errResult("auth: %v", err) + } + ci, err := clusterInfo(req.GetArguments()) + if err != nil { + return errResult("validation: %v", err) + } + client, err := k8sClient.NewClientForUser(ctx, session.AccessToken, ci) + if err != nil { + return safeErr("failed to connect to cluster", err) + } + + nsList, err := client.CoreV1().Namespaces().List(ctx, metav1.ListOptions{Limit: 500}) + if err != nil { + return safeErr("kubernetes API error", err) + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Found %d namespaces:\n\n", len(nsList.Items))) + sb.WriteString(fmt.Sprintf("%-40s %-10s %s\n", "NAME", "STATUS", "AGE")) + sb.WriteString(strings.Repeat("-", 65) + "\n") + for _, ns := range 
nsList.Items { + age := time.Since(ns.CreationTimestamp.Time).Truncate(time.Minute) + sb.WriteString(fmt.Sprintf("%-40s %-10s %s\n", ns.Name, ns.Status.Phase, age)) + } + return mcp.NewToolResultText(sb.String()), nil +} + +func handleListDeployments(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + session, err := auth.SessionFromContext(ctx) + if err != nil { + return errResult("auth: %v", err) + } + args := req.GetArguments() + namespace := str(args, "namespace") + ci, err := clusterInfo(args) + if err != nil { + return errResult("validation: %v", err) + } + client, err := k8sClient.NewClientForUser(ctx, session.AccessToken, ci) + if err != nil { + return safeErr("failed to connect to cluster", err) + } + + deps, err := client.AppsV1().Deployments(namespace).List(ctx, metav1.ListOptions{Limit: 500}) + if err != nil { + return safeErr("kubernetes API error", err) + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Found %d deployments:\n\n", len(deps.Items))) + sb.WriteString(fmt.Sprintf("%-45s %-10s %-12s %-10s %s\n", "NAMESPACE/NAME", "READY", "UP-TO-DATE", "AVAILABLE", "AGE")) + sb.WriteString(strings.Repeat("-", 100) + "\n") + for _, d := range deps.Items { + age := time.Since(d.CreationTimestamp.Time).Truncate(time.Minute) + desired := int32(1) + if d.Spec.Replicas != nil { + desired = *d.Spec.Replicas + } + ready := fmt.Sprintf("%d/%d", d.Status.ReadyReplicas, desired) + sb.WriteString(fmt.Sprintf("%-45s %-10s %-12d %-10d %s\n", + d.Namespace+"/"+d.Name, ready, d.Status.UpdatedReplicas, d.Status.AvailableReplicas, age)) + } + return mcp.NewToolResultText(sb.String()), nil +} + +func handleDescribeDeployment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + session, err := auth.SessionFromContext(ctx) + if err != nil { + return errResult("auth: 
%v", err) + } + args := req.GetArguments() + namespace := str(args, "namespace") + depName := str(args, "deployment") + ci, err := clusterInfo(args) + if err != nil { + return errResult("validation: %v", err) + } + client, err := k8sClient.NewClientForUser(ctx, session.AccessToken, ci) + if err != nil { + return safeErr("failed to connect to cluster", err) + } + + d, err := client.AppsV1().Deployments(namespace).Get(ctx, depName, metav1.GetOptions{}) + if err != nil { + return safeErr("kubernetes API error", err) + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Deployment: %s/%s\n", d.Namespace, d.Name)) + replicas := int32(0) + if d.Spec.Replicas != nil { + replicas = *d.Spec.Replicas + } + sb.WriteString(fmt.Sprintf("Replicas: %d desired | %d ready | %d up-to-date | %d available\n", + replicas, d.Status.ReadyReplicas, d.Status.UpdatedReplicas, d.Status.AvailableReplicas)) + sb.WriteString(fmt.Sprintf("Strategy: %s\n", d.Spec.Strategy.Type)) + sb.WriteString(fmt.Sprintf("Created: %s\n\n", d.CreationTimestamp.Format(time.RFC3339))) + + sb.WriteString("Conditions:\n") + for _, c := range d.Status.Conditions { + sb.WriteString(fmt.Sprintf(" %-25s %-8s %s\n", c.Type, c.Status, c.Message)) + } + + sb.WriteString("\nContainers:\n") + for _, c := range d.Spec.Template.Spec.Containers { + sb.WriteString(fmt.Sprintf(" %s:\n", c.Name)) + sb.WriteString(fmt.Sprintf(" Image: %s\n", c.Image)) + if c.Resources.Requests != nil { + sb.WriteString(fmt.Sprintf(" Requests: cpu=%s, memory=%s\n", + c.Resources.Requests.Cpu().String(), c.Resources.Requests.Memory().String())) + } + if c.Resources.Limits != nil { + sb.WriteString(fmt.Sprintf(" Limits: cpu=%s, memory=%s\n", + c.Resources.Limits.Cpu().String(), c.Resources.Limits.Memory().String())) + } + } + return mcp.NewToolResultText(sb.String()), nil +} + +func handleListServices(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer 
cancel() + session, err := auth.SessionFromContext(ctx) + if err != nil { + return errResult("auth: %v", err) + } + args := req.GetArguments() + namespace := str(args, "namespace") + ci, err := clusterInfo(args) + if err != nil { + return errResult("validation: %v", err) + } + client, err := k8sClient.NewClientForUser(ctx, session.AccessToken, ci) + if err != nil { + return safeErr("failed to connect to cluster", err) + } + + svcs, err := client.CoreV1().Services(namespace).List(ctx, metav1.ListOptions{Limit: 500}) + if err != nil { + return safeErr("kubernetes API error", err) + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Found %d services:\n\n", len(svcs.Items))) + sb.WriteString(fmt.Sprintf("%-40s %-12s %-16s %-30s %s\n", "NAMESPACE/NAME", "TYPE", "CLUSTER-IP", "PORTS", "AGE")) + sb.WriteString(strings.Repeat("-", 115) + "\n") + for _, s := range svcs.Items { + age := time.Since(s.CreationTimestamp.Time).Truncate(time.Minute) + var ports []string + for _, p := range s.Spec.Ports { + ports = append(ports, fmt.Sprintf("%d/%s", p.Port, p.Protocol)) + } + portStr := strings.Join(ports, ",") + if len(portStr) > 28 { + portStr = portStr[:25] + "..." 
+ } + sb.WriteString(fmt.Sprintf("%-40s %-12s %-16s %-30s %s\n", + s.Namespace+"/"+s.Name, s.Spec.Type, s.Spec.ClusterIP, portStr, age)) + } + return mcp.NewToolResultText(sb.String()), nil +} + +func handleListIngresses(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + session, err := auth.SessionFromContext(ctx) + if err != nil { + return errResult("auth: %v", err) + } + args := req.GetArguments() + namespace := str(args, "namespace") + ci, err := clusterInfo(args) + if err != nil { + return errResult("validation: %v", err) + } + client, err := k8sClient.NewClientForUser(ctx, session.AccessToken, ci) + if err != nil { + return safeErr("failed to connect to cluster", err) + } + + ingList, err := client.NetworkingV1().Ingresses(namespace).List(ctx, metav1.ListOptions{Limit: 500}) + if err != nil { + return safeErr("kubernetes API error", err) + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Found %d ingresses:\n\n", len(ingList.Items))) + sb.WriteString(fmt.Sprintf("%-40s %-35s %-25s %s\n", "NAMESPACE/NAME", "HOSTS", "PATHS", "AGE")) + sb.WriteString(strings.Repeat("-", 110) + "\n") + for _, ing := range ingList.Items { + age := time.Since(ing.CreationTimestamp.Time).Truncate(time.Minute) + var hosts, paths []string + for _, rule := range ing.Spec.Rules { + hosts = append(hosts, rule.Host) + if rule.HTTP != nil { + for _, p := range rule.HTTP.Paths { + paths = append(paths, p.Path) + } + } + } + hostStr := strings.Join(hosts, ",") + pathStr := strings.Join(paths, ",") + if len(hostStr) > 33 { + hostStr = hostStr[:30] + "..." 
+ } + sb.WriteString(fmt.Sprintf("%-40s %-35s %-25s %s\n", + ing.Namespace+"/"+ing.Name, hostStr, pathStr, age)) + } + return mcp.NewToolResultText(sb.String()), nil +} + +func handleListJobs(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + session, err := auth.SessionFromContext(ctx) + if err != nil { + return errResult("auth: %v", err) + } + args := req.GetArguments() + namespace := str(args, "namespace") + ci, err := clusterInfo(args) + if err != nil { + return errResult("validation: %v", err) + } + client, err := k8sClient.NewClientForUser(ctx, session.AccessToken, ci) + if err != nil { + return safeErr("failed to connect to cluster", err) + } + + jobs, err := client.BatchV1().Jobs(namespace).List(ctx, metav1.ListOptions{Limit: 500}) + if err != nil { + return safeErr("kubernetes API error", err) + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Found %d jobs:\n\n", len(jobs.Items))) + sb.WriteString(fmt.Sprintf("%-45s %-14s %-10s %-8s %s\n", "NAMESPACE/NAME", "COMPLETIONS", "SUCCEEDED", "FAILED", "AGE")) + sb.WriteString(strings.Repeat("-", 100) + "\n") + for _, j := range jobs.Items { + age := time.Since(j.CreationTimestamp.Time).Truncate(time.Minute) + desired := int32(1) + if j.Spec.Completions != nil { + desired = *j.Spec.Completions + } + sb.WriteString(fmt.Sprintf("%-45s %d/%-12d %-10d %-8d %s\n", + j.Namespace+"/"+j.Name, j.Status.Succeeded, desired, j.Status.Succeeded, j.Status.Failed, age)) + } + return mcp.NewToolResultText(sb.String()), nil +} + +func handleListHPA(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + session, err := auth.SessionFromContext(ctx) + if err != nil { + return errResult("auth: %v", err) + } + args := req.GetArguments() + namespace := str(args, "namespace") + ci, err := clusterInfo(args) + if err != nil { + 
return errResult("validation: %v", err) + } + client, err := k8sClient.NewClientForUser(ctx, session.AccessToken, ci) + if err != nil { + return safeErr("failed to connect to cluster", err) + } + + hpas, err := client.AutoscalingV2().HorizontalPodAutoscalers(namespace).List(ctx, metav1.ListOptions{Limit: 500}) + if err != nil { + return safeErr("kubernetes API error", err) + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Found %d HPAs:\n\n", len(hpas.Items))) + sb.WriteString(fmt.Sprintf("%-40s %-20s %-8s %-8s %-8s %s\n", "NAMESPACE/NAME", "REFERENCE", "MIN", "MAX", "CURRENT", "AGE")) + sb.WriteString(strings.Repeat("-", 100) + "\n") + for _, h := range hpas.Items { + age := time.Since(h.CreationTimestamp.Time).Truncate(time.Minute) + ref := fmt.Sprintf("%s/%s", h.Spec.ScaleTargetRef.Kind, h.Spec.ScaleTargetRef.Name) + minReplicas := int32(1) + if h.Spec.MinReplicas != nil { + minReplicas = *h.Spec.MinReplicas + } + sb.WriteString(fmt.Sprintf("%-40s %-20s %-8d %-8d %-8d %s\n", + h.Namespace+"/"+h.Name, ref, minReplicas, h.Spec.MaxReplicas, h.Status.CurrentReplicas, age)) + } + return mcp.NewToolResultText(sb.String()), nil +} + +func handleListPVCs(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + session, err := auth.SessionFromContext(ctx) + if err != nil { + return errResult("auth: %v", err) + } + args := req.GetArguments() + namespace := str(args, "namespace") + ci, err := clusterInfo(args) + if err != nil { + return errResult("validation: %v", err) + } + client, err := k8sClient.NewClientForUser(ctx, session.AccessToken, ci) + if err != nil { + return safeErr("failed to connect to cluster", err) + } + + pvcs, err := client.CoreV1().PersistentVolumeClaims(namespace).List(ctx, metav1.ListOptions{Limit: 500}) + if err != nil { + return safeErr("kubernetes API error", err) + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Found %d 
PVCs:\n\n", len(pvcs.Items))) + sb.WriteString(fmt.Sprintf("%-40s %-10s %-20s %-10s %-15s %s\n", "NAMESPACE/NAME", "STATUS", "VOLUME", "CAPACITY", "STORAGECLASS", "AGE")) + sb.WriteString(strings.Repeat("-", 115) + "\n") + for _, p := range pvcs.Items { + age := time.Since(p.CreationTimestamp.Time).Truncate(time.Minute) + capacity := "" + if p.Status.Capacity != nil { + if storage, ok := p.Status.Capacity["storage"]; ok { + capacity = storage.String() + } + } + sc := "" + if p.Spec.StorageClassName != nil { + sc = *p.Spec.StorageClassName + } + sb.WriteString(fmt.Sprintf("%-40s %-10s %-20s %-10s %-15s %s\n", + p.Namespace+"/"+p.Name, p.Status.Phase, p.Spec.VolumeName, capacity, sc, age)) + } + return mcp.NewToolResultText(sb.String()), nil +} + +func handleListConfigMaps(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + session, err := auth.SessionFromContext(ctx) + if err != nil { + return errResult("auth: %v", err) + } + args := req.GetArguments() + namespace := str(args, "namespace") + ci, err := clusterInfo(args) + if err != nil { + return errResult("validation: %v", err) + } + client, err := k8sClient.NewClientForUser(ctx, session.AccessToken, ci) + if err != nil { + return safeErr("failed to connect to cluster", err) + } + + cms, err := client.CoreV1().ConfigMaps(namespace).List(ctx, metav1.ListOptions{Limit: 500}) + if err != nil { + return safeErr("kubernetes API error", err) + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Found %d config maps:\n\n", len(cms.Items))) + sb.WriteString(fmt.Sprintf("%-45s %-8s %s\n", "NAMESPACE/NAME", "KEYS", "AGE")) + sb.WriteString(strings.Repeat("-", 70) + "\n") + for _, cm := range cms.Items { + age := time.Since(cm.CreationTimestamp.Time).Truncate(time.Minute) + sb.WriteString(fmt.Sprintf("%-45s %-8d %s\n", + cm.Namespace+"/"+cm.Name, len(cm.Data), age)) + } + return mcp.NewToolResultText(sb.String()), nil 
+} + +func handleListStatefulSets(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + session, err := auth.SessionFromContext(ctx) + if err != nil { + return errResult("auth: %v", err) + } + args := req.GetArguments() + namespace := str(args, "namespace") + ci, err := clusterInfo(args) + if err != nil { + return errResult("validation: %v", err) + } + client, err := k8sClient.NewClientForUser(ctx, session.AccessToken, ci) + if err != nil { + return safeErr("failed to connect to cluster", err) + } + + ssets, err := client.AppsV1().StatefulSets(namespace).List(ctx, metav1.ListOptions{Limit: 500}) + if err != nil { + return safeErr("kubernetes API error", err) + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Found %d stateful sets:\n\n", len(ssets.Items))) + sb.WriteString(fmt.Sprintf("%-45s %-10s %s\n", "NAMESPACE/NAME", "READY", "AGE")) + sb.WriteString(strings.Repeat("-", 70) + "\n") + for _, s := range ssets.Items { + age := time.Since(s.CreationTimestamp.Time).Truncate(time.Minute) + replicas := int32(0) + if s.Spec.Replicas != nil { + replicas = *s.Spec.Replicas + } + sb.WriteString(fmt.Sprintf("%-45s %d/%-8d %s\n", + s.Namespace+"/"+s.Name, s.Status.ReadyReplicas, replicas, age)) + } + return mcp.NewToolResultText(sb.String()), nil +} + +func handleListDaemonSets(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + session, err := auth.SessionFromContext(ctx) + if err != nil { + return errResult("auth: %v", err) + } + args := req.GetArguments() + namespace := str(args, "namespace") + ci, err := clusterInfo(args) + if err != nil { + return errResult("validation: %v", err) + } + client, err := k8sClient.NewClientForUser(ctx, session.AccessToken, ci) + if err != nil { + return safeErr("failed to connect to cluster", err) + } + + dsets, err := 
client.AppsV1().DaemonSets(namespace).List(ctx, metav1.ListOptions{Limit: 500}) + if err != nil { + return safeErr("kubernetes API error", err) + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Found %d daemon sets:\n\n", len(dsets.Items))) + sb.WriteString(fmt.Sprintf("%-45s %-10s %-8s %-12s %-10s %s\n", "NAMESPACE/NAME", "DESIRED", "READY", "UP-TO-DATE", "AVAILABLE", "AGE")) + sb.WriteString(strings.Repeat("-", 105) + "\n") + for _, d := range dsets.Items { + age := time.Since(d.CreationTimestamp.Time).Truncate(time.Minute) + sb.WriteString(fmt.Sprintf("%-45s %-10d %-8d %-12d %-10d %s\n", + d.Namespace+"/"+d.Name, + d.Status.DesiredNumberScheduled, d.Status.NumberReady, + d.Status.UpdatedNumberScheduled, d.Status.NumberAvailable, age)) + } + return mcp.NewToolResultText(sb.String()), nil +} + +func handleListCRDs(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + session, err := auth.SessionFromContext(ctx) + if err != nil { + return errResult("auth: %v", err) + } + ci, err := clusterInfo(req.GetArguments()) + if err != nil { + return errResult("validation: %v", err) + } + dynClient, err := k8sClient.NewDynamicClientForUser(ctx, session.AccessToken, ci) + if err != nil { + return safeErr("failed to connect to cluster", err) + } + + crdGVR := schema.GroupVersionResource{ + Group: "apiextensions.k8s.io", + Version: "v1", + Resource: "customresourcedefinitions", + } + crdList, err := dynClient.Resource(crdGVR).List(ctx, metav1.ListOptions{Limit: 500}) + if err != nil { + return safeErr("kubernetes API error", err) + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Found %d CRDs:\n\n", len(crdList.Items))) + sb.WriteString(fmt.Sprintf("%-55s %-30s %-12s %s\n", "NAME", "GROUP", "SCOPE", "AGE")) + sb.WriteString(strings.Repeat("-", 110) + "\n") + for _, item := range crdList.Items { + name := item.GetName() + spec, _ := 
item.Object["spec"].(map[string]interface{}) + group, _ := spec["group"].(string) + scope, _ := spec["scope"].(string) + age := time.Since(item.GetCreationTimestamp().Time).Truncate(time.Minute) + sb.WriteString(fmt.Sprintf("%-55s %-30s %-12s %s\n", name, group, scope, age)) + } + return mcp.NewToolResultText(sb.String()), nil +} + +func handleGetCRDInstances(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + session, err := auth.SessionFromContext(ctx) + if err != nil { + return errResult("auth: %v", err) + } + args := req.GetArguments() + ci, err := clusterInfo(args) + if err != nil { + return errResult("validation: %v", err) + } + group := str(args, "group") + version := str(args, "version") + resource := str(args, "resource") + namespace := str(args, "namespace") + + dynClient, err := k8sClient.NewDynamicClientForUser(ctx, session.AccessToken, ci) + if err != nil { + return safeErr("failed to connect to cluster", err) + } + + gvr := schema.GroupVersionResource{Group: group, Version: version, Resource: resource} + var list *unstructured.UnstructuredList + if namespace != "" { + list, err = dynClient.Resource(gvr).Namespace(namespace).List(ctx, metav1.ListOptions{Limit: 500}) + } else { + list, err = dynClient.Resource(gvr).List(ctx, metav1.ListOptions{Limit: 500}) + } + if err != nil { + return safeErr("kubernetes API error", err) + } + + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Found %d %s.%s/%s:\n\n", len(list.Items), resource, group, version)) + sb.WriteString(fmt.Sprintf("%-45s %-30s %s\n", "NAME", "NAMESPACE", "AGE")) + sb.WriteString(strings.Repeat("-", 85) + "\n") + for _, item := range list.Items { + age := time.Since(item.GetCreationTimestamp().Time).Truncate(time.Minute) + sb.WriteString(fmt.Sprintf("%-45s %-30s %s\n", item.GetName(), item.GetNamespace(), age)) + } + return mcp.NewToolResultText(sb.String()), nil +} + +func 
handleGetResourceYAML(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
+	defer cancel()
+	session, err := auth.SessionFromContext(ctx)
+	if err != nil {
+		return errResult("auth: %v", err)
+	}
+	args := req.GetArguments()
+	ci, err := clusterInfo(args)
+	if err != nil {
+		return errResult("validation: %v", err)
+	}
+	apiVersion := str(args, "api_version")
+	kind := str(args, "kind")
+	name := str(args, "name")
+	namespace := str(args, "namespace")
+
+	if apiVersion == "" || kind == "" || name == "" {
+		return errResult("api_version, kind, and name are required")
+	}
+
+	dynClient, err := k8sClient.NewDynamicClientForUser(ctx, session.AccessToken, ci)
+	if err != nil {
+		return safeErr("failed to connect to cluster", err)
+	}
+
+	// Parse apiVersion into group + version.
+	gv := strings.SplitN(apiVersion, "/", 2)
+	var group, version string
+	if len(gv) == 1 {
+		group = ""
+		version = gv[0] // core API: "v1"
+	} else {
+		group = gv[0]
+		version = gv[1]
+	}
+
+	// Convert kind to lowercase plural (naive: appends "s"). NOTE(review): breaks for kinds like Ingress ("ingresss") or NetworkPolicy ("networkpolicys"); consider resolving the plural resource name via API discovery (RESTMapper) instead.
+ resource := strings.ToLower(kind) + "s" + + gvr := schema.GroupVersionResource{Group: group, Version: version, Resource: resource} + var obj *unstructured.Unstructured + if namespace != "" { + obj, err = dynClient.Resource(gvr).Namespace(namespace).Get(ctx, name, metav1.GetOptions{}) + } else { + obj, err = dynClient.Resource(gvr).Get(ctx, name, metav1.GetOptions{}) + } + if err != nil { + return safeErr("failed to get resource", err) + } + + yamlBytes, err := sigsyaml.Marshal(obj.Object) + if err != nil { + return safeErr("failed to serialize resource", err) + } + + return mcp.NewToolResultText(string(yamlBytes)), nil +} From 357d84c8ce5f48e298195e024a6bc1110cb338a1 Mon Sep 17 00:00:00 2001 From: AdityaK011 Date: Tue, 14 Apr 2026 00:12:46 +0900 Subject: [PATCH 2/3] Add unit tests and GitHub Actions CI --- .github/workflows/ci.yml | 22 +++ internal/auth/middleware_test.go | 167 +++++++++++++++++++++ internal/auth/oauth_test.go | 249 ++++++++++++++++++++++++++++++ internal/auth/ratelimit_test.go | 69 +++++++++ internal/auth/session_test.go | 250 +++++++++++++++++++++++++++++++ internal/k8s/client_test.go | 79 ++++++++++ internal/tools/tools_test.go | 224 +++++++++++++++++++++++++++ 7 files changed, 1060 insertions(+) create mode 100644 .github/workflows/ci.yml create mode 100644 internal/auth/middleware_test.go create mode 100644 internal/auth/oauth_test.go create mode 100644 internal/auth/ratelimit_test.go create mode 100644 internal/auth/session_test.go create mode 100644 internal/k8s/client_test.go create mode 100644 internal/tools/tools_test.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..b6cd196 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,22 @@ +name: CI + +on: + push: + branches: ['*'] + pull_request: + branches: [main] + +jobs: + build-and-test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: '1.23' + - name: Build + run: 
go build ./... + - name: Vet + run: go vet ./... + - name: Test + run: go test ./... -v -race -count=1 diff --git a/internal/auth/middleware_test.go b/internal/auth/middleware_test.go new file mode 100644 index 0000000..0122cba --- /dev/null +++ b/internal/auth/middleware_test.go @@ -0,0 +1,167 @@ +package auth + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestMiddlewareNoToken(t *testing.T) { + g := newTestGoogleOAuth() + + handler := g.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("handler should not be called when no token is provided") + })) + + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusUnauthorized) + } +} + +func TestMiddlewareInvalidToken(t *testing.T) { + g := newTestGoogleOAuth() + + handler := g.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Error("handler should not be called with invalid token") + })) + + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer invalid123") + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusUnauthorized) + } +} + +func TestMiddlewareValidToken(t *testing.T) { + g := newTestGoogleOAuth() + + // Create a session with an access token that won't expire soon, + // so EnsureFreshToken does not attempt a Google refresh. 
+ sessionID := g.Store.CreateSession(Session{ + Email: "test@example.com", + AccessToken: "google-access-token", + RefreshToken: "google-refresh-token", + ExpiresAt: time.Now().Add(time.Hour), + }) + + var receivedSession *Session + handler := g.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sess, err := SessionFromContext(r.Context()) + if err != nil { + t.Errorf("SessionFromContext failed: %v", err) + return + } + receivedSession = sess + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+sessionID) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String()) + } + + if receivedSession == nil { + t.Fatal("handler did not receive session in context") + } + if receivedSession.Email != "test@example.com" { + t.Errorf("session Email = %q, want %q", receivedSession.Email, "test@example.com") + } + if receivedSession.AccessToken != "google-access-token" { + t.Errorf("session AccessToken = %q, want %q", receivedSession.AccessToken, "google-access-token") + } +} + +func TestSessionFromContextWithSession(t *testing.T) { + sess := &Session{ + Email: "ctx@example.com", + AccessToken: "tok-abc", + RefreshToken: "ref-xyz", + ExpiresAt: time.Now().Add(time.Hour), + CreatedAt: time.Now(), + } + + ctx := context.WithValue(context.Background(), sessionCtxKey, sess) + + got, err := SessionFromContext(ctx) + if err != nil { + t.Fatalf("SessionFromContext failed: %v", err) + } + if got.Email != "ctx@example.com" { + t.Errorf("Email = %q, want %q", got.Email, "ctx@example.com") + } + if got.AccessToken != "tok-abc" { + t.Errorf("AccessToken = %q, want %q", got.AccessToken, "tok-abc") + } +} + +func TestSessionFromContextWithoutSession(t *testing.T) { + ctx := context.Background() + + _, err := SessionFromContext(ctx) + if err == nil { + 
t.Fatal("expected error from empty context, got nil") + } +} + +func TestExtractBearer(t *testing.T) { + tests := []struct { + name string + header string + want string + }{ + { + name: "valid bearer token", + header: "Bearer abc123", + want: "abc123", + }, + { + name: "basic auth scheme", + header: "Basic xyz", + want: "", + }, + { + name: "empty header", + header: "", + want: "", + }, + { + name: "bearer with long token", + header: "Bearer eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9", + want: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9", + }, + { + name: "lowercase bearer (invalid)", + header: "bearer abc123", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tt.header != "" { + req.Header.Set("Authorization", tt.header) + } + got := extractBearer(req) + if got != tt.want { + t.Errorf("extractBearer() = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/internal/auth/oauth_test.go b/internal/auth/oauth_test.go new file mode 100644 index 0000000..45af80a --- /dev/null +++ b/internal/auth/oauth_test.go @@ -0,0 +1,249 @@ +package auth + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func newTestGoogleOAuth() *GoogleOAuth { + return NewGoogleOAuth("test-id", "test-secret", "http://localhost:8080/callback") +} + +func TestHandleMetadata(t *testing.T) { + g := newTestGoogleOAuth() + + req := httptest.NewRequest(http.MethodGet, "/.well-known/oauth-authorization-server", nil) + rec := httptest.NewRecorder() + + g.handleMetadata(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) + } + + contentType := rec.Header().Get("Content-Type") + if contentType != "application/json" { + t.Errorf("Content-Type = %q, want %q", contentType, "application/json") + } + + var body map[string]interface{} + if err := json.NewDecoder(rec.Body).Decode(&body); err != nil { + t.Fatalf("failed to decode 
response body: %v", err) + } + + requiredFields := []string{ + "issuer", + "authorization_endpoint", + "token_endpoint", + "registration_endpoint", + "response_types_supported", + "grant_types_supported", + "code_challenge_methods_supported", + "token_endpoint_auth_methods_supported", + } + for _, field := range requiredFields { + if _, ok := body[field]; !ok { + t.Errorf("response missing required field %q", field) + } + } + + // Verify endpoints contain the host. + if authEP, ok := body["authorization_endpoint"].(string); ok { + if !strings.HasSuffix(authEP, "/authorize") { + t.Errorf("authorization_endpoint = %q, want it to end with /authorize", authEP) + } + } + if tokenEP, ok := body["token_endpoint"].(string); ok { + if !strings.HasSuffix(tokenEP, "/token") { + t.Errorf("token_endpoint = %q, want it to end with /token", tokenEP) + } + } + if regEP, ok := body["registration_endpoint"].(string); ok { + if !strings.HasSuffix(regEP, "/register") { + t.Errorf("registration_endpoint = %q, want it to end with /register", regEP) + } + } +} + +func TestHandleRegister(t *testing.T) { + g := newTestGoogleOAuth() + + body := `{ + "client_name": "Test Client", + "redirect_uris": ["http://127.0.0.1:3000/callback"], + "grant_types": ["authorization_code"], + "token_endpoint_auth_method": "none" + }` + req := httptest.NewRequest(http.MethodPost, "/register", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + g.handleRegister(rec, req) + + if rec.Code != http.StatusCreated { + t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusCreated, rec.Body.String()) + } + + var resp map[string]interface{} + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + clientID, ok := resp["client_id"].(string) + if !ok || clientID == "" { + t.Error("response missing or empty client_id") + } + + if name, _ := resp["client_name"].(string); name != "Test 
Client" { + t.Errorf("client_name = %q, want %q", name, "Test Client") + } +} + +func TestHandleRegisterInvalidRedirectURI(t *testing.T) { + g := newTestGoogleOAuth() + + body := `{ + "client_name": "Evil Client", + "redirect_uris": ["https://evil.com/callback"], + "grant_types": ["authorization_code"], + "token_endpoint_auth_method": "none" + }` + req := httptest.NewRequest(http.MethodPost, "/register", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + g.handleRegister(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusBadRequest, rec.Body.String()) + } + + if !strings.Contains(rec.Body.String(), "invalid_redirect_uri") { + t.Errorf("response body = %q, want it to contain 'invalid_redirect_uri'", rec.Body.String()) + } +} + +func TestHandleRegisterMissingRedirectURIs(t *testing.T) { + g := newTestGoogleOAuth() + + body := `{ + "client_name": "No Redirect Client", + "redirect_uris": [], + "grant_types": ["authorization_code"] + }` + req := httptest.NewRequest(http.MethodPost, "/register", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + + g.handleRegister(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusBadRequest, rec.Body.String()) + } + + if !strings.Contains(rec.Body.String(), "invalid_redirect_uri") { + t.Errorf("response body = %q, want it to contain 'invalid_redirect_uri'", rec.Body.String()) + } +} + +func TestIsLoopbackURI(t *testing.T) { + tests := []struct { + name string + uri string + want bool + }{ + { + name: "IPv4 loopback", + uri: "http://127.0.0.1:3000/callback", + want: true, + }, + { + name: "IPv6 loopback", + uri: "http://[::1]:3000/callback", + want: true, + }, + { + name: "localhost", + uri: "http://localhost:3000/callback", + want: true, + }, + { + name: "external host", + uri: 
"http://evil.com/callback", + want: false, + }, + { + name: "empty string", + uri: "", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isLoopbackURI(tt.uri) + if got != tt.want { + t.Errorf("isLoopbackURI(%q) = %v, want %v", tt.uri, got, tt.want) + } + }) + } +} + +func TestMatchesRegisteredURI(t *testing.T) { + tests := []struct { + name string + requestURI string + registeredURIs []string + want bool + }{ + { + name: "same scheme+host+path, different port", + requestURI: "http://127.0.0.1:9999/callback", + registeredURIs: []string{"http://127.0.0.1:3000/callback"}, + want: true, + }, + { + name: "different host", + requestURI: "http://192.168.1.1:3000/callback", + registeredURIs: []string{"http://127.0.0.1:3000/callback"}, + want: false, + }, + { + name: "different scheme", + requestURI: "https://127.0.0.1:3000/callback", + registeredURIs: []string{"http://127.0.0.1:3000/callback"}, + want: false, + }, + { + name: "exact match", + requestURI: "http://127.0.0.1:3000/callback", + registeredURIs: []string{"http://127.0.0.1:3000/callback"}, + want: true, + }, + { + name: "different path", + requestURI: "http://127.0.0.1:3000/other", + registeredURIs: []string{"http://127.0.0.1:3000/callback"}, + want: false, + }, + { + name: "no registered URIs", + requestURI: "http://127.0.0.1:3000/callback", + registeredURIs: []string{}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := matchesRegisteredURI(tt.requestURI, tt.registeredURIs) + if got != tt.want { + t.Errorf("matchesRegisteredURI(%q, %v) = %v, want %v", + tt.requestURI, tt.registeredURIs, got, tt.want) + } + }) + } +} diff --git a/internal/auth/ratelimit_test.go b/internal/auth/ratelimit_test.go new file mode 100644 index 0000000..40c6950 --- /dev/null +++ b/internal/auth/ratelimit_test.go @@ -0,0 +1,69 @@ +package auth + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestRateLimiterAllows(t 
*testing.T) { + rl, stop := NewRateLimitMiddleware() + defer stop() + + okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + handler := rl(okHandler) + + // The rate limiter allows a burst of 5. Send 5 requests; all should pass. + for i := 0; i < 5; i++ { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.RemoteAddr = "192.168.1.1:12345" + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("request %d: status = %d, want %d", i+1, rec.Code, http.StatusOK) + } + } +} + +func TestRateLimiterBlocks(t *testing.T) { + rl, stop := NewRateLimitMiddleware() + defer stop() + + okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + handler := rl(okHandler) + + // Burst is 5 and rate is 10/min (1 every 6s). Sending 10 rapid requests + // from the same IP should result in some being rate-limited. + var okCount, blockedCount int + for i := 0; i < 10; i++ { + req := httptest.NewRequest(http.MethodGet, "/test", nil) + req.RemoteAddr = "10.0.0.1:12345" + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + switch rec.Code { + case http.StatusOK: + okCount++ + case http.StatusTooManyRequests: + blockedCount++ + default: + t.Fatalf("request %d: unexpected status %d", i+1, rec.Code) + } + } + + if blockedCount == 0 { + t.Error("expected at least one request to be rate-limited (429), but all passed") + } + if okCount == 0 { + t.Error("expected at least one request to succeed (200), but all were blocked") + } + + t.Logf("out of 10 requests: %d OK, %d blocked", okCount, blockedCount) +} diff --git a/internal/auth/session_test.go b/internal/auth/session_test.go new file mode 100644 index 0000000..277ef8a --- /dev/null +++ b/internal/auth/session_test.go @@ -0,0 +1,250 @@ +package auth + +import ( + "encoding/hex" + "strings" + "testing" + "time" +) + +func TestCreateSession(t 
*testing.T) { + store := NewStore() + + id := store.CreateSession(Session{ + Email: "user@example.com", + AccessToken: "access-123", + RefreshToken: "refresh-456", + ExpiresAt: time.Now().Add(time.Hour), + }) + + // Session ID must be 64-char hex (32 random bytes). + if len(id) != 64 { + t.Fatalf("expected session ID of length 64, got %d", len(id)) + } + if _, err := hex.DecodeString(id); err != nil { + t.Fatalf("session ID is not valid hex: %v", err) + } + + // Retrieve and verify fields. + sess, err := store.GetSession(id) + if err != nil { + t.Fatalf("GetSession failed: %v", err) + } + if sess.Email != "user@example.com" { + t.Errorf("Email = %q, want %q", sess.Email, "user@example.com") + } + if sess.AccessToken != "access-123" { + t.Errorf("AccessToken = %q, want %q", sess.AccessToken, "access-123") + } + if sess.RefreshToken != "refresh-456" { + t.Errorf("RefreshToken = %q, want %q", sess.RefreshToken, "refresh-456") + } + if sess.CreatedAt.IsZero() { + t.Error("CreatedAt should be set by CreateSession") + } +} + +func TestGetSessionNotFound(t *testing.T) { + store := NewStore() + + _, err := store.GetSession("nonexistent-id") + if err == nil { + t.Fatal("expected error for nonexistent session, got nil") + } + if !strings.Contains(err.Error(), "not found") { + t.Errorf("error = %q, want it to contain 'not found'", err.Error()) + } +} + +func TestGetSessionExpiry(t *testing.T) { + store := NewStore() + + id := store.CreateSession(Session{ + Email: "expired@example.com", + AccessToken: "token", + }) + + // Manually set CreatedAt to 25 hours ago to force expiry. 
+ store.mu.Lock() + store.sessions[id].CreatedAt = time.Now().Add(-25 * time.Hour) + store.mu.Unlock() + + _, err := store.GetSession(id) + if err == nil { + t.Fatal("expected error for expired session, got nil") + } + if !strings.Contains(err.Error(), "expired") { + t.Errorf("error = %q, want it to contain 'expired'", err.Error()) + } +} + +func TestPendingAuthOneTimeUse(t *testing.T) { + store := NewStore() + + store.StorePending("key1", PendingAuth{ + ClientID: "client-1", + CodeChallenge: "challenge", + ClientRedirectURI: "http://127.0.0.1:3000/callback", + State: "state-abc", + }) + + // First retrieval should succeed. + p, err := store.GetPending("key1") + if err != nil { + t.Fatalf("first GetPending failed: %v", err) + } + if p.ClientID != "client-1" { + t.Errorf("ClientID = %q, want %q", p.ClientID, "client-1") + } + if p.State != "state-abc" { + t.Errorf("State = %q, want %q", p.State, "state-abc") + } + + // Second retrieval should fail (one-time use). + _, err = store.GetPending("key1") + if err == nil { + t.Fatal("expected error on second GetPending, got nil") + } + if !strings.Contains(err.Error(), "not found") { + t.Errorf("error = %q, want it to contain 'not found'", err.Error()) + } +} + +func TestPendingAuthExpiry(t *testing.T) { + store := NewStore() + + store.StorePending("key-expired", PendingAuth{ + ClientID: "client-1", + State: "state-xyz", + }) + + // Manually set CreatedAt to 6 minutes ago. 
+ store.pendingMu.Lock() + store.pending["key-expired"].CreatedAt = time.Now().Add(-6 * time.Minute) + store.pendingMu.Unlock() + + _, err := store.GetPending("key-expired") + if err == nil { + t.Fatal("expected error for expired pending auth, got nil") + } + if !strings.Contains(err.Error(), "expired") { + t.Errorf("error = %q, want it to contain 'expired'", err.Error()) + } +} + +func TestAuthCodeOneTimeUse(t *testing.T) { + store := NewStore() + + store.StoreAuthCode("code1", AuthCode{ + ClientID: "client-1", + SessionID: "session-abc", + CodeChallenge: "challenge", + }) + + // First retrieval should succeed. + ac, err := store.GetAuthCode("code1") + if err != nil { + t.Fatalf("first GetAuthCode failed: %v", err) + } + if ac.ClientID != "client-1" { + t.Errorf("ClientID = %q, want %q", ac.ClientID, "client-1") + } + if ac.SessionID != "session-abc" { + t.Errorf("SessionID = %q, want %q", ac.SessionID, "session-abc") + } + + // Second retrieval should fail (one-time use). + _, err = store.GetAuthCode("code1") + if err == nil { + t.Fatal("expected error on second GetAuthCode, got nil") + } + if !strings.Contains(err.Error(), "not found") { + t.Errorf("error = %q, want it to contain 'not found'", err.Error()) + } +} + +func TestAuthCodeExpiry(t *testing.T) { + store := NewStore() + + store.StoreAuthCode("code-expired", AuthCode{ + ClientID: "client-1", + SessionID: "session-abc", + }) + + // Manually set CreatedAt to 6 minutes ago. 
+ store.codesMu.Lock() + store.codes["code-expired"].CreatedAt = time.Now().Add(-6 * time.Minute) + store.codesMu.Unlock() + + _, err := store.GetAuthCode("code-expired") + if err == nil { + t.Fatal("expected error for expired auth code, got nil") + } + if !strings.Contains(err.Error(), "expired") { + t.Errorf("error = %q, want it to contain 'expired'", err.Error()) + } +} + +func TestRegisterClient(t *testing.T) { + store := NewStore() + + clientID := store.RegisterClient(Client{ + ClientName: "Test App", + RedirectURIs: []string{"http://127.0.0.1:3000/callback"}, + GrantTypes: []string{"authorization_code"}, + TokenEndpointAuthMethod: "none", + }) + + if len(clientID) != 32 { // 16 random bytes = 32 hex chars + t.Fatalf("expected client ID of length 32, got %d", len(clientID)) + } + + client, err := store.GetClient(clientID) + if err != nil { + t.Fatalf("GetClient failed: %v", err) + } + if client.ClientName != "Test App" { + t.Errorf("ClientName = %q, want %q", client.ClientName, "Test App") + } + if len(client.RedirectURIs) != 1 || client.RedirectURIs[0] != "http://127.0.0.1:3000/callback" { + t.Errorf("RedirectURIs = %v, want [http://127.0.0.1:3000/callback]", client.RedirectURIs) + } + if client.TokenEndpointAuthMethod != "none" { + t.Errorf("TokenEndpointAuthMethod = %q, want %q", client.TokenEndpointAuthMethod, "none") + } + if client.CreatedAt.IsZero() { + t.Error("CreatedAt should be set by RegisterClient") + } +} + +func TestGetClientNotFound(t *testing.T) { + store := NewStore() + + _, err := store.GetClient("nonexistent-client-id") + if err == nil { + t.Fatal("expected error for nonexistent client, got nil") + } + if !strings.Contains(err.Error(), "not found") { + t.Errorf("error = %q, want it to contain 'not found'", err.Error()) + } +} + +func TestRandomHex(t *testing.T) { + result := randomHex(32) + + // 32 bytes encoded as hex = 64 characters. 
+ if len(result) != 64 { + t.Fatalf("expected length 64, got %d", len(result)) + } + + // All characters must be valid hex. + if _, err := hex.DecodeString(result); err != nil { + t.Fatalf("result is not valid hex: %v", err) + } + + // Two calls should produce different values (probabilistically). + result2 := randomHex(32) + if result == result2 { + t.Error("two calls to randomHex returned identical values, which is extremely unlikely") + } +} diff --git a/internal/k8s/client_test.go b/internal/k8s/client_test.go new file mode 100644 index 0000000..13088bf --- /dev/null +++ b/internal/k8s/client_test.go @@ -0,0 +1,79 @@ +package k8s + +import ( + "context" + "testing" + "time" +) + +func TestClusterCacheHit(t *testing.T) { + // Manually insert an entry into the cluster cache. + key := "test-project/us-central1/test-cluster" + expectedEndpoint := "10.0.0.1" + expectedCA := []byte("fake-ca-data") + + clusterCacheMu.Lock() + clusterCache[key] = &cachedCluster{ + endpoint: expectedEndpoint, + ca: expectedCA, + cachedAt: time.Now(), + } + clusterCacheMu.Unlock() + + // Clean up after test. + defer func() { + clusterCacheMu.Lock() + delete(clusterCache, key) + clusterCacheMu.Unlock() + }() + + ci := ClusterInfo{ + Project: "test-project", + Location: "us-central1", + Name: "test-cluster", + } + + endpoint, ca, err := getCachedClusterDetails(context.Background(), "fake-token", ci) + if err != nil { + t.Fatalf("getCachedClusterDetails returned error: %v", err) + } + if endpoint != expectedEndpoint { + t.Errorf("endpoint = %q, want %q", endpoint, expectedEndpoint) + } + if string(ca) != string(expectedCA) { + t.Errorf("ca = %q, want %q", string(ca), string(expectedCA)) + } +} + +func TestClusterCacheExpiry(t *testing.T) { + // Insert an entry that is older than the TTL (11 minutes ago). 
+ key := "expired-project/us-east1/expired-cluster" + clusterCacheMu.Lock() + clusterCache[key] = &cachedCluster{ + endpoint: "old-endpoint", + ca: []byte("old-ca"), + cachedAt: time.Now().Add(-11 * time.Minute), + } + clusterCacheMu.Unlock() + + // Clean up after test. + defer func() { + clusterCacheMu.Lock() + delete(clusterCache, key) + clusterCacheMu.Unlock() + }() + + ci := ClusterInfo{ + Project: "expired-project", + Location: "us-east1", + Name: "expired-cluster", + } + + // The cache entry is expired, so getCachedClusterDetails will attempt + // to call the GKE API, which will fail because we have a fake token. + // The important thing is that it did NOT return the stale cached data. + _, _, err := getCachedClusterDetails(context.Background(), "fake-token", ci) + if err == nil { + t.Fatal("expected error when cache is expired and GKE API call fails, got nil") + } +} diff --git a/internal/tools/tools_test.go b/internal/tools/tools_test.go new file mode 100644 index 0000000..0f043f2 --- /dev/null +++ b/internal/tools/tools_test.go @@ -0,0 +1,224 @@ +package tools + +import ( + "errors" + "testing" + + "github.com/mark3labs/mcp-go/mcp" +) + +func TestStr(t *testing.T) { + tests := []struct { + name string + args map[string]interface{} + key string + want string + }{ + { + name: "key exists with string value", + args: map[string]interface{}{"project": "my-project"}, + key: "project", + want: "my-project", + }, + { + name: "key exists with non-string value", + args: map[string]interface{}{"count": 42}, + key: "count", + want: "", + }, + { + name: "key missing", + args: map[string]interface{}{}, + key: "project", + want: "", + }, + { + name: "nil map value", + args: map[string]interface{}{"ns": nil}, + key: "ns", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := str(tt.args, tt.key) + if got != tt.want { + t.Errorf("str(%v, %q) = %q, want %q", tt.args, tt.key, got, tt.want) + } + }) + } +} + +func TestNum(t 
*testing.T) { + tests := []struct { + name string + args map[string]interface{} + key string + def float64 + want int64 + }{ + { + name: "key exists with float64 value", + args: map[string]interface{}{"tail_lines": float64(50)}, + key: "tail_lines", + def: 100, + want: 50, + }, + { + name: "key missing returns default", + args: map[string]interface{}{}, + key: "tail_lines", + def: 100, + want: 100, + }, + { + name: "key exists with wrong type returns default", + args: map[string]interface{}{"tail_lines": "not-a-number"}, + key: "tail_lines", + def: 200, + want: 200, + }, + { + name: "zero value", + args: map[string]interface{}{"lines": float64(0)}, + key: "lines", + def: 100, + want: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := num(tt.args, tt.key, tt.def) + if got != tt.want { + t.Errorf("num(%v, %q, %v) = %d, want %d", tt.args, tt.key, tt.def, got, tt.want) + } + }) + } +} + +func TestClusterInfoValid(t *testing.T) { + args := map[string]interface{}{ + "project": "my-project-123", + "location": "us-central1", + "cluster": "my-cluster", + } + + ci, err := clusterInfo(args) + if err != nil { + t.Fatalf("clusterInfo returned unexpected error: %v", err) + } + if ci.Project != "my-project-123" { + t.Errorf("Project = %q, want %q", ci.Project, "my-project-123") + } + if ci.Location != "us-central1" { + t.Errorf("Location = %q, want %q", ci.Location, "us-central1") + } + if ci.Name != "my-cluster" { + t.Errorf("Name = %q, want %q", ci.Name, "my-cluster") + } +} + +func TestClusterInfoInvalidProject(t *testing.T) { + args := map[string]interface{}{ + "project": "INVALID", + "location": "us-central1", + "cluster": "my-cluster", + } + + _, err := clusterInfo(args) + if err == nil { + t.Fatal("expected error for invalid project, got nil") + } + if got := err.Error(); !contains(got, "invalid project") { + t.Errorf("error = %q, want it to contain %q", got, "invalid project") + } +} + +func TestClusterInfoInvalidLocation(t 
*testing.T) { + args := map[string]interface{}{ + "project": "my-project-123", + "location": "bad", + "cluster": "my-cluster", + } + + _, err := clusterInfo(args) + if err == nil { + t.Fatal("expected error for invalid location, got nil") + } + if got := err.Error(); !contains(got, "invalid location") { + t.Errorf("error = %q, want it to contain %q", got, "invalid location") + } +} + +func TestClusterInfoInvalidCluster(t *testing.T) { + args := map[string]interface{}{ + "project": "my-project-123", + "location": "us-central1", + "cluster": "", + } + + _, err := clusterInfo(args) + if err == nil { + t.Fatal("expected error for invalid cluster, got nil") + } + if got := err.Error(); !contains(got, "invalid cluster") { + t.Errorf("error = %q, want it to contain %q", got, "invalid cluster") + } +} + +func TestErrResult(t *testing.T) { + result, err := errResult("something went wrong: %s", "details") + if err != nil { + t.Fatalf("errResult returned non-nil error: %v", err) + } + if result == nil { + t.Fatal("errResult returned nil CallToolResult") + } + if !result.IsError { + t.Error("expected IsError to be true") + } +} + +func TestSafeErr(t *testing.T) { + testErr := errors.New("connection refused") + result, err := safeErr("failed to connect", testErr) + if err != nil { + t.Fatalf("safeErr returned non-nil error: %v", err) + } + if result == nil { + t.Fatal("safeErr returned nil CallToolResult") + } + if !result.IsError { + t.Error("expected IsError to be true") + } + + // Verify the result contains the error message. + found := false + for _, c := range result.Content { + if tc, ok := c.(mcp.TextContent); ok { + if contains(tc.Text, "connection refused") { + found = true + break + } + } + } + if !found { + t.Error("expected result content to contain the error message 'connection refused'") + } +} + +// contains is a small helper to check substring presence. 
+func contains(s, substr string) bool { + return len(s) >= len(substr) && searchSubstring(s, substr) +} + +func searchSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} From c467a6f738186a4e4a8015a22d771e8129f8d13c Mon Sep 17 00:00:00 2001 From: Aditya Kumar Date: Tue, 14 Apr 2026 00:16:25 +0900 Subject: [PATCH 3/3] Test the code in all branches --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b6cd196..f8d3fd2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,7 +4,7 @@ on: push: branches: ['*'] pull_request: - branches: [main] + branches: ['*'] jobs: build-and-test: