Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions runs/migrations/sql/20260618120000_add_actions_created_by.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Add created_by to actions: the OIDC subject (`sub`) of the identity that created the run.
-- Captured from the auth headers the load balancer forwards (it enforces auth),
-- and used to populate ActionMetadata.executed_by on read.
--
-- The column is TEXT rather than a length-limited VARCHAR: the subject length is
-- IdP-dependent and is not bounded by the code that captures it, so a length limit
-- here could reject the INSERT that creates the run.
-- The column is nullable; runs created without an authenticated identity store NULL.
ALTER TABLE actions ADD COLUMN IF NOT EXISTS created_by TEXT;

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed: the column is now TEXT (the migration comment notes the OIDC sub length is IdP-dependent and can exceed 255).

6 changes: 3 additions & 3 deletions runs/repository/impl/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,13 @@ func (r *actionRepo) CreateAction(ctx context.Context, action *models.Action, up
}

result, err := tx.ExecContext(ctx,
`INSERT INTO actions (project, domain, run_name, name, parent_action_name, phase, run_source, action_type, action_group, task_project, task_domain, task_name, task_version, task_type, task_short_name, function_name, environment_name, action_spec, action_details, detailed_info, run_spec, attempts, cache_status, trigger_name, trigger_task_name, trigger_revision, created_at, ended_at, duration_ms)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, CASE WHEN $28::timestamptz IS NOT NULL THEN EXTRACT(EPOCH FROM (GREATEST($28::timestamptz, $27) - $27)) * 1000 ELSE NULL END)
`INSERT INTO actions (project, domain, run_name, name, parent_action_name, phase, run_source, action_type, action_group, task_project, task_domain, task_name, task_version, task_type, task_short_name, function_name, environment_name, action_spec, action_details, detailed_info, run_spec, attempts, cache_status, trigger_name, trigger_task_name, trigger_revision, created_at, ended_at, created_by, duration_ms)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, CASE WHEN $28::timestamptz IS NOT NULL THEN EXTRACT(EPOCH FROM (GREATEST($28::timestamptz, $27) - $27)) * 1000 ELSE NULL END)
ON CONFLICT DO NOTHING`,
action.Project, action.Domain, action.RunName, action.Name, action.ParentActionName, action.Phase, action.RunSource, action.ActionType, action.ActionGroup,
action.TaskProject, action.TaskDomain, action.TaskName, action.TaskVersion, action.TaskType, action.TaskShortName, action.FunctionName, action.EnvironmentName,
action.ActionSpec, action.ActionDetails, action.DetailedInfo, action.RunSpec, action.Attempts, action.CacheStatus,
action.TriggerName, action.TriggerTaskName, action.TriggerRevision, createdAt, action.EndedAt)
action.TriggerName, action.TriggerTaskName, action.TriggerRevision, createdAt, action.EndedAt, action.CreatedBy)
if err != nil {
return nil, fmt.Errorf("failed to create action: %w", err)
}
Expand Down
5 changes: 5 additions & 0 deletions runs/repository/models/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ type Action struct {
// Who initiated this run (web, CLI, scheduler, etc.)
RunSource string `db:"run_source" json:"run_source,omitempty"`

// CreatedBy is the OIDC subject of the identity that created this run, captured
// from the auth headers the load balancer forwards. Surfaced as
// ActionMetadata.executed_by. NULL for runs created without an authenticated identity.
CreatedBy sql.NullString `db:"created_by" json:"created_by,omitempty"`

// Trigger fields — only set for runs created via RUN_SOURCE_SCHEDULE_TRIGGER.
TriggerTaskName sql.NullString `db:"trigger_task_name"`
TriggerName sql.NullString `db:"trigger_name"`
Expand Down
72 changes: 72 additions & 0 deletions runs/service/identity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package service

import (
"encoding/base64"
"encoding/json"
"net/http"
"strings"

"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common"
)

const (
	// authorizationHeader is the request header holding the Bearer token on the
	// JWT-validation path (SDK/CLI). The load balancer validates the token and
	// forwards the header unchanged.
	authorizationHeader = "Authorization"

	// bearerPrefix is the scheme prefix (including the separating space) that
	// precedes the token inside the Authorization header value.
	bearerPrefix = "Bearer "

	// albIdentityHeader is the header ALB authenticate-oidc sets on the
	// browser/cookie path; its value is the OIDC subject (`sub`) itself.
	albIdentityHeader = "X-Amzn-Oidc-Identity"
)

// subjectFromHeaders returns the authenticated OIDC subject (`sub`) found in the
// auth headers forwarded by the load balancer, or "" when neither header yields one.
//
// Nothing is verified here: authentication is expected to be enforced upstream
// (e.g. ALB OIDC / JWT validation), so the forwarded values are only decoded.
//
// Lookup order:
//  1. albIdentityHeader: the whitespace-trimmed value, used when non-empty.
//  2. authorizationHeader: the `sub` claim of the Bearer token (see subjectFromBearer).
func subjectFromHeaders(h http.Header) string {
	forwarded := strings.TrimSpace(h.Get(albIdentityHeader))
	if forwarded == "" {
		// No directly forwarded subject (authenticate-oidc path), so fall back to
		// the JWT path used by SDK/CLI callers.
		return subjectFromBearer(h.Get(authorizationHeader))
	}
	return forwarded
}

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in e7dfbf10: gated behind Config.TrustForwardedIdentityHeaders (default true). When false, the forwarded Authorization / X-Amzn-Oidc-* headers are not trusted and executed_by is left unset, so a direct caller can't spoof attribution. The runs service does no authz of its own, so this flag is the explicit "prove the request came through a trusted proxy" guard — operators whose service can be reached without one set it false. See runs/config/config.go and run_service.go (s.trustHeaders).


// subjectFromBearer extracts the `sub` claim from an Authorization header value of
// the form "Bearer <jwt>". The scheme is matched case-insensitively.
//
// The token's signature is NOT checked; the caller relies on the load balancer
// having validated it already. Only the middle (payload) segment is decoded.
//
// It returns "" when the value is not a Bearer credential, the token does not have
// exactly three dot-separated segments, the payload is not unpadded base64url, the
// payload cannot be unmarshalled as JSON with a string `sub`, or `sub` is absent.
func subjectFromBearer(authz string) string {
	prefixLen := len(bearerPrefix)
	isBearer := len(authz) > prefixLen && strings.EqualFold(authz[:prefixLen], bearerPrefix)
	if !isBearer {
		return ""
	}

	// Cap the split at four pieces: any token with more than two dots still
	// produces a count other than three and is rejected.
	token := strings.TrimSpace(authz[prefixLen:])
	segments := strings.SplitN(token, ".", 4)
	if len(segments) != 3 {
		return ""
	}

	rawClaims, decodeErr := base64.RawURLEncoding.DecodeString(segments[1])
	if decodeErr != nil {
		return ""
	}

	body := struct {
		Sub string `json:"sub"`
	}{}
	if json.Unmarshal(rawClaims, &body) != nil {
		return ""
	}
	return body.Sub
}

// subjectOnlyIdentity wraps subject in an EnrichedIdentity whose user principal
// carries nothing but that subject. It returns nil for an empty subject.
//
// This mirrors the cloud transformer fallback: the standalone runs service has no
// identity service that could expand the subject into full user details
// (email, name, groups).
func subjectOnlyIdentity(subject string) *common.EnrichedIdentity {
	if subject == "" {
		return nil
	}

	userID := &common.UserIdentifier{Subject: subject}
	principal := &common.EnrichedIdentity_User{
		User: &common.User{Id: userID},
	}
	return &common.EnrichedIdentity{Principal: principal}
}
67 changes: 67 additions & 0 deletions runs/service/identity_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package service

import (
"encoding/base64"
"net/http"
"testing"

"github.com/stretchr/testify/assert"
)

// bearerToken returns an Authorization header value ("Bearer <token>") whose token
// is a syntactically valid but unsigned JWT: a fixed RS256 header, payloadJSON as
// the claims segment (both unpadded base64url), and the literal "sig" as signature.
// payloadJSON is encoded verbatim, so callers may pass non-JSON to exercise error paths.
func bearerToken(payloadJSON string) string {
	b64 := base64.RawURLEncoding
	header := b64.EncodeToString([]byte(`{"alg":"RS256"}`))
	claims := b64.EncodeToString([]byte(payloadJSON))
	return "Bearer " + header + "." + claims + ".sig"
}

// TestSubjectFromHeaders is a table-driven test of subjectFromHeaders. Each case
// sets the listed headers on a fresh http.Header and checks the extracted subject.
//
// Besides the happy paths and header precedence, the table covers every rejection
// branch of the Bearer parsing: wrong scheme, prefix with no token, wrong segment
// count, payload that is not base64url, payload that is not JSON, a non-string
// `sub`, and a missing `sub`.
func TestSubjectFromHeaders(t *testing.T) {
	tests := []struct {
		name    string
		headers map[string]string
		want    string
	}{
		{
			name:    "amzn oidc identity header (cookie path)",
			headers: map[string]string{albIdentityHeader: "okta|user-123"},
			want:    "okta|user-123",
		},
		{
			name:    "amzn oidc identity header is trimmed",
			headers: map[string]string{albIdentityHeader: " user-456 "},
			want:    "user-456",
		},
		{
			name:    "bearer token sub claim (jwt path)",
			headers: map[string]string{authorizationHeader: bearerToken(`{"sub":"sdk-user-789","email":"a@b.com"}`)},
			want:    "sdk-user-789",
		},
		{
			name: "amzn header takes precedence over bearer",
			headers: map[string]string{
				albIdentityHeader:   "cookie-user",
				authorizationHeader: bearerToken(`{"sub":"bearer-user"}`),
			},
			want: "cookie-user",
		},
		{
			// A whitespace-only identity header trims to "" and must not mask the token.
			name: "blank amzn header falls back to bearer",
			headers: map[string]string{
				albIdentityHeader:   "   ",
				authorizationHeader: bearerToken(`{"sub":"fallback-user"}`),
			},
			want: "fallback-user",
		},
		{
			// Swap the canonical "Bearer " prefix for a lowercase one.
			name:    "bearer scheme is case-insensitive",
			headers: map[string]string{authorizationHeader: "bearer " + bearerToken(`{"sub":"lower-user"}`)[len(bearerPrefix):]},
			want:    "lower-user",
		},
		{name: "no auth headers", headers: map[string]string{}, want: ""},
		{name: "non-bearer authorization", headers: map[string]string{authorizationHeader: "Basic abc"}, want: ""},
		{name: "bearer prefix with no token", headers: map[string]string{authorizationHeader: "Bearer "}, want: ""},
		{name: "malformed bearer (two segments)", headers: map[string]string{authorizationHeader: "Bearer a.b"}, want: ""},
		{name: "malformed bearer (four segments)", headers: map[string]string{authorizationHeader: "Bearer a.b.c.d"}, want: ""},
		{name: "payload is not base64url", headers: map[string]string{authorizationHeader: "Bearer a.@@@.c"}, want: ""},
		{name: "payload is not JSON", headers: map[string]string{authorizationHeader: bearerToken("not-json")}, want: ""},
		{name: "sub claim is not a string", headers: map[string]string{authorizationHeader: bearerToken(`{"sub":123}`)}, want: ""},
		{name: "bearer without sub", headers: map[string]string{authorizationHeader: bearerToken(`{"email":"a@b.com"}`)}, want: ""},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			h := http.Header{}
			for k, v := range tt.headers {
				h.Set(k, v)
			}
			assert.Equal(t, tt.want, subjectFromHeaders(h))
		})
	}
}

func TestSubjectOnlyIdentity(t *testing.T) {
assert.Nil(t, subjectOnlyIdentity(""))

id := subjectOnlyIdentity("user-123")
assert.Equal(t, "user-123", id.GetUser().GetId().GetSubject())
}
12 changes: 11 additions & 1 deletion runs/service/run_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,12 @@ func (s *RunService) CreateRun(
}
}

// Capture who created the run from the auth headers the load balancer forwards
// (it enforces auth upstream). Empty when there is no authenticated identity.
createdBy := subjectFromHeaders(req.Header())

// Persist task spec and create run model
run, err := s.persistRunModel(ctx, runId, taskID, taskSpec, inputPrefix, runOutputBase, runSpec, request.GetSource(), triggerName, triggerTaskName, triggerRevision, triggerType)
run, err := s.persistRunModel(ctx, runId, taskID, taskSpec, inputPrefix, runOutputBase, runSpec, request.GetSource(), triggerName, triggerTaskName, triggerRevision, triggerType, createdBy)
if err != nil {
logger.Errorf(ctx, "Failed to create run: %v", err)
return nil, connect.NewError(connect.CodeInternal, err)
Expand Down Expand Up @@ -372,6 +376,7 @@ func (s *RunService) persistRunModel(
triggerName, triggerTaskName string,
triggerRevision int64,
triggerType string,
createdBy string,
) (*models.Run, error) {
// Store task spec and compute digest
info := &workflow.RunInfo{InputsUri: inputPrefix + "/inputs.pb"}
Expand Down Expand Up @@ -443,6 +448,7 @@ func (s *RunService) persistRunModel(
RunSpec: runSpecBytes,
Attempts: 1,
RunSource: source.String(),
CreatedBy: nullStr(createdBy),
TriggerTaskName: nullStr(triggerTaskName),
TriggerName: nullStr(triggerName),
TriggerRevision: sql.NullInt64{Int64: triggerRevision, Valid: triggerRevision != 0},
Expand Down Expand Up @@ -1549,6 +1555,10 @@ func actionMetadataFromModel(action *models.Action) *workflow.ActionMetadata {
}
}

if action.CreatedBy.Valid {
metadata.ExecutedBy = subjectOnlyIdentity(action.CreatedBy.String)
}

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed: added TestActionMetadataFromModel_ExecutedBy (runs/service/executed_by_test.go) covering full identity from executed_by, fallback to subject-only from created_by, corrupt executed_by bytes, and the no-identity (nil) case.


if action.TriggerName.Valid {
metadata.TriggerName = action.TriggerName.String
metadata.TriggerId = &common.TriggerIdentifier{
Expand Down
Loading