package rewrite

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"testing"

	"github.com/gofiber/fiber/v3"
	"github.com/stretchr/testify/require"
)

// Test_New checks that New returns a non-nil handler when called with no
// config, with only Rules set, and with both Next and Rules set.
func Test_New(t *testing.T) {
	// Test with no config
	m := New()
	require.NotNil(t, m, "Expected middleware to be returned, got nil")

	// Test with config
	m = New(Config{
		Rules: map[string]string{
			"/old": "/new",
		},
	})
	require.NotNil(t, m, "Expected middleware to be returned, got nil")

	// Test with full config
	m = New(Config{
		Next: func(fiber.Ctx) bool {
			return true
		},
		Rules: map[string]string{
			"/old": "/new",
		},
	})
	require.NotNil(t, m, "Expected middleware to be returned, got nil")
}

// Test_Rewrite exercises the rewrite middleware end to end through app.Test:
// skipping via Next, a static rule, a rule with captured wildcard tokens, and
// requests that match no rule (with and without a fallback handler).
func Test_Rewrite(t *testing.T) {
	// doGet sends a GET request for path to app and returns the response
	// status code and body. The response body is always read and closed.
	doGet := func(t *testing.T, app *fiber.App, path string) (int, string) {
		t.Helper()

		req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, path, nil)
		require.NoError(t, err)

		resp, err := app.Test(req)
		require.NoError(t, err)

		body, err := io.ReadAll(resp.Body)
		require.NoError(t, err)
		require.NoError(t, resp.Body.Close())

		return resp.StatusCode, string(body)
	}

	// newUserOrderApp builds an app with a single rule mapping
	// /users/*/orders/* to /user/$1/order/$2, plus a route that echoes the
	// captured userID and orderID parameters.
	newUserOrderApp := func() *fiber.App {
		app := fiber.New()
		app.Use(New(Config{
			Rules: map[string]string{
				"/users/*/orders/*": "/user/$1/order/$2",
			},
		}))
		app.Get("/user/:userID/order/:orderID", func(c fiber.Ctx) error {
			return c.SendString(fmt.Sprintf("User ID: %s, Order ID: %s", c.Params("userID"), c.Params("orderID")))
		})
		return app
	}

	// Case 1: Next always returns true, so the middleware is skipped and the
	// request is served by the original /old route without being rewritten.
	t.Run("next returns true skips rewrite", func(t *testing.T) {
		app := fiber.New()
		app.Use(New(Config{
			Next: func(fiber.Ctx) bool {
				return true
			},
			Rules: map[string]string{
				"/old": "/new",
			},
		}))
		app.Get("/old", func(c fiber.Ctx) error {
			return c.SendString("Rewrite Successful")
		})

		status, body := doGet(t, app, "/old")
		require.Equal(t, fiber.StatusOK, status)
		require.Equal(t, "Rewrite Successful", body)
	})

	// Case 2: Next always returns false, so /old is rewritten and served by
	// the /new route.
	t.Run("next returns false applies rewrite", func(t *testing.T) {
		app := fiber.New()
		app.Use(New(Config{
			Next: func(fiber.Ctx) bool {
				return false
			},
			Rules: map[string]string{
				"/old": "/new",
			},
		}))
		app.Get("/new", func(c fiber.Ctx) error {
			return c.SendString("Rewrite Successful")
		})

		status, body := doGet(t, app, "/old")
		require.Equal(t, fiber.StatusOK, status)
		require.Equal(t, "Rewrite Successful", body)
	})

	// Case 3: wildcard captures are substituted into $1/$2 of the target.
	t.Run("captured tokens", func(t *testing.T) {
		app := newUserOrderApp()

		status, body := doGet(t, app, "/users/123/orders/456")
		require.Equal(t, fiber.StatusOK, status)
		require.Equal(t, "User ID: 123, Order ID: 456", body)
	})

	// Case 4: a request matching no rule falls through to the default route.
	t.Run("non-matching request with default route", func(t *testing.T) {
		app := newUserOrderApp()
		app.Use(func(c fiber.Ctx) error {
			return c.SendStatus(fiber.StatusOK)
		})

		status, body := doGet(t, app, "/not-matching-any-rule")
		require.Equal(t, fiber.StatusOK, status)
		require.Equal(t, "OK", body)
	})

	// Case 5: a request matching no rule, with no default route, yields 404.
	t.Run("non-matching request without default route", func(t *testing.T) {
		app := newUserOrderApp()

		status, _ := doGet(t, app, "/not-matching-any-rule")
		require.Equal(t, fiber.StatusNotFound, status)
	})
}