Skip to content

Commit 8ff25d5

Browse files
Demétrio de Castro Menezes NetoDemétrio de Castro Menezes Neto
Demétrio de Castro Menezes Neto
authored and
Demétrio de Castro Menezes Neto
committed
Add configurable preflight status code for CORS middleware
1 parent de44c53 commit 8ff25d5

File tree

2 files changed

+56
-6
lines changed

2 files changed

+56
-6
lines changed

middleware/cors.go

+20-6
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,23 @@ type CORSConfig struct {
107107
//
108108
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age
109109
MaxAge int `yaml:"max_age"`
110+
111+
// PreflightStatusCode determines the status code to be returned on a
112+
// successful preflight request.
113+
//
114+
// Optional. Default value is http.StatusNoContent(204)
115+
//
116+
// See also: https://fetch.spec.whatwg.org/#ref-for-ok-status
117+
// See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Methods/OPTIONS#preflighted_requests_in_cors
118+
PreflightStatusCode int `yaml:"preflight_status_code"`
110119
}
111120

112121
// DefaultCORSConfig is the default CORS middleware config.
113122
var DefaultCORSConfig = CORSConfig{
114-
Skipper: DefaultSkipper,
115-
AllowOrigins: []string{"*"},
116-
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
123+
Skipper: DefaultSkipper,
124+
AllowOrigins: []string{"*"},
125+
AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete},
126+
PreflightStatusCode: http.StatusNoContent,
117127
}
118128

119129
// CORS returns a Cross-Origin Resource Sharing (CORS) middleware.
@@ -147,6 +157,10 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
147157
config.AllowMethods = DefaultCORSConfig.AllowMethods
148158
}
149159

160+
if config.PreflightStatusCode == 0 {
161+
config.PreflightStatusCode = DefaultCORSConfig.PreflightStatusCode
162+
}
163+
150164
allowOriginPatterns := make([]*regexp.Regexp, 0, len(config.AllowOrigins))
151165
for _, origin := range config.AllowOrigins {
152166
if origin == "*" {
@@ -214,7 +228,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
214228
if !preflight {
215229
return next(c)
216230
}
217-
return c.NoContent(http.StatusNoContent)
231+
return c.NoContent(config.PreflightStatusCode)
218232
}
219233

220234
if config.AllowOriginFunc != nil {
@@ -264,7 +278,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
264278
if !preflight {
265279
return echo.ErrUnauthorized
266280
}
267-
return c.NoContent(http.StatusNoContent)
281+
return c.NoContent(config.PreflightStatusCode)
268282
}
269283

270284
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
@@ -301,7 +315,7 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
301315
if config.MaxAge != 0 {
302316
res.Header().Set(echo.HeaderAccessControlMaxAge, maxAge)
303317
}
304-
return c.NoContent(http.StatusNoContent)
318+
return c.NoContent(config.PreflightStatusCode)
305319
}
306320
}
307321
}

middleware/cors_test.go

+36
Original file line numberDiff line numberDiff line change
@@ -683,3 +683,39 @@ func Test_allowOriginFunc(t *testing.T) {
683683
}
684684
}
685685
}
686+
687+
func TestCORSWithConfig_PreflightStatusCode(t *testing.T) {
688+
tests := []struct {
689+
name string
690+
mw echo.MiddlewareFunc
691+
expectedStatusCode int
692+
}{
693+
{
694+
name: "ok, preflight with default config returns http.StatusNoContent (204)",
695+
mw: CORS(),
696+
expectedStatusCode: http.StatusNoContent,
697+
},
698+
{
699+
name: "ok, preflight returning http.StatusOK (200)",
700+
mw: CORSWithConfig(CORSConfig{
701+
PreflightStatusCode: http.StatusOK,
702+
}),
703+
expectedStatusCode: http.StatusOK,
704+
},
705+
}
706+
e := echo.New()
707+
708+
for _, tc := range tests {
709+
req := httptest.NewRequest(http.MethodOptions, "/", nil)
710+
rec := httptest.NewRecorder()
711+
712+
c := e.NewContext(req, rec)
713+
714+
cors := tc.mw(echo.NotFoundHandler)
715+
err := cors(c)
716+
717+
assert.NoError(t, err)
718+
assert.Equal(t, rec.Result().StatusCode, tc.expectedStatusCode)
719+
720+
}
721+
}

0 commit comments

Comments
 (0)