Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions validator/core/walk.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ type Events struct {
directiveList []func(walker *Walker, directives []*ast.Directive)
value []func(walker *Walker, value *ast.Value)
variable []func(walker *Walker, variable *ast.VariableDefinition)

// StopOnFirstError indicates whether to stop traversal on the first error.
StopOnFirstError bool

// Stopped indicates traversal should stop early. This is set by validators
// that wish to abort walking once an error has been encountered.
Stopped bool
}

func (o *Events) OnOperation(f func(walker *Walker, operation *ast.OperationDefinition)) {
Expand Down Expand Up @@ -76,6 +83,9 @@ type Walker struct {
}

func (w *Walker) walk() {
if w.Observers != nil && w.Observers.Stopped {
return
}
for _, child := range w.Document.Operations {
w.validatedFragmentSpreads = make(map[string]bool)
w.walkOperation(child)
Expand All @@ -87,6 +97,9 @@ func (w *Walker) walk() {
}

func (w *Walker) walkOperation(operation *ast.OperationDefinition) {
if w.Observers != nil && w.Observers.Stopped {
return
}
w.CurrentOperation = operation
for _, varDef := range operation.VariableDefinitions {
varDef.Definition = w.Schema.Types[varDef.Type.Name()]
Expand Down Expand Up @@ -130,6 +143,9 @@ func (w *Walker) walkOperation(operation *ast.OperationDefinition) {
}

func (w *Walker) walkFragment(it *ast.FragmentDefinition) {
if w.Observers != nil && w.Observers.Stopped {
return
}
def := w.Schema.Types[it.TypeCondition]

it.Definition = def
Expand All @@ -143,6 +159,9 @@ func (w *Walker) walkFragment(it *ast.FragmentDefinition) {
}

func (w *Walker) walkDirectives(parentDef *ast.Definition, directives []*ast.Directive, location ast.DirectiveLocation) {
if w.Observers != nil && w.Observers.Stopped {
return
}
for _, dir := range directives {
def := w.Schema.Directives[dir.Name]
dir.Definition = def
Expand All @@ -169,6 +188,9 @@ func (w *Walker) walkDirectives(parentDef *ast.Definition, directives []*ast.Dir
}

func (w *Walker) walkValue(value *ast.Value) {
if w.Observers != nil && w.Observers.Stopped {
return
}
if value.Kind == ast.Variable && w.CurrentOperation != nil {
value.VariableDefinition = w.CurrentOperation.VariableDefinitions.ForName(value.Raw)
if value.VariableDefinition != nil {
Expand Down Expand Up @@ -207,6 +229,9 @@ func (w *Walker) walkValue(value *ast.Value) {
}

func (w *Walker) walkArgument(argDef *ast.ArgumentDefinition, arg *ast.Argument) {
if w.Observers != nil && w.Observers.Stopped {
return
}
if argDef != nil {
arg.Value.ExpectedType = argDef.Type
arg.Value.ExpectedTypeHasDefault = argDef.DefaultValue != nil && argDef.DefaultValue.Kind != ast.NullValue
Expand All @@ -217,12 +242,18 @@ func (w *Walker) walkArgument(argDef *ast.ArgumentDefinition, arg *ast.Argument)
}

func (w *Walker) walkSelectionSet(parentDef *ast.Definition, it ast.SelectionSet) {
if w.Observers != nil && w.Observers.Stopped {
return
}
for _, child := range it {
w.walkSelection(parentDef, child)
}
}

func (w *Walker) walkSelection(parentDef *ast.Definition, it ast.Selection) {
if w.Observers != nil && w.Observers.Stopped {
return
}
switch it := it.(type) {
case *ast.Field:
var def *ast.FieldDefinition
Expand Down
12 changes: 11 additions & 1 deletion validator/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ func Validate(schema *Schema, doc *QueryDocument, rules ...Rule) gqlerror.List {
}

func ValidateWithRules(schema *Schema, doc *QueryDocument, rules *validatorrules.Rules) gqlerror.List {
return ValidateWithRulesAndStopOnFirstError(schema, doc, rules, false)
}

func ValidateWithRulesAndStopOnFirstError(schema *Schema, doc *QueryDocument, rules *validatorrules.Rules, stopOnFirstError bool) gqlerror.List {
if rules == nil {
rules = validatorrules.NewDefaultRules()
}
Expand All @@ -132,7 +136,9 @@ func ValidateWithRules(schema *Schema, doc *QueryDocument, rules *validatorrules
if len(errs) > 0 {
return errs
}
observers := &core.Events{}
observers := &core.Events{
StopOnFirstError: stopOnFirstError,
}

var currentRules []Rule // nolint:prealloc // would require extra local refs for len
for name, ruleFunc := range rules.GetInner() {
Expand All @@ -150,6 +156,10 @@ func ValidateWithRules(schema *Schema, doc *QueryDocument, rules *validatorrules
o(err)
}
errs = append(errs, err)

if observers.StopOnFirstError {
observers.Stopped = true
}
})
}

Expand Down
Loading