Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions runs/migrations/sql/20260618120000_add_actions_created_by.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Add created_by to actions: the OIDC subject of the identity that created the run.
-- Captured from the auth headers the load balancer forwards (it enforces auth),
-- and used to populate ActionMetadata.executed_by on read.
ALTER TABLE actions ADD COLUMN IF NOT EXISTS created_by VARCHAR(255);

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed: the column is now TEXT (the migration comment notes the OIDC sub length is IdP-dependent and can exceed 255).

Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-- Add executed_by to actions: the serialized common.EnrichedIdentity of the run's
-- creator, captured from the OIDC claims the load balancer forwards (subject plus
-- name/email when present). created_by keeps the bare subject for querying; this
-- column carries the full identity surfaced as ActionMetadata.executed_by.
ALTER TABLE actions ADD COLUMN IF NOT EXISTS executed_by BYTEA;
6 changes: 3 additions & 3 deletions runs/repository/impl/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,13 +224,13 @@ func (r *actionRepo) CreateAction(ctx context.Context, action *models.Action, up
}

result, err := tx.ExecContext(ctx,
`INSERT INTO actions (project, domain, run_name, name, parent_action_name, phase, run_source, action_type, action_group, task_project, task_domain, task_name, task_version, task_type, task_short_name, function_name, environment_name, action_spec, action_details, detailed_info, run_spec, attempts, cache_status, trigger_name, trigger_task_name, trigger_revision, created_at, ended_at, duration_ms)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, CASE WHEN $28::timestamptz IS NOT NULL THEN EXTRACT(EPOCH FROM (GREATEST($28::timestamptz, $27) - $27)) * 1000 ELSE NULL END)
`INSERT INTO actions (project, domain, run_name, name, parent_action_name, phase, run_source, action_type, action_group, task_project, task_domain, task_name, task_version, task_type, task_short_name, function_name, environment_name, action_spec, action_details, detailed_info, run_spec, attempts, cache_status, trigger_name, trigger_task_name, trigger_revision, created_at, ended_at, created_by, executed_by, duration_ms)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, CASE WHEN $28::timestamptz IS NOT NULL THEN EXTRACT(EPOCH FROM (GREATEST($28::timestamptz, $27) - $27)) * 1000 ELSE NULL END)
ON CONFLICT DO NOTHING`,
action.Project, action.Domain, action.RunName, action.Name, action.ParentActionName, action.Phase, action.RunSource, action.ActionType, action.ActionGroup,
action.TaskProject, action.TaskDomain, action.TaskName, action.TaskVersion, action.TaskType, action.TaskShortName, action.FunctionName, action.EnvironmentName,
action.ActionSpec, action.ActionDetails, action.DetailedInfo, action.RunSpec, action.Attempts, action.CacheStatus,
action.TriggerName, action.TriggerTaskName, action.TriggerRevision, createdAt, action.EndedAt)
action.TriggerName, action.TriggerTaskName, action.TriggerRevision, createdAt, action.EndedAt, action.CreatedBy, action.ExecutedBy)
if err != nil {
return nil, fmt.Errorf("failed to create action: %w", err)
}
Expand Down
10 changes: 10 additions & 0 deletions runs/repository/models/action.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ type Action struct {
// Who initiated this run (web, CLI, scheduler, etc.).
RunSource string `db:"run_source" json:"run_source,omitempty"`

// CreatedBy is the OIDC subject of the identity that created this run, captured
// from the auth headers the load balancer forwards. Kept for querying/filtering.
// NULL for runs created without an authenticated identity.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clarified the field comment: created_by is intentionally not in ActionColumnsSet. It's persisted for internal querying/attribution only, not user-supplied list filters/sorts — left out to avoid widening the API surface. The comment now says to add it there only if API-level filtering is desired.

CreatedBy sql.NullString `db:"created_by" json:"created_by,omitempty"`

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Intentional: created_by is persisted for internal querying/attribution, not API-level filtering, so it's deliberately omitted from ActionColumnsSet. I clarified the field comment to say so (add it there only if API filtering is wanted).


// ExecutedBy is the serialized common.EnrichedIdentity of the run's creator
// (subject plus name/email when the forwarded OIDC claims include them).
// Surfaced directly as ActionMetadata.executed_by. NULL for unauthenticated runs.
ExecutedBy []byte `db:"executed_by" json:"executed_by,omitempty"`

// Trigger fields — only set for runs created via RUN_SOURCE_SCHEDULE_TRIGGER.
TriggerTaskName sql.NullString `db:"trigger_task_name"`
TriggerName sql.NullString `db:"trigger_name"`
Expand Down
99 changes: 99 additions & 0 deletions runs/service/identity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package service

import (
"encoding/base64"
"encoding/json"
"net/http"
"strings"

"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common"
)

const (
// albDataHeader is the signed JWT of user claims set by ALB authenticate-oidc
// (browser/cookie path). Its payload carries sub, email, given_name, family_name.
albDataHeader = "X-Amzn-Oidc-Data"
// albIdentityHeader is also set by ALB authenticate-oidc and carries the OIDC
// subject (`sub`) directly — used as a fallback when the data header is absent.
albIdentityHeader = "X-Amzn-Oidc-Identity"
// authorizationHeader carries the Bearer token on the JWT-validation path
// (SDK/CLI). The load balancer validates it and forwards it unchanged.
authorizationHeader = "Authorization"
// bearerPrefix is the scheme prefix, including its trailing space, that an
// Authorization header value must start with (matched case-insensitively by
// the callers in this package) for the remainder to be treated as a token.
bearerPrefix = "Bearer "
)

// oidcClaims is the subset of OIDC claims we surface as the executing identity.
// It is decoded from JSON (a JWT payload or a userinfo response); claims that are
// absent in the input are left as empty strings.
type oidcClaims struct {
// Sub is the `sub` claim: the subject identifier of the caller.
Sub string `json:"sub"`
// Email is the `email` claim.
Email string `json:"email"`
// GivenName is the `given_name` claim; mapped to UserSpec.FirstName.
GivenName string `json:"given_name"`
// FamilyName is the `family_name` claim; mapped to UserSpec.LastName.
FamilyName string `json:"family_name"`
}

// identityFromHeaders builds the EnrichedIdentity of the caller from the auth headers
// the load balancer forwards. Auth is enforced upstream (e.g. ALB OIDC / JWT
// validation), so the claims are trusted and only decoded here — not re-verified.
//
// Headers are consulted in this order, and the first one that yields an identity wins:
//  1. X-Amzn-Oidc-Data (JWT whose payload is decoded for subject and profile claims),
//  2. X-Amzn-Oidc-Identity (subject only),
//  3. Authorization: Bearer <jwt> (payload decoded the same way as 1).
//
// Returns nil when no authenticated identity is present, including when a Bearer
// token is supplied but cannot be decoded or carries no subject.
func identityFromHeaders(h http.Header) *common.EnrichedIdentity {
Comment on lines +33 to +37

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in e7dfbf10: added Config.TrustForwardedIdentityHeaders (default true). When false, the decoded-but-unverified forwarded JWT claims are not trusted and executed_by is left unset. We rely on the upstream proxy/LB to validate tokens and strip client-supplied headers rather than verifying against the issuer's JWKS here.

// authenticate-oidc (browser/cookie) path: full claims in the signed data JWT.
if id := identityFromJWT(h.Get(albDataHeader)); id != nil {
return id
}
// Same path, subject only — when the data header is unavailable.
if sub := strings.TrimSpace(h.Get(albIdentityHeader)); sub != "" {
return subjectOnlyIdentity(sub)
}
// JWT (SDK/CLI) path: decode the forwarded Bearer token's claims.
// The length check guarantees the slice below is in range and that at least
// one character follows the "Bearer " prefix; the prefix match is case-insensitive.
if authz := h.Get(authorizationHeader); len(authz) > len(bearerPrefix) &&
strings.EqualFold(authz[:len(bearerPrefix)], bearerPrefix) {
return identityFromJWT(strings.TrimSpace(authz[len(bearerPrefix):]))
}
return nil
}

// identityFromJWT decodes the claims payload of a compact JWT into an
// EnrichedIdentity. The signature is NOT verified here — the load balancer is
// expected to have validated the token before forwarding it.
//
// The payload segment is base64url. Standard JWTs omit the '=' padding, but the
// ALB-issued X-Amzn-Oidc-Data token includes it, so trailing padding is stripped
// before decoding and both forms are accepted.
//
// Returns nil on any malformed input (not exactly three dot-separated segments,
// invalid base64url, invalid JSON) or when no subject claim is present. Profile
// fields (email, given/family name) are attached only when at least one is set.
func identityFromJWT(token string) *common.EnrichedIdentity {
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		return nil
	}
	// RawURLEncoding rejects '=' padding outright, so normalise to the unpadded
	// form first; this is a no-op for spec-conformant (unpadded) JWTs.
	payload, err := base64.RawURLEncoding.DecodeString(strings.TrimRight(parts[1], "="))
	if err != nil {
		return nil
	}
	var c oidcClaims
	if err := json.Unmarshal(payload, &c); err != nil || c.Sub == "" {
		return nil
	}
	id := subjectOnlyIdentity(c.Sub)
	if c.Email != "" || c.GivenName != "" || c.FamilyName != "" {
		id.GetUser().Spec = &common.UserSpec{
			FirstName: c.GivenName,
			LastName:  c.FamilyName,
			Email:     c.Email,
		}
	}
	return id
}

// subjectOnlyIdentity wraps a bare subject in an EnrichedIdentity whose user
// principal has only its identifier populated (no profile spec). An empty
// subject yields nil.
func subjectOnlyIdentity(subject string) *common.EnrichedIdentity {
	if len(subject) == 0 {
		return nil
	}
	user := &common.User{
		Id: &common.UserIdentifier{Subject: subject},
	}
	principal := &common.EnrichedIdentity_User{User: user}
	return &common.EnrichedIdentity{Principal: principal}
}

// identitySubject extracts the user subject carried by id. It yields "" when the
// identity, its user principal, or the user's identifier is missing.
func identitySubject(id *common.EnrichedIdentity) string {
	userID := id.GetUser().GetId()
	return userID.GetSubject()
}
182 changes: 182 additions & 0 deletions runs/service/identity_enricher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package service

import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"sync"
"time"

"github.com/flyteorg/flyte/v2/flytestdlib/logger"
"github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common"
)

const (
// userinfoHTTPTimeout is the overall timeout of the HTTP client used for the
// OIDC discovery and userinfo requests.
userinfoHTTPTimeout = 3 * time.Second
// identityCacheTTL is how long an enriched identity stays valid in the
// per-subject cache after it is stored.
identityCacheTTL = 10 * time.Minute
// oidcDiscoveryPath is appended to the authorization-server base URL to locate
// the OIDC discovery document.
oidcDiscoveryPath = "/.well-known/openid-configuration"
// albAccessTokenHdr is the header read for the ALB-provided access token
// (cookie path) when no Bearer token is present.
albAccessTokenHdr = "X-Amzn-Oidc-Accesstoken"
)

// identityEnricher fills in a caller's profile (email, first/last name) by calling
// the OIDC userinfo endpoint with their access token. It is needed on the Bearer
// path, where the access token carries only the subject — the profile claims live
// in userinfo, not the token. Results are cached by subject. Every failure mode is
// best-effort: the caller's unenriched (subject-only) identity is returned instead.
type identityEnricher struct {
// authServerBaseURL is the authorization-server base URL with trailing slashes
// removed; the discovery path is appended to it.
authServerBaseURL string
// httpClient performs both the discovery and the userinfo requests.
httpClient *http.Client

// mu guards userinfoURL and cache.
mu sync.Mutex
userinfoURL string // resolved lazily from OIDC discovery, then cached
// cache maps a subject to its enriched identity and expiry time.
cache map[string]cachedIdentity
}

// cachedIdentity is one entry of the enricher's per-subject cache.
type cachedIdentity struct {
// id is the enriched identity stored for the subject.
id *common.EnrichedIdentity
// expires is the instant after which the entry is no longer served.
expires time.Time
}

// newIdentityEnricher builds an enricher bound to the given OAuth2
// authorization-server base URL (e.g. https://signin.example.com/oauth2/default).
// Trailing slashes on the URL are dropped. An empty URL yields a nil enricher,
// on which enrich is a no-op so identities stay subject-only.
func newIdentityEnricher(authServerBaseURL string) *identityEnricher {
	if len(authServerBaseURL) == 0 {
		return nil
	}
	enricher := new(identityEnricher)
	enricher.authServerBaseURL = strings.TrimRight(authServerBaseURL, "/")
	enricher.httpClient = &http.Client{Timeout: userinfoHTTPTimeout}
	enricher.cache = make(map[string]cachedIdentity)
	return enricher
}

// enrich augments base with profile claims fetched from the userinfo endpoint
// when base lacks them and an access token is available.
//
// base is returned unchanged when the enricher is nil, base has no user
// principal or no subject, base already carries a profile, the access token is
// empty, the userinfo fetch fails, or the userinfo response's subject does not
// match base's subject. Enrichment is best-effort and never fails run creation.
//
// On success the merged identity is cached by subject; note that mergeClaims
// sets the spec on base itself, so the returned value aliases base.
func (e *identityEnricher) enrich(ctx context.Context, accessToken string, base *common.EnrichedIdentity) *common.EnrichedIdentity {
	if e == nil || base.GetUser() == nil || hasProfile(base) || accessToken == "" {
		return base
	}
	subject := base.GetUser().GetId().GetSubject()
	if subject == "" {
		// Without a subject there is no safe cache key and nothing to match the
		// userinfo response against.
		return base
	}
	if cached := e.cachedFor(subject); cached != nil {
		return cached
	}
	claims, err := e.fetchUserinfo(ctx, accessToken)
	if err != nil {
		logger.Warnf(ctx, "identity enrichment: userinfo fetch failed for subject %q: %v", subject, err)
		return base
	}
	// OIDC Core requires the userinfo `sub` to exactly match the caller's subject;
	// otherwise the response describes someone else and must not be used (or cached).
	if claims.Sub != subject {
		logger.Warnf(ctx, "identity enrichment: userinfo subject %q does not match caller subject %q; ignoring profile", claims.Sub, subject)
		return base
	}
	enriched := mergeClaims(base, claims)
	e.store(subject, enriched)
	return enriched
}

// cachedFor returns the cached identity for subject, or nil when there is no
// entry or the entry has expired. An expired entry is deleted on detection so
// stale keys do not linger in the map.
func (e *identityEnricher) cachedFor(subject string) *common.EnrichedIdentity {
	e.mu.Lock()
	defer e.mu.Unlock()
	entry, ok := e.cache[subject]
	if !ok {
		return nil
	}
	if !time.Now().Before(entry.expires) {
		delete(e.cache, subject)
		return nil
	}
	return entry.id
}

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in e7dfbf10: cachedFor deletes the stale entry when it detects expiry, and store sweeps expired entries — dead keys no longer accumulate.

Comment on lines +92 to +105

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in e7dfbf10: expired entries are now deleted in cachedFor (on detection) and swept in store, so the map doesn't accumulate dead keys.

Comment on lines +92 to +105

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in e7dfbf10: cachedFor deletes the expired entry on detection and store sweeps expired entries, so the map no longer grows with dead keys.


// store caches id under subject with a lifetime of identityCacheTTL. Before
// inserting, it sweeps every entry that has already expired, so subjects that
// are never looked up again do not accumulate in the map.
func (e *identityEnricher) store(subject string, id *common.EnrichedIdentity) {
	now := time.Now()
	e.mu.Lock()
	defer e.mu.Unlock()
	// Deleting from a map while ranging over it is permitted in Go.
	for key, entry := range e.cache {
		if !now.Before(entry.expires) {
			delete(e.cache, key)
		}
	}
	e.cache[subject] = cachedIdentity{id: id, expires: now.Add(identityCacheTTL)}
}
Comment on lines +107 to +118

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e7dfbf10 evicts expired entries (in cachedFor and store), so dead keys don't accumulate. I didn't add a hard max-size / LRU: the cache only holds distinct run-creators within the 10m TTL (bounded in practice) and entries are swept on expiry. A bounded/LRU cache is a reasonable follow-up if we see high-cardinality bursts within the TTL — can add if you think it's warranted.


// fetchUserinfo issues a GET to the resolved userinfo endpoint, authenticating
// with accessToken as a Bearer credential, and decodes the JSON response into
// oidcClaims. It returns an error when the endpoint cannot be resolved, the
// request cannot be built or sent, the status is not 200, or the body is not
// valid JSON.
func (e *identityEnricher) fetchUserinfo(ctx context.Context, accessToken string) (*oidcClaims, error) {
	endpoint, err := e.resolveUserinfoURL(ctx)
	if err != nil {
		return nil, err
	}

	request, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, err
	}
	request.Header.Set(authorizationHeader, bearerPrefix+accessToken)

	response, err := e.httpClient.Do(request)
	if err != nil {
		return nil, err
	}
	defer func() { _ = response.Body.Close() }()

	if status := response.StatusCode; status != http.StatusOK {
		return nil, fmt.Errorf("userinfo returned status %d", status)
	}

	claims := new(oidcClaims)
	if decodeErr := json.NewDecoder(response.Body).Decode(claims); decodeErr != nil {
		return nil, fmt.Errorf("decode userinfo: %w", decodeErr)
	}
	return claims, nil
}

// resolveUserinfoURL returns the userinfo endpoint advertised by the OIDC
// discovery document. The first successful lookup is remembered on the enricher
// and served from memory afterwards; failed lookups are not remembered, so a
// later call tries discovery again.
func (e *identityEnricher) resolveUserinfoURL(ctx context.Context) (string, error) {
	e.mu.Lock()
	known := e.userinfoURL
	e.mu.Unlock()
	if known != "" {
		return known, nil
	}

	discoveryURL := e.authServerBaseURL + oidcDiscoveryPath
	request, err := http.NewRequestWithContext(ctx, http.MethodGet, discoveryURL, nil)
	if err != nil {
		return "", err
	}
	response, err := e.httpClient.Do(request)
	if err != nil {
		return "", err
	}
	defer func() { _ = response.Body.Close() }()

	if status := response.StatusCode; status != http.StatusOK {
		return "", fmt.Errorf("oidc discovery returned status %d", status)
	}

	var discovery struct {
		UserinfoEndpoint string `json:"userinfo_endpoint"`
	}
	if decodeErr := json.NewDecoder(response.Body).Decode(&discovery); decodeErr != nil {
		return "", fmt.Errorf("decode oidc discovery: %w", decodeErr)
	}
	endpoint := discovery.UserinfoEndpoint
	if len(endpoint) == 0 {
		return "", fmt.Errorf("oidc discovery has no userinfo_endpoint")
	}

	e.mu.Lock()
	e.userinfoURL = endpoint
	e.mu.Unlock()
	return endpoint, nil
}

// hasProfile reports whether id's user spec has at least one of email, first
// name, or last name set.
func hasProfile(id *common.EnrichedIdentity) bool {
	spec := id.GetUser().GetSpec()
	switch {
	case spec.GetEmail() != "":
		return true
	case spec.GetFirstName() != "":
		return true
	case spec.GetLastName() != "":
		return true
	default:
		return false
	}
}

// mergeClaims overwrites the user spec of base with the profile found in c and
// returns base. base is modified in place. Nothing changes when c is nil, base
// has no user principal, or c carries no email, given name, or family name.
func mergeClaims(base *common.EnrichedIdentity, c *oidcClaims) *common.EnrichedIdentity {
	user := base.GetUser()
	if c == nil || user == nil {
		return base
	}
	if c.Email == "" && c.GivenName == "" && c.FamilyName == "" {
		return base
	}
	user.Spec = &common.UserSpec{
		FirstName: c.GivenName,
		LastName:  c.FamilyName,
		Email:     c.Email,
	}
	return base
}

// accessTokenFromHeaders picks the caller's access token out of the request
// headers. A non-empty remainder after a case-insensitive "Bearer " prefix on the
// Authorization header wins (SDK/JWT path); otherwise the ALB access-token header
// is used (cookie path). The result is whitespace-trimmed and may be "".
func accessTokenFromHeaders(h http.Header) string {
	authz := h.Get(authorizationHeader)
	prefixLen := len(bearerPrefix)
	if len(authz) > prefixLen && strings.EqualFold(authz[:prefixLen], bearerPrefix) {
		token := authz[prefixLen:]
		return strings.TrimSpace(token)
	}
	albToken := h.Get(albAccessTokenHdr)
	return strings.TrimSpace(albToken)
}

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By design accessTokenFromHeaders reads only the Bearer token — the cookie/ALB path isn't userinfo-enriched (its forwarded access token is already expired by request time; it uses the claims ALB injects into x-amzn-oidc-data, which is a complete profile, so enrich short-circuits on isCompleteProfile). And if a mismatched Bearer were ever present, its userinfo sub wouldn't match the caller's subject and is now rejected by the subject-mismatch guard (e7dfbf10). So the wrong profile can't be associated; I kept the single-source token read rather than adding ALB-vs-Authorization precedence.

Loading
Loading