diff --git a/middleware/timeout/timeout.go b/middleware/timeout/timeout.go index 9e6caa4295..1bd2739c28 100644 --- a/middleware/timeout/timeout.go +++ b/middleware/timeout/timeout.go @@ -1,10 +1,11 @@ package timeout import ( - "context" "errors" + "time" "github.com/gofiber/fiber/v3" + utils "github.com/gofiber/utils/v2" ) // New enforces a timeout for each incoming request. It replaces the request's @@ -21,42 +22,34 @@ func New(h fiber.Handler, config ...Config) fiber.Handler { timeout := cfg.Timeout if timeout <= 0 { - return runHandler(ctx, h, cfg) + return h(ctx) } - parent := ctx.Context() - tCtx, cancel := context.WithTimeout(parent, timeout) - ctx.SetContext(tCtx) - defer func() { - cancel() - ctx.SetContext(parent) + err := make(chan error, 1) + go func() { + err <- h(ctx) }() - - err := runHandler(ctx, h, cfg) - - if errors.Is(tCtx.Err(), context.DeadlineExceeded) && err == nil { + select { + case err := <-err: + if err != nil && (len(cfg.Errors) > 0 && isCustomError(err, cfg.Errors)) { + if cfg.OnTimeout != nil { + if toErr := cfg.OnTimeout(ctx); toErr != nil { + return toErr + } + } + return fiber.ErrRequestTimeout + } + return err + case <-time.After(timeout): if cfg.OnTimeout != nil { - return cfg.OnTimeout(ctx) + err := cfg.OnTimeout(ctx) + ctx.RequestCtx().TimeoutErrorWithResponse(&ctx.RequestCtx().Response) + return err } + ctx.RequestCtx().TimeoutErrorWithCode(utils.StatusMessage(fiber.StatusRequestTimeout), fiber.StatusRequestTimeout) return fiber.ErrRequestTimeout } - return err - } -} - -// runHandler executes the handler and returns fiber.ErrRequestTimeout if it -// sees a deadline exceeded error or one of the custom "timeout-like" errors. -func runHandler(c fiber.Ctx, h fiber.Handler, cfg Config) error { - err := h(c) - if err != nil && (errors.Is(err, context.DeadlineExceeded) || (len(cfg.Errors) > 0 && isCustomError(err, cfg.Errors))) { - if cfg.OnTimeout != nil { - if toErr := cfg.OnTimeout(c); toErr != nil { - return toErr - } - } - return fiber.ErrRequestTimeout } - return err } // isCustomError checks whether err matches any error in errList using errors.Is. diff --git a/middleware/timeout/timeout_test.go b/middleware/timeout/timeout_test.go index 98530fc3a4..22b42f59b9 100644 --- a/middleware/timeout/timeout_test.go +++ b/middleware/timeout/timeout_test.go @@ -21,6 +21,21 @@ var ( errUnrelated = errors.New("unmatched error") ) +// runHandler executes the handler and returns fiber.ErrRequestTimeout if it +// sees a deadline exceeded error or one of the custom "timeout-like" errors. +func runHandler(c fiber.Ctx, h fiber.Handler, cfg Config) error { + err := h(c) + if err != nil && (len(cfg.Errors) > 0 && isCustomError(err, cfg.Errors)) { + if cfg.OnTimeout != nil { + if toErr := cfg.OnTimeout(c); toErr != nil { + return toErr + } + } + return fiber.ErrRequestTimeout + } + return err +} + // sleepWithContext simulates a task that takes `d` time, but returns `te` if the context is canceled. func sleepWithContext(ctx context.Context, d time.Duration, te error) error { timer := time.NewTimer(d) @@ -155,11 +170,9 @@ func TestTimeout_CustomHandler(t *testing.T) { app := fiber.New() called := 0 - app.Get("/custom-handler", New(func(c fiber.Ctx) error { - if err := sleepWithContext(c.Context(), 100*time.Millisecond, context.DeadlineExceeded); err != nil { - return err - } - return c.SendString("should not reach") + app.Get("/custom-handler", New(func(_ fiber.Ctx) error { + time.Sleep(100 * time.Millisecond) + return context.DeadlineExceeded }, Config{ Timeout: 20 * time.Millisecond, OnTimeout: func(c fiber.Ctx) error { @@ -175,19 +188,6 @@ func TestTimeout_CustomHandler(t *testing.T) { require.Equal(t, 1, called) } -// TestRunHandler_DefaultOnTimeout ensures context.DeadlineExceeded triggers ErrRequestTimeout. -func TestRunHandler_DefaultOnTimeout(t *testing.T) { - app := fiber.New() - ctx := app.AcquireCtx(&fasthttp.RequestCtx{}) - defer app.ReleaseCtx(ctx) - - err := runHandler(ctx, func(_ fiber.Ctx) error { - return context.DeadlineExceeded - }, Config{}) - - require.Equal(t, fiber.ErrRequestTimeout, err) -} - // TestRunHandler_CustomOnTimeout verifies that a custom error and OnTimeout handler are used. func TestRunHandler_CustomOnTimeout(t *testing.T) { app := fiber.New()