Skip to content

Commit

Permalink
feat(hosterrorscache): add Remove and MarkFailedOrRemove methods (#…
Browse files Browse the repository at this point in the history
…5984)

* feat(hosterrorscache): add `Remove` and `MarkFailedOrRemove` methods

and also deprecating `MarkFailed`

Signed-off-by: Dwi Siswanto <[email protected]>

* refactor(*): unwraps `hosterrorscache\.MarkFailed` invocation

Signed-off-by: Dwi Siswanto <[email protected]>

* feat(hosterrorscache): add sync in `Check` and `MarkFailedOrRemove` methods

* test(hosterrorscache): add concurrent test for `Check` method

* refactor(hosterrorscache): do NOT change `MarkFailed` behavior

Signed-off-by: Dwi Siswanto <[email protected]>

* feat(*): use `MarkFailedOrRemove` explicitly

Signed-off-by: Dwi Siswanto <[email protected]>

---------

Signed-off-by: Dwi Siswanto <[email protected]>
  • Loading branch information
dwisiswant0 authored Jan 31, 2025
1 parent 5a52e93 commit 052fd8b
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 86 deletions.
6 changes: 3 additions & 3 deletions pkg/core/workflow_execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@ func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan
firstMatched = true
}
}
if w.Options.HostErrorsCache != nil {
w.Options.HostErrorsCache.MarkFailedOrRemove(w.Options.ProtocolType.String(), ctx.Input, err)
}
if err != nil {
if w.Options.HostErrorsCache != nil {
w.Options.HostErrorsCache.MarkFailed(w.Options.ProtocolType.String(), ctx.Input, err)
}
if len(template.Executers) == 1 {
mainErr = err
} else {
Expand Down
128 changes: 98 additions & 30 deletions pkg/protocols/common/hosterrorscache/hosterrorscache.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package hosterrorscache

import (
"errors"
"net"
"net/url"
"regexp"
Expand All @@ -20,10 +21,12 @@ import (
// CacheInterface defines the signature of the hosterrorscache so that
// users of Nuclei as embedded lib may implement their own cache
type CacheInterface interface {
SetVerbose(verbose bool) // log verbosely
Close() // close the cache
Check(protoType string, ctx *contextargs.Context) bool // return true if the host should be skipped
MarkFailed(protoType string, ctx *contextargs.Context, err error) // record a failure (and cause) for the host
SetVerbose(verbose bool) // log verbosely
Close() // close the cache
Check(protoType string, ctx *contextargs.Context) bool // return true if the host should be skipped
Remove(ctx *contextargs.Context) // remove a host from the cache
MarkFailed(protoType string, ctx *contextargs.Context, err error) // record a failure (and cause) for the host
MarkFailedOrRemove(protoType string, ctx *contextargs.Context, err error) // record a failure (and cause) for the host or remove it
}

var (
Expand All @@ -47,16 +50,20 @@ type cacheItem struct {
errors atomic.Int32
isPermanentErr bool
cause error // optional cause
mu sync.Mutex
}

const DefaultMaxHostsCount = 10000

// New returns a new host max errors cache
func New(maxHostError, maxHostsCount int, trackError []string) *Cache {
gc := gcache.New[string, *cacheItem](maxHostsCount).
ARC().
Build()
return &Cache{failedTargets: gc, MaxHostError: maxHostError, TrackError: trackError}
gc := gcache.New[string, *cacheItem](maxHostsCount).ARC().Build()

return &Cache{
failedTargets: gc,
MaxHostError: maxHostError,
TrackError: trackError,
}
}

// SetVerbose sets the cache to log at verbose level
Expand Down Expand Up @@ -118,47 +125,108 @@ func (c *Cache) NormalizeCacheValue(value string) string {
func (c *Cache) Check(protoType string, ctx *contextargs.Context) bool {
finalValue := c.GetKeyFromContext(ctx, nil)

existingCacheItem, err := c.failedTargets.GetIFPresent(finalValue)
cache, err := c.failedTargets.GetIFPresent(finalValue)
if err != nil {
return false
}
if existingCacheItem.isPermanentErr {

cache.mu.Lock()
defer cache.mu.Unlock()

if cache.isPermanentErr {
// skipping permanent errors is expected so verbose instead of info
gologger.Verbose().Msgf("Skipped %s from target list as found unresponsive permanently: %s", finalValue, existingCacheItem.cause)
gologger.Verbose().Msgf("Skipped %s from target list as found unresponsive permanently: %s", finalValue, cache.cause)
return true
}

if existingCacheItem.errors.Load() >= int32(c.MaxHostError) {
existingCacheItem.Do(func() {
gologger.Info().Msgf("Skipped %s from target list as found unresponsive %d times", finalValue, existingCacheItem.errors.Load())
if cache.errors.Load() >= int32(c.MaxHostError) {
cache.Do(func() {
gologger.Info().Msgf("Skipped %s from target list as found unresponsive %d times", finalValue, cache.errors.Load())
})
return true
}

return false
}

// Remove removes a host from the cache
func (c *Cache) Remove(ctx *contextargs.Context) {
key := c.GetKeyFromContext(ctx, nil)
_ = c.failedTargets.Remove(key) // remove even the cache is not present
}

// MarkFailed marks a host as failed previously
//
// Deprecated: Use MarkFailedOrRemove instead.
func (c *Cache) MarkFailed(protoType string, ctx *contextargs.Context, err error) {
if !c.checkError(protoType, err) {
if err == nil {
return
}
finalValue := c.GetKeyFromContext(ctx, err)
existingCacheItem, err := c.failedTargets.GetIFPresent(finalValue)
if err != nil || existingCacheItem == nil {
newItem := &cacheItem{errors: atomic.Int32{}}
newItem.errors.Store(1)
if errkit.IsKind(err, errkit.ErrKindNetworkPermanent) {
// skip this address altogether
// permanent errors are always permanent hence this is created once
// and never updated so no need to synchronize
newItem.isPermanentErr = true
newItem.cause = err
}
_ = c.failedTargets.Set(finalValue, newItem)

c.MarkFailedOrRemove(protoType, ctx, err)
}

// MarkFailedOrRemove marks a host as failed previously or removes it
func (c *Cache) MarkFailedOrRemove(protoType string, ctx *contextargs.Context, err error) {
if err != nil && !c.checkError(protoType, err) {
return
}
existingCacheItem.errors.Add(1)
_ = c.failedTargets.Set(finalValue, existingCacheItem)

if err == nil {
// Remove the host from cache
//
// NOTE(dwisiswant0): The decision was made to completely remove the
// cached entry for the host instead of simply decrementing the error
// count (using `(atomic.Int32).Swap` to update the value to `N-1`).
// This approach was chosen because the error handling logic operates
// concurrently, and decrementing the count could lead to UB (unexpected
// behavior) even when the error is `nil`.
//
// To clarify, consider the following scenario where the error
// encountered does NOT belong to the permanent network error category
// (`errkit.ErrKindNetworkPermanent`):
//
// 1. Iteration 1: A timeout error occurs, and the error count for the
// host is incremented.
// 2. Iteration 2: Another timeout error is encountered, leading to
// another increment in the host's error count.
// 3. Iteration 3: A third timeout error happens, which increments the
// error count further. At this point, the host is flagged as
// unresponsive.
// 4. Iteration 4: The host becomes reachable (no error or a transient
// issue resolved). Instead of performing a no-op and leaving the
// host in the cache, the host entry is removed entirely to reset its
// state.
// 5. Iteration 5: A subsequent timeout error occurs after the host was
// removed and re-added to the cache. The error count is reset and
// starts from 1 again.
//
// This removal strategy ensures the cache is updated dynamically to
// reflect the current state of the host without persisting stale or
// irrelevant error counts that could interfere with future error
// handling and tracking logic.
c.Remove(ctx)

return
}

cacheKey := c.GetKeyFromContext(ctx, err)
cache, cacheErr := c.failedTargets.GetIFPresent(cacheKey)
if errors.Is(cacheErr, gcache.KeyNotFoundError) {
cache = &cacheItem{errors: atomic.Int32{}}
}

cache.mu.Lock()
defer cache.mu.Unlock()

if errkit.IsKind(err, errkit.ErrKindNetworkPermanent) {
cache.isPermanentErr = true
}

cache.cause = err
cache.errors.Add(1)

_ = c.failedTargets.Set(cacheKey, cache)
}

// GetKeyFromContext returns the key for the cache from the context
Expand Down
78 changes: 62 additions & 16 deletions pkg/protocols/common/hosterrorscache/hosterrorscache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package hosterrorscache

import (
"context"
"fmt"
"errors"
"sync"
"sync/atomic"
"testing"
Expand All @@ -17,28 +17,40 @@ const (

func TestCacheCheck(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, nil)
err := errors.New("net/http: timeout awaiting response headers")

t.Run("increment host error", func(t *testing.T) {
ctx := newCtxArgs(t.Name())
for i := 1; i < 3; i++ {
cache.MarkFailed(protoType, ctx, err)
got := cache.Check(protoType, ctx)
require.Falsef(t, got, "got %v in iteration %d", got, i)
}
})

for i := 0; i < 100; i++ {
cache.MarkFailed(protoType, newCtxArgs("test"), fmt.Errorf("could not resolve host"))
got := cache.Check(protoType, newCtxArgs("test"))
if i < 2 {
// till 3 the host is not flagged to skip
require.False(t, got)
} else {
// above 3 it must remain flagged to skip
require.True(t, got)
t.Run("flagged", func(t *testing.T) {
ctx := newCtxArgs(t.Name())
for i := 1; i <= 3; i++ {
cache.MarkFailed(protoType, ctx, err)
}
}

value := cache.Check(protoType, newCtxArgs("test"))
require.Equal(t, true, value, "could not get checked value")
got := cache.Check(protoType, ctx)
require.True(t, got)
})

t.Run("mark failed or remove", func(t *testing.T) {
ctx := newCtxArgs(t.Name())
cache.MarkFailedOrRemove(protoType, ctx, nil) // nil error should remove the host from cache
got := cache.Check(protoType, ctx)
require.False(t, got)
})
}

func TestTrackErrors(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, []string{"custom error"})

for i := 0; i < 100; i++ {
cache.MarkFailed(protoType, newCtxArgs("custom"), fmt.Errorf("got: nested: custom error"))
cache.MarkFailed(protoType, newCtxArgs("custom"), errors.New("got: nested: custom error"))
got := cache.Check(protoType, newCtxArgs("custom"))
if i < 2 {
// till 3 the host is not flagged to skip
Expand Down Expand Up @@ -74,6 +86,20 @@ func TestCacheItemDo(t *testing.T) {
require.Equal(t, count, 1)
}

func TestRemove(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, nil)
ctx := newCtxArgs(t.Name())
err := errors.New("net/http: timeout awaiting response headers")

for i := 0; i < 100; i++ {
cache.MarkFailed(protoType, ctx, err)
}

require.True(t, cache.Check(protoType, ctx))
cache.Remove(ctx)
require.False(t, cache.Check(protoType, ctx))
}

func TestCacheMarkFailed(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, nil)

Expand All @@ -90,7 +116,7 @@ func TestCacheMarkFailed(t *testing.T) {

for _, test := range tests {
normalizedCacheValue := cache.GetKeyFromContext(newCtxArgs(test.host), nil)
cache.MarkFailed(protoType, newCtxArgs(test.host), fmt.Errorf("no address found for host"))
cache.MarkFailed(protoType, newCtxArgs(test.host), errors.New("no address found for host"))
failedTarget, err := cache.failedTargets.Get(normalizedCacheValue)
require.Nil(t, err)
require.NotNil(t, failedTarget)
Expand Down Expand Up @@ -126,7 +152,7 @@ func TestCacheMarkFailedConcurrent(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
cache.MarkFailed(protoType, newCtxArgs(currentTest.host), fmt.Errorf("could not resolve host"))
cache.MarkFailed(protoType, newCtxArgs(currentTest.host), errors.New("net/http: timeout awaiting response headers"))
}()
}
}
Expand All @@ -144,6 +170,26 @@ func TestCacheMarkFailedConcurrent(t *testing.T) {
}
}

func TestCacheCheckConcurrent(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, nil)
ctx := newCtxArgs(t.Name())

wg := sync.WaitGroup{}
for i := 1; i <= 100; i++ {
wg.Add(1)
i := i
go func() {
defer wg.Done()
cache.MarkFailed(protoType, ctx, errors.New("no address found for host"))
if i >= 3 {
got := cache.Check(protoType, ctx)
require.True(t, got)
}
}()
}
wg.Wait()
}

func newCtxArgs(value string) *contextargs.Context {
ctx := contextargs.NewWithInput(context.TODO(), value)
return ctx
Expand Down
26 changes: 7 additions & 19 deletions pkg/protocols/http/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,8 @@ func (request *Request) executeRaceRequest(input *contextargs.Context, previous

// look for unresponsive hosts and cancel inflight requests as well
spmHandler.SetOnResultCallback(func(err error) {
if err == nil {
return
}
// marks thsi host as unresponsive if applicable
request.markUnresponsiveAddress(input, err)
request.markHostError(input, err)
if request.isUnresponsiveAddress(input) {
// stop all inflight requests
spmHandler.Cancel()
Expand Down Expand Up @@ -234,11 +231,8 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV

// look for unresponsive hosts and cancel inflight requests as well
spmHandler.SetOnResultCallback(func(err error) {
if err == nil {
return
}
// marks thsi host as unresponsive if applicable
request.markUnresponsiveAddress(input, err)
request.markHostError(input, err)
if request.isUnresponsiveAddress(input) {
// stop all inflight requests
spmHandler.Cancel()
Expand Down Expand Up @@ -378,11 +372,8 @@ func (request *Request) executeTurboHTTP(input *contextargs.Context, dynamicValu

// look for unresponsive hosts and cancel inflight requests as well
spmHandler.SetOnResultCallback(func(err error) {
if err == nil {
return
}
// marks thsi host as unresponsive if applicable
request.markUnresponsiveAddress(input, err)
request.markHostError(input, err)
if request.isUnresponsiveAddress(input) {
// stop all inflight requests
spmHandler.Cancel()
Expand Down Expand Up @@ -551,12 +542,12 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
}
if execReqErr != nil {
// if applicable mark the host as unresponsive
request.markUnresponsiveAddress(updatedInput, execReqErr)
requestErr = errorutil.NewWithErr(execReqErr).Msgf("got err while executing %v", generatedHttpRequest.URL())
request.options.Progress.IncrementFailedRequestsBy(1)
} else {
request.options.Progress.IncrementRequests()
}
request.markHostError(updatedInput, execReqErr)

// If this was a match, and we want to stop at first match, skip all further requests.
shouldStopAtFirstMatch := generatedHttpRequest.original.options.Options.StopAtFirstMatch || generatedHttpRequest.original.options.StopAtFirstMatch || request.StopAtFirstMatch
Expand Down Expand Up @@ -1199,13 +1190,10 @@ func (request *Request) newContext(input *contextargs.Context) context.Context {
return input.Context()
}

// markUnresponsiveAddress checks if the error is a unreponsive host error and marks it
func (request *Request) markUnresponsiveAddress(input *contextargs.Context, err error) {
if err == nil {
return
}
// markHostError checks if the error is a unreponsive host error and marks it
func (request *Request) markHostError(input *contextargs.Context, err error) {
if request.options.HostErrorsCache != nil {
request.options.HostErrorsCache.MarkFailed(request.options.ProtocolType.String(), input, err)
request.options.HostErrorsCache.MarkFailedOrRemove(request.options.ProtocolType.String(), input, err)
}
}

Expand Down
Loading

0 comments on commit 052fd8b

Please sign in to comment.