Refactor proxy middleware and increase coverage

pull/924/head
Konstantinos Lypitkas 2020-10-14 00:48:09 +03:00
parent d4e604f1a3
commit 3d59648b7d
2 changed files with 131 additions and 120 deletions

View File

@ -44,7 +44,7 @@ var ConfigDefault = Config{
// New is deprecated // New is deprecated
func New(config Config) fiber.Handler { func New(config Config) fiber.Handler {
fmt.Println("proxy.New is deprecated, please us proxy.Balancer instead") fmt.Println("proxy.New is deprecated, please use proxy.Balancer instead")
return Balancer(config) return Balancer(config)
} }
@ -58,8 +58,10 @@ func Balancer(config Config) fiber.Handler {
cfg.Next = ConfigDefault.Next cfg.Next = ConfigDefault.Next
} }
if len(cfg.Servers) == 0 { if len(cfg.Servers) == 0 {
return func(c *fiber.Ctx) (err error) {
panic("Servers cannot be empty") panic("Servers cannot be empty")
} }
}
client := fasthttp.Client{ client := fasthttp.Client{
NoDefaultUserAgentHeader: true, NoDefaultUserAgentHeader: true,
@ -97,20 +99,16 @@ func Balancer(config Config) fiber.Handler {
} }
req.SetRequestURI(cfg.Servers[counter] + utils.UnsafeString(req.RequestURI())) req.SetRequestURI(cfg.Servers[counter] + utils.UnsafeString(req.RequestURI()))
counter = (counter + 1) % len(cfg.Servers) counter = (counter + 1) % len(cfg.Servers)
// Forward request // Forward request
if err = client.Do(req, res); err != nil { if err = client.Do(req, res); err != nil {
fmt.Println(err)
return err return err
} }
// Don't proxy "Connection" header // Don't proxy "Connection" header
res.Header.Del(fiber.HeaderConnection) res.Header.Del(fiber.HeaderConnection)
//fmt.Println(string(res.Header.ContentType()))
// Modify response // Modify response
if cfg.ModifyResponse != nil { if cfg.ModifyResponse != nil {
if err = cfg.ModifyResponse(c); err != nil { if err = cfg.ModifyResponse(c); err != nil {

View File

@ -1,148 +1,161 @@
package proxy package proxy
// // go test -run Test_Proxy_Empty_Host import (
// func Test_Proxy_Empty_Host(t *testing.T) { "io/ioutil"
// app := fiber.New() "net/http/httptest"
// app.Use(New( "strings"
// Config{Hosts: ""}, "testing"
// )) "time"
// resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) "github.com/gofiber/fiber/v2"
// utils.AssertEqual(t, nil, err) "github.com/gofiber/fiber/v2/middleware/recover"
// utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) "github.com/gofiber/fiber/v2/utils"
// } )
// go test -run Test_Proxy_Empty_Host
func Test_Proxy_Empty_Upstream_Servers(t *testing.T) {
app := fiber.New()
app.Use(recover.New(), Balancer(Config{Servers: []string{}}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
}
// // go test -run Test_Proxy_Next // // go test -run Test_Proxy_Next
// func Test_Proxy_Next(t *testing.T) { func Test_Proxy_Next(t *testing.T) {
// app := fiber.New() app := fiber.New()
// app.Use(New(Config{ app.Use(New(Config{
// Hosts: "next", Servers: []string{"localhost"},
// Next: func(_ *fiber.Ctx) bool { Next: func(_ *fiber.Ctx) bool {
// return true return true
// }, },
// })) }))
// resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
// utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
// utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode) utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
// } }
// // go test -run Test_Proxy // // go test -run Test_Proxy
// func Test_Proxy(t *testing.T) { func Test_Proxy(t *testing.T) {
// target := fiber.New(fiber.Config{ target := fiber.New()
// DisableStartupMessage: true,
// })
// target.Get("/", func(c *fiber.Ctx) error { target.Get("/", func(c *fiber.Ctx) error {
// return c.SendStatus(fiber.StatusTeapot) return c.SendStatus(fiber.StatusTeapot)
// }) })
// go func() { go func() {
// utils.AssertEqual(t, nil, target.Listen(":3001")) utils.AssertEqual(t, nil, target.Listen(":3001"))
// }() }()
// time.Sleep(time.Second) time.Sleep(time.Second)
// resp, err := target.Test(httptest.NewRequest("GET", "/", nil), 2000) resp, err := target.Test(httptest.NewRequest("GET", "/", nil), 2000)
// utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
// utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode) utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
// app := fiber.New() app := fiber.New()
// host := "localhost:3001" app.Use(New(Config{Servers: []string{"localhost:3001"}}))
// app.Use(New(Config{ req := httptest.NewRequest("GET", "/", nil)
// Hosts: host, req.Host = "localhost:3001"
// })) resp, err = app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
}
// req := httptest.NewRequest("GET", "/", nil) // go test -run Test_Proxy_Do_With_Error
// req.Host = host func Test_Proxy_Do_With_Error(t *testing.T) {
// resp, err = app.Test(req) app := fiber.New()
// utils.AssertEqual(t, nil, err)
// utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
// }
// // go test -run Test_Proxy_Before_With_Error app.Use(New(Config{Servers: []string{"localhost:90000"}}))
// func Test_Proxy_Before_With_Error(t *testing.T) {
// app := fiber.New()
// errStr := "error after Before" resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
// app.Use( b, err := ioutil.ReadAll(resp.Body)
// New(Config{ utils.AssertEqual(t, nil, err)
// Hosts: "host", utils.AssertEqual(t, true, strings.Contains(string(b), "127.0.0.1:90000"))
// Before: func(c *fiber.Ctx) error { }
// return fmt.Errorf(errStr)
// },
// }))
// resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) func Test_Proxy_Forward(t *testing.T) {
// utils.AssertEqual(t, nil, err) app := fiber.New()
// utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
// b, err := ioutil.ReadAll(resp.Body) target := fiber.New(fiber.Config{DisableStartupMessage: true})
// utils.AssertEqual(t, nil, err) go func() {
// utils.AssertEqual(t, errStr, string(b)) utils.AssertEqual(t, nil, target.Listen("localhost:50001"))
// } }()
target.Get("/", func(c *fiber.Ctx) error {
return c.SendString("forwarded")
})
// // go test -run Test_Proxy_After_With_Error app.Use(Forward("http://localhost:50001"))
// func Test_Proxy_After_With_Error(t *testing.T) {
// target := fiber.New(fiber.Config{
// DisableStartupMessage: true,
// })
// target.Get("/", func(c *fiber.Ctx) error { resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
// return c.SendStatus(fiber.StatusTeapot) utils.AssertEqual(t, nil, err)
// }) utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
// go func() { b, err := ioutil.ReadAll(resp.Body)
// utils.AssertEqual(t, nil, target.Listen(":3002")) utils.AssertEqual(t, nil, err)
// }() utils.AssertEqual(t, "forwarded", string(b))
}
// time.Sleep(time.Second) func Test_Proxy_Modify_Response(t *testing.T) {
target := fiber.New(fiber.Config{DisableStartupMessage: true})
go func() {
utils.AssertEqual(t, nil, target.Listen("localhost:50002"))
}()
// resp, err := target.Test(httptest.NewRequest("GET", "/", nil), 2000) app := fiber.New()
// utils.AssertEqual(t, nil, err) app.Use(Balancer(Config{
// utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode) Servers: []string{"localhost:50002"},
ModifyResponse: func(c *fiber.Ctx) error {
c.Response().SetStatusCode(fiber.StatusOK)
return c.SendString("modified response")
},
}))
// app := fiber.New() target.Get("/", func(c *fiber.Ctx) error {
return c.SendString("not modified")
})
// host := "localhost:3001" resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
// errStr := "error after After" utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
// app.Use(New(Config{ b, err := ioutil.ReadAll(resp.Body)
// Hosts: host, utils.AssertEqual(t, nil, err)
// After: func(ctx *fiber.Ctx) error { utils.AssertEqual(t, "modified response", string(b))
// utils.AssertEqual(t, fiber.StatusTeapot, ctx.Response().StatusCode()) }
// return fmt.Errorf(errStr)
// },
// }))
// req := httptest.NewRequest("GET", "/", nil) func Test_Proxy_Modify_Request(t *testing.T) {
// req.Host = host target := fiber.New(fiber.Config{DisableStartupMessage: true})
// resp, err = app.Test(req) go func() {
// utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, target.Listen("localhost:50003"))
// utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode) }()
// b, err := ioutil.ReadAll(resp.Body) app := fiber.New()
// utils.AssertEqual(t, nil, err) app.Use(Balancer(Config{
// utils.AssertEqual(t, errStr, string(b)) Servers: []string{"localhost:50003"},
// } ModifyRequest: func(c *fiber.Ctx) error {
c.Request().SetBody([]byte("modified request"))
return nil
},
}))
// // go test -run Test_Proxy_Do_With_Error target.Get("/", func(c *fiber.Ctx) error {
// func Test_Proxy_Do_With_Error(t *testing.T) { b := c.Request().Body()
// app := fiber.New() return c.SendString(string(b))
})
// app.Use( resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
// New(Config{ utils.AssertEqual(t, nil, err)
// Hosts: "localhost:90000", utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
// }))
// resp, err := app.Test(httptest.NewRequest("GET", "/", nil)) b, err := ioutil.ReadAll(resp.Body)
// utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
// utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode) utils.AssertEqual(t, "modified request", string(b))
}
// b, err := ioutil.ReadAll(resp.Body)
// utils.AssertEqual(t, nil, err)
// utils.AssertEqual(t, true, strings.Contains(string(b), "127.0.0.1:90000"))
// }