Feature: Add DoRedirects, DoTimeout and DoDeadline to Proxy middleware (#2332)

* Add support for DoRedirects

Signed-off-by: Juan Calderon-Perez <jgcalderonperez@protonmail.com>

* Fix linter issues

Signed-off-by: Juan Calderon-Perez <jgcalderonperez@protonmail.com>

* Add example to README

* Add support for DoDeadline and DoTimeout. Expand unit-tests

* Fix linter errors

Signed-off-by: Juan Calderon-Perez <jgcalderonperez@protonmail.com>

* Add examples for Proxy Middleware

---------

Signed-off-by: Juan Calderon-Perez <jgcalderonperez@protonmail.com>
pull/2344/head
Juan Calderon-Perez 2023-02-24 09:09:00 -05:00 committed by GitHub
parent b634ba0a58
commit dc038d8233
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 250 additions and 20 deletions

View File

@ -18,6 +18,12 @@ func Balancer(config Config) fiber.Handler
func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler
// Do performs the given http request and fills the given http response.
func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error
// DoRedirects performs the given http request and fills the given http response while following up to maxRedirectsCount redirects.
func DoRedirects(c *fiber.Ctx, addr string, maxRedirectsCount int, clients ...*fasthttp.Client) error
// DoDeadline performs the given request and waits for response until the given deadline.
func DoDeadline(c *fiber.Ctx, addr string, deadline time.Time, clients ...*fasthttp.Client) error
// DoTimeout performs the given request and waits for response during the given timeout duration.
func DoTimeout(c *fiber.Ctx, addr string, timeout time.Duration, clients ...*fasthttp.Client) error
// DomainForward the given http request based on the given domain and fills the given http response
func DomainForward(hostname string, addr string, clients ...*fasthttp.Client) fiber.Handler
// BalancerForward performs the given http request based round robin balancer and fills the given http response
@ -73,6 +79,36 @@ app.Get("/:id", func(c *fiber.Ctx) error {
return nil
})
// Make proxy requests while following redirects
app.Get("/proxy", func(c *fiber.Ctx) error {
if err := proxy.DoRedirects(c, "http://google.com", 3); err != nil {
return err
}
// Remove Server header from response
c.Response().Header.Del(fiber.HeaderServer)
return nil
})
// Make proxy requests and wait up to 5 seconds before timing out
app.Get("/proxy", func(c *fiber.Ctx) error {
if err := proxy.DoTimeout(c, "http://localhost:3000", time.Second * 5); err != nil {
return err
}
// Remove Server header from response
c.Response().Header.Del(fiber.HeaderServer)
return nil
})
// Make proxy requests, timeout a minute from now
app.Get("/proxy", func(c *fiber.Ctx) error {
if err := DoDeadline(c, "http://localhost", time.Now().Add(time.Minute)); err != nil {
return err
}
// Remove Server header from response
c.Response().Header.Del(fiber.HeaderServer)
return nil
})
// Minimal round robin balancer
app.Use(proxy.Balancer(proxy.Config{
Servers: []string{

View File

@ -7,6 +7,7 @@ import (
"net/url"
"strings"
"sync"
"time"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
@ -139,16 +140,53 @@ func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler {
// Do performs the given http request and fills the given http response.
// This method can be used within a fiber.Handler
func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error {
return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
return cli.Do(req, resp)
}, clients...)
}
// DoRedirects performs the given http request and fills the given http response, following up to maxRedirectsCount redirects.
// When the redirect count exceeds maxRedirectsCount, ErrTooManyRedirects is returned.
// This method can be used within a fiber.Handler
func DoRedirects(c *fiber.Ctx, addr string, maxRedirectsCount int, clients ...*fasthttp.Client) error {
return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
return cli.DoRedirects(req, resp, maxRedirectsCount)
}, clients...)
}
// DoDeadline performs the given request and waits for response until the given deadline.
// This method can be used within a fiber.Handler
func DoDeadline(c *fiber.Ctx, addr string, deadline time.Time, clients ...*fasthttp.Client) error {
return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
return cli.DoDeadline(req, resp, deadline)
}, clients...)
}
// DoTimeout performs the given request and waits for response during the given timeout duration.
// This method can be used within a fiber.Handler
func DoTimeout(c *fiber.Ctx, addr string, timeout time.Duration, clients ...*fasthttp.Client) error {
return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
return cli.DoTimeout(req, resp, timeout)
}, clients...)
}
func doAction(
c *fiber.Ctx,
addr string,
action func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error,
clients ...*fasthttp.Client,
) error {
var cli *fasthttp.Client
// set local or global client
if len(clients) != 0 {
// Set local client
cli = clients[0]
} else {
// Set global client
lock.RLock()
cli = client
lock.RUnlock()
}
req := c.Request()
res := c.Response()
originalURL := utils.CopyString(c.OriginalURL())
@ -157,14 +195,13 @@ func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error {
copiedURL := utils.CopyString(addr)
req.SetRequestURI(copiedURL)
// NOTE: if req.isTLS is true, SetRequestURI keeps the scheme as https.
// issue reference:
// https://github.com/gofiber/fiber/issues/1762
// Reference: https://github.com/gofiber/fiber/issues/1762
if scheme := getScheme(utils.UnsafeBytes(copiedURL)); len(scheme) > 0 {
req.URI().SetSchemeBytes(scheme)
}
req.Header.Del(fiber.HeaderConnection)
if err := cli.Do(req, res); err != nil {
if err := action(cli, req, res); err != nil {
return err
}
res.Header.Del(fiber.HeaderConnection)

View File

@ -2,6 +2,7 @@ package proxy
import (
"crypto/tls"
"errors"
"io"
"net"
"net/http/httptest"
@ -48,6 +49,19 @@ func Test_Proxy_Empty_Upstream_Servers(t *testing.T) {
app.Use(Balancer(Config{Servers: []string{}}))
}
// go test -run Test_Proxy_Empty_Config
func Test_Proxy_Empty_Config(t *testing.T) {
t.Parallel()
defer func() {
if r := recover(); r != nil {
utils.AssertEqual(t, "Servers cannot be empty", r)
}
}()
app := fiber.New()
app.Use(New(Config{}))
}
// go test -run Test_Proxy_Next
func Test_Proxy_Next(t *testing.T) {
t.Parallel()
@ -345,24 +359,167 @@ func Test_Proxy_Buffer_Size_Response(t *testing.T) {
// go test -race -run Test_Proxy_Do_RestoreOriginalURL
func Test_Proxy_Do_RestoreOriginalURL(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/proxy", func(c *fiber.Ctx) error {
return c.SendString("ok")
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("proxied")
})
app.Get("/test", func(c *fiber.Ctx) error {
originalURL := utils.CopyString(c.OriginalURL())
if err := Do(c, "/proxy"); err != nil {
return err
}
utils.AssertEqual(t, originalURL, c.OriginalURL())
return c.SendString("ok")
})
_, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
// This test requires multiple requests due to zero allocation used in fiber
_, err2 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return Do(c, "http://"+addr)
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
utils.AssertEqual(t, nil, err2)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "proxied", string(body))
}
// go test -race -run Test_Proxy_Do_WithRealURL
func Test_Proxy_Do_WithRealURL(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return Do(c, "https://www.google.com")
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, strings.Contains(string(body), "https://www.google.com/"))
}
// go test -race -run Test_Proxy_Do_WithRedirect
func Test_Proxy_Do_WithRedirect(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return Do(c, "https://google.com")
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, strings.Contains(string(body), "https://www.google.com/"))
utils.AssertEqual(t, 301, resp.StatusCode)
}
// go test -race -run Test_Proxy_DoRedirects_RestoreOriginalURL
func Test_Proxy_DoRedirects_RestoreOriginalURL(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoRedirects(c, "http://google.com", 1)
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
_, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
}
// go test -race -run Test_Proxy_DoRedirects_TooManyRedirects
func Test_Proxy_DoRedirects_TooManyRedirects(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoRedirects(c, "http://google.com", 0)
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "too many redirects detected when doing the request", string(body))
utils.AssertEqual(t, fiber.StatusInternalServerError, resp.StatusCode)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
}
// go test -race -run Test_Proxy_DoTimeout_RestoreOriginalURL
func Test_Proxy_DoTimeout_RestoreOriginalURL(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("proxied")
})
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoTimeout(c, "http://"+addr, time.Second)
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "proxied", string(body))
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
}
// go test -race -run Test_Proxy_DoTimeout_Timeout
func Test_Proxy_DoTimeout_Timeout(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
time.Sleep(time.Second * 5)
return c.SendString("proxied")
})
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoTimeout(c, "http://"+addr, time.Second)
})
_, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, errors.New("test: timeout error 1000ms"), err1)
}
// go test -race -run Test_Proxy_DoDeadline_RestoreOriginalURL
func Test_Proxy_DoDeadline_RestoreOriginalURL(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
return c.SendString("proxied")
})
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoDeadline(c, "http://"+addr, time.Now().Add(time.Second))
})
resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, nil, err1)
body, err := io.ReadAll(resp.Body)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, "proxied", string(body))
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
utils.AssertEqual(t, "/test", resp.Request.URL.String())
}
// go test -race -run Test_Proxy_DoDeadline_PastDeadline
func Test_Proxy_DoDeadline_PastDeadline(t *testing.T) {
t.Parallel()
_, addr := createProxyTestServer(t, func(c *fiber.Ctx) error {
time.Sleep(time.Second * 5)
return c.SendString("proxied")
})
app := fiber.New()
app.Get("/test", func(c *fiber.Ctx) error {
return DoDeadline(c, "http://"+addr, time.Now().Add(time.Second))
})
_, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil))
utils.AssertEqual(t, errors.New("test: timeout error 1000ms"), err1)
}
// go test -race -run Test_Proxy_Do_HTTP_Prefix_URL