Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added: Endware functionality #51

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
148 changes: 127 additions & 21 deletions chain.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Package alice provides a convenient way to chain http handlers.
package alice

import "net/http"
import (
"net/http"
)

// A constructor for a piece of middleware.
// Some middleware use this constructor out of the box,
@@ -14,39 +16,65 @@ type Constructor func(http.Handler) http.Handler
// the same set of constructors in the same order.
type Chain struct {
constructors []Constructor
endwares []Endware
}

// New creates a new chain,
// memorizing the given list of middleware constructors.
// New serves no other function,
// constructors are only called upon a call to Then().
func New(constructors ...Constructor) Chain {
return Chain{append(([]Constructor)(nil), constructors...)}
return Chain{append(([]Constructor)(nil), constructors...), ([]Endware)(nil)}
}

// Then chains the middleware and returns the final http.Handler.
// New(m1, m2, m3).Then(h)
// endwareHandler represents a handler that has been modified
// to execute endwares afterwards. This is a helper for Then()
// because if we just wrap it in an anonymous
// http.HandlerFunc(func(w http.ResponseWriter, r *http.Request)))
// there is a stack overflow
type endwareHandler struct {
handler http.Handler
endwares []Endware
}

// ServeHTTP serves the main endwareHandler's handler as well as
// calling all of the individual endwares afterwards.
func (eh endwareHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
eh.handler.ServeHTTP(w, r)
for _, endware := range eh.endwares {
endware.ServeHTTP(w, r)
}
}

// Then chains the middleware and endwares and returns the final http.Handler.
// New(m1, m2, m3).Finally(e1, e2, e3).Then(h)
// is equivalent to:
// m1(m2(m3(h)))
// When the request comes in, it will be passed to m1, then m2, then m3
// and finally, the given handler
// (assuming every middleware calls the following one).
// followed by:
// e1(e2(e3()))
// When the request comes in, it will be passed to m1, then m2, then m3,
// then the given handler (who serves the response), then e1, e2, e3
// (assuming every middleware/endwares calls the following one).
//
// A chain can be safely reused by calling Then() several times.
// stdStack := alice.New(ratelimitHandler, csrfHandler)
// stdStack := alice.New(ratelimitHandler, csrfHandler).Finally(loggingHandler)
// indexPipe = stdStack.Then(indexHandler)
// authPipe = stdStack.Then(authHandler)
// Note that constructors are called on every call to Then()
// and thus several instances of the same middleware will be created
// Note that constructors and endwares are called on every call to Then()
// and thus several instances of the same middleware/endwares will be created
// when a chain is reused in this way.
// For proper middleware, this should cause no problems.
// For proper middleware/endwares, this should cause no problems.
//
// Then() treats nil as http.DefaultServeMux.
func (c Chain) Then(h http.Handler) http.Handler {
if h == nil {
h = http.DefaultServeMux
}

if len(c.endwares) > 0 {
h = endwareHandler{h, c.endwares}
}

for i := range c.constructors {
h = c.constructors[len(c.constructors)-1-i](h)
}
@@ -73,6 +101,7 @@ func (c Chain) ThenFunc(fn http.HandlerFunc) http.Handler {
// as the last ones in the request flow.
//
// Append returns a new chain, leaving the original one untouched.
// The new chain will have the original chain's endwares.
//
// stdChain := alice.New(m1, m2)
// extChain := stdChain.Append(m3, m4)
@@ -83,7 +112,7 @@ func (c Chain) Append(constructors ...Constructor) Chain {
newCons = append(newCons, c.constructors...)
newCons = append(newCons, constructors...)

return Chain{newCons}
return New(newCons...).AppendEndware(c.endwares...)
}

// Extend extends a chain by adding the specified chain
@@ -92,21 +121,98 @@ func (c Chain) Append(constructors ...Constructor) Chain {
// Extend returns a new chain, leaving the original one untouched.
//
// stdChain := alice.New(m1, m2)
// ext1Chain := alice.New(m3, m4)
// ext1Chain := alice.New(m3, m4).Finally(e1, e2)
// ext2Chain := stdChain.Extend(ext1Chain)
// // requests in stdChain go m1 -> m2
// // requests in ext1Chain go m3 -> m4
// // requests in ext2Chain go m1 -> m2 -> m3 -> m4
// // requests in stdChain go m1 -> m2 -> handler
// // requests in ext1Chain go m3 -> m4 -> handler -> e1 -> e2
// // requests in ext2Chain go m1 -> m2 -> m3 -> m4 -> handler -> e1 -> e2
//
// Another example:
// aHtmlAfterNosurf := alice.New(m2)
// logRequestChain := aHtmlAfterNosurf.Finally(e1)
// aHtml := alice.New(m1, func(h http.Handler) http.Handler {
// csrf := nosurf.New(h)
// csrf.SetFailureHandler(aHtmlAfterNosurf.ThenFunc(csrfFail))
// csrf.SetFailureHandler(logRequestChain.ThenFunc(csrfFail))
// return csrf
// }).Extend(aHtmlAfterNosurf)
// // requests to aHtml hitting nosurfs success handler go m1 -> nosurf -> m2 -> target-handler
// // requests to aHtml hitting nosurfs failure handler go m1 -> nosurf -> m2 -> csrfFail
// }).Extend(logRequestChain)
// // requests to aHtml hitting nosurfs success handler go:
// m1 -> nosurf -> m2 -> target-handler -> e1
// // requests to aHtml hitting nosurfs failure handler go:
// m1 -> nosurf -> m2 -> csrfFail -> e1
func (c Chain) Extend(chain Chain) Chain {
return c.Append(chain.constructors...)
return c.
Append(chain.constructors...).
AppendEndware(chain.endwares...)
}

// Endware is functionality executed after a the main handler is called
// and response has been sent to the requester. Like middleware,
// values from the request or response can be accessed. This will not
// let you access values from the request or the response that can no longer be used.
// e.g. re-reading a request body, re-setting the response headers, etc.
type Endware http.Handler

// Finally creates a new chain with the original chain's
// constructors and endwares, as well as the provided endwares.
// Endwares are executed after both the constructors and
// the Then() handler are called.
func (c Chain) Finally(endwares ...Endware) Chain {
newEnds := make([]Endware, 0, len(c.endwares)+len(endwares))
newEnds = append(newEnds, c.endwares...)
newEnds = append(newEnds, endwares...)

newC := New(c.constructors...)
newC.endwares = newEnds
return newC
}

// FinallyFuncs works identically to Finally, but takes HandlerFuncs
// instead of Endwares.
//
// The following two statements are equivalent:
// c.Finally(http.HandlerFunc(fn1), http.HandlerFunc(fn2))
// c.FinallyFuncs(fn1, fn2)
//
// FinallyFuncs provides all the guarantees of Finally.
func (c Chain) FinallyFuncs(fns ...func(w http.ResponseWriter, r *http.Request)) Chain {
// convert each http.HandlerFunc into an Endware
endwares := make([]Endware, len(fns))
for i, fn := range fns {
endwares[i] = http.HandlerFunc(fn)
}

return c.Finally(endwares...)
}

// AppendEndware extends a chain, adding the specified endwares
// as the last ones in the request flow.
//
// AppendEndware returns a new chain, leaving the original one untouched.
// The new chain will have the original chain's constructors.
//
// stdChain := alice.New(m1).Finally(e1, e2)
// extChain := stdChain.AppendEndware(e3, e4)
// // requests in stdHandler go m1 -> handler -> e1 -> e2
// // requests in extHandler go m1 -> handler -> e1 -> e2 -> e3 -> e4
func (c Chain) AppendEndware(endwares ...Endware) Chain {
return New(c.constructors...).Finally(append(c.endwares, endwares...)...)
}

// AppendEndwareFuncs works identically to AppendEndware, but takes HandlerFuncs
// instead of Endwares.
//
// The following two statements are equivalent:
// c.AppendEndware(http.HandlerFunc(fn1), http.HandlerFunc(fn2))
// c.AppendEndwareFuncs(fn1, fn2)
//
// AppendEndwareFuncs provides all the guarantees of AppendEndware.
func (c Chain) AppendEndwareFuncs(fns ...func(w http.ResponseWriter, r *http.Request)) Chain {
// convert each http.HandlerFunc into an Endware
endwares := make([]Endware, len(fns))
for i, fn := range fns {
endwares[i] = http.HandlerFunc(fn)
}

return c.AppendEndware(endwares...)

}
148 changes: 136 additions & 12 deletions chain_test.go
Original file line number Diff line number Diff line change
@@ -20,6 +20,12 @@ func tagMiddleware(tag string) Constructor {
}
}

func tagEndware(tag string) Endware {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(tag))
})
}

// Not recommended (https://golang.org/pkg/reflect/#Value.Pointer),
// but the best we can do.
func funcsEqual(f1, f2 interface{}) bool {
@@ -51,20 +57,40 @@ func TestNew(t *testing.T) {
}
}

func TestFinally(t *testing.T) {
e1 := tagEndware("e1\n")
e2 := tagEndware("e2\n")

slice := []Endware{e1, e2}

chain := New().Finally(slice...)
for k := range slice {
if !funcsEqual(chain.endwares[k], slice[k]) {
t.Error("Finally does not add endwares correctly")
}
}
}

func TestThenWorksWithNoMiddleware(t *testing.T) {
if !funcsEqual(New().Then(testApp), testApp) {
t.Error("Then does not work with no middleware")
}
}

func TestThenWorksWithNoEndware(t *testing.T) {
if !funcsEqual(New().Finally().Then(testApp), testApp) {
t.Error("Then does not work with no endware")
}
}

func TestThenTreatsNilAsDefaultServeMux(t *testing.T) {
if New().Then(nil) != http.DefaultServeMux {
if New().Finally().Then(nil) != http.DefaultServeMux {
t.Error("Then does not treat nil as DefaultServeMux")
}
}

func TestThenFuncTreatsNilAsDefaultServeMux(t *testing.T) {
if New().ThenFunc(nil) != http.DefaultServeMux {
if New().Finally().ThenFunc(nil) != http.DefaultServeMux {
t.Error("ThenFunc does not treat nil as DefaultServeMux")
}
}
@@ -73,7 +99,7 @@ func TestThenFuncConstructsHandlerFunc(t *testing.T) {
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
})
chained := New().ThenFunc(fn)
chained := New().Finally().ThenFunc(fn)
rec := httptest.NewRecorder()

chained.ServeHTTP(rec, (*http.Request)(nil))
@@ -87,8 +113,11 @@ func TestThenOrdersHandlersCorrectly(t *testing.T) {
t1 := tagMiddleware("t1\n")
t2 := tagMiddleware("t2\n")
t3 := tagMiddleware("t3\n")
e1 := tagEndware("e1\n")
e2 := tagEndware("e2\n")
e3 := tagEndware("e3\n")

chained := New(t1, t2, t3).Then(testApp)
chained := New(t1, t2, t3).Finally(e1, e2, e3).Then(testApp)

w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/", nil)
@@ -98,7 +127,7 @@ func TestThenOrdersHandlersCorrectly(t *testing.T) {

chained.ServeHTTP(w, r)

if w.Body.String() != "t1\nt2\nt3\napp\n" {
if w.Body.String() != "t1\nt2\nt3\napp\ne1\ne2\ne3\n" {
t.Error("Then does not order handlers correctly")
}
}
@@ -110,9 +139,15 @@ func TestAppendAddsHandlersCorrectly(t *testing.T) {
if len(chain.constructors) != 2 {
t.Error("chain should have 2 constructors")
}
if len(chain.endwares) != 0 {
t.Error("chain should have 0 endwares")
}
if len(newChain.constructors) != 4 {
t.Error("newChain should have 4 constructors")
}
if len(newChain.endwares) != 0 {
t.Error("newChain should have 0 endwares")
}

chained := newChain.Then(testApp)

@@ -129,29 +164,114 @@ func TestAppendAddsHandlersCorrectly(t *testing.T) {
}
}

func TestAppendEndwareAddsHandlersCorrectly(t *testing.T) {
chain := New(tagMiddleware("t1\n")).Finally(tagEndware("e1\n"), tagEndware("e2\n"))
newChain := chain.AppendEndware(tagEndware("e3\n"), tagEndware("e4\n"))

if len(chain.constructors) != 1 {
t.Error("chain should have 1 constructor")
}
if len(chain.endwares) != 2 {
t.Error("chain should have 2 endwares")
}
if len(newChain.constructors) != 1 {
t.Error("newChain should have 1 constructor")
}
if len(newChain.endwares) != 4 {
t.Error("newChain should have 4 endwares")
}

chained := newChain.Then(testApp)

w := httptest.NewRecorder()
r, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err)
}

chained.ServeHTTP(w, r)

if w.Body.String() != "t1\napp\ne1\ne2\ne3\ne4\n" {
t.Error("AppendEndware does not add handlers correctly")
}
}

func TestAppendRespectsImmutability(t *testing.T) {
chain := New(tagMiddleware(""))
chain := New(tagMiddleware("")).Finally(tagEndware(""))
newChain := chain.Append(tagMiddleware(""))

if &chain.constructors[0] == &newChain.constructors[0] {
t.Error("Apppend does not respect immutability")
t.Error("Append does not respect constructor immutability")
}

if &chain.endwares[0] == &newChain.endwares[0] {
t.Error("Append does not respect endware immutability")
}
}

func TestAppendEndwareRespectsImmutability(t *testing.T) {
chain := New(tagMiddleware("")).Finally(tagEndware(""))
newChain := chain.AppendEndware(tagEndware(""))

if &chain.constructors[0] == &newChain.constructors[0] {
t.Error("AppendEndware does not respect constructor immutability")
}

if &chain.endwares[0] == &newChain.endwares[0] {
t.Error("AppendEndware does not respect endware immutability")
}
}

func TestExtendsRespectsImmutability(t *testing.T) {
chain := New(tagMiddleware("")).Finally(tagEndware(""))
newChain := New(tagMiddleware("")).Finally(tagEndware(""))
newChain = chain.Extend(newChain)

// chain.constructors[0] should have the same functionality as
// newChain.constructors[1], but check both anyways
if &chain.constructors[0] == &newChain.constructors[0] {
t.Error("Extends does not respect constructor immutability")
}

if &chain.constructors[0] == &newChain.constructors[1] {
t.Error("Extends does not respect constructor immutability")
}

if &chain.endwares[0] == &newChain.endwares[0] {
t.Error("Extends does not respect endware immutability")
}

if &chain.endwares[0] == &newChain.endwares[1] {
t.Error("Extends does not respect endware immutability")
}
}

func TestExtendAddsHandlersCorrectly(t *testing.T) {
chain1 := New(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
chain2 := New(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
chain2 := New(tagMiddleware("t3\n"), tagMiddleware("t4\n")).
Finally(tagEndware("e1\n"), tagEndware("e2\n"))
newChain := chain1.Extend(chain2)

if len(chain1.constructors) != 2 {
t.Error("chain1 should contain 2 constructors")
}
if len(chain1.endwares) != 0 {
t.Error("chain1 should contain 0 endwares")
}

if len(chain2.constructors) != 2 {
t.Error("chain2 should contain 2 constructors")
}
if len(chain2.endwares) != 2 {
t.Error("chain2 should contain 2 endwares")
}

if len(newChain.constructors) != 4 {
t.Error("newChain should contain 4 constructors")
}
if len(newChain.endwares) != 2 {
t.Error("newChain should contain 2 endwares")
}

chained := newChain.Then(testApp)

@@ -163,16 +283,20 @@ func TestExtendAddsHandlersCorrectly(t *testing.T) {

chained.ServeHTTP(w, r)

if w.Body.String() != "t1\nt2\nt3\nt4\napp\n" {
if w.Body.String() != "t1\nt2\nt3\nt4\napp\ne1\ne2\n" {
t.Error("Extend does not add handlers in correctly")
}
}

func TestExtendRespectsImmutability(t *testing.T) {
chain := New(tagMiddleware(""))
newChain := chain.Extend(New(tagMiddleware("")))
chain := New(tagMiddleware("")).Finally(tagEndware(""))
newChain := chain.Extend(New())

if &chain.constructors[0] == &newChain.constructors[0] {
t.Error("Extend does not respect immutability")
t.Error("Extend does not respect immutability for constructors")
}

if &chain.endwares[0] == &newChain.endwares[0] {
t.Error("Extend does not respect immutability for endwares")
}
}