Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions docs/en/integrations/trino/tools/trino-execute-sql.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ statement against the `source`.
> **Note:** This tool is intended for developer assistant workflows with
> human-in-the-loop and shouldn't be used for production agents.

### User impersonation

Set `impersonateUser: true` to run each statement as a specific Trino user. When
enabled, the tool exposes an additional optional input parameter `trino_user`
whose value is forwarded as the `X-Trino-User` header for that statement only. If
`trino_user` is omitted (or empty), the query runs as the source's configured
user. The connection pool's configured principal (DSN `user` / `accessToken`)
still authenticates the request, so that principal must be authorized to
impersonate on the Trino side.

## Compatible Sources

{{< compatible-sources >}}
Expand All @@ -36,6 +46,7 @@ description: Use this tool to execute sql statement.

| **field** | **type** | **required** | **description** |
|-------------|:------------------------------------------:|:------------:|--------------------------------------------------------------------------------------------------|
| type | string | true | Must be "trino-execute-sql". |
| source | string | true | Name of the source the SQL should execute on. |
| description | string | true | Description of the tool that is passed to the LLM. |
| type | string | true | Must be "trino-execute-sql". |
| source | string | true | Name of the source the SQL should execute on. |
| description | string | true | Description of the tool that is passed to the LLM. |
| impersonateUser | bool | false | When true, adds an optional `trino_user` input parameter forwarded as the `X-Trino-User` header. |
11 changes: 11 additions & 0 deletions docs/en/integrations/trino/tools/trino-sql.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@ The specified SQL statement is executed as a [prepared statement][trino-prepare]

[trino-prepare]: https://trino.io/docs/current/sql/prepare.html

### User impersonation

Set `impersonateUser: true` to run the statement as a specific Trino user. When
enabled, the tool exposes an additional optional input parameter `trino_user`
whose value is forwarded as the `X-Trino-User` header for that statement only
(it is not bound into the SQL). If `trino_user` is omitted (or empty), the query
runs as the source's configured user. The connection pool's configured principal
(DSN `user` / `accessToken`) still authenticates the request, so that principal
must be authorized to impersonate on the Trino side.

## Compatible Sources

{{< compatible-sources >}}
Expand Down Expand Up @@ -101,3 +111,4 @@ templateParameters:
| statement | string | true | SQL statement to execute on. |
| parameters | [parameters](../../../documentation/configuration/tools/_index.md#specifying-parameters) | false | List of [parameters](../../../documentation/configuration/tools/_index.md#specifying-parameters) that will be inserted into the SQL statement. |
| templateParameters | [templateParameters](../../../documentation/configuration/tools/_index.md#template-parameters) | false | List of [templateParameters](../../../documentation/configuration/tools/_index.md#template-parameters) that will be inserted into the SQL statement before executing prepared statement. |
| impersonateUser | bool | false | When true, adds an optional `trino_user` input parameter forwarded as the `X-Trino-User` header. |
19 changes: 19 additions & 0 deletions internal/sources/trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ import (

const SourceType string = "trino"

// trinoUserHeader is the HTTP header the Trino protocol uses to identify the
// session user. The trino-go-client forwards any query argument whose name
// carries the "X-Trino-" prefix as a request header (and skips it when binding
// positional "?" placeholders), so passing a sql.Named with this name lets a
// single pooled connection run an individual statement as a different user.
const trinoUserHeader string = "X-Trino-User"

// validate interface
var _ sources.SourceConfig = Config{}

Expand Down Expand Up @@ -108,6 +115,18 @@ func (s *Source) TrinoDB() *sql.DB {
return s.Pool
}

// RunSQLAsUser runs statement while impersonating the given Trino user. When
// user is empty it is equivalent to RunSQL. Impersonation attaches the
// X-Trino-User header to this statement only; the connection pool's configured
// principal (DSN user / access token) still authenticates the request, so that
// principal must be authorized to impersonate on the Trino side.
func (s *Source) RunSQLAsUser(ctx context.Context, user, statement string, params []any) (any, error) {
if user != "" {
params = append(params, sql.Named(trinoUserHeader, user))
}
return s.RunSQL(ctx, statement, params)
}

func (s *Source) RunSQL(ctx context.Context, statement string, params []any) (any, error) {
results, err := s.TrinoDB().QueryContext(ctx, statement, params...)
if err != nil {
Expand Down
27 changes: 25 additions & 2 deletions internal/tools/trino/trinoexecutesql/trinoexecutesql.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ import (

const resourceType string = "trino-execute-sql"

// impersonationParamName is the tool input parameter that supplies the Trino
// user to impersonate. It is exposed only when impersonateUser is enabled and
// is forwarded as the X-Trino-User header rather than bound into the SQL.
const impersonationParamName string = "trino_user"

func init() {
if !tools.Register(resourceType, newConfig) {
panic(fmt.Sprintf("tool type %q already registered", resourceType))
Expand All @@ -45,12 +50,14 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
TrinoDB() *sql.DB
RunSQL(context.Context, string, []any) (any, error)
RunSQLAsUser(context.Context, string, string, []any) (any, error)
}

type Config struct {
tools.ConfigBase `yaml:",inline"`
Type string `yaml:"type" validate:"required"`
Source string `yaml:"source" validate:"required"`
ImpersonateUser bool `yaml:"impersonateUser"`
Annotations *tools.ToolAnnotations `yaml:"annotations,omitempty"`
}

Expand All @@ -68,6 +75,9 @@ func (cfg Config) Initialize(context.Context) (tools.Tool, error) {

sqlParameter := parameters.NewStringParameter("sql", "The SQL query to execute against the Trino database.")
params := parameters.Parameters{sqlParameter}
if cfg.ImpersonateUser {
params = append(params, parameters.NewStringParameter(impersonationParamName, "The Trino user to impersonate for this query (forwarded as the X-Trino-User header). Optional; if omitted the query runs as the source's configured user.", parameters.WithStringDefault("")))
}

return Tool{
BaseTool: tools.NewBaseTool(
Expand Down Expand Up @@ -96,11 +106,24 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, util.NewClientServerError("source not compatible with this tool", http.StatusInternalServerError, err)
}

sliceParams := params.AsSlice()
sql, ok := sliceParams[0].(string)
paramsMap := params.AsMap()
sql, ok := paramsMap["sql"].(string)
if !ok {
return nil, util.NewAgentError("unable to cast the `sql` input parameter into string", nil)
}

if t.Cfg.ImpersonateUser {
user, ok := paramsMap[impersonationParamName].(string)
if !ok {
return nil, util.NewAgentError(fmt.Sprintf("unable to cast the `%s` input parameter into string", impersonationParamName), nil)
}
res, err := source.RunSQLAsUser(ctx, user, sql, []any{})
if err != nil {
return nil, util.ProcessGeneralError(err)
}
return res, nil
}

res, err := source.RunSQL(ctx, sql, []any{})
if err != nil {
return nil, util.ProcessGeneralError(err)
Expand Down
128 changes: 128 additions & 0 deletions internal/tools/trino/trinoexecutesql/trinoexecutesql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@
package trinoexecutesql_test

import (
"context"
"database/sql"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/googleapis/mcp-toolbox/internal/server"
"github.com/googleapis/mcp-toolbox/internal/sources"
"github.com/googleapis/mcp-toolbox/internal/testutils"
"github.com/googleapis/mcp-toolbox/internal/tools"
"github.com/googleapis/mcp-toolbox/internal/tools/trino/trinoexecutesql"
"github.com/googleapis/mcp-toolbox/internal/util/parameters"
)

func TestParseFromYamlTrinoExecuteSQL(t *testing.T) {
Expand Down Expand Up @@ -58,6 +62,29 @@ func TestParseFromYamlTrinoExecuteSQL(t *testing.T) {
},
},
},
{
desc: "with user impersonation",
in: `
kind: tool
name: example_tool
type: trino-execute-sql
source: my-trino-instance
description: some description
impersonateUser: true
`,
want: server.ToolConfigs{
"example_tool": trinoexecutesql.Config{
ConfigBase: tools.ConfigBase{
Name: "example_tool",
Description: "some description",
AuthRequired: []string{},
},
Type: "trino-execute-sql",
Source: "my-trino-instance",
ImpersonateUser: true,
},
},
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
Expand All @@ -71,3 +98,104 @@ func TestParseFromYamlTrinoExecuteSQL(t *testing.T) {
})
}
}

// mockSource implements the tool's compatibleSource interface and records how
// the tool routed the call (plain RunSQL vs. impersonating RunSQLAsUser).
type mockSource struct {
ranPlain bool
ranAsUser bool
gotUser string
gotStmt string
gotParams []any
}
Comment thread
nakulgan marked this conversation as resolved.

func (m *mockSource) SourceType() string { return "trino" }
func (m *mockSource) ToConfig() sources.SourceConfig { return nil }
func (m *mockSource) TrinoDB() *sql.DB { return nil }

func (m *mockSource) RunSQL(_ context.Context, stmt string, params []any) (any, error) {
m.ranPlain = true
m.gotStmt = stmt
m.gotParams = params
return []any{}, nil
}

func (m *mockSource) RunSQLAsUser(_ context.Context, user, stmt string, params []any) (any, error) {
m.ranAsUser = true
m.gotUser = user
m.gotStmt = stmt
m.gotParams = params
return []any{}, nil
}

type mockSourceProvider struct {
tools.SourceProvider
source *mockSource
}

func (m *mockSourceProvider) GetSource(string) (sources.Source, bool) { return m.source, true }

func TestInvokeImpersonation(t *testing.T) {
tcs := []struct {
desc string
impersonateUser bool
params parameters.ParamValues
wantAsUser bool
wantUser string
}{
{
desc: "impersonation disabled uses plain RunSQL",
impersonateUser: false,
params: parameters.ParamValues{{Name: "sql", Value: "SELECT 1"}},
wantAsUser: false,
},
{
desc: "impersonation forwards trino_user",
impersonateUser: true,
params: parameters.ParamValues{{Name: "sql", Value: "SELECT 1"}, {Name: "trino_user", Value: "alice@seedtag.com"}},
wantAsUser: true,
wantUser: "alice@seedtag.com",
},
{
desc: "empty trino_user falls back to source user",
impersonateUser: true,
params: parameters.ParamValues{{Name: "sql", Value: "SELECT 1"}, {Name: "trino_user", Value: ""}},
wantAsUser: true,
wantUser: "",
},
}
for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
cfg := trinoexecutesql.Config{
ConfigBase: tools.ConfigBase{Name: "tool", Description: "d"},
Type: "trino-execute-sql",
Source: "s",
ImpersonateUser: tc.impersonateUser,
}
tool, err := cfg.Initialize(context.Background())
if err != nil {
t.Fatalf("initialize: %v", err)
}
src := &mockSource{}
if _, toolErr := tool.Invoke(context.Background(), &mockSourceProvider{source: src}, tc.params, ""); toolErr != nil {
t.Fatalf("invoke: %v", toolErr)
}
if src.ranAsUser != tc.wantAsUser {
t.Errorf("ranAsUser = %v, want %v", src.ranAsUser, tc.wantAsUser)
}
if src.ranPlain == tc.wantAsUser {
t.Errorf("ranPlain = %v, want %v", src.ranPlain, !tc.wantAsUser)
}
if tc.wantAsUser && src.gotUser != tc.wantUser {
t.Errorf("forwarded user = %q, want %q", src.gotUser, tc.wantUser)
}
if src.gotStmt != "SELECT 1" {
t.Errorf("forwarded statement = %q, want %q", src.gotStmt, "SELECT 1")
}
// trino_user must never be bound as a SQL parameter.
if len(src.gotParams) != 0 {
t.Errorf("expected no bind params, got %v", src.gotParams)
}
})
}
}
29 changes: 29 additions & 0 deletions internal/tools/trino/trinosql/trinosql.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ import (

const resourceType string = "trino-sql"

// impersonationParamName is the tool input parameter that supplies the Trino
// user to impersonate. It is exposed only when impersonateUser is enabled and
// is forwarded as the X-Trino-User header rather than bound into the statement.
const impersonationParamName string = "trino_user"

func init() {
if !tools.Register(resourceType, newConfig) {
panic(fmt.Sprintf("tool type %q already registered", resourceType))
Expand All @@ -45,6 +50,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.T
type compatibleSource interface {
TrinoDB() *sql.DB
RunSQL(context.Context, string, []any) (any, error)
RunSQLAsUser(context.Context, string, string, []any) (any, error)
}

type Config struct {
Expand All @@ -54,6 +60,7 @@ type Config struct {
Statement string `yaml:"statement" validate:"required"`
Parameters parameters.Parameters `yaml:"parameters"`
TemplateParameters parameters.Parameters `yaml:"templateParameters"`
ImpersonateUser bool `yaml:"impersonateUser"`
Annotations *tools.ToolAnnotations `yaml:"annotations,omitempty"`
}

Expand All @@ -74,6 +81,15 @@ func (cfg Config) Initialize(context.Context) (tools.Tool, error) {
return nil, fmt.Errorf("unable to process parameters: %w", err)
}

// The impersonation user is exposed as a tool input parameter so the caller
// supplies it, but it is forwarded as the X-Trino-User header in Invoke
// rather than bound into the statement (so it stays out of cfg.Parameters).
if cfg.ImpersonateUser {
userParam := parameters.NewStringParameter(impersonationParamName, "The Trino user to impersonate for this query (forwarded as the X-Trino-User header). Optional; if omitted the query runs as the source's configured user.", parameters.WithStringDefault(""))
allParameters = append(allParameters, userParam)
paramManifest = append(paramManifest, userParam.Manifest())
}

return Tool{
BaseTool: tools.NewBaseTool(
cfg,
Expand Down Expand Up @@ -111,6 +127,19 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
return nil, util.NewAgentError("unable to extract standard params", err)
}
sliceParams := newParams.AsSlice()

if t.Cfg.ImpersonateUser {
user, ok := paramsMap[impersonationParamName].(string)
if !ok {
return nil, util.NewAgentError(fmt.Sprintf("unable to cast the `%s` input parameter into string", impersonationParamName), nil)
}
res, err := source.RunSQLAsUser(ctx, user, newStatement, sliceParams)
if err != nil {
return nil, util.ProcessGeneralError(err)
}
return res, nil
}

res, err := source.RunSQL(ctx, newStatement, sliceParams)
if err != nil {
return nil, util.ProcessGeneralError(err)
Expand Down
Loading