diff --git a/openshift-kube-apiserver/admission/customresourcevalidation/authentication/validate_authentication.go b/openshift-kube-apiserver/admission/customresourcevalidation/authentication/validate_authentication.go index 26506e47019cf..0d109374a415b 100644 --- a/openshift-kube-apiserver/admission/customresourcevalidation/authentication/validate_authentication.go +++ b/openshift-kube-apiserver/admission/customresourcevalidation/authentication/validate_authentication.go @@ -1,22 +1,51 @@ package authentication import ( + "cmp" "context" "fmt" "io" + "math" + "slices" + "time" + "golang.org/x/sync/singleflight" "k8s.io/apimachinery/pkg/api/validation" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/util/validation/field" "k8s.io/apiserver/pkg/admission" + "k8s.io/apiserver/pkg/cel/library" + "k8s.io/apiserver/pkg/warning" + "k8s.io/klog/v2" + "k8s.io/utils/lru" + + "github.com/google/cel-go/checker" configv1 "github.com/openshift/api/config/v1" + authenticationcel "k8s.io/apiserver/pkg/authentication/cel" crvalidation "k8s.io/kubernetes/openshift-kube-apiserver/admission/customresourcevalidation" ) const PluginName = "config.openshift.io/ValidateAuthentication" +const ( + wholeResourceExcessiveCostThreshold = 100000000 + excessiveCompileDuration = time.Second + costlyExpressionWarningCount = 3 + + // This is the default KAS request header size limit in bytes. + // Because JWTs are only limited in size by the maximum request header size, + // we can use this fixed value to make pessimistic size estimates by assuming + // that the inputs were decoded from base64-encoded JSON. + // + // This isn't very precise, but can still be used to provide + // end-users a signal that they are potentially doing very expensive + // operations with CEL expressions whose cost is dependent + // on the size of the input. + fixedSize = 1 << 20 +) + // Register registers a plugin func Register(plugins *admission.Plugins) { plugins.Register(PluginName, func(config io.Reader) (admission.Interface, error) { @@ -25,7 +54,17 @@ func Register(plugins *admission.Plugins) { configv1.GroupVersion.WithResource("authentications").GroupResource(): true, }, map[schema.GroupVersionKind]crvalidation.ObjectValidator{ - configv1.GroupVersion.WithKind("Authentication"): authenticationV1{}, + configv1.GroupVersion.WithKind("Authentication"): authenticationV1{ + cel: &celStore{ + compiledStore: lru.New(100), + compilingGroup: new(singleflight.Group), + compiler: authenticationcel.NewDefaultCompiler(), + sizeEstimator: &fixedSizeEstimator{ + size: fixedSize, + }, + timerFactory: &excessiveCompileTimerFactory{}, + }, + }, }) }) } @@ -46,21 +85,45 @@ func toAuthenticationV1(uncastObj runtime.Object) (*configv1.Authentication, fie return obj, nil } -type authenticationV1 struct{} +type celStore struct { + compilingGroup *singleflight.Group + compiledStore *lru.Cache + compiler authenticationcel.Compiler + sizeEstimator checker.CostEstimator + timerFactory TimerFactory +} + +type TimerFactory interface { + Timer(time.Duration, func()) Timer +} + +type Timer interface { + Stop() bool +} + +type excessiveCompileTimerFactory struct{} + +func (ectf *excessiveCompileTimerFactory) Timer(duration time.Duration, do func()) Timer { + return time.AfterFunc(duration, do) +} + +type authenticationV1 struct { + cel *celStore +} -func (authenticationV1) ValidateCreate(_ context.Context, uncastObj runtime.Object) field.ErrorList { +func (a authenticationV1) ValidateCreate(ctx context.Context, uncastObj runtime.Object) field.ErrorList { obj, errs := toAuthenticationV1(uncastObj) if len(errs) > 0 { return errs } errs = append(errs, validation.ValidateObjectMeta(&obj.ObjectMeta, false, crvalidation.RequireNameCluster, field.NewPath("metadata"))...) - errs = append(errs, validateAuthenticationSpecCreate(obj.Spec)...) + errs = append(errs, validateAuthenticationSpecCreate(ctx, obj.Spec, a.cel)...) return errs } -func (authenticationV1) ValidateUpdate(_ context.Context, uncastObj runtime.Object, uncastOldObj runtime.Object) field.ErrorList { +func (a authenticationV1) ValidateUpdate(ctx context.Context, uncastObj runtime.Object, uncastOldObj runtime.Object) field.ErrorList { obj, errs := toAuthenticationV1(uncastObj) if len(errs) > 0 { return errs @@ -71,7 +134,7 @@ func (authenticationV1) ValidateUpdate(_ context.Context, uncastObj runtime.Obje } errs = append(errs, validation.ValidateObjectMetaUpdate(&obj.ObjectMeta, &oldObj.ObjectMeta, field.NewPath("metadata"))...) - errs = append(errs, validateAuthenticationSpecUpdate(obj.Spec, oldObj.Spec)...) + errs = append(errs, validateAuthenticationSpecUpdate(ctx, obj.Spec, oldObj.Spec, a.cel)...) return errs } @@ -92,15 +155,15 @@ func (authenticationV1) ValidateStatusUpdate(_ context.Context, uncastObj runtim return errs } -func validateAuthenticationSpecCreate(spec configv1.AuthenticationSpec) field.ErrorList { - return validateAuthenticationSpec(spec) +func validateAuthenticationSpecCreate(ctx context.Context, spec configv1.AuthenticationSpec, cel *celStore) field.ErrorList { + return validateAuthenticationSpec(ctx, spec, cel) } -func validateAuthenticationSpecUpdate(newspec, oldspec configv1.AuthenticationSpec) field.ErrorList { - return validateAuthenticationSpec(newspec) +func validateAuthenticationSpecUpdate(ctx context.Context, newspec, oldspec configv1.AuthenticationSpec, cel *celStore) field.ErrorList { + return validateAuthenticationSpec(ctx, newspec, cel) } -func validateAuthenticationSpec(spec configv1.AuthenticationSpec) field.ErrorList { +func validateAuthenticationSpec(ctx context.Context, spec configv1.AuthenticationSpec, cel *celStore) field.ErrorList { errs := field.ErrorList{} specField := field.NewPath("spec") @@ -121,14 +184,238 @@ func validateAuthenticationSpec(spec configv1.AuthenticationSpec) field.ErrorLis spec.WebhookTokenAuthenticator, fmt.Sprintf("this field cannot be set with the %q .spec.type", spec.Type), )) } - } errs = append(errs, crvalidation.ValidateConfigMapReference(specField.Child("oauthMetadata"), spec.OAuthMetadata, false)...) + // Perform External OIDC Provider related validations + // ---------------- + + // There is currently no guarantee that these fields are not set when the spec.Type is != OIDC. + // To ensure we are enforcing approriate admission validations at all times, just always iterate through the list + // of OIDC Providers and perform the validations. + // If/when the openshift/api admission validations are updated to enforce that this field is not configured + // when Type != OIDC, this loop should be a no-op due to an empty list. + for i, provider := range spec.OIDCProviders { + errs = append(errs, validateOIDCProvider(ctx, specField.Child("oidcProviders").Index(i), cel, provider)...) + } + // ---------------- + return errs } func validateAuthenticationStatus(status configv1.AuthenticationStatus) field.ErrorList { return crvalidation.ValidateConfigMapReference(field.NewPath("status", "integratedOAuthMetadata"), status.IntegratedOAuthMetadata, false) } + +type costRecorder struct { + Recordings []costRecording +} + +func (cr *costRecorder) AddRecording(field *field.Path, cost uint64) { + cr.Recordings = append(cr.Recordings, costRecording{ + Field: field, + Cost: cost, + }) +} + +type costRecording struct { + Field *field.Path + Cost uint64 +} + +func validateOIDCProvider(ctx context.Context, path *field.Path, cel *celStore, provider configv1.OIDCProvider) field.ErrorList { + costRecorder := &costRecorder{} + + errs := validateClaimMappings(ctx, path, cel, costRecorder, provider.ClaimMappings) + + var totalCELExpressionCost uint64 = 0 + + for _, recording := range costRecorder.Recordings { + totalCELExpressionCost = addCost(totalCELExpressionCost, recording.Cost) + } + + if totalCELExpressionCost > wholeResourceExcessiveCostThreshold { + costlyExpressions := getNMostCostlyExpressions(costlyExpressionWarningCount, costRecorder.Recordings...) + warn := fmt.Sprintf("runtime cost of all CEL expressions exceeds %d points. top %d most costly expressions: %v", wholeResourceExcessiveCostThreshold, len(costlyExpressions), costlyExpressions) + warning.AddWarning(ctx, "", warn) + klog.Warning(warn) + } + + return errs +} + +// addCost adds a cost value to a total value, +// returning the resulting value. +// addCost handles integer overflow errors +// by just always returning the maximum uint64 +// value if an overflow would occur. +func addCost(total, cost uint64) uint64 { + if total > math.MaxUint64-cost { + return math.MaxUint64 + } + + return total + cost +} + +func getNMostCostlyExpressions(n int, records ...costRecording) []costRecording { + // sort in descending order of cost + slices.SortFunc(records, func(a, b costRecording) int { + return cmp.Compare(a.Cost, b.Cost) + }) + slices.Reverse(records) + + // safely get the N most expensive cost records + out := []costRecording{} + if len(records) > n { + out = records[:n] + } else { + out = records + } + + return out +} + +func validateClaimMappings(ctx context.Context, path *field.Path, cel *celStore, costRecorder *costRecorder, claimMappings configv1.TokenClaimMappings) field.ErrorList { + path = path.Child("claimMappings") + + out := field.ErrorList{} + + out = append(out, validateUIDClaimMapping(ctx, path, cel, costRecorder, claimMappings.UID)...) + out = append(out, validateExtraClaimMapping(ctx, path, cel, costRecorder, claimMappings.Extra...)...) + + return out +} + +func validateUIDClaimMapping(ctx context.Context, path *field.Path, cel *celStore, costRecorder *costRecorder, uid *configv1.TokenClaimOrExpressionMapping) field.ErrorList { + if uid == nil { + return nil + } + + out := field.ErrorList{} + + if uid.Expression != "" { + childPath := path.Child("uid", "expression") + + out = append(out, validateCELExpression(ctx, cel, costRecorder, childPath, &authenticationcel.ClaimMappingExpression{ + Expression: uid.Expression, + })...) + } + + return out +} + +func validateExtraClaimMapping(ctx context.Context, path *field.Path, cel *celStore, costRecorder *costRecorder, extras ...configv1.ExtraMapping) field.ErrorList { + out := field.ErrorList{} + for i, extra := range extras { + out = append(out, validateExtra(ctx, path.Child("extra").Index(i), cel, costRecorder, extra)...) + } + + return out +} + +func validateExtra(ctx context.Context, path *field.Path, cel *celStore, costRecorder *costRecorder, extra configv1.ExtraMapping) field.ErrorList { + childPath := path.Child("valueExpression") + + return validateCELExpression(ctx, cel, costRecorder, childPath, &authenticationcel.ExtraMappingExpression{ + Key: extra.Key, + Expression: extra.ValueExpression, + }) +} + +type celCompileResult struct { + err error + cost uint64 +} + +type panickedErr struct { + error +} + +func validateCELExpression(ctx context.Context, cel *celStore, costRecorder *costRecorder, path *field.Path, accessor authenticationcel.ExpressionAccessor) field.ErrorList { + // if context has been canceled, don't try to compile any expressions + if err := ctx.Err(); err != nil { + return field.ErrorList{field.InternalError(path, err)} + } + + result, err, _ := cel.compilingGroup.Do(accessor.GetExpression(), func() (interface{}, error) { + // if the expression is not currently being compiled, it might have already been compiled + if val, ok := cel.compiledStore.Get(accessor.GetExpression()); ok { + res, ok := val.(celCompileResult) + if !ok { + return nil, fmt.Errorf("expected return value from cache of compiled expressions to be of type celCompileResult but was %T", val) + } + + return res, nil + } + + // expression is not currently being compiled, and has not been compiled before (or has been long enough since it was last compiled that we dropped it). + // Let's compile it. + + // Asynchronously handle excessive compilation time so we + // can still log a warning in the event the process has died + // before compilation of the expression has finished. + warningChan := make(chan string, 1) + timer := cel.timerFactory.Timer(excessiveCompileDuration, func() { + defer close(warningChan) + warn := fmt.Sprintf("cel expression %q took excessively long to compile (%s)", accessor.GetExpression(), excessiveCompileDuration) + klog.Warning(warn) + warningChan <- warn + }) + + compRes, compErr := cel.compiler.CompileClaimsExpression(accessor) + cost, err := checker.Cost(compRes.AST.NativeRep(), &library.CostEstimator{ + SizeEstimator: cel.sizeEstimator, + }) + if err != nil { + klog.Errorf("unable to estimate cost for expression %q: %v. Defaulting cost to %d", accessor.GetExpression(), err, fixedSize) + cost = checker.CostEstimate{Max: fixedSize} + } + + res := celCompileResult{ + err: compErr, + cost: cost.Max, + } + + timer.Stop() + + // check if we received a warning. If not, continue + select { + case warn := <-warningChan: + warning.AddWarning(ctx, "", warn) + default: + } + + cel.compiledStore.Add(accessor.GetExpression(), res) + + return res, nil + }) + if err != nil { + return field.ErrorList{field.InternalError(path, fmt.Errorf("running compilation of expression %q: %v", accessor.GetExpression(), err))} + } + + compileRes, ok := result.(celCompileResult) + if !ok { + return field.ErrorList{field.InternalError(path, fmt.Errorf("expected result to be of type celCompileResult, but got %T", result))} + } + + costRecorder.AddRecording(path, compileRes.cost) + + if compileRes.err != nil { + return field.ErrorList{field.Invalid(path, accessor.GetExpression(), compileRes.err.Error())} + } + + return nil +} + +type fixedSizeEstimator struct { + size uint64 +} + +func (fcse *fixedSizeEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate { + return &checker.SizeEstimate{Min: fcse.size, Max: fcse.size} +} + +func (fcse *fixedSizeEstimator) EstimateCallCost(function, overloadID string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { + return nil +} diff --git a/openshift-kube-apiserver/admission/customresourcevalidation/authentication/validate_authentication_test.go b/openshift-kube-apiserver/admission/customresourcevalidation/authentication/validate_authentication_test.go index d93f3f67f6fe9..c509e47d0e18e 100644 --- a/openshift-kube-apiserver/admission/customresourcevalidation/authentication/validate_authentication_test.go +++ b/openshift-kube-apiserver/admission/customresourcevalidation/authentication/validate_authentication_test.go @@ -1,10 +1,20 @@ package authentication import ( + "context" + "errors" + "fmt" + "strings" "testing" + "time" configv1 "github.com/openshift/api/config/v1" + "golang.org/x/sync/singleflight" + "k8s.io/apimachinery/pkg/util/rand" "k8s.io/apimachinery/pkg/util/validation/field" + authenticationcel "k8s.io/apiserver/pkg/authentication/cel" + "k8s.io/apiserver/pkg/warning" + "k8s.io/utils/lru" ) func TestFailValidateAuthenticationSpec(t *testing.T) { @@ -49,10 +59,50 @@ func TestFailValidateAuthenticationSpec(t *testing.T) { errorType: field.ErrorTypeInvalid, errorField: "spec.webhookTokenAuthenticator", }, + "invalid UID CEL expression": { + spec: configv1.AuthenticationSpec{ + Type: "OIDC", + OIDCProviders: []configv1.OIDCProvider{ + { + ClaimMappings: configv1.TokenClaimMappings{ + UID: &configv1.TokenClaimOrExpressionMapping{ + Expression: "!@^#&(!^@(*#&(", + }, + }, + }, + }, + }, + errorType: field.ErrorTypeInvalid, + errorField: "spec.oidcProviders[0].claimMappings.uid.expression", + }, + "invalid Extra CEL expression": { + spec: configv1.AuthenticationSpec{ + Type: "OIDC", + OIDCProviders: []configv1.OIDCProvider{ + { + ClaimMappings: configv1.TokenClaimMappings{ + Extra: []configv1.ExtraMapping{ + { + Key: "foo/bar", + ValueExpression: "!@*(&#^(!@*)&^&", + }, + }, + }, + }, + }, + }, + errorType: field.ErrorTypeInvalid, + errorField: "spec.oidcProviders[0].claimMappings.extra[0].valueExpression", + }, } for tcName, tc := range errorCases { - errs := validateAuthenticationSpec(tc.spec) + errs := validateAuthenticationSpec(context.TODO(), tc.spec, &celStore{ + compiler: authenticationcel.NewDefaultCompiler(), + compilingGroup: new(singleflight.Group), + compiledStore: lru.New(100), + timerFactory: &excessiveCompileTimerFactory{}, + }) if (len(errs) > 0) != (len(tc.errorType) != 0) { t.Errorf("'%s': expected failure: %t, got: %t", tcName, len(tc.errorType) != 0, len(errs) > 0) } @@ -109,10 +159,42 @@ func TestSucceedValidateAuthenticationSpec(t *testing.T) { {KubeConfig: configv1.SecretNameReference{Name: "thisisawebhook33"}}, }, }, + "valid uid CEL expression": { + Type: "OIDC", + OIDCProviders: []configv1.OIDCProvider{ + { + ClaimMappings: configv1.TokenClaimMappings{ + UID: &configv1.TokenClaimOrExpressionMapping{ + Expression: "claims.uid", + }, + }, + }, + }, + }, + "valid Extra CEL expression": { + Type: "OIDC", + OIDCProviders: []configv1.OIDCProvider{ + { + ClaimMappings: configv1.TokenClaimMappings{ + Extra: []configv1.ExtraMapping{ + { + Key: "foo/bar", + ValueExpression: "claims.roles", + }, + }, + }, + }, + }, + }, } for tcName, s := range successCases { - errs := validateAuthenticationSpec(s) + errs := validateAuthenticationSpec(context.TODO(), s, &celStore{ + compiler: authenticationcel.NewDefaultCompiler(), + compilingGroup: new(singleflight.Group), + compiledStore: lru.New(100), + timerFactory: &excessiveCompileTimerFactory{}, + }) if len(errs) != 0 { t.Errorf("'%s': expected success, but failed: %v", tcName, errs.ToAggregate().Error()) } @@ -175,5 +257,321 @@ func TestSucceedValidateAuthenticationStatus(t *testing.T) { t.Errorf("'%s': expected success, but failed: %v", tcName, errs.ToAggregate().Error()) } } +} + +func TestValidateCELExpression(t *testing.T) { + type testcase struct { + name string + cel func() *celStore + ctx func() context.Context + shouldErr bool + shouldWarn bool + } + + expression := &authenticationcel.ClaimMappingExpression{ + Expression: `["foo", "bar"].exists(x, x == "foo")`, + } + + testcases := []testcase{ + { + name: "does not return a warning when excessive compilation timer is not triggered", + cel: func() *celStore { + return &celStore{ + compiler: &mockCompiler{ + err: nil, + }, + compilingGroup: new(singleflight.Group), + compiledStore: lru.New(1), + timerFactory: &mockTimerFactory{ + trigger: false, + }, + } + }, + ctx: func() context.Context { return context.TODO() }, + }, + { + name: "returns a warning when excessive compilation timer is triggered", + cel: func() *celStore { + return &celStore{ + compiler: &mockCompiler{ + err: nil, + }, + compilingGroup: new(singleflight.Group), + compiledStore: lru.New(1), + timerFactory: &mockTimerFactory{ + trigger: true, + }, + } + }, + ctx: func() context.Context { return context.TODO() }, + shouldWarn: true, + }, + { + name: "still returns error if excessive compilation timer is triggered and errors out", + cel: func() *celStore { + return &celStore{ + compiler: &mockCompiler{ + err: errors.New("boom"), + }, + compilingGroup: new(singleflight.Group), + compiledStore: lru.New(1), + timerFactory: &mockTimerFactory{ + trigger: true, + }, + } + }, + ctx: func() context.Context { return context.TODO() }, + shouldWarn: true, + shouldErr: true, + }, + { + name: "returns an error if the context has been canceled", + cel: func() *celStore { + return &celStore{ + compiler: &mockCompiler{ + err: nil, + }, + compilingGroup: new(singleflight.Group), + compiledStore: lru.New(1), + timerFactory: &mockTimerFactory{ + trigger: false, + }, + } + }, + ctx: func() context.Context { + ctx, cancel := context.WithCancel(context.TODO()) + cancel() + return ctx + }, + shouldErr: true, + }, + { + name: "returns already compiled expression results if the expression has been compiled before", + cel: func() *celStore { + compiledLRU := lru.New(1) + res := celCompileResult{ + err: errors.New("boom"), + } + compiledLRU.Add(expression.Expression, res) + + return &celStore{ + compiler: nil, // should never end up calling this + compilingGroup: new(singleflight.Group), + compiledStore: compiledLRU, + timerFactory: &mockTimerFactory{ + trigger: false, + }, + } + }, + ctx: func() context.Context { return context.TODO() }, + shouldErr: true, + shouldWarn: false, + }, + } + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + warningRecorder := &mockWarningRecorder{} + ctx := warning.WithWarningRecorder(tc.ctx(), warningRecorder) + err := validateCELExpression(ctx, tc.cel(), &costRecorder{}, field.NewPath("^"), expression) + if tc.shouldErr != (err != nil) { + t.Fatalf("error expectation does not match actual. expected? %v . received: %v", tc.shouldErr, err) + } + + if tc.shouldWarn != (len(warningRecorder.warnings) > 0) { + t.Fatalf("warning expectation does not match actual. expected? %v . received: %v", tc.shouldWarn, warningRecorder.warnings) + } + }) + } +} + +type mockCompiler struct { + receiver chan error + err error + useDelegate bool + delegate authenticationcel.Compiler +} + +func (mc *mockCompiler) CompileClaimsExpression(expressionAccessor authenticationcel.ExpressionAccessor) (authenticationcel.CompilationResult, error) { + if mc.receiver != nil { + err := <-mc.receiver + return authenticationcel.CompilationResult{}, err + } + return authenticationcel.CompilationResult{}, mc.err +} + +func (mc *mockCompiler) CompileUserExpression(expressionAccessor authenticationcel.ExpressionAccessor) (authenticationcel.CompilationResult, error) { + return authenticationcel.CompilationResult{}, mc.err +} + +type mockTimerFactory struct { + trigger bool +} + +func (mct *mockTimerFactory) Timer(_ time.Duration, do func()) Timer { + if mct.trigger { + do() + return &mockTimer{done: true} + } + + return &mockTimer{done: false} +} + +type mockTimer struct { + done bool +} + +func (mt *mockTimer) Stop() bool { + return mt.done +} + +func TestValidateCELExpressionDeduplicatesWork(t *testing.T) { + receiver := make(chan error) + cel := &celStore{ + compiler: &mockCompiler{ + receiver: receiver, + err: errors.New("boom"), + }, + compilingGroup: new(singleflight.Group), + compiledStore: lru.New(1), + timerFactory: &mockTimerFactory{ + trigger: false, + }, + } + + expression := &authenticationcel.ClaimMappingExpression{ + Expression: `["foo", "bar"].exists(x, x == "foo")`, + } + + errListOneChan := make(chan field.ErrorList) + errListTwoChan := make(chan field.ErrorList) + + go func() { + errListOneChan <- validateCELExpression(context.TODO(), cel, &costRecorder{}, field.NewPath("^"), expression) + }() + + go func() { + errListTwoChan <- validateCELExpression(context.TODO(), cel, &costRecorder{}, field.NewPath("^"), expression) + }() + + for range 100 { + select { + case <-errListOneChan: + t.Fatal("expected first call to validateCELExpression to hang until specified but it did not") + return + + case <-errListTwoChan: + t.Fatal("expected second call to validateCELExpression to hang until specified but it did not") + return + default: + } + } + + receiver <- errors.New("boom") + + errOne := <-errListOneChan + errTwo := <-errListTwoChan + + if errOne[0].Error() != errTwo[0].Error() { + t.Fatalf("expected the same result from both calls to validateCELExpression, but got different results. first call: %v | second call: %v", errOne[0], errTwo[0]) + } +} + +func TestValidAuthenticationSpecWithExcessivelyLongCELExpressionCompileTime(t *testing.T) { + // Create an expression that takes excessively long to compile + // but would not blow the top off the entire resource runtime cost estimation + // warning threshold + var sb strings.Builder + sb.WriteString(`["foo","bar"]`) + const toappend = `.map(x, [x+x,x+x])` + for 4096-sb.Len() >= len(toappend) { + sb.WriteString(toappend) + } + expr := sb.String() + + authn := configv1.AuthenticationSpec{ + Type: "OIDC", + OIDCProviders: []configv1.OIDCProvider{ + { + ClaimMappings: configv1.TokenClaimMappings{ + UID: &configv1.TokenClaimOrExpressionMapping{ + Expression: expr, + }, + }, + }, + }, + } + + warningRecorder := &mockWarningRecorder{} + ctx := warning.WithWarningRecorder(context.TODO(), warningRecorder) + + errs := validateAuthenticationSpec(ctx, authn, &celStore{ + compiler: authenticationcel.NewDefaultCompiler(), + compilingGroup: new(singleflight.Group), + compiledStore: lru.New(100), + timerFactory: &excessiveCompileTimerFactory{}, + }) + + if len(errs) > 0 { + t.Fatalf("should not have received any errors, but got: %v", errs.ToAggregate()) + } + + if len(warningRecorder.warnings) != 1 { + t.Fatalf("expected to receive one warning about excessively long cel compilation time, got: %v", warningRecorder.warnings) + } + + if !strings.Contains(warningRecorder.warnings[0], "took excessively long to compile") { + t.Fatalf("expected warning to mention excessively long compile time but instead got: %s", warningRecorder.warnings[0]) + } +} + +func TestValidAuthenticationSpecWithExcessiveCELExpressionRuntimeCost(t *testing.T) { + extras := []configv1.ExtraMapping{} + for range 5 { + extras = append(extras, configv1.ExtraMapping{ + Key: fmt.Sprintf("test.io/%s", rand.String(8)), + ValueExpression: "claims.map(x, x+x)", + }) + } + + authn := configv1.AuthenticationSpec{ + Type: "OIDC", + OIDCProviders: []configv1.OIDCProvider{ + { + ClaimMappings: configv1.TokenClaimMappings{ + Extra: extras, + }, + }, + }, + } + + warningRecorder := &mockWarningRecorder{} + ctx := warning.WithWarningRecorder(context.TODO(), warningRecorder) + + errs := validateAuthenticationSpec(ctx, authn, &celStore{ + compiler: authenticationcel.NewDefaultCompiler(), + compilingGroup: new(singleflight.Group), + compiledStore: lru.New(100), + timerFactory: &excessiveCompileTimerFactory{}, + }) + + if len(errs) > 0 { + t.Fatalf("should not have received any errors, but got: %v", errs.ToAggregate()) + } + + if len(warningRecorder.warnings) != 1 { + t.Fatalf("expected to receive one warning about excessive runtime cost, got: %v", warningRecorder.warnings) + } + + if !strings.Contains(warningRecorder.warnings[0], "runtime cost of all CEL expressions exceeds") { + t.Fatalf("expected warning to mention excessive runtime cost but instead got: %s", warningRecorder.warnings[0]) + } +} + +type mockWarningRecorder struct { + warnings []string +} +func (mwr *mockWarningRecorder) AddWarning(agent, text string) { + mwr.warnings = append(mwr.warnings, text) }