From 00192d6ccc70989366b512186c56e2494b92a027 Mon Sep 17 00:00:00 2001 From: Nakul Ganesh S Date: Fri, 26 Jun 2026 10:58:01 +0200 Subject: [PATCH] feat(trino): support per-query user impersonation Add an opt-in `impersonateUser` field to the `trino-sql` and `trino-execute-sql` tools. When enabled, the tool exposes an additional optional `trino_user` input parameter whose value is forwarded as the `X-Trino-User` header for that statement only, letting a single pooled connection run individual queries as different users. If `trino_user` is omitted (or empty), the query runs as the source's configured user. This is implemented on the source via a new `RunSQLAsUser` method that attaches the user as a `sql.Named("X-Trino-User", ...)` query argument; the trino-go-client forwards `X-Trino-`-prefixed arguments as request headers and excludes them from positional placeholder binding, so the impersonation user never consumes a `?` parameter. The pool's configured principal still authenticates the request, so it must be authorized to impersonate on the Trino side. --- .../trino/tools/trino-execute-sql.md | 17 ++- docs/en/integrations/trino/tools/trino-sql.md | 11 ++ internal/sources/trino/trino.go | 19 +++ .../trino/trinoexecutesql/trinoexecutesql.go | 27 +++- .../trinoexecutesql/trinoexecutesql_test.go | 128 +++++++++++++++++ internal/tools/trino/trinosql/trinosql.go | 29 ++++ .../tools/trino/trinosql/trinosql_test.go | 131 ++++++++++++++++++ tests/trino/trino_integration_test.go | 12 ++ 8 files changed, 369 insertions(+), 5 deletions(-) diff --git a/docs/en/integrations/trino/tools/trino-execute-sql.md b/docs/en/integrations/trino/tools/trino-execute-sql.md index ac15553bd49d..d77239961683 100644 --- a/docs/en/integrations/trino/tools/trino-execute-sql.md +++ b/docs/en/integrations/trino/tools/trino-execute-sql.md @@ -18,6 +18,16 @@ statement against the `source`. > **Note:** This tool is intended for developer assistant workflows with > human-in-the-loop and shouldn't be used for production agents. +### User impersonation + +Set `impersonateUser: true` to run each statement as a specific Trino user. When +enabled, the tool exposes an additional optional input parameter `trino_user` +whose value is forwarded as the `X-Trino-User` header for that statement only. If +`trino_user` is omitted (or empty), the query runs as the source's configured +user. The connection pool's configured principal (DSN `user` / `accessToken`) +still authenticates the request, so that principal must be authorized to +impersonate on the Trino side. + ## Compatible Sources {{< compatible-sources >}} @@ -36,6 +46,7 @@ description: Use this tool to execute sql statement. | **field** | **type** | **required** | **description** | |-------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------| -| type | string | true | Must be "trino-execute-sql". | -| source | string | true | Name of the source the SQL should execute on. | -| description | string | true | Description of the tool that is passed to the LLM. | +| type | string | true | Must be "trino-execute-sql". | +| source | string | true | Name of the source the SQL should execute on. | +| description | string | true | Description of the tool that is passed to the LLM. | +| impersonateUser | bool | false | When true, adds an optional `trino_user` input parameter forwarded as the `X-Trino-User` header. | diff --git a/docs/en/integrations/trino/tools/trino-sql.md b/docs/en/integrations/trino/tools/trino-sql.md index 486b92ca0285..1ae450151e77 100644 --- a/docs/en/integrations/trino/tools/trino-sql.md +++ b/docs/en/integrations/trino/tools/trino-sql.md @@ -15,6 +15,16 @@ The specified SQL statement is executed as a [prepared statement][trino-prepare] [trino-prepare]: https://trino.io/docs/current/sql/prepare.html +### User impersonation + +Set `impersonateUser: true` to run the statement as a specific Trino user. When +enabled, the tool exposes an additional optional input parameter `trino_user` +whose value is forwarded as the `X-Trino-User` header for that statement only +(it is not bound into the SQL). If `trino_user` is omitted (or empty), the query +runs as the source's configured user. The connection pool's configured principal +(DSN `user` / `accessToken`) still authenticates the request, so that principal +must be authorized to impersonate on the Trino side. + ## Compatible Sources {{< compatible-sources >}} @@ -101,3 +111,4 @@ templateParameters: | statement | string | true | SQL statement to execute on. | | parameters | [parameters](../../../documentation/configuration/tools/_index.md#specifying-parameters) | false | List of [parameters](../../../documentation/configuration/tools/_index.md#specifying-parameters) that will be inserted into the SQL statement. | | templateParameters | [templateParameters](../../../documentation/configuration/tools/_index.md#template-parameters) | false | List of [templateParameters](../../../documentation/configuration/tools/_index.md#template-parameters) that will be inserted into the SQL statement before executing prepared statement. | +| impersonateUser | bool | false | When true, adds an optional `trino_user` input parameter forwarded as the `X-Trino-User` header. | diff --git a/internal/sources/trino/trino.go b/internal/sources/trino/trino.go index 0f37a835409b..d713669ba83f 100644 --- a/internal/sources/trino/trino.go +++ b/internal/sources/trino/trino.go @@ -32,6 +32,13 @@ import ( const SourceType string = "trino" +// trinoUserHeader is the HTTP header the Trino protocol uses to identify the +// session user. The trino-go-client forwards any query argument whose name +// carries the "X-Trino-" prefix as a request header (and skips it when binding +// positional "?" placeholders), so passing a sql.Named with this name lets a +// single pooled connection run an individual statement as a different user. +const trinoUserHeader string = "X-Trino-User" + // validate interface var _ sources.SourceConfig = Config{} @@ -108,6 +115,18 @@ func (s *Source) TrinoDB() *sql.DB { return s.Pool } +// RunSQLAsUser runs statement while impersonating the given Trino user. When +// user is empty it is equivalent to RunSQL. Impersonation attaches the +// X-Trino-User header to this statement only; the connection pool's configured +// principal (DSN user / access token) still authenticates the request, so that +// principal must be authorized to impersonate on the Trino side. +func (s *Source) RunSQLAsUser(ctx context.Context, user, statement string, params []any) (any, error) { + if user != "" { + params = append(params, sql.Named(trinoUserHeader, user)) + } + return s.RunSQL(ctx, statement, params) +} + func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) { results, err := s.TrinoDB().QueryContext(ctx, statement, params...) if err != nil { diff --git a/internal/tools/trino/trinoexecutesql/trinoexecutesql.go b/internal/tools/trino/trinoexecutesql/trinoexecutesql.go index 06261fc8fb43..a4177553436d 100644 --- a/internal/tools/trino/trinoexecutesql/trinoexecutesql.go +++ b/internal/tools/trino/trinoexecutesql/trinoexecutesql.go @@ -28,6 +28,11 @@ import ( const resourceType string = "trino-execute-sql" +// impersonationParamName is the tool input parameter that supplies the Trino +// user to impersonate. It is exposed only when impersonateUser is enabled and +// is forwarded as the X-Trino-User header rather than bound into the SQL. +const impersonationParamName string = "trino_user" + func init() { if !tools.Register(resourceType, newConfig) { panic(fmt.Sprintf("tool type %q already registered", resourceType)) @@ -45,12 +50,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { TrinoDB() *sql.DB RunSQL(context.Context, string, []any) (any, error) + RunSQLAsUser(context.Context, string, string, []any) (any, error) } type Config struct { tools.ConfigBase `yaml:",inline"` Type string `yaml:"type" validate:"required"` Source string `yaml:"source" validate:"required"` + ImpersonateUser bool `yaml:"impersonateUser"` Annotations *tools.ToolAnnotations `yaml:"annotations,omitempty"` } @@ -68,6 +75,9 @@ func (cfg Config) Initialize(context.Context) (tools.Tool, error) { sqlParameter := parameters.NewStringParameter("sql", "The SQL query to execute against the Trino database.") params := parameters.Parameters{sqlParameter} + if cfg.ImpersonateUser { + params = append(params, parameters.NewStringParameter(impersonationParamName, "The Trino user to impersonate for this query (forwarded as the X-Trino-User header). Optional; if omitted the query runs as the source's configured user.", parameters.WithStringDefault(""))) + } return Tool{ BaseTool: tools.NewBaseTool( @@ -96,11 +106,24 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, util.NewClientServerError("source not compatible with this tool", http.StatusInternalServerError, err) } - sliceParams := params.AsSlice() - sql, ok := sliceParams[0].(string) + paramsMap := params.AsMap() + sql, ok := paramsMap["sql"].(string) if !ok { return nil, util.NewAgentError("unable to cast the `sql` input parameter into string", nil) } + + if t.Cfg.ImpersonateUser { + user, ok := paramsMap[impersonationParamName].(string) + if !ok { + return nil, util.NewAgentError(fmt.Sprintf("unable to cast the `%s` input parameter into string", impersonationParamName), nil) + } + res, err := source.RunSQLAsUser(ctx, user, sql, []any{}) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return res, nil + } + res, err := source.RunSQL(ctx, sql, []any{}) if err != nil { return nil, util.ProcessGeneralError(err) diff --git a/internal/tools/trino/trinoexecutesql/trinoexecutesql_test.go b/internal/tools/trino/trinoexecutesql/trinoexecutesql_test.go index ddcea20cab11..95d9941f38e4 100644 --- a/internal/tools/trino/trinoexecutesql/trinoexecutesql_test.go +++ b/internal/tools/trino/trinoexecutesql/trinoexecutesql_test.go @@ -15,13 +15,17 @@ package trinoexecutesql_test import ( + "context" + "database/sql" "testing" "github.com/google/go-cmp/cmp" "github.com/googleapis/mcp-toolbox/internal/server" + "github.com/googleapis/mcp-toolbox/internal/sources" "github.com/googleapis/mcp-toolbox/internal/testutils" "github.com/googleapis/mcp-toolbox/internal/tools" "github.com/googleapis/mcp-toolbox/internal/tools/trino/trinoexecutesql" + "github.com/googleapis/mcp-toolbox/internal/util/parameters" ) func TestParseFromYamlTrinoExecuteSQL(t *testing.T) { @@ -58,6 +62,29 @@ func TestParseFromYamlTrinoExecuteSQL(t *testing.T) { }, }, }, + { + desc: "with user impersonation", + in: ` + kind: tool + name: example_tool + type: trino-execute-sql + source: my-trino-instance + description: some description + impersonateUser: true + `, + want: server.ToolConfigs{ + "example_tool": trinoexecutesql.Config{ + ConfigBase: tools.ConfigBase{ + Name: "example_tool", + Description: "some description", + AuthRequired: []string{}, + }, + Type: "trino-execute-sql", + Source: "my-trino-instance", + ImpersonateUser: true, + }, + }, + }, } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { @@ -71,3 +98,104 @@ func TestParseFromYamlTrinoExecuteSQL(t *testing.T) { }) } } + +// mockSource implements the tool's compatibleSource interface and records how +// the tool routed the call (plain RunSQL vs. impersonating RunSQLAsUser). +type mockSource struct { + ranPlain bool + ranAsUser bool + gotUser string + gotStmt string + gotParams []any +} + +func (m *mockSource) SourceType() string { return "trino" } +func (m *mockSource) ToConfig() sources.SourceConfig { return nil } +func (m *mockSource) TrinoDB() *sql.DB { return nil } + +func (m *mockSource) RunSQL(_ context.Context, stmt string, params []any) (any, error) { + m.ranPlain = true + m.gotStmt = stmt + m.gotParams = params + return []any{}, nil +} + +func (m *mockSource) RunSQLAsUser(_ context.Context, user, stmt string, params []any) (any, error) { + m.ranAsUser = true + m.gotUser = user + m.gotStmt = stmt + m.gotParams = params + return []any{}, nil +} + +type mockSourceProvider struct { + tools.SourceProvider + source *mockSource +} + +func (m *mockSourceProvider) GetSource(string) (sources.Source, bool) { return m.source, true } + +func TestInvokeImpersonation(t *testing.T) { + tcs := []struct { + desc string + impersonateUser bool + params parameters.ParamValues + wantAsUser bool + wantUser string + }{ + { + desc: "impersonation disabled uses plain RunSQL", + impersonateUser: false, + params: parameters.ParamValues{{Name: "sql", Value: "SELECT 1"}}, + wantAsUser: false, + }, + { + desc: "impersonation forwards trino_user", + impersonateUser: true, + params: parameters.ParamValues{{Name: "sql", Value: "SELECT 1"}, {Name: "trino_user", Value: "alice@seedtag.com"}}, + wantAsUser: true, + wantUser: "alice@seedtag.com", + }, + { + desc: "empty trino_user falls back to source user", + impersonateUser: true, + params: parameters.ParamValues{{Name: "sql", Value: "SELECT 1"}, {Name: "trino_user", Value: ""}}, + wantAsUser: true, + wantUser: "", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + cfg := trinoexecutesql.Config{ + ConfigBase: tools.ConfigBase{Name: "tool", Description: "d"}, + Type: "trino-execute-sql", + Source: "s", + ImpersonateUser: tc.impersonateUser, + } + tool, err := cfg.Initialize(context.Background()) + if err != nil { + t.Fatalf("initialize: %v", err) + } + src := &mockSource{} + if _, toolErr := tool.Invoke(context.Background(), &mockSourceProvider{source: src}, tc.params, ""); toolErr != nil { + t.Fatalf("invoke: %v", toolErr) + } + if src.ranAsUser != tc.wantAsUser { + t.Errorf("ranAsUser = %v, want %v", src.ranAsUser, tc.wantAsUser) + } + if src.ranPlain == tc.wantAsUser { + t.Errorf("ranPlain = %v, want %v", src.ranPlain, !tc.wantAsUser) + } + if tc.wantAsUser && src.gotUser != tc.wantUser { + t.Errorf("forwarded user = %q, want %q", src.gotUser, tc.wantUser) + } + if src.gotStmt != "SELECT 1" { + t.Errorf("forwarded statement = %q, want %q", src.gotStmt, "SELECT 1") + } + // trino_user must never be bound as a SQL parameter. + if len(src.gotParams) != 0 { + t.Errorf("expected no bind params, got %v", src.gotParams) + } + }) + } +} diff --git a/internal/tools/trino/trinosql/trinosql.go b/internal/tools/trino/trinosql/trinosql.go index 8b4162805284..2bb48d9c8cda 100644 --- a/internal/tools/trino/trinosql/trinosql.go +++ b/internal/tools/trino/trinosql/trinosql.go @@ -28,6 +28,11 @@ import ( const resourceType string = "trino-sql" +// impersonationParamName is the tool input parameter that supplies the Trino +// user to impersonate. It is exposed only when impersonateUser is enabled and +// is forwarded as the X-Trino-User header rather than bound into the statement. +const impersonationParamName string = "trino_user" + func init() { if !tools.Register(resourceType, newConfig) { panic(fmt.Sprintf("tool type %q already registered", resourceType)) @@ -45,6 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T type compatibleSource interface { TrinoDB() *sql.DB RunSQL(context.Context, string, []any) (any, error) + RunSQLAsUser(context.Context, string, string, []any) (any, error) } type Config struct { @@ -54,6 +60,7 @@ type Config struct { Statement string `yaml:"statement" validate:"required"` Parameters parameters.Parameters `yaml:"parameters"` TemplateParameters parameters.Parameters `yaml:"templateParameters"` + ImpersonateUser bool `yaml:"impersonateUser"` Annotations *tools.ToolAnnotations `yaml:"annotations,omitempty"` } @@ -74,6 +81,15 @@ func (cfg Config) Initialize(context.Context) (tools.Tool, error) { return nil, fmt.Errorf("unable to process parameters: %w", err) } + // The impersonation user is exposed as a tool input parameter so the caller + // supplies it, but it is forwarded as the X-Trino-User header in Invoke + // rather than bound into the statement (so it stays out of cfg.Parameters). + if cfg.ImpersonateUser { + userParam := parameters.NewStringParameter(impersonationParamName, "The Trino user to impersonate for this query (forwarded as the X-Trino-User header). Optional; if omitted the query runs as the source's configured user.", parameters.WithStringDefault("")) + allParameters = append(allParameters, userParam) + paramManifest = append(paramManifest, userParam.Manifest()) + } + return Tool{ BaseTool: tools.NewBaseTool( cfg, @@ -111,6 +127,19 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para return nil, util.NewAgentError("unable to extract standard params", err) } sliceParams := newParams.AsSlice() + + if t.Cfg.ImpersonateUser { + user, ok := paramsMap[impersonationParamName].(string) + if !ok { + return nil, util.NewAgentError(fmt.Sprintf("unable to cast the `%s` input parameter into string", impersonationParamName), nil) + } + res, err := source.RunSQLAsUser(ctx, user, newStatement, sliceParams) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return res, nil + } + res, err := source.RunSQL(ctx, newStatement, sliceParams) if err != nil { return nil, util.ProcessGeneralError(err) diff --git a/internal/tools/trino/trinosql/trinosql_test.go b/internal/tools/trino/trinosql/trinosql_test.go index 820fdb79504c..b415048ca37a 100644 --- a/internal/tools/trino/trinosql/trinosql_test.go +++ b/internal/tools/trino/trinosql/trinosql_test.go @@ -15,10 +15,13 @@ package trinosql_test import ( + "context" + "database/sql" "testing" "github.com/google/go-cmp/cmp" "github.com/googleapis/mcp-toolbox/internal/server" + "github.com/googleapis/mcp-toolbox/internal/sources" "github.com/googleapis/mcp-toolbox/internal/testutils" "github.com/googleapis/mcp-toolbox/internal/tools" "github.com/googleapis/mcp-toolbox/internal/tools/trino/trinosql" @@ -76,6 +79,32 @@ func TestParseFromYamlTrino(t *testing.T) { }, }, }, + { + desc: "with user impersonation", + in: ` + kind: tool + name: example_tool + type: trino-sql + source: my-trino-instance + description: some description + statement: | + SELECT 1; + impersonateUser: true + `, + want: server.ToolConfigs{ + "example_tool": trinosql.Config{ + ConfigBase: tools.ConfigBase{ + Name: "example_tool", + Description: "some description", + AuthRequired: []string{}, + }, + Type: "trino-sql", + Source: "my-trino-instance", + Statement: "SELECT 1;\n", + ImpersonateUser: true, + }, + }, + }, } for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { @@ -177,3 +206,105 @@ func TestParseFromYamlWithTemplateParamsTrino(t *testing.T) { }) } } + +// mockSource implements the tool's compatibleSource interface and records how +// the tool routed the call (plain RunSQL vs. impersonating RunSQLAsUser). +type mockSource struct { + ranPlain bool + ranAsUser bool + gotUser string + gotStmt string + gotParams []any +} + +func (m *mockSource) SourceType() string { return "trino" } +func (m *mockSource) ToConfig() sources.SourceConfig { return nil } +func (m *mockSource) TrinoDB() *sql.DB { return nil } + +func (m *mockSource) RunSQL(_ context.Context, stmt string, params []any) (any, error) { + m.ranPlain = true + m.gotStmt = stmt + m.gotParams = params + return []any{}, nil +} + +func (m *mockSource) RunSQLAsUser(_ context.Context, user, stmt string, params []any) (any, error) { + m.ranAsUser = true + m.gotUser = user + m.gotStmt = stmt + m.gotParams = params + return []any{}, nil +} + +type mockSourceProvider struct { + tools.SourceProvider + source *mockSource +} + +func (m *mockSourceProvider) GetSource(string) (sources.Source, bool) { return m.source, true } + +func TestInvokeImpersonation(t *testing.T) { + tcs := []struct { + desc string + impersonateUser bool + params parameters.ParamValues + wantAsUser bool + wantUser string + }{ + { + desc: "impersonation disabled uses plain RunSQL", + impersonateUser: false, + params: parameters.ParamValues{}, + wantAsUser: false, + }, + { + desc: "impersonation forwards trino_user", + impersonateUser: true, + params: parameters.ParamValues{{Name: "trino_user", Value: "alice@seedtag.com"}}, + wantAsUser: true, + wantUser: "alice@seedtag.com", + }, + { + desc: "empty trino_user falls back to source user", + impersonateUser: true, + params: parameters.ParamValues{{Name: "trino_user", Value: ""}}, + wantAsUser: true, + wantUser: "", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + cfg := trinosql.Config{ + ConfigBase: tools.ConfigBase{Name: "tool", Description: "d"}, + Type: "trino-sql", + Source: "s", + Statement: "SELECT 1", + ImpersonateUser: tc.impersonateUser, + } + tool, err := cfg.Initialize(context.Background()) + if err != nil { + t.Fatalf("initialize: %v", err) + } + src := &mockSource{} + if _, toolErr := tool.Invoke(context.Background(), &mockSourceProvider{source: src}, tc.params, ""); toolErr != nil { + t.Fatalf("invoke: %v", toolErr) + } + if src.ranAsUser != tc.wantAsUser { + t.Errorf("ranAsUser = %v, want %v", src.ranAsUser, tc.wantAsUser) + } + if src.ranPlain == tc.wantAsUser { + t.Errorf("ranPlain = %v, want %v", src.ranPlain, !tc.wantAsUser) + } + if tc.wantAsUser && src.gotUser != tc.wantUser { + t.Errorf("forwarded user = %q, want %q", src.gotUser, tc.wantUser) + } + if src.gotStmt != "SELECT 1" { + t.Errorf("forwarded statement = %q, want %q", src.gotStmt, "SELECT 1") + } + // trino_user must never be bound as a SQL parameter. + if len(src.gotParams) != 0 { + t.Errorf("expected no bind params, got %v", src.gotParams) + } + }) + } +} diff --git a/tests/trino/trino_integration_test.go b/tests/trino/trino_integration_test.go index b0acafd959c0..21d9989210b6 100644 --- a/tests/trino/trino_integration_test.go +++ b/tests/trino/trino_integration_test.go @@ -204,6 +204,12 @@ func addTrinoExecuteSqlConfig(t *testing.T, config map[string]any) map[string]an "my-google-auth", }, } + tools["my-impersonate-exec-sql-tool"] = map[string]any{ + "type": "trino-execute-sql", + "source": "my-instance", + "description": "Tool to execute sql impersonating a Trino user", + "impersonateUser": true, + } config["tools"] = tools return config } @@ -264,4 +270,10 @@ func TestTrinoToolEndpoints(t *testing.T) { tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want) tests.RunExecuteSqlToolInvokeTest(t, createTableStatement, select1Want) tests.RunToolInvokeWithTemplateParameters(t, tableNameTemplateParam, tests.WithInsert1Want(`[{"rows":1}]`)) + + // Verify user impersonation: the `trino_user` parameter is forwarded as the + // X-Trino-User header, so the query runs as that user (current_user reflects + // it rather than the source's configured user). + impersonateParams := []byte(`{"sql": "SELECT current_user", "trino_user": "impersonated_user"}`) + tests.RunToolInvokeParametersTest(t, "my-impersonate-exec-sql-tool", impersonateParams, "impersonated_user") }