diff --git a/validator/core/walk.go b/validator/core/walk.go index 5324508..7e58dd3 100644 --- a/validator/core/walk.go +++ b/validator/core/walk.go @@ -17,6 +17,10 @@ type Events struct { directiveList []func(walker *Walker, directives []*ast.Directive) value []func(walker *Walker, value *ast.Value) variable []func(walker *Walker, variable *ast.VariableDefinition) + + // Stopped indicates traversal should stop early. This is set by validators + // that wish to abort walking once an error has been encountered. + Stopped bool } func (o *Events) OnOperation(f func(walker *Walker, operation *ast.OperationDefinition)) { @@ -76,6 +80,9 @@ type Walker struct { } func (w *Walker) walk() { + if w.Observers != nil && w.Observers.Stopped { + return + } for _, child := range w.Document.Operations { w.validatedFragmentSpreads = make(map[string]bool) w.walkOperation(child) @@ -87,6 +94,9 @@ func (w *Walker) walk() { } func (w *Walker) walkOperation(operation *ast.OperationDefinition) { + if w.Observers != nil && w.Observers.Stopped { + return + } w.CurrentOperation = operation for _, varDef := range operation.VariableDefinitions { varDef.Definition = w.Schema.Types[varDef.Type.Name()] @@ -130,6 +140,9 @@ func (w *Walker) walkOperation(operation *ast.OperationDefinition) { } func (w *Walker) walkFragment(it *ast.FragmentDefinition) { + if w.Observers != nil && w.Observers.Stopped { + return + } def := w.Schema.Types[it.TypeCondition] it.Definition = def @@ -143,6 +156,9 @@ func (w *Walker) walkFragment(it *ast.FragmentDefinition) { } func (w *Walker) walkDirectives(parentDef *ast.Definition, directives []*ast.Directive, location ast.DirectiveLocation) { + if w.Observers != nil && w.Observers.Stopped { + return + } for _, dir := range directives { def := w.Schema.Directives[dir.Name] dir.Definition = def @@ -169,6 +185,9 @@ func (w *Walker) walkDirectives(parentDef *ast.Definition, directives []*ast.Dir } func (w *Walker) walkValue(value *ast.Value) { + if w.Observers != nil && w.Observers.Stopped { + return + } if value.Kind == ast.Variable && w.CurrentOperation != nil { value.VariableDefinition = w.CurrentOperation.VariableDefinitions.ForName(value.Raw) if value.VariableDefinition != nil { @@ -207,6 +226,9 @@ func (w *Walker) walkValue(value *ast.Value) { } func (w *Walker) walkArgument(argDef *ast.ArgumentDefinition, arg *ast.Argument) { + if w.Observers != nil && w.Observers.Stopped { + return + } if argDef != nil { arg.Value.ExpectedType = argDef.Type arg.Value.ExpectedTypeHasDefault = argDef.DefaultValue != nil && argDef.DefaultValue.Kind != ast.NullValue @@ -217,12 +239,18 @@ func (w *Walker) walkArgument(argDef *ast.ArgumentDefinition, arg *ast.Argument) } func (w *Walker) walkSelectionSet(parentDef *ast.Definition, it ast.SelectionSet) { + if w.Observers != nil && w.Observers.Stopped { + return + } for _, child := range it { w.walkSelection(parentDef, child) } } func (w *Walker) walkSelection(parentDef *ast.Definition, it ast.Selection) { + if w.Observers != nil && w.Observers.Stopped { + return + } switch it := it.(type) { case *ast.Field: var def *ast.FieldDefinition diff --git a/validator/validator.go b/validator/validator.go index 1214ed1..d241f4e 100644 --- a/validator/validator.go +++ b/validator/validator.go @@ -118,6 +118,10 @@ func Validate(schema *Schema, doc *QueryDocument, rules ...Rule) gqlerror.List { } func ValidateWithRules(schema *Schema, doc *QueryDocument, rules *validatorrules.Rules) gqlerror.List { + return ValidateWithRulesAndMaximumErrors(schema, doc, rules, 0) +} + +func ValidateWithRulesAndMaximumErrors(schema *Schema, doc *QueryDocument, rules *validatorrules.Rules, maximumErrors int) gqlerror.List { if rules == nil { rules = validatorrules.NewDefaultRules() } @@ -129,6 +133,9 @@ func ValidateWithRules(schema *Schema, doc *QueryDocument, rules *validatorrules if doc == nil { errs = append(errs, gqlerror.Errorf("cannot validate as QueryDocument is nil")) } + if maximumErrors < 0 { + errs = append(errs, gqlerror.Errorf("maximumErrors cannot be negative")) + } if len(errs) > 0 { return errs } @@ -150,6 +157,10 @@ func ValidateWithRules(schema *Schema, doc *QueryDocument, rules *validatorrules o(err) } errs = append(errs, err) + + if maximumErrors > 0 && len(errs) >= maximumErrors { + observers.Stopped = true + } }) } diff --git a/validator/validator_test.go b/validator/validator_test.go index 6befa03..a9482d3 100644 --- a/validator/validator_test.go +++ b/validator/validator_test.go @@ -318,3 +318,215 @@ func TestRemoveRule(t *testing.T) { // no error validator.RemoveRule("Rule that should no longer exist") } + +func TestValidateWithRulesAndMaximumErrors(t *testing.T) { + t.Run("maximumErrors limits error count", func(t *testing.T) { + s := gqlparser.MustLoadSchema( + &ast.Source{ + Name: "graph/schema.graphqls", + Input: ` + type Query { + field1: String! + field2: String! + field3: String! + } + `, BuiltIn: false}, + ) + + q, err := parser.ParseQuery(&ast.Source{ + Name: "SomeQuery", + Input: ` + query { + field1 + field2 + field3 + } + `}) + require.NoError(t, err) + + // Create a rule that generates errors for each field + errorRule := validator.Rule{ + Name: "ErrorRule", + RuleFunc: func(observers *validator.Events, addError validator.AddErrFunc) { + observers.OnField(func(walker *validator.Walker, field *ast.Field) { + addError(validator.Message("Error for field %s", field.Name)) + }) + }, + } + + rules := rules.NewRules(errorRule) + errList := validator.ValidateWithRulesAndMaximumErrors(s, q, rules, 2) + + // Should only return 2 errors even though there are 3 fields + require.Len(t, errList, 2) + }) + + t.Run("maximumErrors zero means no limit", func(t *testing.T) { + s := gqlparser.MustLoadSchema( + &ast.Source{ + Name: "graph/schema.graphqls", + Input: ` + type Query { + field1: String! + field2: String! + field3: String! + } + `, BuiltIn: false}, + ) + + q, err := parser.ParseQuery(&ast.Source{ + Name: "SomeQuery", + Input: ` + query { + field1 + field2 + field3 + } + `}) + require.NoError(t, err) + + // Create a rule that generates errors for each field + errorRule := validator.Rule{ + Name: "ErrorRule", + RuleFunc: func(observers *validator.Events, addError validator.AddErrFunc) { + observers.OnField(func(walker *validator.Walker, field *ast.Field) { + addError(validator.Message("Error for field %s", field.Name)) + }) + }, + } + + rules := rules.NewRules(errorRule) + errList := validator.ValidateWithRulesAndMaximumErrors(s, q, rules, 0) + + // Should return all errors when maximumErrors is 0 + require.Len(t, errList, 3) + }) + + t.Run("negative maximumErrors returns error", func(t *testing.T) { + s := gqlparser.MustLoadSchema( + &ast.Source{ + Name: "graph/schema.graphqls", + Input: ` + type Query { + field1: String! + } + `, BuiltIn: false}, + ) + + q, err := parser.ParseQuery(&ast.Source{ + Name: "SomeQuery", + Input: ` + query { + field1 + } + `}) + require.NoError(t, err) + + errList := validator.ValidateWithRulesAndMaximumErrors(s, q, nil, -1) + + // Should return an error about negative maximumErrors + require.Len(t, errList, 1) + require.Contains(t, errList[0].Message, "maximumErrors cannot be negative") + }) + + t.Run("maximumErrors stops traversal early", func(t *testing.T) { + s := gqlparser.MustLoadSchema( + &ast.Source{ + Name: "graph/schema.graphqls", + Input: ` + type Query { + field1: String! + field2: String! + field3: String! + field4: String! + field5: String! + } + `, BuiltIn: false}, + ) + + q, err := parser.ParseQuery(&ast.Source{ + Name: "SomeQuery", + Input: ` + query { + field1 + field2 + field3 + field4 + field5 + } + `}) + require.NoError(t, err) + + fieldCount := 0 + // Create a rule that generates errors and counts fields + errorRule := validator.Rule{ + Name: "ErrorRule", + RuleFunc: func(observers *validator.Events, addError validator.AddErrFunc) { + observers.OnField(func(walker *validator.Walker, field *ast.Field) { + fieldCount++ + addError(validator.Message("Error for field %s", field.Name)) + }) + }, + } + + rules := rules.NewRules(errorRule) + errList := validator.ValidateWithRulesAndMaximumErrors(s, q, rules, 2) + + // Should only return 2 errors + require.Len(t, errList, 2) + // Should have stopped traversal early after exactly 2 fields processed + require.Equal(t, 2, fieldCount) + }) + + t.Run("maximumErrors with multiple rules", func(t *testing.T) { + s := gqlparser.MustLoadSchema( + &ast.Source{ + Name: "graph/schema.graphqls", + Input: ` + type Query { + field1: String! + field2: String! + } + `, BuiltIn: false}, + ) + + q, err := parser.ParseQuery(&ast.Source{ + Name: "SomeQuery", + Input: ` + query { + field1 + field2 + } + `}) + require.NoError(t, err) + + // Create two rules that each generate errors and count fields + fieldCount := 0 + rule1 := validator.Rule{ + Name: "Rule1", + RuleFunc: func(observers *validator.Events, addError validator.AddErrFunc) { + observers.OnField(func(walker *validator.Walker, field *ast.Field) { + fieldCount++ + addError(validator.Message("Rule1 error for field %s", field.Name)) + }) + }, + } + rule2 := validator.Rule{ + Name: "Rule2", + RuleFunc: func(observers *validator.Events, addError validator.AddErrFunc) { + observers.OnField(func(walker *validator.Walker, field *ast.Field) { + fieldCount++ + addError(validator.Message("Rule2 error for field %s", field.Name)) + }) + }, + } + + rules := rules.NewRules(rule1, rule2) + errList := validator.ValidateWithRulesAndMaximumErrors(s, q, rules, 3) + + // Although we set maximumErrors to 3, we expect 4 errors here (2 rules × 2 fields). + // The limit is evaluated after the batch is processed, allowing a final overflow. + require.Equal(t, 4, fieldCount) + require.Equal(t, 4, len(errList)) + }) +} diff --git a/validator/walk_test.go b/validator/walk_test.go index d92b885..168e158 100644 --- a/validator/walk_test.go +++ b/validator/walk_test.go @@ -50,3 +50,23 @@ func TestWalkInlineFragment(t *testing.T) { require.True(t, called) } + +func TestWalkStoppedEarly(t *testing.T) { + schema, err := LoadSchema(Prelude, &ast.Source{Input: "type Query { name: String, age: Int }\n schema { query: Query }"}) + require.NoError(t, err) + query, err := parser.ParseQuery(&ast.Source{Input: "{ name age }"}) + require.NoError(t, err) + + fieldCount := 0 + observers := &Events{} + observers.OnField(func(walker *Walker, field *ast.Field) { + fieldCount++ + if fieldCount == 1 { + observers.Stopped = true + } + }) + + Walk(schema, query, observers) + + require.Equal(t, 1, fieldCount) +}