From 01a58abcff9b01434b57af8f77cea337e86d3947 Mon Sep 17 00:00:00 2001 From: Cory Koch Date: Sun, 10 Nov 2024 23:08:05 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A5=20feat:=20Add=20Support=20for=20Re?= =?UTF-8?q?moving=20Routes=20(#3230)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add new methods named RemoveRoute and RemoveRouteByName. * Update register method to prevent duplicate routes. * Clean up tests * Update docs --- docs/api/app.md | 53 +++++++++ docs/whats_new.md | 8 ++ router.go | 76 ++++++++++++- router_test.go | 269 +++++++++++++++++++++++++++++++++++++++++++--- 4 files changed, 392 insertions(+), 14 deletions(-) diff --git a/docs/api/app.md b/docs/api/app.md index 6582159c0f..955dbd143b 100644 --- a/docs/api/app.md +++ b/docs/api/app.md @@ -761,3 +761,56 @@ func main() { ``` In this example, a new route is defined and then `RebuildTree()` is called to ensure the new route is registered and available. + + +## RemoveRoute + +This method removes a route by path. You must call the `RebuildTree()` method after the remove in to ensure the route is removed. + +```go title="Signature" +func (app *App) RemoveRoute(path string, methods ...string) +``` + +This method removes a route by name +```go title="Signature" +func (app *App) RemoveRouteByName(name string, methods ...string) +``` + +```go title="Example" +package main + +import ( + "log" + + "github.com/gofiber/fiber/v3" +) + +func main() { + app := fiber.New() + + app.Get("/api/feature-a", func(c Ctx) error { + app.RemoveRoute("/api/feature", MethodGet) + app.RebuildTree() + // Redefine route + app.Get("/api/feature", func(c Ctx) error { + return c.SendString("Testing feature-a") + }) + + app.RebuildTree() + return c.SendStatus(http.StatusOK) + }) + app.Get("/api/feature-b", func(c Ctx) error { + app.RemoveRoute("/api/feature", MethodGet) + app.RebuildTree() + // Redefine route + app.Get("/api/feature", func(c Ctx) error { + return c.SendString("Testing feature-b") + }) + + app.RebuildTree() + return c.SendStatus(http.StatusOK) + }) + + log.Fatal(app.Listen(":3000")) +} +``` diff --git a/docs/whats_new.md b/docs/whats_new.md index bfc6f25c29..e0575dd8c7 100644 --- a/docs/whats_new.md +++ b/docs/whats_new.md @@ -941,6 +941,14 @@ In this example, a new route is defined, and `RebuildTree()` is called to ensure Note: Use this method with caution. It is **not** thread-safe and can be very performance-intensive. Therefore, it should be used sparingly and primarily in development mode. It should not be invoke concurrently. +## RemoveRoute + +- **RemoveRoute**: Removes route by path + +- **RemoveRouteByName**: Removes route by name + +For more details, refer to the [app documentation](./api/app.md#removeroute): + ### 🧠 Context Fiber v3 introduces several new features and changes to the Ctx interface, enhancing its functionality and flexibility. diff --git a/router.go b/router.go index 2091cfc6cb..cfe83adf87 100644 --- a/router.go +++ b/router.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "html" + "slices" "sort" "strings" "sync/atomic" @@ -302,6 +303,13 @@ func (app *App) register(methods []string, pathRaw string, group *Group, handler if method != methodUse && app.methodInt(method) == -1 { panic(fmt.Sprintf("add: invalid http method %s\n", method)) } + + // Duplicate Route Handling + if app.routeExists(method, pathRaw) { + matchPathFunc := func(r *Route) bool { return r.Path == pathRaw } + app.deleteRoute([]string{method}, matchPathFunc) + } + // is mounted app isMount := group != nil && group.app != app // A route requires atleast one ctx handler @@ -375,6 +383,72 @@ func (app *App) register(methods []string, pathRaw string, group *Group, handler } } +func (app *App) routeExists(method, pathRaw string) bool { + pathToCheck := pathRaw + if !app.config.CaseSensitive { + pathToCheck = utils.ToLower(pathToCheck) + } + + return slices.ContainsFunc(app.stack[app.methodInt(method)], func(r *Route) bool { + routePath := r.path + if !app.config.CaseSensitive { + routePath = utils.ToLower(routePath) + } + + return routePath == pathToCheck + }) +} + +// RemoveRoute is used to remove a route from the stack by path. +// This only needs to be called to remove a route, route registration prevents duplicate routes. +// You should call RebuildTree after using this to ensure consistency of the tree. +func (app *App) RemoveRoute(path string, methods ...string) { + pathMatchFunc := func(r *Route) bool { return r.Path == path } + app.deleteRoute(methods, pathMatchFunc) +} + +// RemoveRouteByName is used to remove a route from the stack by name. +// This only needs to be called to remove a route, route registration prevents duplicate routes. +// You should call RebuildTree after using this to ensure consistency of the tree. +func (app *App) RemoveRouteByName(name string, methods ...string) { + matchFunc := func(r *Route) bool { return r.Name == name } + app.deleteRoute(methods, matchFunc) +} + +func (app *App) deleteRoute(methods []string, matchFunc func(r *Route) bool) { + app.mutex.Lock() + defer app.mutex.Unlock() + + for _, method := range methods { + // Uppercase HTTP methods + method = utils.ToUpper(method) + + // Get unique HTTP method identifier + m := app.methodInt(method) + if m == -1 { + continue // Skip invalid HTTP methods + } + + // Find the index of the route to remove + index := slices.IndexFunc(app.stack[m], matchFunc) + if index == -1 { + continue // Route not found + } + + route := app.stack[m][index] + + // Decrement global handler count + atomic.AddUint32(&app.handlersCount, ^uint32(len(route.Handlers)-1)) //nolint:gosec // Not a concern + // Decrement global route position + atomic.AddUint32(&app.routesCount, ^uint32(0)) + + // Remove route from tree stack + app.stack[m] = slices.Delete(app.stack[m], index, index+1) + } + + app.routesRefreshed = true +} + func (app *App) addRoute(method string, route *Route, isMounted ...bool) { app.mutex.Lock() defer app.mutex.Unlock() @@ -415,7 +489,7 @@ func (app *App) addRoute(method string, route *Route, isMounted ...bool) { // This method is useful when you want to register routes dynamically after the app has started. // It is not recommended to use this method on production environments because rebuilding // the tree is performance-intensive and not thread-safe in runtime. Since building the tree -// is only done in the startupProcess of the app, this method does not makes sure that the +// is only done in the startupProcess of the app, this method does not make sure that the // routeTree is being safely changed, as it would add a great deal of overhead in the request. // Latest benchmark results showed a degradation from 82.79 ns/op to 94.48 ns/op and can be found in: // https://github.com/gofiber/fiber/issues/2769#issuecomment-2227385283 diff --git a/router_test.go b/router_test.go index 5509039c66..a1d9c4da5f 100644 --- a/router_test.go +++ b/router_test.go @@ -12,6 +12,10 @@ import ( "net/http" "net/http/httptest" "os" + "reflect" + "runtime" + "strings" + "sync" "testing" "github.com/gofiber/utils/v2" @@ -369,31 +373,270 @@ func Test_Router_NotFound_HTML_Inject(t *testing.T) { require.Equal(t, "Cannot DELETE /does/not/exist<script>alert('foo');</script>", string(c.Response.Body())) } +func registerTreeManipulationRoutes(app *App, middleware ...func(Ctx) error) { + app.Get("/test", func(c Ctx) error { + app.Get("/dynamically-defined", func(c Ctx) error { + return c.SendStatus(StatusOK) + }) + + app.RebuildTree() + + return c.SendStatus(StatusOK) + }, middleware...) +} + +func verifyRequest(tb testing.TB, app *App, path string, expectedStatus int) *http.Response { + tb.Helper() + + resp, err := app.Test(httptest.NewRequest(MethodGet, path, nil)) + require.NoError(tb, err, "app.Test(req)") + require.Equal(tb, expectedStatus, resp.StatusCode, "Status code") + + return resp +} + +func verifyRouteHandlerCounts(tb testing.TB, app *App, expectedRoutesCount int) { + tb.Helper() + + // this is taken from listen.go's printRoutesMessage app method + var routes []RouteMessage + for _, routeStack := range app.stack { + for _, route := range routeStack { + routeMsg := RouteMessage{ + name: route.Name, + method: route.Method, + path: route.Path, + } + + for _, handler := range route.Handlers { + routeMsg.handlers += runtime.FuncForPC(reflect.ValueOf(handler).Pointer()).Name() + " " + } + + routes = append(routes, routeMsg) + } + } + + for _, route := range routes { + require.Equal(tb, expectedRoutesCount, strings.Count(route.handlers, " ")) + } +} + +func verifyThereAreNoRoutes(tb testing.TB, app *App) { + tb.Helper() + + require.Equal(tb, uint32(0), app.handlersCount) + require.Equal(tb, uint32(0), app.routesCount) + verifyRouteHandlerCounts(tb, app, 0) +} + func Test_App_Rebuild_Tree(t *testing.T) { t.Parallel() app := New() - app.Get("/test", func(c Ctx) error { - app.Get("/dynamically-defined", func(c Ctx) error { - return c.SendStatus(http.StatusOK) + registerTreeManipulationRoutes(app) + + verifyRequest(t, app, "/dynamically-defined", StatusNotFound) + verifyRequest(t, app, "/test", StatusOK) + verifyRequest(t, app, "/dynamically-defined", StatusOK) +} + +func Test_App_Remove_Route_A_B_Feature_Testing(t *testing.T) { + t.Parallel() + app := New() + + app.Get("/api/feature-a", func(c Ctx) error { + app.RemoveRoute("/api/feature", MethodGet) + app.RebuildTree() + // Redefine route + app.Get("/api/feature", func(c Ctx) error { + return c.SendString("Testing feature-a") }) app.RebuildTree() + return c.SendStatus(StatusOK) + }) + app.Get("/api/feature-b", func(c Ctx) error { + app.RemoveRoute("/api/feature", MethodGet) + app.RebuildTree() + // Redefine route + app.Get("/api/feature", func(c Ctx) error { + return c.SendString("Testing feature-b") + }) - return c.SendStatus(http.StatusOK) + app.RebuildTree() + return c.SendStatus(StatusOK) }) - resp, err := app.Test(httptest.NewRequest(MethodGet, "/dynamically-defined", nil)) - require.NoError(t, err, "app.Test(req)") - require.Equal(t, http.StatusNotFound, resp.StatusCode, "Status code") + verifyRequest(t, app, "/api/feature-a", StatusOK) - resp, err = app.Test(httptest.NewRequest(MethodGet, "/test", nil)) - require.NoError(t, err, "app.Test(req)") - require.Equal(t, http.StatusOK, resp.StatusCode, "Status code") + resp := verifyRequest(t, app, "/api/feature", StatusOK) + require.Equal(t, "Testing feature-a", resp, "Response Message") - resp, err = app.Test(httptest.NewRequest(MethodGet, "/dynamically-defined", nil)) - require.NoError(t, err, "app.Test(req)") - require.Equal(t, http.StatusOK, resp.StatusCode, "Status code") + resp = verifyRequest(t, app, "/api/feature-b", StatusOK) + require.Equal(t, "Testing feature-b", resp, "Response Message") +} + +func Test_App_Remove_Route_By_Name(t *testing.T) { + t.Parallel() + app := New() + + app.Get("/api/test", func(c Ctx) error { + return c.SendStatus(StatusOK) + }).Name("test") + + app.RemoveRouteByName("test", MethodGet) + app.RebuildTree() + + verifyRequest(t, app, "/test", StatusNotFound) + verifyThereAreNoRoutes(t, app) +} + +func Test_App_Remove_Route_By_Name_Non_Existing_Route(t *testing.T) { + t.Parallel() + app := New() + + app.RemoveRouteByName("test", MethodGet) + app.RebuildTree() + + verifyThereAreNoRoutes(t, app) +} + +func Test_App_Remove_Route_Nested(t *testing.T) { + t.Parallel() + app := New() + + api := app.Group("/api") + + v1 := api.Group("/v1") + v1.Get("/test", func(c Ctx) error { + return c.SendStatus(StatusOK) + }) + + verifyRequest(t, app, "/api/v1/test", StatusOK) + app.RemoveRoute("/api/v1/test", MethodGet) + + verifyThereAreNoRoutes(t, app) +} + +func Test_App_Remove_Route_Parameterized(t *testing.T) { + t.Parallel() + app := New() + + app.Get("/test/:id", func(c Ctx) error { + return c.SendStatus(StatusOK) + }) + verifyRequest(t, app, "/test/:id", StatusOK) + app.RemoveRoute("/test/:id", MethodGet) + + verifyThereAreNoRoutes(t, app) +} + +func Test_App_Remove_Route(t *testing.T) { + t.Parallel() + app := New() + + app.Get("/test", func(c Ctx) error { + return c.SendStatus(StatusOK) + }) + + app.RemoveRoute("/test", MethodGet) + app.RebuildTree() + + verifyRequest(t, app, "/test", StatusNotFound) +} + +func Test_App_Remove_Route_Non_Existing_Route(t *testing.T) { + t.Parallel() + app := New() + + app.RemoveRoute("/test", MethodGet, MethodHead) + app.RebuildTree() + + verifyThereAreNoRoutes(t, app) +} + +func Test_App_Remove_Route_Concurrent(t *testing.T) { + t.Parallel() + app := New() + + // Add test route + app.Get("/test", func(c Ctx) error { + return c.SendStatus(StatusOK) + }) + + // Concurrently remove and add routes + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + app.RemoveRoute("/test", MethodGet) + app.Get("/test", func(c Ctx) error { + return c.SendStatus(StatusOK) + }) + }() + } + wg.Wait() + + // Verify final state + app.RebuildTree() + verifyRequest(t, app, "/test", StatusOK) +} + +func Test_App_Route_Registration_Prevent_Duplicate(t *testing.T) { + t.Parallel() + app := New() + + registerTreeManipulationRoutes(app) + registerTreeManipulationRoutes(app) + + verifyRequest(t, app, "/dynamically-defined", StatusNotFound) + require.Equal(t, uint32(1), app.handlersCount) + + verifyRequest(t, app, "/test", StatusOK) + require.Equal(t, uint32(2), app.handlersCount) + + verifyRequest(t, app, "/dynamically-defined", StatusOK) + require.Equal(t, uint32(2), app.handlersCount) + + verifyRequest(t, app, "/test", StatusOK) + require.Equal(t, uint32(2), app.handlersCount) + + verifyRequest(t, app, "/dynamically-defined", StatusOK) + require.Equal(t, uint32(2), app.handlersCount) + require.Equal(t, uint32(2), app.routesCount) + + verifyRouteHandlerCounts(t, app, 1) +} + +func Test_Route_Registration_Prevent_Duplicate_With_Middleware(t *testing.T) { + t.Parallel() + app := New() + + middleware := func(c Ctx) error { + return c.Next() + } + + registerTreeManipulationRoutes(app, middleware) + registerTreeManipulationRoutes(app) + + verifyRequest(t, app, "/dynamically-defined", StatusNotFound) + require.Equal(t, uint32(2), app.handlersCount) + + verifyRequest(t, app, "/test", StatusOK) + require.Equal(t, uint32(3), app.handlersCount) + + verifyRequest(t, app, "/dynamically-defined", StatusOK) + require.Equal(t, uint32(3), app.handlersCount) + + verifyRequest(t, app, "/test", StatusOK) + require.Equal(t, uint32(3), app.handlersCount) + + verifyRequest(t, app, "/dynamically-defined", StatusOK) + require.Equal(t, uint32(3), app.handlersCount) + require.Equal(t, uint32(2), app.routesCount) + + verifyRouteHandlerCounts(t, app, 1) } //////////////////////////////////////////////