From 48f7cbaddfc87d31b4a569405b9ec6ecbe79beed Mon Sep 17 00:00:00 2001 From: sashayakovtseva Date: Thu, 17 Jun 2021 18:18:41 +0300 Subject: [PATCH] Support custom labels --- go.mod | 4 +- internal/mocks/doc.go | 5 +- internal/mocks/metrics/Recorder.go | 2 +- .../mocks/middleware/CustomLabelReporter.go | 102 +++++++++++ internal/mocks/middleware/Reporter.go | 2 +- metrics/metrics.go | 6 + metrics/prometheus/prometheus.go | 39 +++- metrics/prometheus/prometheus_test.go | 36 +++- middleware/fasthttp/example_test.go | 89 +++++++++- middleware/middleware.go | 25 ++- middleware/middleware_test.go | 168 ++++++++++++------ 11 files changed, 395 insertions(+), 83 deletions(-) create mode 100644 internal/mocks/middleware/CustomLabelReporter.go diff --git a/go.mod b/go.mod index 69f130a..a2ca9b0 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,7 @@ module github.com/slok/go-http-metrics +go 1.15 + require ( contrib.go.opencensus.io/exporter/prometheus v0.3.0 github.com/emicklei/go-restful v2.15.0+incompatible @@ -17,5 +19,3 @@ require ( go.opencensus.io v0.23.0 goji.io v2.0.2+incompatible ) - -go 1.15 diff --git a/internal/mocks/doc.go b/internal/mocks/doc.go index 74a6c32..ea51b4d 100644 --- a/internal/mocks/doc.go +++ b/internal/mocks/doc.go @@ -3,5 +3,6 @@ Package mocks will have all the mocks of the library. */ package mocks // import "github.com/slok/go-http-metrics/internal/mocks" -//go:generate mockery -output ./metrics -outpkg metrics -dir ../../metrics -name Recorder -//go:generate mockery -output ./middleware -outpkg middleware -dir ../../middleware -name Reporter +//go:generate mockery --output ./metrics --outpkg metrics --dir ../../metrics --name Recorder +//go:generate mockery --output ./middleware --outpkg middleware --dir ../../middleware --name Reporter +//go:generate mockery --output ./middleware --outpkg middleware --dir ../../middleware --name CustomLabelReporter diff --git a/internal/mocks/metrics/Recorder.go b/internal/mocks/metrics/Recorder.go index 2e04256..a297b2c 100644 --- a/internal/mocks/metrics/Recorder.go +++ b/internal/mocks/metrics/Recorder.go @@ -1,4 +1,4 @@ -// Code generated by mockery v1.0.0. DO NOT EDIT. +// Code generated by mockery v2.8.0. DO NOT EDIT. package metrics diff --git a/internal/mocks/middleware/CustomLabelReporter.go b/internal/mocks/middleware/CustomLabelReporter.go new file mode 100644 index 0000000..79141dd --- /dev/null +++ b/internal/mocks/middleware/CustomLabelReporter.go @@ -0,0 +1,102 @@ +// Code generated by mockery v2.8.0. DO NOT EDIT. + +package middleware + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// CustomLabelReporter is an autogenerated mock type for the CustomLabelReporter type +type CustomLabelReporter struct { + mock.Mock +} + +// BytesWritten provides a mock function with given fields: +func (_m *CustomLabelReporter) BytesWritten() int64 { + ret := _m.Called() + + var r0 int64 + if rf, ok := ret.Get(0).(func() int64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int64) + } + + return r0 +} + +// Context provides a mock function with given fields: +func (_m *CustomLabelReporter) Context() context.Context { + ret := _m.Called() + + var r0 context.Context + if rf, ok := ret.Get(0).(func() context.Context); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(context.Context) + } + } + + return r0 +} + +// CustomLabels provides a mock function with given fields: +func (_m *CustomLabelReporter) CustomLabels() []string { + ret := _m.Called() + + var r0 []string + if rf, ok := ret.Get(0).(func() []string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + return r0 +} + +// Method provides a mock function with given fields: +func (_m *CustomLabelReporter) Method() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// StatusCode provides a mock function with given fields: +func (_m *CustomLabelReporter) StatusCode() int { + ret := _m.Called() + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + return r0 +} + +// URLPath provides a mock function with given fields: +func (_m *CustomLabelReporter) URLPath() string { + ret := _m.Called() + + var r0 string + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} diff --git a/internal/mocks/middleware/Reporter.go b/internal/mocks/middleware/Reporter.go index 7ed1742..b618092 100644 --- a/internal/mocks/middleware/Reporter.go +++ b/internal/mocks/middleware/Reporter.go @@ -1,4 +1,4 @@ -// Code generated by mockery v1.0.0. DO NOT EDIT. +// Code generated by mockery v2.8.0. DO NOT EDIT. package middleware diff --git a/metrics/metrics.go b/metrics/metrics.go index f4fa46a..d17b5d5 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -16,6 +16,9 @@ type HTTPReqProperties struct { Method string // Code is the response of the request. Code string + + // CustomLabels hold values of the custom labels, if any. + CustomLabels []string } // HTTPProperties are the metric properties for the global server metrics. @@ -24,6 +27,9 @@ type HTTPProperties struct { Service string // ID is the id of the request handler. ID string + + // CustomLabels hold values of the custom labels, if any. + CustomLabels []string } // Recorder knows how to record and measure the metrics. This diff --git a/metrics/prometheus/prometheus.go b/metrics/prometheus/prometheus.go index 41e95b5..4ed0689 100644 --- a/metrics/prometheus/prometheus.go +++ b/metrics/prometheus/prometheus.go @@ -30,6 +30,9 @@ type Config struct { MethodLabel string // ServiceLabel is the name that will be set to the service label, by default is `service`. ServiceLabel string + + // CustomLabels hold names of the custom labels, if any. + CustomLabels []string } func (c *Config) defaults() { @@ -73,6 +76,24 @@ type recorder struct { func NewRecorder(cfg Config) metrics.Recorder { cfg.defaults() + perReqLabels := append( + []string{ + cfg.ServiceLabel, + cfg.HandlerIDLabel, + cfg.MethodLabel, + cfg.StatusCodeLabel, + }, + cfg.CustomLabels..., + ) + + serviceLabels := append( + []string{ + cfg.ServiceLabel, + cfg.HandlerIDLabel, + }, + cfg.CustomLabels..., + ) + r := &recorder{ httpRequestDurHistogram: prometheus.NewHistogramVec(prometheus.HistogramOpts{ Namespace: cfg.Prefix, @@ -80,7 +101,7 @@ func NewRecorder(cfg Config) metrics.Recorder { Name: "request_duration_seconds", Help: "The latency of the HTTP requests.", Buckets: cfg.DurationBuckets, - }, []string{cfg.ServiceLabel, cfg.HandlerIDLabel, cfg.MethodLabel, cfg.StatusCodeLabel}), + }, perReqLabels), httpResponseSizeHistogram: prometheus.NewHistogramVec(prometheus.HistogramOpts{ Namespace: cfg.Prefix, @@ -88,14 +109,14 @@ func NewRecorder(cfg Config) metrics.Recorder { Name: "response_size_bytes", Help: "The size of the HTTP responses.", Buckets: cfg.SizeBuckets, - }, []string{cfg.ServiceLabel, cfg.HandlerIDLabel, cfg.MethodLabel, cfg.StatusCodeLabel}), + }, perReqLabels), httpRequestsInflight: prometheus.NewGaugeVec(prometheus.GaugeOpts{ Namespace: cfg.Prefix, Subsystem: "http", Name: "requests_inflight", Help: "The number of inflight requests being handled at the same time.", - }, []string{cfg.ServiceLabel, cfg.HandlerIDLabel}), + }, serviceLabels), } cfg.Registry.MustRegister( @@ -108,13 +129,19 @@ func NewRecorder(cfg Config) metrics.Recorder { } func (r recorder) ObserveHTTPRequestDuration(_ context.Context, p metrics.HTTPReqProperties, duration time.Duration) { - r.httpRequestDurHistogram.WithLabelValues(p.Service, p.ID, p.Method, p.Code).Observe(duration.Seconds()) + lvs := []string{p.Service, p.ID, p.Method, p.Code} + lvs = append(lvs, p.CustomLabels...) + r.httpRequestDurHistogram.WithLabelValues(lvs...).Observe(duration.Seconds()) } func (r recorder) ObserveHTTPResponseSize(_ context.Context, p metrics.HTTPReqProperties, sizeBytes int64) { - r.httpResponseSizeHistogram.WithLabelValues(p.Service, p.ID, p.Method, p.Code).Observe(float64(sizeBytes)) + lvs := []string{p.Service, p.ID, p.Method, p.Code} + lvs = append(lvs, p.CustomLabels...) + r.httpResponseSizeHistogram.WithLabelValues(lvs...).Observe(float64(sizeBytes)) } func (r recorder) AddInflightRequests(_ context.Context, p metrics.HTTPProperties, quantity int) { - r.httpRequestsInflight.WithLabelValues(p.Service, p.ID).Add(float64(quantity)) + lvs := []string{p.Service, p.ID} + lvs = append(lvs, p.CustomLabels...) + r.httpRequestsInflight.WithLabelValues(lvs...).Add(float64(quantity)) } diff --git a/metrics/prometheus/prometheus_test.go b/metrics/prometheus/prometheus_test.go index bbc8be8..936d80d 100644 --- a/metrics/prometheus/prometheus_test.go +++ b/metrics/prometheus/prometheus_test.go @@ -159,7 +159,7 @@ func TestPrometheusRecorder(t *testing.T) { }, }, { - name: "Using a custom labels in the configuration should measure with those labels.", + name: "Using a custom label names in the configuration should measure with those labels.", config: libprometheus.Config{ HandlerIDLabel: "route_id", StatusCodeLabel: "status_code", @@ -186,12 +186,38 @@ func TestPrometheusRecorder(t *testing.T) { `http_request_duration_seconds_count{http_method="GET",http_service="svc1",route_id="test1",status_code="200"} 2`, }, }, + { + name: "Using a custom labels in the configuration should measure with those labels.", + config: libprometheus.Config{ + DurationBuckets: []float64{1, 10}, + CustomLabels: []string{"user_id"}, + }, + recordMetrics: func(r metrics.Recorder) { + r.ObserveHTTPRequestDuration(context.TODO(), metrics.HTTPReqProperties{ + Service: "svc1", + ID: "test1", + Method: http.MethodGet, + Code: "200", + CustomLabels: []string{"userVIP"}, + }, 6*time.Second) + r.AddInflightRequests(context.TODO(), metrics.HTTPProperties{ + Service: "svc1", + ID: "test1", + CustomLabels: []string{"userVIP"}, + }, 1) + }, + expMetrics: []string{ + `http_request_duration_seconds_bucket{code="200",handler="test1",method="GET",service="svc1",user_id="userVIP",le="1"} 0`, + `http_request_duration_seconds_bucket{code="200",handler="test1",method="GET",service="svc1",user_id="userVIP",le="10"} 1`, + `http_request_duration_seconds_bucket{code="200",handler="test1",method="GET",service="svc1",user_id="userVIP",le="+Inf"} 1`, + `http_request_duration_seconds_count{code="200",handler="test1",method="GET",service="svc1",user_id="userVIP"} 1`, + `http_requests_inflight{handler="test1",service="svc1",user_id="userVIP"} 1`, + }, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - assert := assert.New(t) - reg := prometheus.NewRegistry() test.config.Registry = reg mrecorder := libprometheus.NewRecorder(test.config) @@ -205,10 +231,10 @@ func TestPrometheusRecorder(t *testing.T) { resp := rec.Result() // Check all metrics are present. - if assert.Equal(http.StatusOK, resp.StatusCode) { + if assert.Equal(t, http.StatusOK, resp.StatusCode) { body, _ := ioutil.ReadAll(resp.Body) for _, expMetric := range test.expMetrics { - assert.Contains(string(body), expMetric, "metric not present on the result") + assert.Contains(t, string(body), expMetric, "metric not present on the result") } } }) diff --git a/middleware/fasthttp/example_test.go b/middleware/fasthttp/example_test.go index ac565e0..8cd1d9f 100644 --- a/middleware/fasthttp/example_test.go +++ b/middleware/fasthttp/example_test.go @@ -1,33 +1,73 @@ package fasthttp_test import ( + "context" + "fmt" "log" "net/http" + "github.com/fasthttp/router" "github.com/prometheus/client_golang/prometheus/promhttp" - metrics "github.com/slok/go-http-metrics/metrics/prometheus" + promMetrics "github.com/slok/go-http-metrics/metrics/prometheus" "github.com/slok/go-http-metrics/middleware" fasthttpMiddleware "github.com/slok/go-http-metrics/middleware/fasthttp" "github.com/valyala/fasthttp" ) +func handleHello(rCtx *fasthttp.RequestCtx) { + userID, ok := rCtx.UserValue("user_id").(string) + if !ok { + userID = "unknown" + } + + rCtx.SetStatusCode(fasthttp.StatusOK) + rCtx.SetBodyString(fmt.Sprintf("Hello, %s!", userID)) +} + // FasthttpMiddleware shows how you would create a default middleware // factory and use it to create a fasthttp compatible middleware. func Example_fasthttpMiddleware() { // Create our middleware factory with the default settings. mdlw := middleware.New(middleware.Config{ - Recorder: metrics.NewRecorder(metrics.Config{}), + Recorder: promMetrics.NewRecorder(promMetrics.Config{}), }) - // Add our handler and middleware - h := func(rCtx *fasthttp.RequestCtx) { - rCtx.SetStatusCode(fasthttp.StatusOK) - rCtx.SetBodyString("OK") + // Create our fasthttp instance. + srv := &fasthttp.Server{ + Handler: fasthttpMiddleware.Handler("", mdlw, handleHello), } - // Create our fasthttp instance. + // Serve metrics from the default prometheus registry. + log.Printf("serving metrics at: %s", ":8081") + go func() { + _ = http.ListenAndServe(":8081", promhttp.Handler()) + }() + + // Serve our handler. + log.Printf("listening at: %s", ":8080") + if err := srv.ListenAndServe(":8080"); err != nil { + log.Panicf("error while serving: %s", err) + } +} + +func Example_fasthttpCustomLabels() { + mdlw := middleware.New(middleware.Config{ + Recorder: promMetrics.NewRecorder(promMetrics.Config{ + CustomLabels: []string{"user_id"}, + }), + }) + + mux := router.New() + mux.GET("/{user_id}", + func(c *fasthttp.RequestCtx) { + mdlw.Measure("/hello", userIDReporter{c}, func() { + handleHello(c) + }) + }, + ) + srv := &fasthttp.Server{ - Handler: fasthttpMiddleware.Handler("", mdlw, h), + Handler: mux.Handler, } // Serve metrics from the default prometheus registry. @@ -42,3 +82,36 @@ func Example_fasthttpMiddleware() { log.Panicf("error while serving: %s", err) } } + +type userIDReporter struct { + c *fasthttp.RequestCtx +} + +func (r userIDReporter) Method() string { + return string(r.c.Method()) +} + +func (r userIDReporter) Context() context.Context { + return r.c +} + +func (r userIDReporter) URLPath() string { + return string(r.c.Path()) +} + +func (r userIDReporter) StatusCode() int { + return r.c.Response.StatusCode() +} + +func (r userIDReporter) BytesWritten() int64 { + return int64(len(r.c.Response.Body())) +} + +func (r userIDReporter) CustomLabels() []string { + userID, ok := r.c.UserValue("user_id").(string) + if !ok { + return nil + } + + return []string{userID} +} diff --git a/middleware/middleware.go b/middleware/middleware.go index 079e77c..b6d0b77 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -67,6 +67,11 @@ func New(cfg Config) Middleware { func (m Middleware) Measure(handlerID string, reporter Reporter, next func()) { ctx := reporter.Context() + var customLabels []string + if cr, ok := reporter.(CustomLabelReporter); ok { + customLabels = cr.CustomLabels() + } + // If there isn't predefined handler ID we // set that ID as the URL path. hid := handlerID @@ -77,9 +82,11 @@ func (m Middleware) Measure(handlerID string, reporter Reporter, next func()) { // Measure inflights if required. if !m.cfg.DisableMeasureInflight { props := metrics.HTTPProperties{ - Service: m.cfg.Service, - ID: hid, + Service: m.cfg.Service, + ID: hid, + CustomLabels: customLabels, } + m.cfg.Recorder.AddInflightRequests(ctx, props, 1) defer m.cfg.Recorder.AddInflightRequests(ctx, props, -1) } @@ -100,10 +107,11 @@ func (m Middleware) Measure(handlerID string, reporter Reporter, next func()) { } props := metrics.HTTPReqProperties{ - Service: m.cfg.Service, - ID: hid, - Method: reporter.Method(), - Code: code, + Service: m.cfg.Service, + ID: hid, + Method: reporter.Method(), + Code: code, + CustomLabels: customLabels, } m.cfg.Recorder.ObserveHTTPRequestDuration(ctx, props, duration) @@ -126,3 +134,8 @@ type Reporter interface { StatusCode() int BytesWritten() int64 } + +type CustomLabelReporter interface { + Reporter + CustomLabels() []string +} diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 0e70e68..2255274 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -14,19 +14,23 @@ import ( ) func TestMiddlewareMeasure(t *testing.T) { - tests := map[string]struct { + tests := []struct { + name string handlerID string - config func() middleware.Config - mock func(mrec *mockmetrics.Recorder, mrep *mockmiddleware.Reporter) + config middleware.Config + recorder func() metrics.Recorder + setup func() (metrics.Recorder, middleware.Reporter, func(t *testing.T)) }{ - "Having default config with service, it should measure the metrics.": { + { + name: "Having default config with service, it should measure the metrics.", handlerID: "test01", - config: func() middleware.Config { - return middleware.Config{ - Service: "svc1", - } + config: middleware.Config{ + Service: "svc1", }, - mock: func(mrec *mockmetrics.Recorder, mrep *mockmiddleware.Reporter) { + setup: func() (metrics.Recorder, middleware.Reporter, func(t *testing.T)) { + mrec := &mockmetrics.Recorder{} + mrep := &mockmiddleware.Reporter{} + // Reporter mocks. mrep.On("Context").Once().Return(context.TODO()) mrep.On("StatusCode").Once().Return(418) @@ -41,15 +45,61 @@ func TestMiddlewareMeasure(t *testing.T) { mrec.On("AddInflightRequests", mock.Anything, expProps, -1).Once() mrec.On("ObserveHTTPRequestDuration", mock.Anything, expRepProps, mock.Anything).Once() mrec.On("ObserveHTTPResponseSize", mock.Anything, expRepProps, int64(42)).Once() + + return mrec, mrep, func(t *testing.T) { + mrec.AssertExpectations(t) + mrep.AssertExpectations(t) + } }, }, + { + name: "Custom labels should work", + handlerID: "test01", + config: middleware.Config{ + Service: "svc1", + }, + setup: func() (metrics.Recorder, middleware.Reporter, func(t *testing.T)) { + mrec := &mockmetrics.Recorder{} + mrep := &mockmiddleware.CustomLabelReporter{} - "Without having handler ID, it should measure the metrics using the request path.": { - handlerID: "", - config: func() middleware.Config { - return middleware.Config{} + mrep.On("Context").Once().Return(context.TODO()) + mrep.On("StatusCode").Once().Return(418) + mrep.On("Method").Once().Return("PATCH") + mrep.On("BytesWritten").Once().Return(int64(42)) + mrep.On("CustomLabels").Once().Return([]string{"user_VIP"}) + + expProps := metrics.HTTPProperties{ + Service: "svc1", + ID: "test01", + CustomLabels: []string{"user_VIP"}, + } + expRepProps := metrics.HTTPReqProperties{ + Service: "svc1", + ID: "test01", + Method: "PATCH", + Code: "418", + CustomLabels: []string{"user_VIP"}, + } + + mrec.On("AddInflightRequests", mock.Anything, expProps, 1).Once() + mrec.On("AddInflightRequests", mock.Anything, expProps, -1).Once() + mrec.On("ObserveHTTPRequestDuration", mock.Anything, expRepProps, mock.Anything).Once() + mrec.On("ObserveHTTPResponseSize", mock.Anything, expRepProps, int64(42)).Once() + + return mrec, mrep, func(t *testing.T) { + mrec.AssertExpectations(t) + mrep.AssertExpectations(t) + } }, - mock: func(mrec *mockmetrics.Recorder, mrep *mockmiddleware.Reporter) { + }, + { + name: "Without having handler ID, it should measure the metrics using the request path.", + handlerID: "", + config: middleware.Config{}, + setup: func() (metrics.Recorder, middleware.Reporter, func(t *testing.T)) { + mrec := &mockmetrics.Recorder{} + mrep := &mockmiddleware.Reporter{} + // Reporter mocks. mrep.On("URLPath").Once().Return("/test/01") mrep.On("Context").Once().Return(context.TODO()) @@ -64,17 +114,23 @@ func TestMiddlewareMeasure(t *testing.T) { mrec.On("AddInflightRequests", mock.Anything, mock.Anything, mock.Anything).Once() mrec.On("ObserveHTTPRequestDuration", mock.Anything, expRepProps, mock.Anything).Once() mrec.On("ObserveHTTPResponseSize", mock.Anything, expRepProps, mock.Anything).Once() + + return mrec, mrep, func(t *testing.T) { + mrec.AssertExpectations(t) + mrep.AssertExpectations(t) + } }, }, - - "Having grouped status code, it should measure the metrics using grouped status codes.": { + { + name: "Having grouped status code, it should measure the metrics using grouped status codes.", handlerID: "test01", - config: func() middleware.Config { - return middleware.Config{ - GroupedStatus: true, - } + config: middleware.Config{ + GroupedStatus: true, }, - mock: func(mrec *mockmetrics.Recorder, mrep *mockmiddleware.Reporter) { + setup: func() (metrics.Recorder, middleware.Reporter, func(t *testing.T)) { + mrec := &mockmetrics.Recorder{} + mrep := &mockmiddleware.Reporter{} + // Reporter mocks. mrep.On("Context").Once().Return(context.TODO()) mrep.On("StatusCode").Once().Return(418) @@ -88,17 +144,23 @@ func TestMiddlewareMeasure(t *testing.T) { mrec.On("AddInflightRequests", mock.Anything, mock.Anything, mock.Anything).Once() mrec.On("ObserveHTTPRequestDuration", mock.Anything, expRepProps, mock.Anything).Once() mrec.On("ObserveHTTPResponseSize", mock.Anything, expRepProps, mock.Anything).Once() + + return mrec, mrep, func(t *testing.T) { + mrec.AssertExpectations(t) + mrep.AssertExpectations(t) + } }, }, - - "Disabling inflight requests measuring, it shouldn't measure inflight metrics.": { + { + name: "Disabling inflight requests measuring, it shouldn't measure inflight metrics.", handlerID: "test01", - config: func() middleware.Config { - return middleware.Config{ - DisableMeasureInflight: true, - } + config: middleware.Config{ + DisableMeasureInflight: true, }, - mock: func(mrec *mockmetrics.Recorder, mrep *mockmiddleware.Reporter) { + setup: func() (metrics.Recorder, middleware.Reporter, func(t *testing.T)) { + mrec := &mockmetrics.Recorder{} + mrep := &mockmiddleware.Reporter{} + // Reporter mocks. mrep.On("Context").Once().Return(context.TODO()) mrep.On("StatusCode").Once().Return(418) @@ -110,17 +172,23 @@ func TestMiddlewareMeasure(t *testing.T) { mrec.On("ObserveHTTPRequestDuration", mock.Anything, expRepProps, mock.Anything).Once() mrec.On("ObserveHTTPResponseSize", mock.Anything, expRepProps, mock.Anything).Once() + + return mrec, mrep, func(t *testing.T) { + mrec.AssertExpectations(t) + mrep.AssertExpectations(t) + } }, }, - - "Disabling size measuring, it shouldn't measure size metrics.": { + { + name: "Disabling size measuring, it shouldn't measure size metrics.", handlerID: "test01", - config: func() middleware.Config { - return middleware.Config{ - DisableMeasureSize: true, - } + config: middleware.Config{ + DisableMeasureSize: true, }, - mock: func(mrec *mockmetrics.Recorder, mrep *mockmiddleware.Reporter) { + setup: func() (metrics.Recorder, middleware.Reporter, func(t *testing.T)) { + mrec := &mockmetrics.Recorder{} + mrep := &mockmiddleware.Reporter{} + // Reporter mocks. mrep.On("Context").Once().Return(context.TODO()) mrep.On("StatusCode").Once().Return(418) @@ -132,31 +200,27 @@ func TestMiddlewareMeasure(t *testing.T) { mrec.On("AddInflightRequests", mock.Anything, mock.Anything, mock.Anything).Once() mrec.On("AddInflightRequests", mock.Anything, mock.Anything, mock.Anything).Once() mrec.On("ObserveHTTPRequestDuration", mock.Anything, expRepProps, mock.Anything).Once() + + return mrec, mrep, func(t *testing.T) { + mrec.AssertExpectations(t) + mrep.AssertExpectations(t) + } }, }, } - for name, test := range tests { - t.Run(name, func(t *testing.T) { - assert := assert.New(t) - - // Mocks. - mrec := &mockmetrics.Recorder{} - mrep := &mockmiddleware.Reporter{} - test.mock(mrec, mrep) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mrec, mrep, cleanup := tc.setup() - // Execute. - config := test.config() - config.Recorder = mrec // Set mocked recorder. - mdlw := middleware.New(config) + tc.config.Recorder = mrec + mdlw := middleware.New(tc.config) calledNext := false - mdlw.Measure(test.handlerID, mrep, func() { calledNext = true }) + mdlw.Measure(tc.handlerID, mrep, func() { calledNext = true }) - // Check. - mrec.AssertExpectations(t) - mrep.AssertExpectations(t) - assert.True(calledNext) + cleanup(t) + assert.True(t, calledNext) }) } }