fiber/middleware/proxy/proxy.go

144 lines
3.3 KiB
Go

package proxy
import (
"bytes"
"crypto/tls"
"net/url"
"strings"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/utils"
"github.com/valyala/fasthttp"
)
// Balancer creates a load balancer among multiple upstream servers
func Balancer(config Config) fiber.Handler {
// Set default config
cfg := configDefault(config)
// Load balanced client
var lbc fasthttp.LBClient
// Set timeout
lbc.Timeout = cfg.Timeout
// Scheme must be provided, falls back to http
// TODO add https support
for _, server := range cfg.Servers {
if !strings.HasPrefix(server, "http") {
server = "http://" + server
}
u, err := url.Parse(server)
if err != nil {
panic(err)
}
client := &fasthttp.HostClient{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
Addr: u.Host,
ReadBufferSize: config.ReadBufferSize,
WriteBufferSize: config.WriteBufferSize,
TLSConfig: config.TlsConfig,
}
lbc.Clients = append(lbc.Clients, client)
}
// Return new handler
return func(c fiber.Ctx) (err error) {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Set request and response
req := c.Request()
res := c.Response()
// Don't proxy "Connection" header
req.Header.Del(fiber.HeaderConnection)
// Modify request
if cfg.ModifyRequest != nil {
if err = cfg.ModifyRequest(c); err != nil {
return err
}
}
req.SetRequestURI(utils.UnsafeString(req.RequestURI()))
// Forward request
if err = lbc.Do(req, res); err != nil {
return err
}
// Don't proxy "Connection" header
res.Header.Del(fiber.HeaderConnection)
// Modify response
if cfg.ModifyResponse != nil {
if err = cfg.ModifyResponse(c); err != nil {
return err
}
}
// Return nil to end proxying if no error
return nil
}
}
var client = fasthttp.Client{
NoDefaultUserAgentHeader: true,
DisablePathNormalizing: true,
}
// WithTlsConfig update http client with a user specified tls.config
// This function should be called before Do and Forward.
func WithTlsConfig(tlsConfig *tls.Config) {
client.TLSConfig = tlsConfig
}
// Forward performs the given http request and fills the given http response.
// This method will return an fiber.Handler
func Forward(addr string) fiber.Handler {
return func(c fiber.Ctx) error {
return Do(c, addr)
}
}
// Do performs the given http request and fills the given http response.
// This method can be used within a fiber.Handler
func Do(c fiber.Ctx, addr string) error {
req := c.Request()
res := c.Response()
originalURL := utils.CopyString(c.OriginalURL())
defer req.SetRequestURI(originalURL)
copiedURL := utils.CopyString(addr)
req.SetRequestURI(copiedURL)
// NOTE: if req.isTLS is true, SetRequestURI keeps the scheme as https.
// issue reference:
// https://github.com/gofiber/fiber/issues/1762
if scheme := getScheme(utils.UnsafeBytes(copiedURL)); len(scheme) > 0 {
req.URI().SetSchemeBytes(scheme)
}
req.Header.Del(fiber.HeaderConnection)
if err := client.Do(req, res); err != nil {
return err
}
res.Header.Del(fiber.HeaderConnection)
return nil
}
func getScheme(uri []byte) []byte {
i := bytes.IndexByte(uri, '/')
if i < 1 || uri[i-1] != ':' || i == len(uri)-1 || uri[i+1] != '/' {
return nil
}
return uri[:i-1]
}