Skip to content

Commit

Permalink
Merge pull request #1 from ivov/drop-chi
Browse files Browse the repository at this point in the history
refactor: Replace `chi` with Go standard library
  • Loading branch information
ivov authored Oct 16, 2024
2 parents 68bdc00 + 183a8b9 commit 66fe512
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 65 deletions.
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ go 1.22.5
require (
github.com/felixge/httpsnoop v1.0.4
github.com/getsentry/sentry-go v0.28.1
github.com/go-chi/chi/v5 v5.1.0
github.com/golang-migrate/migrate/v4 v4.17.1
github.com/jmoiron/sqlx v1.4.0
github.com/mattn/go-sqlite3 v1.14.22
Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/getsentry/sentry-go v0.28.1 h1:zzaSm/vHmGllRM6Tpx1492r0YDzauArdBfkJRtY6P5k=
github.com/getsentry/sentry-go v0.28.1/go.mod h1:1fQZ+7l7eeJ3wYi82q5Hg8GqAPgefRq+FP/QhafYVgg=
github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw=
github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA=
github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og=
github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y=
Expand Down
29 changes: 14 additions & 15 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"net/http"
"sync"

"github.com/go-chi/chi/v5"
"github.com/ivov/n8n-shortlink/internal"
"github.com/ivov/n8n-shortlink/internal/config"
"github.com/ivov/n8n-shortlink/internal/log"
Expand All @@ -24,34 +23,34 @@ type API struct {

// Routes sets up middleware and routes on the API.
func (api *API) Routes() http.Handler {
r := chi.NewRouter()

api.SetupMiddleware(r)
r := http.NewServeMux()

static := internal.Static()

fileServer := http.FileServer(http.FS(static))
r.Handle("/static/*", http.StripPrefix("/static", fileServer))
r.Handle("GET /static/*", http.StripPrefix("/static", fileServer))

// /static/canvas.tmpl.html and /static/swagger.html blocked by reverse proxy

r.Get("/health", api.HandleGetHealth)
r.Get("/debug/vars", expvar.Handler().ServeHTTP) // blocked by reverse proxy
r.Get("/metrics", api.HandleGetMetrics) // blocked by reverse proxy
r.Get("/docs", func(w http.ResponseWriter, r *http.Request) {
r.HandleFunc("GET /health", api.HandleGetHealth)
r.HandleFunc("GET /debug/vars", expvar.Handler().ServeHTTP) // blocked by reverse proxy
r.HandleFunc("GET /metrics", api.HandleGetMetrics) // blocked by reverse proxy
r.HandleFunc("GET /docs", func(w http.ResponseWriter, r *http.Request) {
http.ServeFileFS(w, r, static, "swagger.html")
})
r.Get("/spec", func(w http.ResponseWriter, r *http.Request) {
r.HandleFunc("GET /spec", func(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, "openapi.yml")
})

r.Post("/shortlink", api.HandlePostShortlink)
r.Get("/{slug}/view", api.HandleGetSlug)
r.Get("/{slug}", api.HandleGetSlug)
r.HandleFunc("POST /shortlink", api.HandlePostShortlink)
r.HandleFunc("GET /{slug}/view", api.HandleGetSlug)
r.HandleFunc("GET /{slug}", api.HandleGetSlug)

r.Get("/", func(w http.ResponseWriter, r *http.Request) {
r.HandleFunc("GET /", func(w http.ResponseWriter, r *http.Request) {
http.ServeFileFS(w, r, static, "index.html")
})

return r
mw := api.SetupMiddleware()

return mw(r)
}
7 changes: 1 addition & 6 deletions internal/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ func TestAPI(t *testing.T) {
assert.Contains(t, string(body), "total_responses_sent")
assert.Contains(t, string(body), "in_flight_requests")
assert.Contains(t, string(body), "total_processing_time_ms")
assert.Contains(t, string(body), "responses_sent_by_status")
assert.Contains(t, string(body), "total_responses_sent_by_status")
})
})

Expand All @@ -231,7 +231,6 @@ func TestAPI(t *testing.T) {
// ------------------------

t.Run("base use cases", func(t *testing.T) {

t.Run("should create URL shortlink and redirect on retrieval", func(t *testing.T) {
candidate := entities.Shortlink{
Kind: "url",
Expand Down Expand Up @@ -341,7 +340,6 @@ func TestAPI(t *testing.T) {
// ------------------------

t.Run("custom slug", func(t *testing.T) {

t.Run("should create custom-slug shortlink and redirect on retrieval", func(t *testing.T) {
candidate := entities.Shortlink{
Slug: "my-custom-slug",
Expand Down Expand Up @@ -369,7 +367,6 @@ func TestAPI(t *testing.T) {
// ------------------------

t.Run("creation payload validation", func(t *testing.T) {

t.Run("should reject on invalid creation payload", func(t *testing.T) {
testCases := []struct {
name string
Expand Down Expand Up @@ -510,7 +507,6 @@ func TestAPI(t *testing.T) {
})

t.Run("rate limiting", func(t *testing.T) {

t.Run("should enforce rate limiting", func(t *testing.T) {
// enable rate limiting only for this test
originalConfig := *api.Config
Expand Down Expand Up @@ -559,7 +555,6 @@ func TestAPI(t *testing.T) {
})

t.Run("password protection", func(t *testing.T) {

t.Run("should store password-protected shortlink", func(t *testing.T) {
candidate := entities.Shortlink{
Kind: "url",
Expand Down
2 changes: 1 addition & 1 deletion internal/api/handle_get_metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ var (
})
responsesSentByStatus = prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "responses_sent_by_status",
Name: "total_responses_sent_by_status",
Help: "Total responses sent by HTTP status code",
},
[]string{"status"},
Expand Down
3 changes: 1 addition & 2 deletions internal/api/handle_get_{slug}.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@ import (
"net/http"
"strings"

"github.com/go-chi/chi/v5"
"github.com/ivov/n8n-shortlink/internal"
"github.com/ivov/n8n-shortlink/internal/errors"
)

// HandleGetSlug handles a GET /{slug} request by resolving a regular shortlink.
func (api *API) HandleGetSlug(w http.ResponseWriter, r *http.Request) {
slug := chi.URLParam(r, "slug")
slug := r.PathValue("slug")

shortlink, err := api.ShortlinkService.GetBySlug(slug)
if err != nil {
Expand Down
149 changes: 111 additions & 38 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package api

import (
"crypto/rand"
"encoding/hex"
"expvar"
"fmt"
"net/http"
Expand All @@ -10,25 +12,86 @@ import (
"time"

"github.com/felixge/httpsnoop"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/ivov/n8n-shortlink/internal/log"
"github.com/tomasen/realip"
"golang.org/x/time/rate"
)

// ------------
// setup
// ------------

// MiddlewareFn is a function that wraps an http.Handler.
type MiddlewareFn func(http.Handler) http.Handler

func createStack(fns ...MiddlewareFn) MiddlewareFn {
return func(next http.Handler) http.Handler {
for _, fn := range fns {
next = fn(next)
}
return next
}
}

// SetupMiddleware sets up all middleware on the API.
func (api *API) SetupMiddleware(r *chi.Mux) {
r.Use(api.recoverPanic)
r.Use(middleware.RequestID)
r.Use(middleware.CleanPath)
r.Use(middleware.StripSlashes)
r.Use(api.addRelaxedCorsHeaders)
r.Use(api.addSecurityHeaders)
r.Use(api.addCacheHeadersForStaticFiles)
r.Use(api.rateLimit)
r.Use(api.logRequest)
r.Use(api.metrics)
func (api *API) SetupMiddleware() MiddlewareFn {
return createStack(
api.setRequestID,
api.logRequest,
api.recoverPanic,
api.addRelaxedCorsHeaders,
api.addSecurityHeaders,
api.addCacheHeadersForStaticFiles,
api.rateLimit,
api.metrics,
)
}

// ------------
// headers
// ------------

func (api *API) setRequestID(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
id := r.Header.Get("X-Request-ID")
if id == "" {
id = generateRequestID()
r.Header.Set("X-Request-ID", id)
}
next.ServeHTTP(w, r)
})
}

func generateRequestID() string {
bytes := make([]byte, 16) // 128 bits
_, err := rand.Read(bytes)
if err != nil {
return "unknown"
}
return hex.EncodeToString(bytes)
}

func (api *API) addRelaxedCorsHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Referer, User-Agent")

next.ServeHTTP(w, r)
},
)
}

func (api *API) addSecurityHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
w.Header().Set("Strict-Transport-Security", "max-age=31536000")

next.ServeHTTP(w, r)
})
}

func (api *API) addCacheHeadersForStaticFiles(next http.Handler) http.Handler {
Expand All @@ -40,6 +103,10 @@ func (api *API) addCacheHeadersForStaticFiles(next http.Handler) http.Handler {
})
}

// ------------
// metrics
// ------------

func (api *API) metrics(next http.Handler) http.Handler {
totalRequestsReceived := expvar.NewInt("total_requests_received")
totalResponsesSent := expvar.NewInt("total_responses_sent")
Expand All @@ -57,16 +124,9 @@ func (api *API) metrics(next http.Handler) http.Handler {
})
}

func (api *API) addSecurityHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
w.Header().Set("Strict-Transport-Security", "max-age=31536000")

next.ServeHTTP(w, r)
})
}
// ------------
// recover
// ------------

func (api *API) recoverPanic(next http.Handler) http.Handler {
return http.HandlerFunc(
Expand All @@ -84,6 +144,27 @@ func (api *API) recoverPanic(next http.Handler) http.Handler {
)
}

// ------------
// logging
// ------------

type wrappedResponseWriter struct {
http.ResponseWriter
statusCode int
bytesWritten int
}

func (rw *wrappedResponseWriter) WriteHeader(statusCode int) {
rw.ResponseWriter.WriteHeader(statusCode)
rw.statusCode = statusCode
}

func (rw *wrappedResponseWriter) Write(b []byte) (int, error) {
bytes, err := rw.ResponseWriter.Write(b)
rw.bytesWritten += bytes
return bytes, err
}

func isLogIgnored(path string) bool {
for _, denyExact := range []string{"/health", "/favicon.ico", "/metrics"} {
if path == denyExact || strings.HasPrefix(path, "/static/") {
Expand All @@ -101,7 +182,7 @@ func (api *API) logRequest(next http.Handler) http.Handler {
return
}

wrappedWriter := middleware.NewWrapResponseWriter(w, r.ProtoMajor)
wrappedWriter := &wrappedResponseWriter{ResponseWriter: w, statusCode: http.StatusOK}

start := time.Now()

Expand All @@ -121,9 +202,9 @@ func (api *API) logRequest(next http.Handler) http.Handler {
Path: r.URL.Path,
QueryString: r.URL.Query().Encode(),
Latency: int(time.Duration(time.Since(start).Milliseconds())),
Status: wrappedWriter.Status(),
SizeBytes: wrappedWriter.BytesWritten(),
RequestID: middleware.GetReqID(r.Context()),
Status: wrappedWriter.statusCode,
SizeBytes: wrappedWriter.bytesWritten,
RequestID: r.Header.Get("X-Request-ID"),
}

api.Logger.Info(
Expand All @@ -144,17 +225,9 @@ func (api *API) logRequest(next http.Handler) http.Handler {
)
}

func (api *API) addRelaxedCorsHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Referer, User-Agent")

next.ServeHTTP(w, r)
},
)
}
// ------------
// r-limit
// ------------

var exemptedFromRateLimiting = []string{"/", "/docs"}

Expand Down

0 comments on commit 66fe512

Please sign in to comment.