fiber/middleware/adaptor/adaptor_test.go

637 lines
16 KiB
Go

//nolint:contextcheck, revive // Much easier to just ignore memory leaks in tests
package adaptor
import (
"bytes"
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/valyala/fasthttp"
)
// Test_HTTPHandler exercises HTTPHandlerFunc: a fasthttp request is converted
// into an *http.Request that the wrapped net/http handler inspects field by
// field (including a fiber local read back through r.Context()), and the
// handler's status, headers and body must come back on the fasthttp response.
func Test_HTTPHandler(t *testing.T) {
	wantMethod := fiber.MethodPost
	wantProto := "HTTP/1.1"
	wantProtoMajor, wantProtoMinor := 1, 1
	wantRequestURI := "/foo/bar?baz=123"
	wantBody := "body 123 foo bar baz"
	wantContentLength := len(wantBody)
	wantHost := "foobar.com"
	wantRemoteAddr := "1.2.3.4:6789"
	wantHeader := map[string]string{
		"Foo-Bar":         "baz",
		"Abc":             "defg",
		"XXX-Remote-Addr": "123.43.4543.345",
	}
	wantURL, err := url.ParseRequestURI(wantRequestURI)
	require.NoError(t, err)

	// A test-local key type guarantees this context key cannot collide with
	// keys defined anywhere else.
	type contextKeyType string
	ctxKey := contextKeyType("contextKey")
	ctxValue := "contextValue"

	calls := 0
	netHandler := func(w http.ResponseWriter, r *http.Request) {
		calls++

		assert.Equal(t, wantMethod, r.Method, "Method")
		assert.Equal(t, wantProto, r.Proto, "Proto")
		assert.Equal(t, wantProtoMajor, r.ProtoMajor, "ProtoMajor")
		assert.Equal(t, wantProtoMinor, r.ProtoMinor, "ProtoMinor")
		assert.Equal(t, wantRequestURI, r.RequestURI, "RequestURI")
		assert.Equal(t, wantContentLength, int(r.ContentLength), "ContentLength")
		assert.Empty(t, r.TransferEncoding, "TransferEncoding")
		assert.Equal(t, wantHost, r.Host, "Host")
		assert.Equal(t, wantRemoteAddr, r.RemoteAddr, "RemoteAddr")

		reqBody, readErr := io.ReadAll(r.Body)
		assert.NoError(t, readErr)
		assert.Equal(t, wantBody, string(reqBody), "Body")
		assert.Equal(t, wantURL, r.URL, "URL")
		assert.Equal(t, ctxValue, r.Context().Value(ctxKey), "Context")
		for name, want := range wantHeader {
			assert.Equal(t, want, r.Header.Get(name), "Header")
		}

		w.Header().Set("Header1", "value1")
		w.Header().Set("Header2", "value2")
		w.WriteHeader(http.StatusBadRequest)
		fmt.Fprintf(w, "request body is %q", reqBody) //nolint:errcheck // not needed
	}

	// The wrapper stores ctxKey/ctxValue as a fiber local before the adapted
	// net/http handler runs.
	fiberHandler := setFiberContextValueMiddleware(
		HTTPHandlerFunc(http.HandlerFunc(netHandler)), ctxKey, ctxValue)

	var req fasthttp.Request
	req.Header.SetMethod(wantMethod)
	req.SetRequestURI(wantRequestURI)
	req.Header.SetHost(wantHost)
	req.BodyWriter().Write([]byte(wantBody)) //nolint:errcheck // not needed
	for name, value := range wantHeader {
		req.Header.Set(name, value)
	}

	remoteAddr, err := net.ResolveTCPAddr("tcp", wantRemoteAddr)
	require.NoError(t, err)

	var fctx fasthttp.RequestCtx
	fctx.Init(&req, remoteAddr, &disableLogger{})

	app := fiber.New()
	ctx := app.AcquireCtx(&fctx)
	defer app.ReleaseCtx(ctx)

	require.NoError(t, fiberHandler(ctx))
	require.Equal(t, 1, calls, "callsCount")

	resp := &fctx.Response
	require.Equal(t, http.StatusBadRequest, resp.StatusCode(), "StatusCode")
	require.Equal(t, "value1", string(resp.Header.Peek("Header1")), "Header1")
	require.Equal(t, "value2", string(resp.Header.Peek("Header2")), "Header2")
	require.Equal(t, fmt.Sprintf("request body is %q", wantBody), string(resp.Body()), "Body")
}
// contextKey is a package-private type for the context keys used in these
// tests, so they can never collide with keys defined by other packages.
type contextKey string

// String implements fmt.Stringer by prefixing the raw key name with "test-".
func (c contextKey) String() string {
	const prefix = "test-"
	return prefix + string(c)
}

// Context keys that the net/http middleware in Test_HTTPMiddleware attaches to
// the request context.
var (
	TestContextKey       contextKey = "TestContextKey"
	TestContextSecondKey contextKey = "TestContextSecondKey"
)
// Test_HTTPMiddleware verifies that HTTPMiddleware runs a net/http middleware
// in front of fiber handlers.
//
// The middleware short-circuits non-POST requests with 405. It stores values
// in the request context via r.WithContext, and those values must be visible
// to the downstream fiber handler through c.Context().Value. Writing the
// same key twice must leave the last value visible.
func Test_HTTPMiddleware(t *testing.T) {
	const expectedHost = "foobar.com"

	tests := []struct {
		name       string
		url        string
		method     string
		statusCode int
	}{
		{
			name:       "Should return 200",
			url:        "/",
			method:     "POST",
			statusCode: 200,
		},
		{
			name:       "Should return 405",
			url:        "/",
			method:     "GET",
			statusCode: 405,
		},
		{
			name:       "Should return 404",
			url:        "/unknown",
			method:     "POST",
			statusCode: 404,
		},
	}

	nethttpMW := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.Method != http.MethodPost {
				w.WriteHeader(http.StatusMethodNotAllowed)
				return
			}
			r = r.WithContext(context.WithValue(r.Context(), TestContextKey, "okay"))
			// Overwrite the second key so the handler can verify that the
			// latest value wins.
			r = r.WithContext(context.WithValue(r.Context(), TestContextSecondKey, "not_okay"))
			r = r.WithContext(context.WithValue(r.Context(), TestContextSecondKey, "okay"))
			next.ServeHTTP(w, r)
		})
	}

	app := fiber.New()
	app.Use(HTTPMiddleware(nethttpMW))
	app.Post("/", func(c fiber.Ctx) error {
		// Echo each context value back as a response header so the test can
		// observe what the middleware stored. Both keys are handled the same
		// way; a missing value is caught by the header assertions below.
		if value := c.Context().Value(TestContextKey); value != nil {
			val, ok := value.(string)
			if !ok {
				t.Error("unexpected error on type-assertion")
			}
			c.Set("context_okay", val)
		}
		if value := c.Context().Value(TestContextSecondKey); value != nil {
			val, ok := value.(string)
			if !ok {
				t.Error("unexpected error on type-assertion")
			}
			c.Set("context_second_okay", val)
		}
		return c.SendStatus(fiber.StatusOK)
	})

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req, err := http.NewRequestWithContext(context.Background(), tt.method, tt.url, nil)
			// Check the error before touching req: it is nil on failure.
			require.NoError(t, err)
			req.Host = expectedHost

			resp, err := app.Test(req)
			require.NoError(t, err)
			require.Equal(t, tt.statusCode, resp.StatusCode, "StatusCode")
		})
	}

	req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "/", nil)
	require.NoError(t, err)
	req.Host = expectedHost

	resp, err := app.Test(req)
	require.NoError(t, err)
	require.Equal(t, fiber.StatusOK, resp.StatusCode, "StatusCode")
	require.Equal(t, "okay", resp.Header.Get("context_okay"))
	require.Equal(t, "okay", resp.Header.Get("context_second_okay"))
}
// Test_HTTPMiddlewareWithCookies checks cookie handling through an
// HTTPMiddleware-wrapped app.
//
// The fiber handler echoes every cookie from the request's Cookie header back
// as a Set-Cookie header. The net/http middleware rejects non-POST requests
// with 405.
func Test_HTTPMiddlewareWithCookies(t *testing.T) {
	const (
		cookieHeader    = "Cookie"
		setCookieHeader = "Set-Cookie"
		cookieOneName   = "cookieOne"
		cookieTwoName   = "cookieTwo"
		cookieOneValue  = "valueCookieOne"
		cookieTwoValue  = "valueCookieTwo"
	)

	nethttpMW := func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.Method != http.MethodPost {
				w.WriteHeader(http.StatusMethodNotAllowed)
				return
			}
			next.ServeHTTP(w, r)
		})
	}

	app := fiber.New()
	app.Use(HTTPMiddleware(nethttpMW))
	app.Post("/", func(c fiber.Ctx) error {
		// Return the current cookies to the response. Splitting an absent
		// Cookie header yields [""], so empty pieces are skipped rather than
		// emitted as an empty Set-Cookie header.
		for _, cookie := range strings.Split(c.Get(cookieHeader), "; ") {
			if cookie == "" {
				continue
			}
			c.Set(setCookieHeader, cookie)
		}
		return c.SendStatus(fiber.StatusOK)
	})

	t.Run("POST request with cookies", func(t *testing.T) {
		req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "/", nil)
		require.NoError(t, err)
		req.AddCookie(&http.Cookie{Name: cookieOneName, Value: cookieOneValue})
		req.AddCookie(&http.Cookie{Name: cookieTwoName, Value: cookieTwoValue})

		resp, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, http.StatusOK, resp.StatusCode)

		cookies := resp.Cookies()
		require.Len(t, cookies, 2)
		for _, cookie := range cookies {
			switch cookie.Name {
			case cookieOneName:
				require.Equal(t, cookieOneValue, cookie.Value)
			case cookieTwoName:
				require.Equal(t, cookieTwoValue, cookie.Value)
			default:
				t.Error("unexpected cookie key")
			}
		}
	})

	t.Run("GET request", func(t *testing.T) {
		req, err := http.NewRequestWithContext(context.Background(), fiber.MethodGet, "/", nil)
		require.NoError(t, err)

		resp, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode)
	})

	t.Run("POST request without cookies", func(t *testing.T) {
		req, err := http.NewRequestWithContext(context.Background(), fiber.MethodPost, "/", nil)
		require.NoError(t, err)

		resp, err := app.Test(req)
		require.NoError(t, err)
		require.Equal(t, http.StatusOK, resp.StatusCode)
		require.Empty(t, resp.Cookies())
	})
}
// Test_FiberHandler adapts a bare fiber handler with FiberHandlerFunc; the
// incoming RemoteAddr carries an explicit port.
func Test_FiberHandler(t *testing.T) {
	const useDefaultPort = false
	testFiberToHandlerFunc(t, useDefaultPort)
}

// Test_FiberApp adapts a complete fiber.App with FiberApp; the incoming
// RemoteAddr carries an explicit port.
func Test_FiberApp(t *testing.T) {
	const useDefaultPort = false
	app := fiber.New()
	testFiberToHandlerFunc(t, useDefaultPort, app)
}

// Test_FiberHandlerDefaultPort is Test_FiberHandler with a port-less
// RemoteAddr, which the handler must observe as port 80.
func Test_FiberHandlerDefaultPort(t *testing.T) {
	const useDefaultPort = true
	testFiberToHandlerFunc(t, useDefaultPort)
}

// Test_FiberAppDefaultPort is Test_FiberApp with a port-less RemoteAddr,
// which the handler must observe as port 80.
func Test_FiberAppDefaultPort(t *testing.T) {
	const useDefaultPort = true
	app := fiber.New()
	testFiberToHandlerFunc(t, useDefaultPort, app)
}
// testFiberToHandlerFunc drives a fiber handler through the net/http adaptor
// and checks both directions of the conversion.
//
// When app is supplied, the handler is registered as POST /foo/bar on app[0]
// and the app is adapted with FiberApp. Only the first app is used. Otherwise
// the handler is adapted directly with FiberHandlerFunc.
//
// When checkDefaultPort is true, the incoming RemoteAddr has no port. The
// handler is then expected to see ":80" appended.
func testFiberToHandlerFunc(t *testing.T, checkDefaultPort bool, app ...*fiber.App) {
	t.Helper()

	expectedMethod := fiber.MethodPost
	expectedRequestURI := "/foo/bar?baz=123"
	expectedBody := "body 123 foo bar baz"
	expectedContentLength := len(expectedBody)
	expectedHost := "foobar.com"
	expectedRemoteAddr := "1.2.3.4:6789"
	if checkDefaultPort {
		expectedRemoteAddr = "1.2.3.4:80"
	}
	expectedHeader := map[string]string{
		"Foo-Bar":         "baz",
		"Abc":             "defg",
		"XXX-Remote-Addr": "123.43.4543.345",
	}
	expectedURL, err := url.ParseRequestURI(expectedRequestURI)
	require.NoError(t, err)

	callsCount := 0
	fiberH := func(c fiber.Ctx) error {
		callsCount++
		require.Equal(t, expectedMethod, c.Method(), "Method")
		require.Equal(t, expectedRequestURI, string(c.Context().RequestURI()), "RequestURI")
		require.Equal(t, expectedContentLength, c.Context().Request.Header.ContentLength(), "ContentLength")
		require.Equal(t, expectedHost, c.Hostname(), "Host")
		require.Equal(t, expectedHost, string(c.Request().Header.Host()), "Host")
		require.Equal(t, "http://"+expectedHost, c.BaseURL(), "BaseURL")
		require.Equal(t, expectedRemoteAddr, c.Context().RemoteAddr().String(), "RemoteAddr")

		body := string(c.Body())
		require.Equal(t, expectedBody, body, "Body")
		require.Equal(t, expectedURL.String(), c.OriginalURL(), "URL")
		for k, expectedV := range expectedHeader {
			require.Equal(t, expectedV, c.Get(k), "Header")
		}

		c.Set("Header1", "value1")
		c.Set("Header2", "value2")
		c.Status(fiber.StatusBadRequest)
		_, err := fmt.Fprintf(c, "request body is %q", body)
		return err
	}

	var handlerFunc http.HandlerFunc
	if len(app) > 0 {
		app[0].Post("/foo/bar", fiberH)
		handlerFunc = FiberApp(app[0])
	} else {
		handlerFunc = FiberHandlerFunc(fiberH)
	}

	var r http.Request
	r.Method = expectedMethod
	r.Body = &netHTTPBody{b: []byte(expectedBody)}
	r.RequestURI = expectedRequestURI
	r.ContentLength = int64(expectedContentLength)
	r.Host = expectedHost
	r.RemoteAddr = expectedRemoteAddr
	if checkDefaultPort {
		// No port: the adaptor is expected to fill in the default one.
		r.RemoteAddr = "1.2.3.4"
	}
	hdr := make(http.Header)
	for k, v := range expectedHeader {
		hdr.Set(k, v)
	}
	r.Header = hdr

	var w netHTTPResponseWriter
	handlerFunc.ServeHTTP(&w, &r)

	// The counter was previously incremented but never checked. Asserting it
	// proves the fiber handler ran exactly once.
	require.Equal(t, 1, callsCount, "callsCount")
	require.Equal(t, http.StatusBadRequest, w.StatusCode(), "StatusCode")
	require.Equal(t, "value1", w.Header().Get("Header1"), "Header1")
	require.Equal(t, "value2", w.Header().Get("Header2"), "Header2")
	expectedResponseBody := fmt.Sprintf("request body is %q", expectedBody)
	require.Equal(t, expectedResponseBody, string(w.body), "Body")
}
// setFiberContextValueMiddleware returns a handler that first stores value in
// the request's fiber locals under key and then delegates to next, returning
// whatever error next returns.
func setFiberContextValueMiddleware(next fiber.Handler, key, value any) fiber.Handler {
	wrapped := func(c fiber.Ctx) error {
		c.Locals(key, value)
		return next(c)
	}
	return wrapped
}
// Test_FiberHandler_RequestNilBody checks that FiberHandler accepts an
// *http.Request whose Body is nil.
//
// The fiber handler must run exactly once and see a zero Content-Length. Its
// response must reach the net/http writer with the default 200 status.
func Test_FiberHandler_RequestNilBody(t *testing.T) {
	expectedMethod := fiber.MethodGet
	expectedRequestURI := "/foo/bar"
	expectedContentLength := 0
	expectedResponseBody := "request body is nil"

	callsCount := 0
	fiberH := func(c fiber.Ctx) error {
		callsCount++
		require.Equal(t, expectedMethod, c.Method(), "Method")
		require.Equal(t, expectedRequestURI, string(c.Context().RequestURI()), "RequestURI")
		require.Equal(t, expectedContentLength, c.Context().Request.Header.ContentLength(), "ContentLength")
		_, err := c.Write([]byte(expectedResponseBody))
		return err
	}
	nethttpH := FiberHandler(fiberH)

	// r.Body is deliberately left nil.
	var r http.Request
	r.Method = expectedMethod
	r.RequestURI = expectedRequestURI

	var w netHTTPResponseWriter
	nethttpH.ServeHTTP(&w, &r)

	require.Equal(t, 1, callsCount, "callsCount")
	require.Equal(t, http.StatusOK, w.StatusCode(), "StatusCode")
	require.Equal(t, expectedResponseBody, string(w.body), "Body")
}
// netHTTPBody is a minimal io.ReadCloser over an in-memory byte slice. These
// tests use it as an *http.Request body.
type netHTTPBody struct {
	b []byte
}

// Read copies as much of the unread data as fits into p and advances past it.
// Once all data has been consumed, it reports io.EOF.
func (nb *netHTTPBody) Read(p []byte) (int, error) {
	if len(nb.b) > 0 {
		copied := copy(p, nb.b)
		nb.b = nb.b[copied:]
		return copied, nil
	}
	return 0, io.EOF
}

// Close discards any unread data, so later reads return io.EOF. It never
// fails.
func (nb *netHTTPBody) Close() error {
	nb.b = nb.b[:0]
	return nil
}
// netHTTPResponseWriter is an in-memory http.ResponseWriter that records the
// status code, headers and body produced by the handler under test.
type netHTTPResponseWriter struct {
	h          http.Header
	body       []byte
	statusCode int
}

// StatusCode returns the code last passed to WriteHeader. If WriteHeader was
// never called, it returns http.StatusOK.
func (w *netHTTPResponseWriter) StatusCode() int {
	if w.statusCode != 0 {
		return w.statusCode
	}
	return http.StatusOK
}

// Header returns the response header map, allocating it lazily on first use.
func (w *netHTTPResponseWriter) Header() http.Header {
	if w.h != nil {
		return w.h
	}
	w.h = http.Header{}
	return w.h
}

// WriteHeader records statusCode. Each call overwrites the previous value.
func (w *netHTTPResponseWriter) WriteHeader(statusCode int) {
	w.statusCode = statusCode
}

// Write appends p to the recorded body. It always reports len(p) bytes
// written and never fails.
func (w *netHTTPResponseWriter) Write(p []byte) (int, error) {
	written := len(p)
	w.body = append(w.body, p...)
	return written, nil
}
// Test_ConvertRequest checks that ConvertRequest produces an *http.Request
// whose URL keeps both the path and the query string of the fiber request.
func Test_ConvertRequest(t *testing.T) {
	t.Parallel()

	echoURL := func(c fiber.Ctx) error {
		converted, err := ConvertRequest(c, false)
		if err != nil {
			return err
		}
		return c.SendString("Request URL: " + converted.URL.String())
	}

	app := fiber.New()
	app.Get("/test", echoURL)

	req := httptest.NewRequest(fiber.MethodGet, "/test?hello=world&another=test", nil)
	resp, err := app.Test(req)
	require.NoError(t, err, "app.Test(req)")
	require.Equal(t, http.StatusOK, resp.StatusCode, "Status code")

	got, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	require.Equal(t, "Request URL: /test?hello=world&another=test", string(got))
}
// Benchmark_FiberHandlerFunc measures FiberHandlerFunc for POST requests
// carrying bodies of increasing size.
//
// The request body is rewound before every iteration, so each iteration
// transfers the full payload rather than an already-drained body. Payloads
// are allocated only when their sub-benchmark runs, so filtering with -bench
// does not pay for the sizes that are skipped.
func Benchmark_FiberHandlerFunc(b *testing.B) {
	benchmarks := []struct {
		name     string
		bodySize int
	}{
		{name: "No Content", bodySize: 0},
		{name: "100KB", bodySize: 100 * 1024},
		{name: "500KB", bodySize: 500 * 1024},
		{name: "1MB", bodySize: 1 * 1024 * 1024},
		{name: "5MB", bodySize: 5 * 1024 * 1024},
		{name: "10MB", bodySize: 10 * 1024 * 1024},
		{name: "25MB", bodySize: 25 * 1024 * 1024},
		{name: "50MB", bodySize: 50 * 1024 * 1024},
	}

	fiberH := func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	}
	handlerFunc := FiberHandlerFunc(fiberH)

	for _, bm := range benchmarks {
		b.Run(bm.name, func(b *testing.B) {
			payload := make([]byte, bm.bodySize)
			// One reader is reset per iteration instead of allocating a new
			// body each time.
			bodyReader := bytes.NewReader(payload)
			r := http.Request{
				Method: http.MethodPost,
				Body:   io.NopCloser(bodyReader),
			}
			w := httptest.NewRecorder()

			b.ReportAllocs()
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				bodyReader.Reset(payload)
				// Discard the previous iteration's output so the recorder's
				// buffer does not keep growing with b.N.
				w.Body.Reset()
				handlerFunc.ServeHTTP(w, &r)
			}
		})
	}
}
// Benchmark_FiberHandlerFunc_Parallel measures FiberHandlerFunc under
// concurrent load for POST requests carrying bodies of increasing size.
//
// Each goroutine reads the shared, read-only payload through its own
// bytes.Reader, which it rewinds before every iteration. No mutable reader
// state is shared between goroutines; sharing one reader or bytes.Buffer
// across RunParallel goroutines would be a data race. Payloads are allocated
// only when their sub-benchmark runs.
func Benchmark_FiberHandlerFunc_Parallel(b *testing.B) {
	benchmarks := []struct {
		name     string
		bodySize int
	}{
		{name: "No Content", bodySize: 0},
		{name: "100KB", bodySize: 100 * 1024},
		{name: "500KB", bodySize: 500 * 1024},
		{name: "1MB", bodySize: 1 * 1024 * 1024},
		{name: "5MB", bodySize: 5 * 1024 * 1024},
		{name: "10MB", bodySize: 10 * 1024 * 1024},
		{name: "25MB", bodySize: 25 * 1024 * 1024},
		{name: "50MB", bodySize: 50 * 1024 * 1024},
	}

	fiberH := func(c fiber.Ctx) error {
		return c.SendStatus(fiber.StatusOK)
	}
	handlerFunc := FiberHandlerFunc(fiberH)

	for _, bm := range benchmarks {
		b.Run(bm.name, func(b *testing.B) {
			payload := make([]byte, bm.bodySize)

			b.ReportAllocs()
			b.ResetTimer()
			b.RunParallel(func(pb *testing.PB) {
				bodyReader := bytes.NewReader(payload)
				r := http.Request{
					Method: http.MethodPost,
					Body:   io.NopCloser(bodyReader),
				}
				w := httptest.NewRecorder()
				for pb.Next() {
					bodyReader.Reset(payload)
					// Discard the previous iteration's output so the
					// recorder's buffer does not keep growing.
					w.Body.Reset()
					handlerFunc(w, &r)
				}
			})
		})
	}
}