Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gorilla/mux path middleware #226

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/gorilla/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
metrics "github.com/slok/go-http-metrics/metrics/prometheus"
"github.com/slok/go-http-metrics/middleware"
"github.com/slok/go-http-metrics/middleware/std"
muxmiddleware "github.com/slok/go-http-metrics/middleware/std"
)

const (
Expand All @@ -28,7 +28,7 @@ func main() {

// Create our router with the metrics middleware.
r := mux.NewRouter()
r.Use(std.HandlerProvider("", mdlw))
r.Use(muxmiddleware.HandlerProvider("", mdlw))

// Add paths.
r.Methods("GET").Path("/").HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
42 changes: 42 additions & 0 deletions middleware/mux/example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package mux_test

import (
"log"
"net/http"

"github.com/prometheus/client_golang/prometheus/promhttp"

metrics "github.com/slok/go-http-metrics/metrics/prometheus"
"github.com/slok/go-http-metrics/middleware"
muxmiddleware "github.com/slok/go-http-metrics/middleware/mux"
)

// MuxMiddleware shows how you would create a default middleware factory and use it
// to create a Gorilla Mux `http.Handler` compatible middleware.
func Example_muxMiddleware() {
// Create our middleware factory with the default settings.
mdlw := middleware.New(middleware.Config{
Recorder: metrics.NewRecorder(metrics.Config{}),
})

// Create our handler.
myHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("hello world!"))
})

// Wrap our handler with the middleware.
h := muxmiddleware.Handler("", mdlw, myHandler)

// Serve metrics from the default prometheus registry.
log.Printf("serving metrics at: %s", ":8081")
go func() {
_ = http.ListenAndServe(":8081", promhttp.Handler())
}()

// Serve our handler.
log.Printf("listening at: %s", ":8080")
if err := http.ListenAndServe(":8080", h); err != nil {
log.Panicf("error while serving: %s", err)
}
}
102 changes: 102 additions & 0 deletions middleware/mux/mux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Package mux is a helper package to get a Gorilla Mux compatible middleware.
package mux

import (
"bufio"
"context"
"errors"
"net"
"net/http"

"github.com/gorilla/mux"
"github.com/slok/go-http-metrics/middleware"
)

// Handler returns an measuring standard http.Handler.
func Handler(handlerID string, m middleware.Middleware, h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
wi := &responseWriterInterceptor{
statusCode: http.StatusOK,
ResponseWriter: w,
}
reporter := &muxReporter{
w: wi,
r: r,
}

m.Measure(handlerID, reporter, func() {
h.ServeHTTP(wi, r)
})
})
}

// HandlerProvider is a helper method that returns a handler provider. This kind of
// provider is a defacto standard in some frameworks (e.g: Gorilla, Chi...).
func HandlerProvider(handlerID string, m middleware.Middleware) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return Handler(handlerID, m, next)
}
}

type muxReporter struct {
w *responseWriterInterceptor
r *http.Request
}

func (m *muxReporter) Method() string { return m.r.Method }

func (m *muxReporter) Context() context.Context { return m.r.Context() }

func (m *muxReporter) URLPath() string {
path, err := mux.CurrentRoute(m.r).GetPathTemplate()
if err != nil {
return m.r.URL.Path
}
return path
}

func (m *muxReporter) StatusCode() int { return m.w.statusCode }

func (m *muxReporter) BytesWritten() int64 { return int64(m.w.bytesWritten) }

// responseWriterInterceptor is a simple wrapper to intercept set data on a
// ResponseWriter.
type responseWriterInterceptor struct {
http.ResponseWriter
statusCode int
bytesWritten int
}

func (w *responseWriterInterceptor) WriteHeader(statusCode int) {
w.statusCode = statusCode
w.ResponseWriter.WriteHeader(statusCode)
}

func (w *responseWriterInterceptor) Write(p []byte) (int, error) {
w.bytesWritten += len(p)
return w.ResponseWriter.Write(p)
}

func (w *responseWriterInterceptor) Hijack() (net.Conn, *bufio.ReadWriter, error) {
h, ok := w.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, errors.New("type assertion failed http.ResponseWriter not a http.Hijacker")
}
return h.Hijack()
}

func (w *responseWriterInterceptor) Flush() {
f, ok := w.ResponseWriter.(http.Flusher)
if !ok {
return
}

f.Flush()
}

// Check interface implementations.
var (
_ http.ResponseWriter = &responseWriterInterceptor{}
_ http.Hijacker = &responseWriterInterceptor{}
_ http.Flusher = &responseWriterInterceptor{}
)
163 changes: 163 additions & 0 deletions middleware/mux/mux_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
package mux_test

import (
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"

"github.com/gorilla/mux"
mmetrics "github.com/slok/go-http-metrics/internal/mocks/metrics"
"github.com/slok/go-http-metrics/metrics"
"github.com/slok/go-http-metrics/middleware"
muxmiddleware "github.com/slok/go-http-metrics/middleware/mux"
)

func TestMiddleware(t *testing.T) {
tests := map[string]struct {
handlerID string
config middleware.Config
req func() *http.Request
mock func(m *mmetrics.Recorder)
handler func() http.Handler
expRespCode int
expRespBody string
}{
"A default HTTP middleware should call the recorder to measure.": {
req: func() *http.Request {
return httptest.NewRequest(http.MethodPost, "/test", nil)
},
mock: func(m *mmetrics.Recorder) {
expHTTPReqProps := metrics.HTTPReqProperties{
ID: "/test",
Service: "",
Method: "POST",
Code: "202",
}
m.On("ObserveHTTPRequestDuration", mock.Anything, expHTTPReqProps, mock.Anything).Once()
m.On("ObserveHTTPResponseSize", mock.Anything, expHTTPReqProps, int64(15)).Once()

expHTTPProps := metrics.HTTPProperties{
ID: "/test",
Service: "",
}
m.On("AddInflightRequests", mock.Anything, expHTTPProps, 1).Once()
m.On("AddInflightRequests", mock.Anything, expHTTPProps, -1).Once()
},
handler: func() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(202)
w.Write([]byte("Я бэтмен")) // nolint: errcheck
})
},
expRespCode: 202,
expRespBody: "Я бэтмен",
},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)

// Mocks.
mr := &mmetrics.Recorder{}
test.mock(mr)

// Create our mux instance with the middleware.
test.config.Recorder = mr
m := middleware.New(test.config)
h := muxmiddleware.Handler(test.handlerID, m, test.handler())
r := mux.NewRouter()
r.Handle("/test", h)

// Make the request.
resp := httptest.NewRecorder()
r.ServeHTTP(resp, test.req())

// Check.
mr.AssertExpectations(t)
assert.Equal(test.expRespCode, resp.Result().StatusCode)
gotBody, err := io.ReadAll(resp.Result().Body)
require.NoError(err)
assert.Equal(test.expRespBody, string(gotBody))
})
}
}

func TestProvider(t *testing.T) {
tests := map[string]struct {
handlerID string
config middleware.Config
req func() *http.Request
mock func(m *mmetrics.Recorder)
handler func() http.Handler
expRespCode int
expRespBody string
}{
"A default HTTP middleware should call the recorder to measure.": {
req: func() *http.Request {
return httptest.NewRequest(http.MethodPost, "/test", nil)
},
mock: func(m *mmetrics.Recorder) {
expHTTPReqProps := metrics.HTTPReqProperties{
ID: "/test",
Service: "",
Method: "POST",
Code: "202",
}
m.On("ObserveHTTPRequestDuration", mock.Anything, expHTTPReqProps, mock.Anything).Once()
m.On("ObserveHTTPResponseSize", mock.Anything, expHTTPReqProps, int64(15)).Once()

expHTTPProps := metrics.HTTPProperties{
ID: "/test",
Service: "",
}
m.On("AddInflightRequests", mock.Anything, expHTTPProps, 1).Once()
m.On("AddInflightRequests", mock.Anything, expHTTPProps, -1).Once()
},
handler: func() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(202)
w.Write([]byte("Я бэтмен")) // nolint: errcheck
})
},
expRespCode: 202,
expRespBody: "Я бэтмен",
},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)

// Mocks.
mr := &mmetrics.Recorder{}
test.mock(mr)

// Create our mux instance with the middleware.
test.config.Recorder = mr
m := middleware.New(test.config)
provider := muxmiddleware.HandlerProvider(test.handlerID, m)
h := provider(test.handler())
r := mux.NewRouter()
r.Handle("/test", h)

// Make the request.
resp := httptest.NewRecorder()
r.ServeHTTP(resp, test.req())

// Check.
mr.AssertExpectations(t)
assert.Equal(test.expRespCode, resp.Result().StatusCode)
gotBody, err := io.ReadAll(resp.Result().Body)
require.NoError(err)
assert.Equal(test.expRespBody, string(gotBody))
})
}
}