feat: support to set client in proxy mw (#2117)

optimize: add WithClient
pull/2131/head
kinggo 2022-09-28 20:27:58 +08:00 committed by GitHub
parent 66d5b195c5
commit 8e8ad95079
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 165 additions and 38 deletions

View File

@ -7,11 +7,12 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/gofiber/fiber/v2/internal/gopsutil/common"
"os/exec" "os/exec"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"github.com/gofiber/fiber/v2/internal/gopsutil/common"
) )
var ( var (

View File

@ -13,8 +13,8 @@ Proxy middleware for [Fiber](https://github.com/gofiber/fiber) that allows you t
```go ```go
func Balancer(config Config) fiber.Handler func Balancer(config Config) fiber.Handler
func Forward(addr string) fiber.Handler func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler
func Do(c *fiber.Ctx, addr string) error func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error
``` ```
### Examples ### Examples
@ -37,9 +37,21 @@ proxy.WithTlsConfig(&tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
}) })
// if you need to use global self-custom client, you should use proxy.WithClient.
proxy.WithClient(&fasthttp.Client{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
})
// Forward to url // Forward to url
app.Get("/gif", proxy.Forward("https://i.imgur.com/IWaBepg.gif")) app.Get("/gif", proxy.Forward("https://i.imgur.com/IWaBepg.gif"))
// Forward to url with local custom client
app.Get("/gif", proxy.Forward("https://i.imgur.com/IWaBepg.gif", &fasthttp.Client{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
}))
// Make request within handler // Make request within handler
app.Get("/:id", func(c *fiber.Ctx) error { app.Get("/:id", func(c *fiber.Ctx) error {
url := "https://i.imgur.com/"+c.Params("id")+".gif" url := "https://i.imgur.com/"+c.Params("id")+".gif"
@ -120,8 +132,13 @@ type Config struct {
// Per-connection buffer size for responses' writing. // Per-connection buffer size for responses' writing.
WriteBufferSize int WriteBufferSize int
// tls config for the http client // tls config for the http client.
TlsConfig *tls.Config TlsConfig *tls.Config
// Client is custom client when client config is complex.
// Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig
// will not be used if the client are set.
Client *fasthttp.LBClient
} }
``` ```

View File

@ -47,8 +47,13 @@ type Config struct {
// Per-connection buffer size for responses' writing. // Per-connection buffer size for responses' writing.
WriteBufferSize int WriteBufferSize int
// tls config for the http client // tls config for the http client.
TlsConfig *tls.Config TlsConfig *tls.Config
// Client is custom client when client config is complex.
// Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig
// will not be used if the client are set.
Client *fasthttp.LBClient
} }
// ConfigDefault is the default config // ConfigDefault is the default config
@ -75,7 +80,7 @@ func configDefault(config ...Config) Config {
} }
// Set default values // Set default values
if len(cfg.Servers) == 0 { if len(cfg.Servers) == 0 && cfg.Client == nil {
panic("Servers cannot be empty") panic("Servers cannot be empty")
} }
return cfg return cfg

View File

@ -24,34 +24,39 @@ func Balancer(config Config) fiber.Handler {
cfg := configDefault(config) cfg := configDefault(config)
// Load balanced client // Load balanced client
var lbc fasthttp.LBClient var lbc = &fasthttp.LBClient{}
// Set timeout // Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig
lbc.Timeout = cfg.Timeout // will not be used if the client are set.
if config.Client == nil {
// Set timeout
lbc.Timeout = cfg.Timeout
// Scheme must be provided, falls back to http
for _, server := range cfg.Servers {
if !strings.HasPrefix(server, "http") {
server = "http://" + server
}
// Scheme must be provided, falls back to http u, err := url.Parse(server)
// TODO add https support if err != nil {
for _, server := range cfg.Servers { panic(err)
if !strings.HasPrefix(server, "http") { }
server = "http://" + server
client := &fasthttp.HostClient{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
Addr: u.Host,
ReadBufferSize: config.ReadBufferSize,
WriteBufferSize: config.WriteBufferSize,
TLSConfig: config.TlsConfig,
}
lbc.Clients = append(lbc.Clients, client)
} }
} else {
u, err := url.Parse(server) // Set custom client
if err != nil { lbc = config.Client
panic(err)
}
client := &fasthttp.HostClient{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
Addr: u.Host,
ReadBufferSize: config.ReadBufferSize,
WriteBufferSize: config.WriteBufferSize,
TLSConfig: config.TlsConfig,
}
lbc.Clients = append(lbc.Clients, client)
} }
// Return new handler // Return new handler
@ -97,28 +102,43 @@ func Balancer(config Config) fiber.Handler {
} }
} }
var client = fasthttp.Client{ var client = &fasthttp.Client{
NoDefaultUserAgentHeader: true, NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true, DisablePathNormalizing: true,
} }
// WithTlsConfig update http client with a user specified tls.config // WithTlsConfig update http client with a user specified tls.config
// This function should be called before Do and Forward. // This function should be called before Do and Forward.
// Deprecated: use WithClient instead.
func WithTlsConfig(tlsConfig *tls.Config) { func WithTlsConfig(tlsConfig *tls.Config) {
client.TLSConfig = tlsConfig client.TLSConfig = tlsConfig
} }
// WithClient sets the global proxy client.
// This function should be called before Do and Forward.
func WithClient(cli *fasthttp.Client) {
client = cli
}
// Forward performs the given http request and fills the given http response. // Forward performs the given http request and fills the given http response.
// This method will return an fiber.Handler // This method will return an fiber.Handler
func Forward(addr string) fiber.Handler { func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
return Do(c, addr) return Do(c, addr, clients...)
} }
} }
// Do performs the given http request and fills the given http response. // Do performs the given http request and fills the given http response.
// This method can be used within a fiber.Handler // This method can be used within a fiber.Handler
func Do(c *fiber.Ctx, addr string) error { func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error {
var cli *fasthttp.Client
if len(clients) != 0 {
// Set local client
cli = clients[0]
} else {
// Set global client
cli = client
}
req := c.Request() req := c.Request()
res := c.Response() res := c.Response()
originalURL := utils.CopyString(c.OriginalURL()) originalURL := utils.CopyString(c.OriginalURL())
@ -134,7 +154,7 @@ func Do(c *fiber.Ctx, addr string) error {
} }
req.Header.Del(fiber.HeaderConnection) req.Header.Del(fiber.HeaderConnection)
if err := client.Do(req, res); err != nil { if err := cli.Do(req, res); err != nil {
return err return err
} }
res.Header.Del(fiber.HeaderConnection) res.Header.Del(fiber.HeaderConnection)

View File

@ -13,6 +13,7 @@ import (
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/internal/tlstest" "github.com/gofiber/fiber/v2/internal/tlstest"
"github.com/gofiber/fiber/v2/utils" "github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp"
) )
func createProxyTestServer(handler fiber.Handler, t *testing.T) (*fiber.App, string) { func createProxyTestServer(handler fiber.Handler, t *testing.T) (*fiber.App, string) {
@ -364,6 +365,7 @@ func Test_Proxy_Do_RestoreOriginalURL(t *testing.T) {
utils.AssertEqual(t, nil, err2) utils.AssertEqual(t, nil, err2)
} }
// go test -race -run Test_Proxy_Do_HTTP_Prefix_URL
func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) { func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) {
t.Parallel() t.Parallel()
@ -390,3 +392,85 @@ func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) {
utils.AssertEqual(t, nil, err) utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "hello world", string(s)) utils.AssertEqual(t, "hello world", string(s))
} }
// go test -race -run Test_Proxy_Forward_Global_Client
func Test_Proxy_Forward_Global_Client(t *testing.T) {
t.Parallel()
ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
utils.AssertEqual(t, nil, err)
WithClient(&fasthttp.Client{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
})
app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Get("/test_global_client", func(c *fiber.Ctx) error {
return c.SendString("test_global_client")
})
addr := ln.Addr().String()
app.Use(Forward("http://" + addr + "/test_global_client"))
go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }()
code, body, errs := fiber.Get("http://" + addr).String()
utils.AssertEqual(t, 0, len(errs))
utils.AssertEqual(t, fiber.StatusOK, code)
utils.AssertEqual(t, "test_global_client", body)
}
// go test -race -run Test_Proxy_Forward_Local_Client
func Test_Proxy_Forward_Local_Client(t *testing.T) {
t.Parallel()
ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0")
utils.AssertEqual(t, nil, err)
app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Get("/test_local_client", func(c *fiber.Ctx) error {
return c.SendString("test_local_client")
})
addr := ln.Addr().String()
app.Use(Forward("http://"+addr+"/test_local_client", &fasthttp.Client{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
Dial: func(addr string) (net.Conn, error) {
return fasthttp.Dial(addr)
},
}))
go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }()
code, body, errs := fiber.Get("http://" + addr).String()
utils.AssertEqual(t, 0, len(errs))
utils.AssertEqual(t, fiber.StatusOK, code)
utils.AssertEqual(t, "test_local_client", body)
}
// go test -run Test_ProxyBalancer_Custom_Client
func Test_ProxyBalancer_Custom_Client(t *testing.T) {
t.Parallel()
target, addr := createProxyTestServer(
func(c *fiber.Ctx) error { return c.SendStatus(fiber.StatusTeapot) }, t,
)
resp, err := target.Test(httptest.NewRequest("GET", "/", nil), 2000)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
app := fiber.New(fiber.Config{DisableStartupMessage: true})
app.Use(Balancer(Config{Client: &fasthttp.LBClient{
Clients: []fasthttp.BalancingClient{
&fasthttp.HostClient{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
Addr: addr,
},
},
Timeout: time.Second,
}}))
req := httptest.NewRequest("GET", "/", nil)
req.Host = addr
resp, err = app.Test(req)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusTeapot, resp.StatusCode)
}