diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go index 110c3fe8..071a841f 100644 --- a/middleware/proxy/proxy.go +++ b/middleware/proxy/proxy.go @@ -44,7 +44,7 @@ var ConfigDefault = Config{ // New is deprecated func New(config Config) fiber.Handler { - fmt.Println("proxy.New is deprecated, please us proxy.Balancer instead") + fmt.Println("proxy.New is deprecated, please use proxy.Balancer instead") return Balancer(config) } @@ -58,7 +58,9 @@ func Balancer(config Config) fiber.Handler { cfg.Next = ConfigDefault.Next } if len(cfg.Servers) == 0 { - panic("Servers cannot be empty") + return func(c *fiber.Ctx) (err error) { + panic("Servers cannot be empty") + } } client := fasthttp.Client{ @@ -97,20 +99,16 @@ func Balancer(config Config) fiber.Handler { } req.SetRequestURI(cfg.Servers[counter] + utils.UnsafeString(req.RequestURI())) - counter = (counter + 1) % len(cfg.Servers) // Forward request if err = client.Do(req, res); err != nil { - fmt.Println(err) return err } // Don't proxy "Connection" header res.Header.Del(fiber.HeaderConnection) - //fmt.Println(string(res.Header.ContentType())) - // Modify response if cfg.ModifyResponse != nil { if err = cfg.ModifyResponse(c); err != nil { diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index 7ddb3159..f5eb78a0 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -1,148 +1,161 @@ package proxy -// // go test -run Test_Proxy_Empty_Host -// func Test_Proxy_Empty_Host(t *testing.T) { -// app := fiber.New() -// app.Use(New( -// Config{Hosts: ""}, -// )) +import ( + "io/ioutil" + "net/http/httptest" + "strings" + "testing" + "time" -// resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) -// } + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/recover" + "github.com/gofiber/fiber/v2/utils" +) + +// go test -run Test_Proxy_Empty_Host +func Test_Proxy_Empty_Upstream_Servers(t *testing.T) { + app := fiber.New() + app.Use(recover.New(), Balancer(Config{Servers: []string{}})) + + resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode) +} // // go test -run Test_Proxy_Next -// func Test_Proxy_Next(t *testing.T) { -// app := fiber.New() -// app.Use(New(Config{ -// Hosts: "next", -// Next: func(_ *fiber.Ctx) bool { -// return true -// }, -// })) +func Test_Proxy_Next(t *testing.T) { + app := fiber.New() + app.Use(New(Config{ + Servers: []string{"localhost"}, + Next: func(_ *fiber.Ctx) bool { + return true + }, + })) -// resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) -// } + resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) +} // // go test -run Test_Proxy -// func Test_Proxy(t *testing.T) { -// target := fiber.New(fiber.Config{ -// DisableStartupMessage: true, -// }) +func Test_Proxy(t *testing.T) { + target := fiber.New() -// target.Get("/", func(c *fiber.Ctx) error { -// return c.SendStatus(fiber.StatusTeapot) -// }) + target.Get("/", func(c *fiber.Ctx) error { + return c.SendStatus(fiber.StatusTeapot) + }) -// go func() { -// utils.AssertEqual(t, nil, target.Listen(":3001")) -// }() + go func() { + utils.AssertEqual(t, nil, target.Listen(":3001")) + }() -// time.Sleep(time.Second) + time.Sleep(time.Second) -// resp, err := target.Test(httptest.NewRequest("GET", "/", nil), 2000) -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode) + resp, err := target.Test(httptest.NewRequest("GET", "/", nil), 2000) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode) -// app := fiber.New() + app := fiber.New() -// host := "localhost:3001" + app.Use(New(Config{Servers: []string{"localhost:3001"}})) -// app.Use(New(Config{ -// Hosts: host, -// })) + req := httptest.NewRequest("GET", "/", nil) + req.Host = "localhost:3001" + resp, err = app.Test(req) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode) +} -// req := httptest.NewRequest("GET", "/", nil) -// req.Host = host -// resp, err = app.Test(req) -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode) -// } +// go test -run Test_Proxy_Do_With_Error +func Test_Proxy_Do_With_Error(t *testing.T) { + app := fiber.New() -// // go test -run Test_Proxy_Before_With_Error -// func Test_Proxy_Before_With_Error(t *testing.T) { -// app := fiber.New() + app.Use(New(Config{Servers: []string{"localhost:90000"}})) -// errStr := "error after Before" + resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode) -// app.Use( -// New(Config{ -// Hosts: "host", -// Before: func(c *fiber.Ctx) error { -// return fmt.Errorf(errStr) -// }, -// })) + b, err := ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, true, strings.Contains(string(b), "127.0.0.1:90000")) +} -// resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode) +func Test_Proxy_Forward(t *testing.T) { + app := fiber.New() -// b, err := ioutil.ReadAll(resp.Body) -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, errStr, string(b)) -// } + target := fiber.New(fiber.Config{DisableStartupMessage: true}) + go func() { + utils.AssertEqual(t, nil, target.Listen("localhost:50001")) + }() + target.Get("/", func(c *fiber.Ctx) error { + return c.SendString("forwarded") + }) -// // go test -run Test_Proxy_After_With_Error -// func Test_Proxy_After_With_Error(t *testing.T) { -// target := fiber.New(fiber.Config{ -// DisableStartupMessage: true, -// }) + app.Use(Forward("http://localhost:50001")) -// target.Get("/", func(c *fiber.Ctx) error { -// return c.SendStatus(fiber.StatusTeapot) -// }) + resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) -// go func() { -// utils.AssertEqual(t, nil, target.Listen(":3002")) -// }() + b, err := ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "forwarded", string(b)) +} -// time.Sleep(time.Second) +func Test_Proxy_Modify_Response(t *testing.T) { + target := fiber.New(fiber.Config{DisableStartupMessage: true}) + go func() { + utils.AssertEqual(t, nil, target.Listen("localhost:50002")) + }() -// resp, err := target.Test(httptest.NewRequest("GET", "/", nil), 2000) -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode) + app := fiber.New() + app.Use(Balancer(Config{ + Servers: []string{"localhost:50002"}, + ModifyResponse: func(c *fiber.Ctx) error { + c.Response().SetStatusCode(fiber.StatusOK) + return c.SendString("modified response") + }, + })) -// app := fiber.New() + target.Get("/", func(c *fiber.Ctx) error { + return c.SendString("not modified") + }) -// host := "localhost:3001" -// errStr := "error after After" + resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) -// app.Use(New(Config{ -// Hosts: host, -// After: func(ctx *fiber.Ctx) error { -// utils.AssertEqual(t, fiber.StatusTeapot, ctx.Response().StatusCode()) -// return fmt.Errorf(errStr) -// }, -// })) + b, err := ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "modified response", string(b)) +} -// req := httptest.NewRequest("GET", "/", nil) -// req.Host = host -// resp, err = app.Test(req) -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode) +func Test_Proxy_Modify_Request(t *testing.T) { + target := fiber.New(fiber.Config{DisableStartupMessage: true}) + go func() { + utils.AssertEqual(t, nil, target.Listen("localhost:50003")) + }() -// b, err := ioutil.ReadAll(resp.Body) -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, errStr, string(b)) -// } + app := fiber.New() + app.Use(Balancer(Config{ + Servers: []string{"localhost:50003"}, + ModifyRequest: func(c *fiber.Ctx) error { + c.Request().SetBody([]byte("modified request")) + return nil + }, + })) -// // go test -run Test_Proxy_Do_With_Error -// func Test_Proxy_Do_With_Error(t *testing.T) { -// app := fiber.New() + target.Get("/", func(c *fiber.Ctx) error { + b := c.Request().Body() + return c.SendString(string(b)) + }) -// app.Use( -// New(Config{ -// Hosts: "localhost:90000", -// })) + resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode) -// resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode) - -// b, err := ioutil.ReadAll(resp.Body) -// utils.AssertEqual(t, nil, err) -// utils.AssertEqual(t, true, strings.Contains(string(b), "127.0.0.1:90000")) -// } + b, err := ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "modified request", string(b)) +}